diff options
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 43 |
1 files changed, 17 insertions, 26 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index de32b549..a20d032d 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -261,6 +261,17 @@ impl<'a> Map2 for Conv1D<'a> { struct MatMul((usize, usize, usize, usize)); +impl MatMul { + fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error { + Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding { + lhs_l: lhs_l.clone(), + rhs_l: rhs_l.clone(), + bmnk: self.0, + msg, + })) + } +} + impl Map2 for MatMul { const OP: &'static str = "mat_mul"; @@ -290,19 +301,13 @@ impl Map2 for MatMul { [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, [stride] => stride, [] => m * k, - _ => Err(Error::UnexpectedStriding { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - })?, + _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?, }; let b_skip: usize = match rhs_stride[..rank - 2] { [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, [stride] => stride, [] => n * k, - _ => Err(Error::UnexpectedStriding { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - })?, + _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?, }; let c_skip: usize = m * n; @@ -369,19 +374,13 @@ impl Map2 for MatMul { [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride, [stride] => stride, [] => m * k, - _ => Err(Error::UnexpectedStriding { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - })?, + _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?, }; let b_skip: usize = match rhs_stride[..rank - 2] { [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride, [stride] => stride, [] => n * k, - _ => Err(Error::UnexpectedStriding { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - })?, + _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?, }; let c_skip: usize = m * n; @@ -395,11 +394,7 @@ impl Map2 for MatMul { } else if rhs_m1 == k && rhs_m2 == 1 { (k as i32, b'T') } else { - Err(Error::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })? + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))? }; // The b tensor has dims batching, m, k (lhs) let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k { @@ -407,11 +402,7 @@ impl Map2 for MatMul { } else if lhs_m1 == m && lhs_m2 == 1 { (m as i32, b'T') } else { - Err(Error::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })? + Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))? }; let mut dst = vec![T::zero(); b * m * n]; |