summaryrefslogtreecommitdiff
path: root/candle-core/src/cpu_backend.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r--candle-core/src/cpu_backend.rs64
1 files changed, 45 insertions, 19 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 15982040..dd9dabc1 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -265,7 +265,7 @@ impl Map2 for MatMul {
const OP: &'static str = "mat_mul";
#[cfg(not(feature = "mkl"))]
- fn f<T: 'static + num_traits::Num + Copy>(
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
lhs: &[T],
lhs_l: &Layout,
@@ -350,7 +350,7 @@ impl Map2 for MatMul {
}
#[cfg(feature = "mkl")]
- fn f<T: 'static + num_traits::Num + Copy>(
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
lhs: &[T],
lhs_l: &Layout,
@@ -415,24 +415,50 @@ impl Map2 for MatMul {
};
let mut dst = vec![T::zero(); b * m * n];
- for step in 0..b {
- let lhs_p = &lhs[step * a_skip..];
- let rhs_p = &rhs[step * b_skip..];
- let dst_p = &mut dst[step * c_skip..];
- unsafe {
- let a = rhs_p.as_ptr() as *const f32;
- let b = lhs_p.as_ptr() as *const f32;
- let c = dst_p.as_mut_ptr() as *mut f32;
- let a = std::slice::from_raw_parts(a, a_skip);
- let b = std::slice::from_raw_parts(b, b_skip);
- let c = std::slice::from_raw_parts_mut(c, c_skip);
- blas::sgemm(
- transa, transb, /* m= */ n as i32, /* n= */ m as i32,
- /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
- /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, /* beta= */ 0.,
- /* c= */ c, /* ldc= */ n as i32,
- )
+ match T::DTYPE {
+ DType::F32 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f32;
+ let b = lhs_p.as_ptr() as *const f32;
+ let c = dst_p.as_mut_ptr() as *mut f32;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ blas::sgemm(
+ transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+ /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+ /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
+ /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ DType::F64 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f64;
+ let b = lhs_p.as_ptr() as *const f64;
+ let c = dst_p.as_mut_ptr() as *mut f64;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ blas::dgemm(
+ transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+ /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+ /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
+ /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
+ )
+ }
+ }
}
+ dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul"))?,
}
Ok(dst)
}