/// Helper functions to write CPU kernels. use crate::backend::BackendStorage; use crate::{Error, Layout, Result, WithDType}; type C = super::CpuStorage; pub trait Map1 { fn f(&self, vs: &[T], layout: &Layout) -> Result>; fn map(&self, vs: &C, layout: &Layout) -> Result { match vs { C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)), C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)), C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)), C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)), C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)), C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)), C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)), } } } pub trait Map1Any { fn f) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result; fn map(&self, vs: &C, layout: &Layout) -> Result { match vs { C::U8(vs) => Ok(self.f(vs, layout, C::U8)?), C::U32(vs) => Ok(self.f(vs, layout, C::U32)?), C::I64(vs) => Ok(self.f(vs, layout, C::I64)?), C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?), C::F16(vs) => Ok(self.f(vs, layout, C::F16)?), C::F32(vs) => Ok(self.f(vs, layout, C::F32)?), C::F64(vs) => Ok(self.f(vs, layout, C::F64)?), } } } pub trait Map2 { const OP: &'static str; fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result { match (v1, v2) { (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)), (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)), (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)), (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)), (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), op: Self::OP, } .bt()), } } } pub trait Map2U8 { const OP: &'static str; fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result { match (v1, v2) { (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), op: Self::OP, } .bt()), } } } pub fn binary_map U>( lhs_l: &Layout, rhs_l: &Layout, lhs: &[T], rhs: &[T], mut f: F, ) -> Vec { match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2] .iter() .zip(rhs[o_r1..o_r2].iter()) .map(|(&l, &r)| f(l, r)) .collect(), (Some((o_l1, o_l2)), None) => { // TODO: Maybe we want to avoid going through the layout twice. match rhs_l.offsets_b() { Some(ob) => { let mut i_in_block = 0; let mut i_right_broadcast = 0; lhs[o_l1..o_l2] .iter() .map(|&l| { let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) }; i_right_broadcast += 1; if i_right_broadcast >= ob.right_broadcast { i_in_block += 1; i_right_broadcast = 0; } if i_in_block >= ob.len { i_in_block = 0 } f(l, *r) }) .collect() } None => lhs_l .strided_index() .zip(rhs_l.strided_index()) .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) .collect(), } } (None, Some((o_r1, o_r2))) => { // TODO: Maybe we want to avoid going through the layout twice. match lhs_l.offsets_b() { Some(ob) => { let mut i_in_block = 0; let mut i_right_broadcast = 0; rhs[o_r1..o_r2] .iter() .map(|&r| { let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) }; i_right_broadcast += 1; if i_right_broadcast >= ob.right_broadcast { i_in_block += 1; i_right_broadcast = 0; } if i_in_block >= ob.len { i_in_block = 0 } f(*l, r) }) .collect() } None => lhs_l .strided_index() .zip(rhs_l.strided_index()) .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) .collect(), } } _ => lhs_l .strided_index() .zip(rhs_l.strided_index()) .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) .collect(), } } // Similar to binary_map but with vectorized variants. pub fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])>( lhs_l: &Layout, rhs_l: &Layout, lhs: &[T], rhs: &[T], mut f: F, mut f_vec: FV, ) -> Vec { let el_count = lhs_l.shape().elem_count(); match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => { let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); let ys_to_set = unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [T]>(ys_to_set) }; f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set); // SAFETY: values are all set by f_vec. unsafe { ys.set_len(el_count) }; ys } (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() { Some(ob) if ob.right_broadcast == 1 => { let rhs = &rhs[ob.start..ob.start + ob.len]; let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); let ys_to_set = unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [T]>(ys_to_set) }; let mut dst_i = 0; for src_i in (o_l1..o_l2).step_by(ob.len) { f_vec( &lhs[src_i..src_i + ob.len], rhs, &mut ys_to_set[dst_i..dst_i + ob.len], ); dst_i += ob.len; } // SAFETY: values are all set by f_vec. unsafe { ys.set_len(el_count) }; ys } Some(ob) => { let rhs = &rhs[ob.start..ob.start + ob.len]; let mut ys = lhs[o_l1..o_l2].to_vec(); for idx_l in 0..ob.left_broadcast { let start = idx_l * ob.len * ob.right_broadcast; for (i, &r) in rhs.iter().enumerate() { let start = start + i * ob.right_broadcast; for v in ys[start..start + ob.right_broadcast].iter_mut() { *v = f(*v, r) } } } ys } None => lhs_l .strided_index() .zip(rhs_l.strided_index()) .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) .collect(), }, (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() { Some(ob) if ob.right_broadcast == 1 => { let lhs = &lhs[ob.start..ob.start + ob.len]; let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); let ys_to_set = unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [T]>(ys_to_set) }; let mut dst_i = 0; for src_i in (o_r1..o_r2).step_by(ob.len) { f_vec( lhs, &rhs[src_i..src_i + ob.len], &mut ys_to_set[dst_i..dst_i + ob.len], ); dst_i += ob.len; } // SAFETY: values are all set by f_vec. unsafe { ys.set_len(el_count) }; ys } Some(ob) => { let lhs = &lhs[ob.start..ob.start + ob.len]; let mut ys = rhs[o_r1..o_r2].to_vec(); for idx_l in 0..ob.left_broadcast { let start = idx_l * ob.len * ob.right_broadcast; for (i, &l) in lhs.iter().enumerate() { let start = start + i * ob.right_broadcast; for v in ys[start..start + ob.right_broadcast].iter_mut() { *v = f(l, *v) } } } ys } None => lhs_l .strided_index() .zip(rhs_l.strided_index()) .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) .collect(), }, _ => lhs_l .strided_index() .zip(rhs_l.strided_index()) .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) .collect(), } } pub fn unary_map U>( vs: &[T], layout: &Layout, mut f: F, ) -> Vec { match layout.strided_blocks() { crate::StridedBlocks::SingleBlock { start_offset, len } => vs [start_offset..start_offset + len] .iter() .map(|&v| f(v)) .collect(), crate::StridedBlocks::MultipleBlocks { block_start_index, block_len, } => { let mut result = Vec::with_capacity(layout.shape().elem_count()); // Specialize the case where block_len is one to avoid the second loop. if block_len == 1 { for index in block_start_index { let v = unsafe { vs.get_unchecked(index) }; result.push(f(*v)) } } else { for index in block_start_index { for offset in 0..block_len { let v = unsafe { vs.get_unchecked(index + offset) }; result.push(f(*v)) } } } result } } } pub fn unary_map_vec U, FV: FnMut(&[T], &mut [U])>( vs: &[T], layout: &Layout, mut f: F, mut f_vec: FV, ) -> Vec { match layout.strided_blocks() { crate::StridedBlocks::SingleBlock { start_offset, len } => { let mut ys: Vec = Vec::with_capacity(len); let ys_to_set = ys.spare_capacity_mut(); let ys_to_set = unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [U]>(ys_to_set) }; f_vec(&vs[start_offset..start_offset + len], ys_to_set); // SAFETY: values are all set by f_vec. unsafe { ys.set_len(len) }; ys } crate::StridedBlocks::MultipleBlocks { block_start_index, block_len, } => { let el_count = layout.shape().elem_count(); // Specialize the case where block_len is one to avoid the second loop. if block_len == 1 { let mut result = Vec::with_capacity(el_count); for index in block_start_index { let v = unsafe { vs.get_unchecked(index) }; result.push(f(*v)) } result } else { let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); let ys_to_set = unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [U]>(ys_to_set) }; let mut dst_index = 0; for src_index in block_start_index { let vs = &vs[src_index..src_index + block_len]; let ys = &mut ys_to_set[dst_index..dst_index + block_len]; f_vec(vs, ys); dst_index += block_len; } // SAFETY: values are all set by f_vec. unsafe { ys.set_len(el_count) }; ys } } } }