summaryrefslogtreecommitdiff
path: root/candle-core/src/cpu_backend.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-10 15:02:37 +0100
committerGitHub <noreply@github.com>2023-07-10 15:02:37 +0100
commit221b1aff6594acd6d030c5131dba388590d1917f (patch)
treefc800a0158c721240c7bf224504e202948401bce /candle-core/src/cpu_backend.rs
parent71cd3745a90e277c8d5911b7ddc98d70aebcd8ed (diff)
downloadcandle-221b1aff6594acd6d030c5131dba388590d1917f.tar.gz
candle-221b1aff6594acd6d030c5131dba388590d1917f.tar.bz2
candle-221b1aff6594acd6d030c5131dba388590d1917f.zip
Support dgemm in mkl matmul. (#122)
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r--candle-core/src/cpu_backend.rs64
1 files changed, 45 insertions, 19 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 15982040..dd9dabc1 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -265,7 +265,7 @@ impl Map2 for MatMul {
const OP: &'static str = "mat_mul";
#[cfg(not(feature = "mkl"))]
- fn f<T: 'static + num_traits::Num + Copy>(
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
lhs: &[T],
lhs_l: &Layout,
@@ -350,7 +350,7 @@ impl Map2 for MatMul {
}
#[cfg(feature = "mkl")]
- fn f<T: 'static + num_traits::Num + Copy>(
+ fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
lhs: &[T],
lhs_l: &Layout,
@@ -415,24 +415,50 @@ impl Map2 for MatMul {
};
let mut dst = vec![T::zero(); b * m * n];
- for step in 0..b {
- let lhs_p = &lhs[step * a_skip..];
- let rhs_p = &rhs[step * b_skip..];
- let dst_p = &mut dst[step * c_skip..];
- unsafe {
- let a = rhs_p.as_ptr() as *const f32;
- let b = lhs_p.as_ptr() as *const f32;
- let c = dst_p.as_mut_ptr() as *mut f32;
- let a = std::slice::from_raw_parts(a, a_skip);
- let b = std::slice::from_raw_parts(b, b_skip);
- let c = std::slice::from_raw_parts_mut(c, c_skip);
- blas::sgemm(
- transa, transb, /* m= */ n as i32, /* n= */ m as i32,
- /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
- /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, /* beta= */ 0.,
- /* c= */ c, /* ldc= */ n as i32,
- )
+ match T::DTYPE {
+ DType::F32 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f32;
+ let b = lhs_p.as_ptr() as *const f32;
+ let c = dst_p.as_mut_ptr() as *mut f32;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ blas::sgemm(
+ transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+ /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+ /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
+ /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
+ DType::F64 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f64;
+ let b = lhs_p.as_ptr() as *const f64;
+ let c = dst_p.as_mut_ptr() as *mut f64;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ blas::dgemm(
+ transa, transb, /* m= */ n as i32, /* n= */ m as i32,
+ /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
+ /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
+ /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
+ )
+ }
+ }
}
+ dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul"))?,
}
Ok(dst)
}