diff options
Diffstat (limited to 'candle-core/src/op.rs')
-rw-r--r-- | candle-core/src/op.rs | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 79473d2a..ec91a3fc 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -60,6 +60,17 @@ pub(crate) trait UnaryOp {
     fn f64(v1: f64) -> f64;
     fn u8(v1: u8) -> u8;
     fn u32(v1: u32) -> u32;
+
+    // There is no very good way to represent optional functions in traits so we go for an explicit
+    // boolean flag to mark the function as existing.
+    const BF16_VEC: bool = false;
+    fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
+    const F16_VEC: bool = false;
+    fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
+    const F32_VEC: bool = false;
+    fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
+    const F64_VEC: bool = false;
+    fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
 }
 
 pub(crate) trait BinaryOp {
@@ -219,6 +230,24 @@ impl UnaryOp for Gelu {
         0
     }
     const KERNEL: &'static str = "ugelu";
+
+    #[cfg(feature = "mkl")]
+    const F32_VEC: bool = true;
+
+    #[cfg(feature = "mkl")]
+    #[inline(always)]
+    fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+        crate::mkl::vs_gelu(xs, ys)
+    }
+
+    #[cfg(feature = "mkl")]
+    const F64_VEC: bool = true;
+
+    #[cfg(feature = "mkl")]
+    #[inline(always)]
+    fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+        crate::mkl::vd_gelu(xs, ys)
+    }
 }
 
 impl UnaryOp for Relu {
|