Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 182
1 file changed, 120 insertions(+), 62 deletions(-)
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 82e1f3e2..9a6320ec 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1,6 +1,6 @@
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
-use crate::{DType, Error, Layout, Result, Shape, WithDType};
+use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
 
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -133,9 +133,9 @@ impl Map2U8 for Cmp {
     }
 }
 
-struct WCond<'a>(&'a [u32], &'a Layout);
+struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
 
-impl<'a> Map2 for WCond<'a> {
+impl<'a, I: IntDType> Map2 for WCond<'a, I> {
     const OP: &'static str = "where";
     #[inline(always)]
     fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
@@ -150,14 +150,20 @@ impl<'a> Map2 for WCond<'a> {
                 let f = &f[o_f1..o_f2];
                 pred.iter()
                     .zip(t.iter().zip(f.iter()))
-                    .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
+                    .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
                     .collect::<Vec<_>>()
             }
             _ => self
                 .1
                 .strided_index()
                 .zip(t_l.strided_index().zip(f_l.strided_index()))
-                .map(|(i_p, (i_t, i_f))| if self.0[i_p] > 0 { t[i_t] } else { f[i_f] })
+                .map(|(i_p, (i_t, i_f))| {
+                    if self.0[i_p].is_true() {
+                        t[i_t]
+                    } else {
+                        f[i_f]
+                    }
+                })
                 .collect::<Vec<_>>(),
         };
         Ok(vs)
@@ -628,13 +634,13 @@ impl Map1 for Affine {
     }
 }
 
-struct Gather<'a> {
-    ids: &'a [u32],
+struct Gather<'a, I: IntDType> {
+    ids: &'a [I],
     ids_l: &'a Layout,
     dim: usize,
 }
 
-impl<'a> Map1 for Gather<'a> {
+impl<'a, I: IntDType> Map1 for Gather<'a, I> {
     fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
         let ids = match self.ids_l.contiguous_offsets() {
             Some((a, b)) => &self.ids[a..b],
@@ -663,7 +669,7 @@
             let start_dst_idx = start_dst_idx + i * dst_right_len;
             for right_i in 0..dst_right_len {
                 let dst_idx = start_dst_idx + right_i;
-                let index = ids[dst_idx] as usize;
+                let index = ids[dst_idx].as_usize();
                 if index >= src_dim_len {
                     Err(Error::InvalidIndex {
                         index,
@@ -681,13 +687,13 @@
     }
 }
 
-struct IndexSelect<'a> {
-    ids: &'a [u32],
+struct IndexSelect<'a, T: IntDType> {
+    ids: &'a [T],
     ids_l: &'a Layout,
     dim: usize,
 }
 
-impl<'a> Map1 for IndexSelect<'a> {
+impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
     fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
         let src = match layout.contiguous_offsets() {
             Some((a, b)) => &src[a..b],
@@ -714,7 +720,7 @@
             let start_src_idx = left_i * right_len * src_dim;
             let start_dst_idx = left_i * right_len * n_ids;
             for i in 0..n_ids {
-                let index = self.ids[self.ids_l.start_offset() + stride_ids * i] as usize;
+                let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
                 if index >= src_dim {
                     Err(Error::InvalidIndex {
                         index,
@@ -733,13 +739,13 @@
     }
 }
 
-struct ScatterAdd<'a> {
-    ids: &'a [u32],
+struct ScatterAdd<'a, I: IntDType> {
+    ids: &'a [I],
     ids_l: &'a Layout,
     dim: usize,
 }
 
-impl<'a> Map2 for ScatterAdd<'a> {
+impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
     const OP: &'static str = "scatter-add";
     fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
         let dst_len = l1.shape().elem_count();
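
Note: the new `I: IntDType` bounds above lean on an `IntDType` trait whose definition is not part of this diff (it is imported from the crate root in the first hunk). Inferred from the call sites — `is_true()` for the `where` predicate and `as_usize()` for index conversion — a minimal sketch of the trait, assuming it is implemented for the `u8` and `u32` dtypes dispatched on later in this diff:

// Sketch only; the actual definition lives elsewhere in candle-core and may differ.
pub trait IntDType: WithDType {
    // Predicate semantics used by WCond: any non-zero value counts as true.
    fn is_true(&self) -> bool;
    // Index conversion used by Gather, IndexSelect, ScatterAdd and IndexAdd.
    fn as_usize(&self) -> usize;
}

impl IntDType for u8 {
    fn is_true(&self) -> bool {
        *self != 0
    }
    fn as_usize(&self) -> usize {
        *self as usize
    }
}

impl IntDType for u32 {
    fn is_true(&self) -> bool {
        *self != 0
    }
    fn as_usize(&self) -> usize {
        *self as usize
    }
}

With such a trait, `p.is_true()` generalizes the previous hard-coded `p > 0` on u32 predicates, and `ids[i].as_usize()` replaces the bare `as usize` casts.
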
@@ -771,7 +777,7 @@
             let start_ids_idx = start_ids_idx + i * ids_right_len;
             for right_i in 0..dst_right_len {
                 let ids_idx = start_ids_idx + right_i;
-                let index = ids[ids_idx] as usize;
+                let index = ids[ids_idx].as_usize();
                 if index >= dst_dim_len {
                     Err(Error::InvalidIndex {
                         index,
@@ -790,12 +796,12 @@
     }
 }
 
-struct IndexAdd<'a> {
-    ids: &'a [u32],
+struct IndexAdd<'a, I: IntDType> {
+    ids: &'a [I],
     dim: usize,
 }
 
-impl<'a> Map2 for IndexAdd<'a> {
+impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
     const OP: &'static str = "index-add";
     // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
     // v1, l1 -> self
@@ -811,8 +817,8 @@
         let max_idx = l1.dims()[dim];
         let stride = src_l.stride()[dim];
         if dim == 0 {
-            for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
-                let dst_idx = dst_idx as usize;
+            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
+                let dst_idx = dst_idx.as_usize();
                 if dst_idx >= max_idx {
                     Err(Error::InvalidIndex {
                         index: dst_idx,
@@ -831,8 +837,8 @@
         } else {
            let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
            let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
-            for (src_idx, &dst_idx) in self.ids.iter().enumerate() {
-                let dst_idx = dst_idx as usize;
+            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
+                let dst_idx = dst_idx.as_usize();
                 if dst_idx >= max_idx {
                     Err(Error::InvalidIndex {
                         index: dst_idx,
@@ -856,31 +862,52 @@
     }
 }
 
-struct Embedding<'a> {
+struct Embedding<'a, I: IntDType> {
     vocab_size: usize,
     hidden_size: usize,
-    ids: &'a [u32],
+    ids: &'a [I],
     ids_l: &'a Layout,
 }
 
-impl<'a> Map1 for Embedding<'a> {
+impl<'a, I: IntDType> Map1 for Embedding<'a, I> {
     fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
-        // TODO: We assume that vs is contiguous here.
+        if !layout.is_contiguous() {
+            Err(Error::RequiresContiguous { op: "embedding" })?
+        }
         let vs = &vs[layout.start_offset()..];
         let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
-        // TODO: Optimize for the case where ids are contiguous.
-        for index in self.ids_l.strided_index() {
-            let index = self.ids[index].try_into()?;
-            if index >= self.vocab_size {
-                Err(Error::InvalidIndex {
-                    index,
-                    size: self.vocab_size,
-                    op: "take",
+        match self.ids_l.contiguous_offsets() {
+            Some((o1, o2)) => {
+                for index in self.ids[o1..o2].iter() {
+                    let index = index.as_usize();
+                    if index >= self.vocab_size {
+                        Err(Error::InvalidIndex {
+                            index,
+                            size: self.vocab_size,
+                            op: "take",
+                        }
+                        .bt())?
+                    } else {
+                        let hidden_size = self.hidden_size;
+                        values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
+                    }
+                }
+            }
+            None => {
+                for index in self.ids_l.strided_index() {
+                    let index = self.ids[index].as_usize();
+                    if index >= self.vocab_size {
+                        Err(Error::InvalidIndex {
+                            index,
+                            size: self.vocab_size,
+                            op: "take",
+                        }
+                        .bt())?
+                    } else {
+                        let hidden_size = self.hidden_size;
+                        values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
+                    }
                 }
-            .bt())?
-        } else {
-            let hidden_size = self.hidden_size;
-            values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
             }
         }
         Ok(values)
     }
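
The Embedding rework above adds an explicit contiguity check on the weight layout and a fast path for contiguous ids: when `contiguous_offsets()` returns a range, the ids are read as a flat slice instead of going through `strided_index()`. Stripped of the candle types, the fast path amounts to the following standalone sketch (hypothetical helper, not part of the diff):

// For each id, bounds-check it against the vocabulary, then copy out its
// `hidden_size`-long embedding row from the flat weight buffer.
fn embed_contiguous(
    vs: &[f32],
    vocab_size: usize,
    hidden_size: usize,
    ids: &[u32],
) -> Result<Vec<f32>, String> {
    let mut values = Vec::with_capacity(ids.len() * hidden_size);
    for &id in ids {
        let index = id as usize;
        if index >= vocab_size {
            return Err(format!("invalid index {index} for vocab size {vocab_size}"));
        }
        values.extend_from_slice(&vs[hidden_size * index..hidden_size * (index + 1)]);
    }
    Ok(values)
}

The `None` branch performs the same per-id work but locates each id through the strided traversal, so the contiguous branch saves an offset computation per element.
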
@@ -1671,9 +1698,11 @@ impl BackendStorage for CpuStorage {
         f: &Self,
         f_l: &Layout,
     ) -> Result<Self> {
-        // TODO: Support types that could be casted to a boolean.
-        let pred = self.as_slice::<u32>()?;
-        WCond(pred, layout).map(t, t_l, f, f_l)
+        match self {
+            Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
+            Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
+        }
     }
 
     fn conv1d(
@@ -1687,25 +1716,40 @@ impl BackendStorage for CpuStorage {
     }
 
     fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
-        let ids = self.as_slice::<u32>()?;
         let (vocab_size, hidden_size) = rhs_l.shape().dims2()?;
-        Embedding {
-            vocab_size,
-            hidden_size,
-            ids,
-            ids_l,
+        match self {
+            Self::U8(ids) => Embedding {
+                vocab_size,
+                hidden_size,
+                ids,
+                ids_l,
+            }
+            .map(rhs, rhs_l),
+            Self::U32(ids) => Embedding {
+                vocab_size,
+                hidden_size,
+                ids,
+                ids_l,
+            }
+            .map(rhs, rhs_l),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "embedding")),
         }
-        .map(rhs, rhs_l)
     }
 
     fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
-        let ids = ids.as_slice::<u32>()?;
-        IndexSelect { ids, ids_l, dim }.map(self, l)
+        match ids {
+            Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
+            Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
+        }
    }
 
     fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
-        let ids = ids.as_slice::<u32>()?;
-        Gather { ids, ids_l, dim }.map(self, l)
+        match ids {
+            Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
+            Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
+        }
     }
 
     fn scatter_add(
@@ -1717,8 +1761,11 @@
         src_l: &Layout,
         dim: usize,
     ) -> Result<Self> {
-        let ids = ids.as_slice::<u32>()?;
-        ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l)
+        match ids {
+            Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
+            Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
+        }
     }
 
     fn index_add(
@@ -1730,12 +1777,23 @@
         src_l: &Layout,
         dim: usize,
     ) -> Result<Self> {
-        let ids = ids.as_slice::<u32>()?;
-        let ids = match ids_l.contiguous_offsets() {
-            Some((a, b)) => &ids[a..b],
-            None => Err(Error::RequiresContiguous { op: "index-add" })?,
-        };
-        IndexAdd { ids, dim }.map(self, l, src, src_l)
+        match ids {
+            Self::U8(ids) => {
+                let ids = match ids_l.contiguous_offsets() {
+                    Some((a, b)) => &ids[a..b],
+                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
+                };
+                IndexAdd { ids, dim }.map(self, l, src, src_l)
+            }
+            Self::U32(ids) => {
+                let ids = match ids_l.contiguous_offsets() {
+                    Some((a, b)) => &ids[a..b],
+                    None => Err(Error::RequiresContiguous { op: "index-add" })?,
+                };
+                IndexAdd { ids, dim }.map(self, l, src, src_l)
+            }
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
+        }
     }
 
     fn matmul(
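
Taken together, the `BackendStorage` changes replace the hard `as_slice::<u32>()` requirement with a dispatch on the id tensor's dtype: `U8` and `U32` ids both route to the same generic kernels, and any other dtype fails with `UnsupportedDTypeForOp`. A usage sketch of the observable effect — the `Tensor` API names below match candle's public surface but are an assumption, not part of this diff:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &dev)?;
    // u32 ids were already accepted; u8 ids now take the same path.
    let ids_u32 = Tensor::new(&[2u32, 0], &dev)?;
    let ids_u8 = Tensor::new(&[2u8, 0], &dev)?;
    let a = t.index_select(&ids_u32, 0)?;
    let b = t.index_select(&ids_u8, 0)?;
    // Both selections gather rows 2 and 0 and agree elementwise.
    assert_eq!(a.to_vec2::<f32>()?, b.to_vec2::<f32>()?);
    Ok(())
}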