summaryrefslogtreecommitdiff
path: root/candle-core/src/cpu_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r--candle-core/src/cpu_backend.rs43
1 files changed, 17 insertions, 26 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index de32b549..a20d032d 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -261,6 +261,17 @@ impl<'a> Map2 for Conv1D<'a> {
struct MatMul((usize, usize, usize, usize));
+impl MatMul {
+ fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
+ Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
+ lhs_l: lhs_l.clone(),
+ rhs_l: rhs_l.clone(),
+ bmnk: self.0,
+ msg,
+ }))
+ }
+}
+
impl Map2 for MatMul {
const OP: &'static str = "mat_mul";
@@ -290,19 +301,13 @@ impl Map2 for MatMul {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
- _ => Err(Error::UnexpectedStriding {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- })?,
+ _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
- _ => Err(Error::UnexpectedStriding {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- })?,
+ _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let c_skip: usize = m * n;
@@ -369,19 +374,13 @@ impl Map2 for MatMul {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
- _ => Err(Error::UnexpectedStriding {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- })?,
+ _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
};
let b_skip: usize = match rhs_stride[..rank - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
- _ => Err(Error::UnexpectedStriding {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- })?,
+ _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
};
let c_skip: usize = m * n;
@@ -395,11 +394,7 @@ impl Map2 for MatMul {
} else if rhs_m1 == k && rhs_m2 == 1 {
(k as i32, b'T')
} else {
- Err(Error::MatMulNonContiguous {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- mnk: (m, n, k),
- })?
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
};
// The b tensor has dims batching, m, k (lhs)
let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
@@ -407,11 +402,7 @@ impl Map2 for MatMul {
} else if lhs_m1 == m && lhs_m2 == 1 {
(m as i32, b'T')
} else {
- Err(Error::MatMulNonContiguous {
- lhs_stride: lhs_stride.to_vec(),
- rhs_stride: rhs_stride.to_vec(),
- mnk: (m, n, k),
- })?
+ Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
};
let mut dst = vec![T::zero(); b * m * n];