summaryrefslogtreecommitdiff
path: root/candle-core/src/cpu_backend/utils.rs
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-03-29 23:02:11 +0100
committer GitHub <noreply@github.com> 2024-03-29 23:02:11 +0100
commit 665da304878326e267b178fa6e6d85424249126b (patch)
tree b1c4e16174c84ffadc56d2ac5ec26d2a5882b86a /candle-core/src/cpu_backend/utils.rs
parent 356a170ae92ea85411e605de1be2685b4c923358 (diff)
downloadcandle-665da304878326e267b178fa6e6d85424249126b.tar.gz
candle-665da304878326e267b178fa6e6d85424249126b.tar.bz2
candle-665da304878326e267b178fa6e6d85424249126b.zip
Backend refactoring. (#1966)
* Backend refactoring. * Metal tweaks. * Move the cudnn module.
Diffstat (limited to 'candle-core/src/cpu_backend/utils.rs')
-rw-r--r--candle-core/src/cpu_backend/utils.rs350
1 files changed, 350 insertions, 0 deletions
diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs
new file mode 100644
index 00000000..af25a2af
--- /dev/null
+++ b/candle-core/src/cpu_backend/utils.rs
@@ -0,0 +1,350 @@
+//! Helper functions to write CPU kernels.
+use crate::backend::BackendStorage;
+use crate::{Error, Layout, Result, WithDType};
+
+// Shorthand for the CPU storage type whose per-dtype variants the traits below dispatch on.
+type C = super::CpuStorage;
+/// Unary CPU kernel whose output has the same element type as its input.
+///
+/// Implementors provide the generic kernel `f`; the provided `map` selects the
+/// concrete element type from the storage variant and re-wraps the output in
+/// that same variant.
+pub trait Map1 {
+    /// Runs the kernel on the values `vs`, interpreted through `layout`.
+    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
+
+    /// Dispatches `f` on the dtype stored in `vs`, propagating any error from `f`.
+    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
+        let out = match vs {
+            C::U8(data) => C::U8(self.f(data, layout)?),
+            C::U32(data) => C::U32(self.f(data, layout)?),
+            C::I64(data) => C::I64(self.f(data, layout)?),
+            C::BF16(data) => C::BF16(self.f(data, layout)?),
+            C::F16(data) => C::F16(self.f(data, layout)?),
+            C::F32(data) => C::F32(self.f(data, layout)?),
+            C::F64(data) => C::F64(self.f(data, layout)?),
+        };
+        Ok(out)
+    }
+}
+
+/// Unary CPU kernel that builds its own output storage.
+///
+/// Unlike `Map1`, the kernel `f` returns a complete storage value. It is handed
+/// `wrap`, the constructor of the input's storage variant, which it can use to
+/// package a `Vec<T>` of the input element type.
+pub trait Map1Any {
+    /// Runs the kernel on `vs` (interpreted through `layout`) and returns the
+    /// resulting storage; `wrap` turns a `Vec<T>` into the matching variant.
+    fn f<T: WithDType, W: Fn(Vec<T>) -> C>(
+        &self,
+        vs: &[T],
+        layout: &Layout,
+        wrap: W,
+    ) -> Result<C>;
+
+    /// Dispatches `f` on the dtype stored in `vs`, passing the variant
+    /// constructor for that dtype as `wrap`.
+    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
+        match vs {
+            C::U8(data) => self.f(data, layout, C::U8),
+            C::U32(data) => self.f(data, layout, C::U32),
+            C::I64(data) => self.f(data, layout, C::I64),
+            C::BF16(data) => self.f(data, layout, C::BF16),
+            C::F16(data) => self.f(data, layout, C::F16),
+            C::F32(data) => self.f(data, layout, C::F32),
+            C::F64(data) => self.f(data, layout, C::F64),
+        }
+    }
+}
+
+/// Binary CPU kernel whose two inputs and output all share one element type.
+pub trait Map2 {
+    /// Operation name reported in the dtype-mismatch error.
+    const OP: &'static str;
+
+    /// Runs the kernel on `v1` and `v2`, interpreted through `l1` and `l2`.
+    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
+
+    /// Dispatches `f` when both storages hold the same dtype.
+    ///
+    /// # Errors
+    /// Returns `Error::DTypeMismatchBinaryOp` when the two storages hold
+    /// different dtypes, and propagates any error from `f`.
+    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
+        let out = match (v1, v2) {
+            (C::U8(a), C::U8(b)) => C::U8(self.f(a, l1, b, l2)?),
+            (C::U32(a), C::U32(b)) => C::U32(self.f(a, l1, b, l2)?),
+            (C::I64(a), C::I64(b)) => C::I64(self.f(a, l1, b, l2)?),
+            (C::BF16(a), C::BF16(b)) => C::BF16(self.f(a, l1, b, l2)?),
+            (C::F16(a), C::F16(b)) => C::F16(self.f(a, l1, b, l2)?),
+            (C::F32(a), C::F32(b)) => C::F32(self.f(a, l1, b, l2)?),
+            (C::F64(a), C::F64(b)) => C::F64(self.f(a, l1, b, l2)?),
+            _ => {
+                let mismatch = Error::DTypeMismatchBinaryOp {
+                    lhs: v1.dtype(),
+                    rhs: v2.dtype(),
+                    op: Self::OP,
+                };
+                return Err(mismatch.bt());
+            }
+        };
+        Ok(out)
+    }
+}
+
+/// Binary CPU kernel that always produces `u8` values, whatever the (shared)
+/// element type of its two inputs.
+pub trait Map2U8 {
+    /// Operation name reported in the dtype-mismatch error.
+    const OP: &'static str;
+
+    /// Runs the kernel on `v1` and `v2`, interpreted through `l1` and `l2`,
+    /// producing `u8` output values.
+    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
+
+    /// Dispatches `f` when both storages hold the same dtype and wraps the
+    /// result as `U8` storage.
+    ///
+    /// # Errors
+    /// Returns `Error::DTypeMismatchBinaryOp` when the two storages hold
+    /// different dtypes, and propagates any error from `f`.
+    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
+        let bytes = match (v1, v2) {
+            (C::U8(a), C::U8(b)) => self.f(a, l1, b, l2)?,
+            (C::U32(a), C::U32(b)) => self.f(a, l1, b, l2)?,
+            (C::I64(a), C::I64(b)) => self.f(a, l1, b, l2)?,
+            (C::BF16(a), C::BF16(b)) => self.f(a, l1, b, l2)?,
+            (C::F16(a), C::F16(b)) => self.f(a, l1, b, l2)?,
+            (C::F32(a), C::F32(b)) => self.f(a, l1, b, l2)?,
+            (C::F64(a), C::F64(b)) => self.f(a, l1, b, l2)?,
+            _ => {
+                let mismatch = Error::DTypeMismatchBinaryOp {
+                    lhs: v1.dtype(),
+                    rhs: v2.dtype(),
+                    op: Self::OP,
+                };
+                return Err(mismatch.bt());
+            }
+        };
+        Ok(C::U8(bytes))
+    }
+}
+
+/// Applies `f` to paired elements of `lhs` and `rhs` and collects the results.
+///
+/// Three access strategies are tried in order:
+/// 1. both layouts report contiguous offsets: the two ranges are zipped directly;
+/// 2. one layout is contiguous and the other reports broadcast offsets via
+///    `offsets_b`: the contiguous side is walked linearly while a position
+///    inside the broadcast block is tracked by hand;
+/// 3. otherwise both layouts are walked through `strided_index`.
+pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
+    lhs_l: &Layout,
+    rhs_l: &Layout,
+    lhs: &[T],
+    rhs: &[T],
+    mut f: F,
+) -> Vec<U> {
+    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
+        (Some((l_start, l_end)), Some((r_start, r_end))) => {
+            let lhs_vals = &lhs[l_start..l_end];
+            let rhs_vals = &rhs[r_start..r_end];
+            return lhs_vals
+                .iter()
+                .zip(rhs_vals)
+                .map(|(&a, &b)| f(a, b))
+                .collect();
+        }
+        (Some((l_start, l_end)), None) => {
+            // TODO: Maybe we want to avoid going through the layout twice.
+            if let Some(ob) = rhs_l.offsets_b() {
+                let lhs_vals = &lhs[l_start..l_end];
+                let mut out = Vec::with_capacity(lhs_vals.len());
+                // `block_pos` is the index inside the broadcast block, `repeat`
+                // counts how many times the current element has been reused.
+                let mut block_pos = 0;
+                let mut repeat = 0;
+                for &l in lhs_vals {
+                    // NOTE(review): the unchecked read assumes `ob.start + ob.len`
+                    // fits in `rhs` and `ob.len > 0` — presumably guaranteed by
+                    // `offsets_b`; confirm.
+                    let r = unsafe { *rhs.get_unchecked(ob.start + block_pos) };
+                    repeat += 1;
+                    if repeat >= ob.right_broadcast {
+                        block_pos += 1;
+                        repeat = 0;
+                    }
+                    if block_pos >= ob.len {
+                        block_pos = 0;
+                    }
+                    out.push(f(l, r));
+                }
+                return out;
+            }
+        }
+        (None, Some((r_start, r_end))) => {
+            // TODO: Maybe we want to avoid going through the layout twice.
+            if let Some(ob) = lhs_l.offsets_b() {
+                let rhs_vals = &rhs[r_start..r_end];
+                let mut out = Vec::with_capacity(rhs_vals.len());
+                // Mirror image of the previous arm with the operands swapped.
+                let mut block_pos = 0;
+                let mut repeat = 0;
+                for &r in rhs_vals {
+                    // NOTE(review): same in-bounds assumption as above, for `lhs`.
+                    let l = unsafe { *lhs.get_unchecked(ob.start + block_pos) };
+                    repeat += 1;
+                    if repeat >= ob.right_broadcast {
+                        block_pos += 1;
+                        repeat = 0;
+                    }
+                    if block_pos >= ob.len {
+                        block_pos = 0;
+                    }
+                    out.push(f(l, r));
+                }
+                return out;
+            }
+        }
+        (None, None) => {}
+    }
+    // Generic fallback: no fast path applied, index both sides through their strides.
+    lhs_l
+        .strided_index()
+        .zip(rhs_l.strided_index())
+        .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+        .collect()
+}
+
+/// Similar to `binary_map` but with vectorized variants.
+///
+/// `f` combines a single pair of elements; `f_vec(lhs, rhs, out)` combines two
+/// input slices into `out` and must write every element of `out`, because the
+/// output buffer it receives is uninitialized memory that is afterwards
+/// declared initialized with `set_len`.
+///
+/// Strategy per layout combination:
+/// * both contiguous: a single `f_vec` call over the whole ranges;
+/// * one contiguous, the other broadcast with `right_broadcast == 1`: one
+///   `f_vec` call per chunk of `ob.len` elements of the contiguous side;
+/// * one contiguous, the other broadcast otherwise: the contiguous side is
+///   copied and updated in place with the scalar `f`;
+/// * anything else: both sides are walked through `strided_index` with `f`.
+pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
+    lhs_l: &Layout,
+    rhs_l: &Layout,
+    lhs: &[T],
+    rhs: &[T],
+    mut f: F,
+    mut f_vec: FV,
+) -> Vec<T> {
+    let el_count = lhs_l.shape().elem_count();
+    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
+        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
+            let mut ys: Vec<T> = Vec::with_capacity(el_count);
+            // `with_capacity` is allowed to reserve more than requested, so only
+            // expose the first `el_count` slots: `f_vec` then always sees an
+            // output slice of exactly the expected length.
+            let ys_to_set = &mut ys.spare_capacity_mut()[..el_count];
+            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
+            // SAFETY: values are all set by f_vec.
+            unsafe { ys.set_len(el_count) };
+            ys
+        }
+        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
+            Some(ob) if ob.right_broadcast == 1 => {
+                // The broadcast block of `rhs` is reused against every
+                // `ob.len`-sized chunk of the contiguous `lhs`.
+                let rhs = &rhs[ob.start..ob.start + ob.len];
+                let mut ys: Vec<T> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+                let mut dst_i = 0;
+                for src_i in (o_l1..o_l2).step_by(ob.len) {
+                    f_vec(
+                        &lhs[src_i..src_i + ob.len],
+                        rhs,
+                        &mut ys_to_set[dst_i..dst_i + ob.len],
+                    );
+                    dst_i += ob.len;
+                }
+                // SAFETY: values are all set by f_vec.
+                unsafe { ys.set_len(el_count) };
+                ys
+            }
+            Some(ob) => {
+                // Start from a copy of `lhs` and fold each `rhs` element into
+                // its run of `right_broadcast` consecutive outputs, for every
+                // one of the `left_broadcast` repetitions.
+                let rhs = &rhs[ob.start..ob.start + ob.len];
+                let mut ys = lhs[o_l1..o_l2].to_vec();
+                for idx_l in 0..ob.left_broadcast {
+                    let start = idx_l * ob.len * ob.right_broadcast;
+                    for (i, &r) in rhs.iter().enumerate() {
+                        let start = start + i * ob.right_broadcast;
+                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
+                            *v = f(*v, r)
+                        }
+                    }
+                }
+                ys
+            }
+            None => lhs_l
+                .strided_index()
+                .zip(rhs_l.strided_index())
+                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+                .collect(),
+        },
+        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
+            Some(ob) if ob.right_broadcast == 1 => {
+                // Mirror of the arm above: the broadcast block of `lhs` is
+                // reused against every chunk of the contiguous `rhs`.
+                let lhs = &lhs[ob.start..ob.start + ob.len];
+                let mut ys: Vec<T> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+                let mut dst_i = 0;
+                for src_i in (o_r1..o_r2).step_by(ob.len) {
+                    f_vec(
+                        lhs,
+                        &rhs[src_i..src_i + ob.len],
+                        &mut ys_to_set[dst_i..dst_i + ob.len],
+                    );
+                    dst_i += ob.len;
+                }
+                // SAFETY: values are all set by f_vec.
+                unsafe { ys.set_len(el_count) };
+                ys
+            }
+            Some(ob) => {
+                // Mirror of the scalar broadcast arm above, starting from a
+                // copy of `rhs` and keeping the `lhs` element as first operand.
+                let lhs = &lhs[ob.start..ob.start + ob.len];
+                let mut ys = rhs[o_r1..o_r2].to_vec();
+                for idx_l in 0..ob.left_broadcast {
+                    let start = idx_l * ob.len * ob.right_broadcast;
+                    for (i, &l) in lhs.iter().enumerate() {
+                        let start = start + i * ob.right_broadcast;
+                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
+                            *v = f(l, *v)
+                        }
+                    }
+                }
+                ys
+            }
+            None => lhs_l
+                .strided_index()
+                .zip(rhs_l.strided_index())
+                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+                .collect(),
+        },
+        _ => lhs_l
+            .strided_index()
+            .zip(rhs_l.strided_index())
+            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+            .collect(),
+    }
+}
+
+/// Applies `f` to every element of `vs` selected by `layout` and collects the
+/// results.
+///
+/// A layout made of a single block is mapped straight over the corresponding
+/// sub-slice; a layout made of several equally sized blocks is visited block by
+/// block, in the order given by `block_start_index`.
+pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
+    vs: &[T],
+    layout: &Layout,
+    mut f: F,
+) -> Vec<U> {
+    match layout.strided_blocks() {
+        crate::StridedBlocks::SingleBlock { start_offset, len } => {
+            let block = &vs[start_offset..start_offset + len];
+            block.iter().map(|&x| f(x)).collect()
+        }
+        crate::StridedBlocks::MultipleBlocks {
+            block_start_index,
+            block_len,
+        } => {
+            let mut out = Vec::with_capacity(layout.shape().elem_count());
+            // Specialize the case where block_len is one to avoid the second loop.
+            if block_len == 1 {
+                for start in block_start_index {
+                    // NOTE(review): unchecked read; assumes the layout only
+                    // yields indices inside `vs` — confirm against callers.
+                    let x = unsafe { *vs.get_unchecked(start) };
+                    out.push(f(x));
+                }
+            } else {
+                for start in block_start_index {
+                    // NOTE(review): same in-bounds assumption for every offset
+                    // of the block.
+                    out.extend(
+                        (0..block_len).map(|k| f(unsafe { *vs.get_unchecked(start + k) })),
+                    );
+                }
+            }
+            out
+        }
+    }
+}
+
+/// Like `unary_map`, but able to process whole blocks with a vectorized kernel.
+///
+/// `f` maps a single element; `f_vec(input, out)` maps an input slice into
+/// `out` and must write every element of `out`, because the output buffer it
+/// receives is uninitialized memory that is afterwards declared initialized
+/// with `set_len`.
+///
+/// A single-block layout is handled with one `f_vec` call. A multi-block layout
+/// uses the scalar `f` when blocks hold one element each, and one `f_vec` call
+/// per block otherwise.
+pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
+    vs: &[T],
+    layout: &Layout,
+    mut f: F,
+    mut f_vec: FV,
+) -> Vec<U> {
+    match layout.strided_blocks() {
+        crate::StridedBlocks::SingleBlock { start_offset, len } => {
+            let mut ys: Vec<U> = Vec::with_capacity(len);
+            // `with_capacity` is allowed to reserve more than requested, so only
+            // expose the first `len` slots: `f_vec` then always sees an output
+            // slice exactly as long as its input.
+            let ys_to_set = &mut ys.spare_capacity_mut()[..len];
+            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
+            // SAFETY: values are all set by f_vec.
+            unsafe { ys.set_len(len) };
+            ys
+        }
+        crate::StridedBlocks::MultipleBlocks {
+            block_start_index,
+            block_len,
+        } => {
+            let el_count = layout.shape().elem_count();
+            // Specialize the case where block_len is one to avoid the second loop.
+            if block_len == 1 {
+                let mut result = Vec::with_capacity(el_count);
+                for index in block_start_index {
+                    // NOTE(review): unchecked read; assumes the layout only
+                    // yields indices inside `vs` — confirm against callers.
+                    let v = unsafe { vs.get_unchecked(index) };
+                    result.push(f(*v))
+                }
+                result
+            } else {
+                let mut ys: Vec<U> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+                // Blocks are written back to back in the output, in iteration order.
+                let mut dst_index = 0;
+                for src_index in block_start_index {
+                    let vs = &vs[src_index..src_index + block_len];
+                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
+                    f_vec(vs, ys);
+                    dst_index += block_len;
+                }
+                // SAFETY: values are all set by f_vec.
+                unsafe { ys.set_len(el_count) };
+                ys
+            }
+        }
+    }
+}