diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-18 07:48:48 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-18 07:48:48 +0100 |
commit | d73df74cb28067f4d27e41ffc1426a42bd2bc2bc (patch) | |
tree | 23fd792286bab8871aaa595338a6ba651208b3e2 | |
parent | b8abe2bb4b23c616c60cf4888662da744dc3fa68 (diff) | |
download | candle-d73df74cb28067f4d27e41ffc1426a42bd2bc2bc.tar.gz candle-d73df74cb28067f4d27e41ffc1426a42bd2bc2bc.tar.bz2 candle-d73df74cb28067f4d27e41ffc1426a42bd2bc2bc.zip |
Preliminary support for mkl based gelu. (#187)
* Preliminary support for mkl based gelu.
* Add the vectorized function for unary ops.
* Get the mkl specialized gelu to work.
-rw-r--r-- | candle-core/src/cpu_backend.rs | 78 | ||||
-rw-r--r-- | candle-core/src/mkl.rs | 40 | ||||
-rw-r--r-- | candle-core/src/op.rs | 29 |
3 files changed, 135 insertions(+), 12 deletions(-)
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index c336dfef..91ccd972 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -148,6 +148,48 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut } } +fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>( + vs: &[T], + layout: &Layout, + mut f: F, + mut f_vec: FV, +) -> Vec<U> { + match layout.strided_blocks() { + crate::StridedBlocks::SingleBlock { start_offset, len } => { + let mut ys: Vec<U> = Vec::with_capacity(len); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; + f_vec(&vs[start_offset..start_offset + len], ys_to_set); + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(len) }; + ys + } + crate::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } => { + let mut result = vec![]; + result.reserve(layout.shape().elem_count()); + // Specialize the case where block_len is one to avoid the second loop. + if block_len == 1 { + for index in block_start_index { + let v = unsafe { vs.get_unchecked(index) }; + result.push(f(*v)) + } + } else { + // TODO: Use f_vec here. + for index in block_start_index { + for offset in 0..block_len { + let v = unsafe { vs.get_unchecked(index + offset) }; + result.push(f(*v)) + } + } + } + result + } + } +} + // This function maps over two strided index sequences. 
fn binary_map<T: Copy, F: FnMut(T, T) -> T>( lhs_l: &Layout, @@ -864,20 +906,40 @@ impl BackendStorage for CpuStorage { fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> { match self { Self::BF16(storage) => { - let data = unary_map(storage, layout, B::bf16); - Ok(Self::BF16(data)) + if B::BF16_VEC { + let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec); + Ok(Self::BF16(data)) + } else { + let data = unary_map(storage, layout, B::bf16); + Ok(Self::BF16(data)) + } } Self::F16(storage) => { - let data = unary_map(storage, layout, B::f16); - Ok(Self::F16(data)) + if B::F16_VEC { + let data = unary_map_vec(storage, layout, B::f16, B::f16_vec); + Ok(Self::F16(data)) + } else { + let data = unary_map(storage, layout, B::f16); + Ok(Self::F16(data)) + } } Self::F32(storage) => { - let data = unary_map(storage, layout, B::f32); - Ok(Self::F32(data)) + if B::F32_VEC { + let data = unary_map_vec(storage, layout, B::f32, B::f32_vec); + Ok(Self::F32(data)) + } else { + let data = unary_map(storage, layout, B::f32); + Ok(Self::F32(data)) + } } Self::F64(storage) => { - let data = unary_map(storage, layout, B::f64); - Ok(Self::F64(data)) + if B::F64_VEC { + let data = unary_map_vec(storage, layout, B::f64, B::f64_vec); + Ok(Self::F64(data)) + } else { + let data = unary_map(storage, layout, B::f64); + Ok(Self::F64(data)) + } } Self::U8(storage) => { let data = unary_map(storage, layout, B::u8); diff --git a/candle-core/src/mkl.rs b/candle-core/src/mkl.rs index aabe6edc..60bddcb4 100644 --- a/candle-core/src/mkl.rs +++ b/candle-core/src/mkl.rs @@ -1,3 +1,4 @@ +#![allow(dead_code)] use libc::{c_char, c_double, c_float, c_int}; mod ffi { @@ -156,9 +157,8 @@ pub unsafe fn hgemm( ) } -#[allow(dead_code)] #[inline] -pub fn vs_tanh(a: &[f32], y: &mut [f32]) { +fn vs_tanh(a: &[f32], y: &mut [f32]) { let a_len = a.len(); let y_len = y.len(); if a_len != y_len { @@ -167,9 +167,8 @@ pub fn vs_tanh(a: &[f32], y: &mut [f32]) { unsafe { ffi::vsTanh(a_len as i32, 
a.as_ptr(), y.as_mut_ptr()) } } -#[allow(dead_code)] #[inline] -pub fn vd_tanh(a: &[f64], y: &mut [f64]) { +fn vd_tanh(a: &[f64], y: &mut [f64]) { let a_len = a.len(); let y_len = y.len(); if a_len != y_len { @@ -177,3 +176,36 @@ pub fn vd_tanh(a: &[f64], y: &mut [f64]) { } unsafe { ffi::vdTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) } } + +// The vector functions from mkl can be performed in place by using the same array for input and +// output. +// https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-2/vector-mathematical-functions.html +#[inline] +pub fn vs_tanh_inplace(y: &mut [f32]) { + unsafe { ffi::vsTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) } +} + +#[inline] +pub fn vd_tanh_inplace(y: &mut [f64]) { + unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) } +} + +pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v) + } + vs_tanh_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = 0.5 * v * (1.0 + *y) + } +} + +pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) { + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v) + } + vd_tanh_inplace(ys); + for (&v, y) in vs.iter().zip(ys.iter_mut()) { + *y = 0.5 * v * (1.0 + *y) + } +} diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 79473d2a..ec91a3fc 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -60,6 +60,17 @@ pub(crate) trait UnaryOp { fn f64(v1: f64) -> f64; fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; + + // There is no very good way to represent optional function in traits so we go for an explicit + // boolean flag to mark the function as existing. 
+ const BF16_VEC: bool = false; + fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {} + const F16_VEC: bool = false; + fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {} + const F32_VEC: bool = false; + fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {} + const F64_VEC: bool = false; + fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {} } pub(crate) trait BinaryOp { @@ -219,6 +230,24 @@ impl UnaryOp for Gelu { 0 } const KERNEL: &'static str = "ugelu"; + + #[cfg(feature = "mkl")] + const F32_VEC: bool = true; + + #[cfg(feature = "mkl")] + #[inline(always)] + fn f32_vec(xs: &[f32], ys: &mut [f32]) { + crate::mkl::vs_gelu(xs, ys) + } + + #[cfg(feature = "mkl")] + const F64_VEC: bool = true; + + #[cfg(feature = "mkl")] + #[inline(always)] + fn f64_vec(xs: &[f64], ys: &mut [f64]) { + crate::mkl::vd_gelu(xs, ys) + } } impl UnaryOp for Relu { |