summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-27 10:59:05 +0100
committerGitHub <noreply@github.com>2024-03-27 10:59:05 +0100
commita9abde5f930914ef7ef2d504728f742f80468961 (patch)
tree3172571ca2f02af81d81a043d8ee09d86e3fb03f /candle-metal-kernels
parent75b6d4b0da4e7fef82d9f61e274b49af55777acf (diff)
downloadcandle-a9abde5f930914ef7ef2d504728f742f80468961.tar.gz
candle-a9abde5f930914ef7ef2d504728f742f80468961.tar.bz2
candle-a9abde5f930914ef7ef2d504728f742f80468961.zip
More flexible matmul contiguity checks. (#1949)
* More flexible matmul contiguity checks. * Also relax the checks on the metal side.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs12
1 files changed, 8 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 449bef8f..3f452331 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1451,9 +1451,12 @@ pub fn call_gemm(
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
- let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
+ // lhs has shape b, m, k
+ // We also allow for the case where the stride on the minor dimension is not as expected but
+ // there is a single element.
+ let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
false
- } else if lhs_m1 == m && lhs_m2 == 1 {
+ } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
@@ -1462,9 +1465,10 @@ pub fn call_gemm(
mnk: (m, n, k),
})?;
};
- let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
+ // rhs has shape b, k, n
+ let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
false
- } else if rhs_m1 == k && rhs_m2 == 1 {
+ } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {