diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-18 07:48:48 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-18 07:48:48 +0100 |
commit | d73df74cb28067f4d27e41ffc1426a42bd2bc2bc (patch) | |
tree | 23fd792286bab8871aaa595338a6ba651208b3e2 /candle-core/src/op.rs | |
parent | b8abe2bb4b23c616c60cf4888662da744dc3fa68 (diff) | |
download | candle-d73df74cb28067f4d27e41ffc1426a42bd2bc2bc.tar.gz candle-d73df74cb28067f4d27e41ffc1426a42bd2bc2bc.tar.bz2 candle-d73df74cb28067f4d27e41ffc1426a42bd2bc2bc.zip |
Preliminary support for MKL-based GELU. (#187)
* Preliminary support for MKL-based GELU.
* Add the vectorized function for unary ops.
* Get the MKL-specialized GELU to work.
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r-- | candle-core/src/op.rs | 29 |
1 file changed, 29 insertions, 0 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 79473d2a..ec91a3fc 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -60,6 +60,17 @@ pub(crate) trait UnaryOp { fn f64(v1: f64) -> f64; fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; + + // There is no very good way to represent optional function in traits so we go for an explicit + // boolean flag to mark the function as existing. + const BF16_VEC: bool = false; + fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {} + const F16_VEC: bool = false; + fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {} + const F32_VEC: bool = false; + fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {} + const F64_VEC: bool = false; + fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {} } pub(crate) trait BinaryOp { @@ -219,6 +230,24 @@ impl UnaryOp for Gelu { 0 } const KERNEL: &'static str = "ugelu"; + + #[cfg(feature = "mkl")] + const F32_VEC: bool = true; + + #[cfg(feature = "mkl")] + #[inline(always)] + fn f32_vec(xs: &[f32], ys: &mut [f32]) { + crate::mkl::vs_gelu(xs, ys) + } + + #[cfg(feature = "mkl")] + const F64_VEC: bool = true; + + #[cfg(feature = "mkl")] + #[inline(always)] + fn f64_vec(xs: &[f64], ys: &mut [f64]) { + crate::mkl::vd_gelu(xs, ys) + } } impl UnaryOp for Relu { |