diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-10 20:23:43 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-10 20:23:43 +0100 |
commit | df5f69444e438a7cd8d8ab4971579bf309b72114 (patch) | |
tree | d0c8d2c3ddf6e2163c55dc31fe55af958ea0abba /candle-core/src/quantized/cuda.rs | |
parent | 0c5eecbc0faa7e642210800c735ad8137d5a9e08 (diff) | |
download | candle-df5f69444e438a7cd8d8ab4971579bf309b72114.tar.gz candle-df5f69444e438a7cd8d8ab4971579bf309b72114.tar.bz2 candle-df5f69444e438a7cd8d8ab4971579bf309b72114.zip |
Properly handle the batch dimension in cuda quantized matmul. (#1832)
Diffstat (limited to 'candle-core/src/quantized/cuda.rs')
-rw-r--r-- | candle-core/src/quantized/cuda.rs | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs index 5b684573..c90cf576 100644 --- a/candle-core/src/quantized/cuda.rs +++ b/candle-core/src/quantized/cuda.rs @@ -313,7 +313,7 @@ impl QCudaStorage { } let data_f32 = self.dequantize(n * k)?; - let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0); + let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?; let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?; let mut out_shape = layout.shape().dims().to_vec(); out_shape.pop(); |