diff options
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 3f452331..140927e3 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1454,9 +1454,9 @@ pub fn call_gemm( // lhs has shape b, m, k // We also allow for the case where the stride on the minor dimension is not as expected but // there is a single element. - let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) { + let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { false - } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 { + } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { true } else { return Err(MetalKernelError::MatMulNonContiguous { @@ -1466,9 +1466,9 @@ pub fn call_gemm( })?; }; // rhs has shape b, k, n - let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) { + let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { false - } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 { + } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { true } else { return Err(MetalKernelError::MatMulNonContiguous { |