diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-10 15:52:03 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-10 15:52:03 +0100 |
commit | 548b1df7ea8835f171f04ce28f778a30e8b31f9c (patch) | |
tree | 3957f895cc2da7324af5b9c637187d3162e38e11 | |
parent | 221b1aff6594acd6d030c5131dba388590d1917f (diff) | |
download | candle-548b1df7ea8835f171f04ce28f778a30e8b31f9c.tar.gz candle-548b1df7ea8835f171f04ce28f778a30e8b31f9c.tar.bz2 candle-548b1df7ea8835f171f04ce28f778a30e8b31f9c.zip |
Remove the dependency to blas and use mkl directly. (#125)
-rw-r--r-- | candle-core/Cargo.toml | 4 | ||||
-rw-r--r-- | candle-core/src/cpu_backend.rs | 34 | ||||
-rw-r--r-- | candle-core/src/lib.rs | 2 | ||||
-rw-r--r-- | candle-core/src/mkl.rs | 154 |
4 files changed, 190 insertions, 4 deletions
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 27e85eee..91ca0cff 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -11,7 +11,6 @@ license = "MIT/Apache-2.0" readme = "README.md" [dependencies] -blas = { version = "0.22.0", optional = true } byteorder = "1.4.3" candle-kernels = { path = "../candle-kernels", optional = true } # Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes @@ -22,6 +21,7 @@ cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" } half = { version = "2.3.1", features = ["num-traits"] } intel-mkl-src = {version="0.8.1", optional=true, features = ["mkl-dynamic-lp64-iomp"]} +libc = { version = "0.2.147", optional = true } memmap2 = "0.7.1" num-traits = "0.2.15" num_cpus = "1.15.0" @@ -35,4 +35,4 @@ anyhow = { version = "1", features = ["backtrace"] } [features] default = ["cuda"] cuda = ["dep:cudarc", "dep:candle-kernels"] -mkl = ["dep:blas", "dep:intel-mkl-src"] +mkl = ["dep:libc", "dep:intel-mkl-src"] diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index dd9dabc1..6663021d 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -416,6 +416,36 @@ impl Map2 for MatMul { let mut dst = vec![T::zero(); b * m * n]; match T::DTYPE { + DType::F16 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f16; + let b = lhs_p.as_ptr() as *const f16; + let c = dst_p.as_mut_ptr() as *mut f16; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + crate::mkl::hgemm( + transa, + transb, + /* m= */ n as i32, + /* n= */ m as i32, + /* k= */ k as i32, + /* alpha= */ f16::ONE, + /* a= */ a, + /* lda= */ lda, + /* b= */ b, + /* ldb= */ ldb, + /* beta= */ f16::ZERO, + /* c= */ c, + /* ldc= */ n as i32, + ) + } + } + } DType::F32 => { for step in 0..b { let lhs_p = &lhs[step * a_skip..]; @@ -428,7 +458,7 @@ impl Map2 for MatMul { let a = std::slice::from_raw_parts(a, a_skip); let b = std::slice::from_raw_parts(b, b_skip); let c = std::slice::from_raw_parts_mut(c, c_skip); - blas::sgemm( + crate::mkl::sgemm( transa, transb, /* m= */ n as i32, /* n= */ m as i32, /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, @@ -449,7 +479,7 @@ impl Map2 for MatMul { let a = std::slice::from_raw_parts(a, a_skip); let b = std::slice::from_raw_parts(b, b_skip); let c = std::slice::from_raw_parts_mut(c, c_skip); - blas::dgemm( + crate::mkl::dgemm( transa, transb, /* m= */ n as i32, /* n= */ m as i32, /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 8df44e37..81cd5c30 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -44,6 +44,8 @@ mod dtype; mod dummy_cuda_backend; mod error; mod layout; +#[cfg(feature = "mkl")] +mod mkl; mod npy; mod op; pub mod safetensors; diff --git a/candle-core/src/mkl.rs b/candle-core/src/mkl.rs new file mode 100644 index 00000000..8d1aea4e --- /dev/null +++ b/candle-core/src/mkl.rs @@ -0,0 +1,154 @@ +use libc::{c_char, c_double, c_float, c_int}; + +mod ffi { + use super::*; + extern "C" { + pub fn sgemm_( + transa: *const c_char, + transb: *const c_char, + m: *const c_int, + n: *const c_int, + k: *const c_int, + alpha: *const c_float, + a: *const c_float, + lda: *const c_int, + b: *const c_float, + ldb: *const c_int, + beta: *const c_float, + c: *mut c_float, + ldc: *const c_int, + ); + pub fn dgemm_( + transa: *const c_char, + transb: *const c_char, + m: *const c_int, + n: *const c_int, + k: *const c_int, + alpha: *const c_double, + a: *const c_double, + lda: *const c_int, + b: *const c_double, + ldb: *const c_int, + beta: *const c_double, + c: *mut c_double, + ldc: *const c_int, + ); + pub fn hgemm_( + transa: *const c_char, + transb: *const c_char, + m: *const c_int, + n: *const c_int, + k: *const c_int, + alpha: *const half::f16, + a: *const half::f16, + lda: *const c_int, + b: *const half::f16, + ldb: *const c_int, + beta: *const half::f16, + c: *mut half::f16, + ldc: *const c_int, + ); + } +} + +#[allow(clippy::too_many_arguments)] +#[inline] +pub unsafe fn sgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: f32, + a: &[f32], + lda: i32, + b: &[f32], + ldb: i32, + beta: f32, + c: &mut [f32], + ldc: i32, +) { + ffi::sgemm_( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +} + +#[allow(clippy::too_many_arguments)] +#[inline] +pub unsafe fn dgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: f64, + a: &[f64], + lda: i32, + b: &[f64], + ldb: i32, + beta: f64, + c: &mut [f64], + ldc: i32, +) { + ffi::dgemm_( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +} + +#[allow(clippy::too_many_arguments)] +#[inline] +pub unsafe fn hgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: half::f16, + a: &[half::f16], + lda: i32, + b: &[half::f16], + ldb: i32, + beta: half::f16, + c: &mut [half::f16], + ldc: i32, +) { + ffi::hgemm_( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +} |