summaryrefslogtreecommitdiff
path: root/candle-core/src/op.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-18 12:04:39 +0100
committerGitHub <noreply@github.com>2023-07-18 12:04:39 +0100
commitff61a42ad73fbd2b958096f9ef7594d1f7411099 (patch)
treeac8761d7f5d39eafcd1852213c4e30781dceb174 /candle-core/src/op.rs
parentb706f32839c13aa68686f800dd2102d07770fcbd (diff)
downloadcandle-ff61a42ad73fbd2b958096f9ef7594d1f7411099.tar.gz
candle-ff61a42ad73fbd2b958096f9ef7594d1f7411099.tar.bz2
candle-ff61a42ad73fbd2b958096f9ef7594d1f7411099.zip
Use mkl to accelerate binary ops. (#190)
* Vectorized binary ops with mkl. * Improve the binary op mkl support. * Push the support for mkl binary ops. * Proper vectorization of binary ops. * Proper mkl'isation when broadcasting binary ops.
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r--candle-core/src/op.rs38
1 files changed, 33 insertions, 5 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index ec91a3fc..1344cf50 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -83,6 +83,19 @@ pub(crate) trait BinaryOp {
fn f64(v1: f64, v2: f64) -> f64;
fn u8(v1: u8, v2: u8) -> u8;
fn u32(v1: u32, v2: u32) -> u32;
+
+ const BF16_VEC: bool = false;
+ fn bf16_vec(_xs1: &[bf16], _xs2: &[bf16], _ys: &mut [bf16]) {}
+ const F16_VEC: bool = false;
+ fn f16_vec(_xs1: &[f16], _xs2: &[f16], _ys: &mut [f16]) {}
+ const F32_VEC: bool = false;
+ fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {}
+ const F64_VEC: bool = false;
+ fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {}
+ const U8_VEC: bool = false;
+ fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {}
+ const U32_VEC: bool = false;
+ fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {}
}
pub(crate) struct Add;
@@ -101,7 +114,7 @@ pub(crate) struct Gelu;
pub(crate) struct Relu;
macro_rules! bin_op {
- ($op:ident, $name: literal, $e: expr) => {
+ ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {
impl BinaryOp for $op {
const NAME: &'static str = $name;
const KERNEL: &'static str = concat!("b", $name);
@@ -130,14 +143,29 @@ macro_rules! bin_op {
fn u32(v1: u32, v2: u32) -> u32 {
$e(v1, v2)
}
+
+ #[cfg(feature = "mkl")]
+ const F32_VEC: bool = true;
+ #[cfg(feature = "mkl")]
+ const F64_VEC: bool = true;
+ #[cfg(feature = "mkl")]
+ #[inline(always)]
+ fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
+ crate::mkl::$f32_vec(xs1, xs2, ys)
+ }
+ #[cfg(feature = "mkl")]
+ #[inline(always)]
+ fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
+ crate::mkl::$f64_vec(xs1, xs2, ys)
+ }
}
};
}
-bin_op!(Add, "add", |v1, v2| v1 + v2);
-bin_op!(Sub, "sub", |v1, v2| v1 - v2);
-bin_op!(Mul, "mul", |v1, v2| v1 * v2);
-bin_op!(Div, "div", |v1, v2| v1 / v2);
+bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
+bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
+bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
+bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {