diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-10 15:02:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-10 15:02:37 +0100 |
commit | 221b1aff6594acd6d030c5131dba388590d1917f (patch) | |
tree | fc800a0158c721240c7bf224504e202948401bce /candle-core/src/cpu_backend.rs | |
parent | 71cd3745a90e277c8d5911b7ddc98d70aebcd8ed (diff) | |
download | candle-221b1aff6594acd6d030c5131dba388590d1917f.tar.gz candle-221b1aff6594acd6d030c5131dba388590d1917f.tar.bz2 candle-221b1aff6594acd6d030c5131dba388590d1917f.zip |
Support dgemm in mkl matmul. (#122)
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 64 |
1 file changed, 45 insertions, 19 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 15982040..dd9dabc1 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -265,7 +265,7 @@ impl Map2 for MatMul { const OP: &'static str = "mat_mul"; #[cfg(not(feature = "mkl"))] - fn f<T: 'static + num_traits::Num + Copy>( + fn f<T: 'static + WithDType + num_traits::Num + Copy>( &self, lhs: &[T], lhs_l: &Layout, @@ -350,7 +350,7 @@ impl Map2 for MatMul { } #[cfg(feature = "mkl")] - fn f<T: 'static + num_traits::Num + Copy>( + fn f<T: 'static + WithDType + num_traits::Num + Copy>( &self, lhs: &[T], lhs_l: &Layout, @@ -415,24 +415,50 @@ impl Map2 for MatMul { }; let mut dst = vec![T::zero(); b * m * n]; - for step in 0..b { - let lhs_p = &lhs[step * a_skip..]; - let rhs_p = &rhs[step * b_skip..]; - let dst_p = &mut dst[step * c_skip..]; - unsafe { - let a = rhs_p.as_ptr() as *const f32; - let b = lhs_p.as_ptr() as *const f32; - let c = dst_p.as_mut_ptr() as *mut f32; - let a = std::slice::from_raw_parts(a, a_skip); - let b = std::slice::from_raw_parts(b, b_skip); - let c = std::slice::from_raw_parts_mut(c, c_skip); - blas::sgemm( - transa, transb, /* m= */ n as i32, /* n= */ m as i32, - /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, - /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, /* beta= */ 0., - /* c= */ c, /* ldc= */ n as i32, - ) + match T::DTYPE { + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + blas::sgemm( + transa, transb, /* m= */ n as i32, /* n= */ m as i32, + /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, + /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, + /* beta= 
*/ 0., /* c= */ c, /* ldc= */ n as i32, + ) + } + } + } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + blas::dgemm( + transa, transb, /* m= */ n as i32, /* n= */ m as i32, + /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, + /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, + /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32, + ) + } + } } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul"))?, } Ok(dst) } |