summaryrefslogtreecommitdiff
path: root/candle-core/src/op.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-18 07:48:48 +0100
committerGitHub <noreply@github.com>2023-07-18 07:48:48 +0100
commitd73df74cb28067f4d27e41ffc1426a42bd2bc2bc (patch)
tree23fd792286bab8871aaa595338a6ba651208b3e2 /candle-core/src/op.rs
parentb8abe2bb4b23c616c60cf4888662da744dc3fa68 (diff)
downloadcandle-d73df74cb28067f4d27e41ffc1426a42bd2bc2bc.tar.gz
candle-d73df74cb28067f4d27e41ffc1426a42bd2bc2bc.tar.bz2
candle-d73df74cb28067f4d27e41ffc1426a42bd2bc2bc.zip
Preliminary support for mkl-based gelu. (#187)
* Preliminary support for mkl-based gelu. * Add the vectorized function for unary ops. * Get the mkl-specialized gelu to work.
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r--candle-core/src/op.rs29
1 files changed, 29 insertions, 0 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 79473d2a..ec91a3fc 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -60,6 +60,17 @@ pub(crate) trait UnaryOp {
fn f64(v1: f64) -> f64;
fn u8(v1: u8) -> u8;
fn u32(v1: u32) -> u32;
+
+ // There is no good way to represent optional functions in traits, so we use an explicit
+ // boolean flag to mark each function as implemented.
+ const BF16_VEC: bool = false;
+ fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
+ const F16_VEC: bool = false;
+ fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
+ const F32_VEC: bool = false;
+ fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
+ const F64_VEC: bool = false;
+ fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
pub(crate) trait BinaryOp {
@@ -219,6 +230,24 @@ impl UnaryOp for Gelu {
0
}
const KERNEL: &'static str = "ugelu";
+
+ #[cfg(feature = "mkl")]
+ const F32_VEC: bool = true;
+
+ #[cfg(feature = "mkl")]
+ #[inline(always)]
+ fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+ crate::mkl::vs_gelu(xs, ys)
+ }
+
+ #[cfg(feature = "mkl")]
+ const F64_VEC: bool = true;
+
+ #[cfg(feature = "mkl")]
+ #[inline(always)]
+ fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+ crate::mkl::vd_gelu(xs, ys)
+ }
}
impl UnaryOp for Relu {