summaryrefslogtreecommitdiff
path: root/candle-core/src/quantized/cuda.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-10 20:23:43 +0100
committerGitHub <noreply@github.com>2024-03-10 20:23:43 +0100
commitdf5f69444e438a7cd8d8ab4971579bf309b72114 (patch)
treed0c8d2c3ddf6e2163c55dc31fe55af958ea0abba /candle-core/src/quantized/cuda.rs
parent0c5eecbc0faa7e642210800c735ad8137d5a9e08 (diff)
downloadcandle-df5f69444e438a7cd8d8ab4971579bf309b72114.tar.gz
candle-df5f69444e438a7cd8d8ab4971579bf309b72114.tar.bz2
candle-df5f69444e438a7cd8d8ab4971579bf309b72114.zip
Properly handle the batch dimension in cuda quantized matmul. (#1832)
Diffstat (limited to 'candle-core/src/quantized/cuda.rs')
-rw-r--r--candle-core/src/quantized/cuda.rs2
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 5b684573..c90cf576 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -313,7 +313,7 @@ impl QCudaStorage {
}
let data_f32 = self.dequantize(n * k)?;
- let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0);
+ let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();