summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/quantized/metal.rs5
1 files changed, 4 insertions, 1 deletions
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index 7be0f74e..c310d766 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -149,8 +149,11 @@ impl QMetalStorage {
let (n, k) = self_shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
+ // We always use a single batch dimension and stack all the tensors in the batch on the
+ // second dimension as the implementation in candle-metal-kernels doesn't handle batch
+ // properly.
let (b, m) = match dst_shape.len() {
- 3 => (dst_shape[0], dst_shape[1]),
+ 3 => (1, dst_shape[0] * dst_shape[1]),
2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};