summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-10 15:52:03 +0100
committerGitHub <noreply@github.com>2023-07-10 15:52:03 +0100
commit548b1df7ea8835f171f04ce28f778a30e8b31f9c (patch)
tree3957f895cc2da7324af5b9c637187d3162e38e11 /candle-core
parent221b1aff6594acd6d030c5131dba388590d1917f (diff)
downloadcandle-548b1df7ea8835f171f04ce28f778a30e8b31f9c.tar.gz
candle-548b1df7ea8835f171f04ce28f778a30e8b31f9c.tar.bz2
candle-548b1df7ea8835f171f04ce28f778a30e8b31f9c.zip
Remove the dependency to blas and use mkl directly. (#125)
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/Cargo.toml4
-rw-r--r--candle-core/src/cpu_backend.rs34
-rw-r--r--candle-core/src/lib.rs2
-rw-r--r--candle-core/src/mkl.rs154
4 files changed, 190 insertions, 4 deletions
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 27e85eee..91ca0cff 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -11,7 +11,6 @@ license = "MIT/Apache-2.0"
readme = "README.md"
[dependencies]
-blas = { version = "0.22.0", optional = true }
byteorder = "1.4.3"
candle-kernels = { path = "../candle-kernels", optional = true }
# Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes
@@ -22,6 +21,7 @@ cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas
gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" }
half = { version = "2.3.1", features = ["num-traits"] }
intel-mkl-src = {version="0.8.1", optional=true, features = ["mkl-dynamic-lp64-iomp"]}
+libc = { version = "0.2.147", optional = true }
memmap2 = "0.7.1"
num-traits = "0.2.15"
num_cpus = "1.15.0"
@@ -35,4 +35,4 @@ anyhow = { version = "1", features = ["backtrace"] }
[features]
default = ["cuda"]
cuda = ["dep:cudarc", "dep:candle-kernels"]
-mkl = ["dep:blas", "dep:intel-mkl-src"]
+mkl = ["dep:libc", "dep:intel-mkl-src"]
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index dd9dabc1..6663021d 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -416,6 +416,36 @@ impl Map2 for MatMul {
let mut dst = vec![T::zero(); b * m * n];
match T::DTYPE {
+ DType::F16 => {
+ for step in 0..b {
+ let lhs_p = &lhs[step * a_skip..];
+ let rhs_p = &rhs[step * b_skip..];
+ let dst_p = &mut dst[step * c_skip..];
+ unsafe {
+ let a = rhs_p.as_ptr() as *const f16;
+ let b = lhs_p.as_ptr() as *const f16;
+ let c = dst_p.as_mut_ptr() as *mut f16;
+ let a = std::slice::from_raw_parts(a, a_skip);
+ let b = std::slice::from_raw_parts(b, b_skip);
+ let c = std::slice::from_raw_parts_mut(c, c_skip);
+ crate::mkl::hgemm(
+ transa,
+ transb,
+ /* m= */ n as i32,
+ /* n= */ m as i32,
+ /* k= */ k as i32,
+ /* alpha= */ f16::ONE,
+ /* a= */ a,
+ /* lda= */ lda,
+ /* b= */ b,
+ /* ldb= */ ldb,
+ /* beta= */ f16::ZERO,
+ /* c= */ c,
+ /* ldc= */ n as i32,
+ )
+ }
+ }
+ }
DType::F32 => {
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
@@ -428,7 +458,7 @@ impl Map2 for MatMul {
let a = std::slice::from_raw_parts(a, a_skip);
let b = std::slice::from_raw_parts(b, b_skip);
let c = std::slice::from_raw_parts_mut(c, c_skip);
- blas::sgemm(
+ crate::mkl::sgemm(
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
@@ -449,7 +479,7 @@ impl Map2 for MatMul {
let a = std::slice::from_raw_parts(a, a_skip);
let b = std::slice::from_raw_parts(b, b_skip);
let c = std::slice::from_raw_parts_mut(c, c_skip);
- blas::dgemm(
+ crate::mkl::dgemm(
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 8df44e37..81cd5c30 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -44,6 +44,8 @@ mod dtype;
mod dummy_cuda_backend;
mod error;
mod layout;
+#[cfg(feature = "mkl")]
+mod mkl;
mod npy;
mod op;
pub mod safetensors;
diff --git a/candle-core/src/mkl.rs b/candle-core/src/mkl.rs
new file mode 100644
index 00000000..8d1aea4e
--- /dev/null
+++ b/candle-core/src/mkl.rs
@@ -0,0 +1,154 @@
+use libc::{c_char, c_double, c_float, c_int};
+
+mod ffi {
+ use super::*;
+ extern "C" {
+ pub fn sgemm_(
+ transa: *const c_char,
+ transb: *const c_char,
+ m: *const c_int,
+ n: *const c_int,
+ k: *const c_int,
+ alpha: *const c_float,
+ a: *const c_float,
+ lda: *const c_int,
+ b: *const c_float,
+ ldb: *const c_int,
+ beta: *const c_float,
+ c: *mut c_float,
+ ldc: *const c_int,
+ );
+ pub fn dgemm_(
+ transa: *const c_char,
+ transb: *const c_char,
+ m: *const c_int,
+ n: *const c_int,
+ k: *const c_int,
+ alpha: *const c_double,
+ a: *const c_double,
+ lda: *const c_int,
+ b: *const c_double,
+ ldb: *const c_int,
+ beta: *const c_double,
+ c: *mut c_double,
+ ldc: *const c_int,
+ );
+ pub fn hgemm_(
+ transa: *const c_char,
+ transb: *const c_char,
+ m: *const c_int,
+ n: *const c_int,
+ k: *const c_int,
+ alpha: *const half::f16,
+ a: *const half::f16,
+ lda: *const c_int,
+ b: *const half::f16,
+ ldb: *const c_int,
+ beta: *const half::f16,
+ c: *mut half::f16,
+ ldc: *const c_int,
+ );
+ }
+}
+
+#[allow(clippy::too_many_arguments)]
+#[inline]
+pub unsafe fn sgemm(
+ transa: u8,
+ transb: u8,
+ m: i32,
+ n: i32,
+ k: i32,
+ alpha: f32,
+ a: &[f32],
+ lda: i32,
+ b: &[f32],
+ ldb: i32,
+ beta: f32,
+ c: &mut [f32],
+ ldc: i32,
+) {
+ ffi::sgemm_(
+ &(transa as c_char),
+ &(transb as c_char),
+ &m,
+ &n,
+ &k,
+ &alpha,
+ a.as_ptr(),
+ &lda,
+ b.as_ptr(),
+ &ldb,
+ &beta,
+ c.as_mut_ptr(),
+ &ldc,
+ )
+}
+
+#[allow(clippy::too_many_arguments)]
+#[inline]
+pub unsafe fn dgemm(
+ transa: u8,
+ transb: u8,
+ m: i32,
+ n: i32,
+ k: i32,
+ alpha: f64,
+ a: &[f64],
+ lda: i32,
+ b: &[f64],
+ ldb: i32,
+ beta: f64,
+ c: &mut [f64],
+ ldc: i32,
+) {
+ ffi::dgemm_(
+ &(transa as c_char),
+ &(transb as c_char),
+ &m,
+ &n,
+ &k,
+ &alpha,
+ a.as_ptr(),
+ &lda,
+ b.as_ptr(),
+ &ldb,
+ &beta,
+ c.as_mut_ptr(),
+ &ldc,
+ )
+}
+
+#[allow(clippy::too_many_arguments)]
+#[inline]
+pub unsafe fn hgemm(
+ transa: u8,
+ transb: u8,
+ m: i32,
+ n: i32,
+ k: i32,
+ alpha: half::f16,
+ a: &[half::f16],
+ lda: i32,
+ b: &[half::f16],
+ ldb: i32,
+ beta: half::f16,
+ c: &mut [half::f16],
+ ldc: i32,
+) {
+ ffi::hgemm_(
+ &(transa as c_char),
+ &(transb as c_char),
+ &m,
+ &n,
+ &k,
+ &alpha,
+ a.as_ptr(),
+ &lda,
+ b.as_ptr(),
+ &ldb,
+ &beta,
+ c.as_mut_ptr(),
+ &ldc,
+ )
+}