summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-28 10:17:38 +0100
committerGitHub <noreply@github.com>2024-03-28 10:17:38 +0100
commitb3484e7a5e8d8c613e2a444c6f056142fc1e758d (patch)
tree01d31980402024d62ffc5088fa7d5e452a2aa72b /candle-metal-kernels
parentada5d7c096b530fd29b071d798660f3843945e2b (diff)
downloadcandle-b3484e7a5e8d8c613e2a444c6f056142fc1e758d.tar.gz
candle-b3484e7a5e8d8c613e2a444c6f056142fc1e758d.tar.bz2
candle-b3484e7a5e8d8c613e2a444c6f056142fc1e758d.zip
Fix for the RWKV models. (#1955)
* Fix for the RWKV models. * More general fix + revert the rwkv hack. * Remove the old hack.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs8
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 3f452331..140927e3 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1454,9 +1454,9 @@ pub fn call_gemm(
// lhs has shape b, m, k
// We also allow for the case where the stride on the minor dimension is not as expected but
// there is a single element.
- let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
+ let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
false
- } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
+ } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
@@ -1466,9 +1466,9 @@ pub fn call_gemm(
})?;
};
// rhs has shape b, k, n
- let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
+ let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
false
- } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
+ } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {