diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-18 12:04:39 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-18 12:04:39 +0100 |
commit | ff61a42ad73fbd2b958096f9ef7594d1f7411099 (patch) | |
tree | ac8761d7f5d39eafcd1852213c4e30781dceb174 /candle-core/src/op.rs | |
parent | b706f32839c13aa68686f800dd2102d07770fcbd (diff) | |
download | candle-ff61a42ad73fbd2b958096f9ef7594d1f7411099.tar.gz candle-ff61a42ad73fbd2b958096f9ef7594d1f7411099.tar.bz2 candle-ff61a42ad73fbd2b958096f9ef7594d1f7411099.zip |
Use mkl to accelerate binary ops. (#190)
* Vectorized binary ops with mkl.
* Improve the binary op mkl support.
* Push the support for mkl binary ops.
* Proper vectorization of binary ops.
* Proper mkl'isation when broadcasting binary ops.
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r-- | candle-core/src/op.rs | 38 |
1 files changed, 33 insertions, 5 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index ec91a3fc..1344cf50 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -83,6 +83,19 @@ pub(crate) trait BinaryOp { fn f64(v1: f64, v2: f64) -> f64; fn u8(v1: u8, v2: u8) -> u8; fn u32(v1: u32, v2: u32) -> u32; + + const BF16_VEC: bool = false; + fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {} + const F16_VEC: bool = false; + fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {} + const F32_VEC: bool = false; + fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {} + const F64_VEC: bool = false; + fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {} + const U8_VEC: bool = false; + fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {} + const U32_VEC: bool = false; + fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {} } pub(crate) struct Add; @@ -101,7 +114,7 @@ pub(crate) struct Gelu; pub(crate) struct Relu; macro_rules! bin_op { - ($op:ident, $name: literal, $e: expr) => { + ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => { impl BinaryOp for $op { const NAME: &'static str = $name; const KERNEL: &'static str = concat!("b", $name); @@ -130,14 +143,29 @@ macro_rules! bin_op { fn u32(v1: u32, v2: u32) -> u32 { $e(v1, v2) } + + #[cfg(feature = "mkl")] + const F32_VEC: bool = true; + #[cfg(feature = "mkl")] + const F64_VEC: bool = true; + #[cfg(feature = "mkl")] + #[inline(always)] + fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) { + crate::mkl::$f32_vec(xs1, xs2, ys) + } + #[cfg(feature = "mkl")] + #[inline(always)] + fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) { + crate::mkl::$f64_vec(xs1, xs2, ys) + } } }; } -bin_op!(Add, "add", |v1, v2| v1 + v2); -bin_op!(Sub, "sub", |v1, v2| v1 - v2); -bin_op!(Mul, "mul", |v1, v2| v1 * v2); -bin_op!(Div, "div", |v1, v2| v1 / v2); +bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add); +bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub); +bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul); +bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div); macro_rules! unary_op { ($op: ident, $name: literal, $a: ident, $e: expr) => { |