diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-27 10:59:05 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-27 10:59:05 +0100 |
commit | a9abde5f930914ef7ef2d504728f742f80468961 (patch) | |
tree | 3172571ca2f02af81d81a043d8ee09d86e3fb03f /candle-metal-kernels | |
parent | 75b6d4b0da4e7fef82d9f61e274b49af55777acf (diff) | |
download | candle-a9abde5f930914ef7ef2d504728f742f80468961.tar.gz candle-a9abde5f930914ef7ef2d504728f742f80468961.tar.bz2 candle-a9abde5f930914ef7ef2d504728f742f80468961.zip |
More flexible matmul contiguity checks. (#1949)
* More flexible matmul contiguity checks.
* Also relax the checks on the metal side.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 12 |
1 file changed, 8 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 449bef8f..3f452331 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1451,9 +1451,12 @@ pub fn call_gemm(
     let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
     let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
     let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
-    let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
+    // lhs has shape b, m, k
+    // We also allow for the case where the stride on the minor dimension is not as expected but
+    // there is a single element.
+    let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
         false
-    } else if lhs_m1 == m && lhs_m2 == 1 {
+    } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
         true
     } else {
         return Err(MetalKernelError::MatMulNonContiguous {
@@ -1462,9 +1465,10 @@ pub fn call_gemm(
             mnk: (m, n, k),
         })?;
     };
-    let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
+    // rhs has shape b, k, n
+    let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
         false
-    } else if rhs_m1 == k && rhs_m2 == 1 {
+    } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
         true
     } else {
         return Err(MetalKernelError::MatMulNonContiguous {