summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-18 07:48:48 +0100
committerGitHub <noreply@github.com>2023-07-18 07:48:48 +0100
commitd73df74cb28067f4d27e41ffc1426a42bd2bc2bc (patch)
tree23fd792286bab8871aaa595338a6ba651208b3e2
parentb8abe2bb4b23c616c60cf4888662da744dc3fa68 (diff)
downloadcandle-d73df74cb28067f4d27e41ffc1426a42bd2bc2bc.tar.gz
candle-d73df74cb28067f4d27e41ffc1426a42bd2bc2bc.tar.bz2
candle-d73df74cb28067f4d27e41ffc1426a42bd2bc2bc.zip
Preliminary support for mkl based gelu. (#187)
* Preliminary support for mkl based gelu. * Add the vectorized function for unary ops. * Get the mkl specialized gelu to work.
-rw-r--r--candle-core/src/cpu_backend.rs78
-rw-r--r--candle-core/src/mkl.rs40
-rw-r--r--candle-core/src/op.rs29
3 files changed, 135 insertions, 12 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index c336dfef..91ccd972 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -148,6 +148,48 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut
}
}
+fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
+ vs: &[T],
+ layout: &Layout,
+ mut f: F,
+ mut f_vec: FV,
+) -> Vec<U> {
+ match layout.strided_blocks() {
+ crate::StridedBlocks::SingleBlock { start_offset, len } => {
+ let mut ys: Vec<U> = Vec::with_capacity(len);
+ let ys_to_set = ys.spare_capacity_mut();
+ let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+ f_vec(&vs[start_offset..start_offset + len], ys_to_set);
+ // SAFETY: values are all set by f_vec.
+ unsafe { ys.set_len(len) };
+ ys
+ }
+ crate::StridedBlocks::MultipleBlocks {
+ block_start_index,
+ block_len,
+ } => {
+ let mut result = vec![];
+ result.reserve(layout.shape().elem_count());
+ // Specialize the case where block_len is one to avoid the second loop.
+ if block_len == 1 {
+ for index in block_start_index {
+ let v = unsafe { vs.get_unchecked(index) };
+ result.push(f(*v))
+ }
+ } else {
+ // TODO: Use f_vec here.
+ for index in block_start_index {
+ for offset in 0..block_len {
+ let v = unsafe { vs.get_unchecked(index + offset) };
+ result.push(f(*v))
+ }
+ }
+ }
+ result
+ }
+ }
+}
+
// This function maps over two strided index sequences.
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
lhs_l: &Layout,
@@ -864,20 +906,40 @@ impl BackendStorage for CpuStorage {
fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
match self {
Self::BF16(storage) => {
- let data = unary_map(storage, layout, B::bf16);
- Ok(Self::BF16(data))
+ if B::BF16_VEC {
+ let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
+ Ok(Self::BF16(data))
+ } else {
+ let data = unary_map(storage, layout, B::bf16);
+ Ok(Self::BF16(data))
+ }
}
Self::F16(storage) => {
- let data = unary_map(storage, layout, B::f16);
- Ok(Self::F16(data))
+ if B::F16_VEC {
+ let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
+ Ok(Self::F16(data))
+ } else {
+ let data = unary_map(storage, layout, B::f16);
+ Ok(Self::F16(data))
+ }
}
Self::F32(storage) => {
- let data = unary_map(storage, layout, B::f32);
- Ok(Self::F32(data))
+ if B::F32_VEC {
+ let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
+ Ok(Self::F32(data))
+ } else {
+ let data = unary_map(storage, layout, B::f32);
+ Ok(Self::F32(data))
+ }
}
Self::F64(storage) => {
- let data = unary_map(storage, layout, B::f64);
- Ok(Self::F64(data))
+ if B::F64_VEC {
+ let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
+ Ok(Self::F64(data))
+ } else {
+ let data = unary_map(storage, layout, B::f64);
+ Ok(Self::F64(data))
+ }
}
Self::U8(storage) => {
let data = unary_map(storage, layout, B::u8);
diff --git a/candle-core/src/mkl.rs b/candle-core/src/mkl.rs
index aabe6edc..60bddcb4 100644
--- a/candle-core/src/mkl.rs
+++ b/candle-core/src/mkl.rs
@@ -1,3 +1,4 @@
+#![allow(dead_code)]
use libc::{c_char, c_double, c_float, c_int};
mod ffi {
@@ -156,9 +157,8 @@ pub unsafe fn hgemm(
)
}
-#[allow(dead_code)]
#[inline]
-pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
+fn vs_tanh(a: &[f32], y: &mut [f32]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@@ -167,9 +167,8 @@ pub fn vs_tanh(a: &[f32], y: &mut [f32]) {
unsafe { ffi::vsTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
-#[allow(dead_code)]
#[inline]
-pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
+fn vd_tanh(a: &[f64], y: &mut [f64]) {
let a_len = a.len();
let y_len = y.len();
if a_len != y_len {
@@ -177,3 +176,36 @@ pub fn vd_tanh(a: &[f64], y: &mut [f64]) {
}
unsafe { ffi::vdTanh(a_len as i32, a.as_ptr(), y.as_mut_ptr()) }
}
+
+// The vector functions from mkl can be performed in place by using the same array for input and
+// output.
+// https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-2/vector-mathematical-functions.html
+#[inline]
+pub fn vs_tanh_inplace(y: &mut [f32]) {
+ unsafe { ffi::vsTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
+}
+
+#[inline]
+pub fn vd_tanh_inplace(y: &mut [f64]) {
+ unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
+}
+
+pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = (2.0f32 / std::f32::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+ }
+ vs_tanh_inplace(ys);
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = 0.5 * v * (1.0 + *y)
+ }
+}
+
+pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = (2.0f64 / std::f64::consts::PI).sqrt() * v * (1.0 + 0.044715 * v * v)
+ }
+ vd_tanh_inplace(ys);
+ for (&v, y) in vs.iter().zip(ys.iter_mut()) {
+ *y = 0.5 * v * (1.0 + *y)
+ }
+}
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 79473d2a..ec91a3fc 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -60,6 +60,17 @@ pub(crate) trait UnaryOp {
fn f64(v1: f64) -> f64;
fn u8(v1: u8) -> u8;
fn u32(v1: u32) -> u32;
+
+ // There is no very good way to represent optional function in traits so we go for an explicit
+ // boolean flag to mark the function as existing.
+ const BF16_VEC: bool = false;
+ fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {}
+ const F16_VEC: bool = false;
+ fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {}
+ const F32_VEC: bool = false;
+ fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {}
+ const F64_VEC: bool = false;
+ fn f64_vec(_xs: &[f64], _ys: &mut [f64]) {}
}
pub(crate) trait BinaryOp {
@@ -219,6 +230,24 @@ impl UnaryOp for Gelu {
0
}
const KERNEL: &'static str = "ugelu";
+
+ #[cfg(feature = "mkl")]
+ const F32_VEC: bool = true;
+
+ #[cfg(feature = "mkl")]
+ #[inline(always)]
+ fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+ crate::mkl::vs_gelu(xs, ys)
+ }
+
+ #[cfg(feature = "mkl")]
+ const F64_VEC: bool = true;
+
+ #[cfg(feature = "mkl")]
+ #[inline(always)]
+ fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+ crate::mkl::vd_gelu(xs, ys)
+ }
}
impl UnaryOp for Relu {