diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-22 08:21:28 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-22 07:21:28 +0100 |
commit | 52c5d8c087f6a2ee91b807467860eb3e96bb6267 (patch) | |
tree | 8b5738ae1f5e7fb662f58b192238469d8d28f25f /candle-core/src/cpu_backend.rs | |
parent | 6eeea1b04e5bf52a77b3f7e35a5c51e13b383848 (diff) | |
download | candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.tar.gz candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.tar.bz2 candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.zip |
Add the gather op. (#219)
* Start adding gather.
* Gather cpu implementation + use in simple training.
* Add scatter_add for the gradient of gather.
* Simple cpu implementation of scatter_add.
* Use gather in the simple-training backprop.
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r-- | candle-core/src/cpu_backend.rs | 128 |
1 files changed, 128 insertions, 0 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 9e2d8699..b8d52c95 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -628,6 +628,72 @@ impl Map1 for Affine {
     }
 }
 
+/// CPU kernel for `gather`: picks values of the source along `dim` at the positions
+/// stored in `ids`.
+struct Gather<'a> {
+    /// Index values; each one is a position along `dim` of the source.
+    ids: &'a [u32],
+    /// Layout of `ids`; the kernel requires it to be contiguous.
+    ids_l: &'a Layout,
+    /// Dimension along which the values are picked.
+    dim: usize,
+}
+
+impl<'a> Map1 for Gather<'a> {
+    /// Returns a buffer shaped like `ids` such that, along `dim`,
+    /// `dst[.., i, ..] = src[.., ids[.., i, ..], ..]`.
+    ///
+    /// Fails with `Error::RequiresContiguous` when `ids` or `src` is not contiguous and
+    /// with `Error::InvalidIndex` when an index is not below the size of `src` along `dim`.
+    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+        let ids = match self.ids_l.contiguous_offsets() {
+            Some((a, b)) => &self.ids[a..b],
+            None => Err(Error::RequiresContiguous { op: "gather" })?,
+        };
+        let src = match src_l.contiguous_offsets() {
+            Some((a, b)) => &src[a..b],
+            None => Err(Error::RequiresContiguous { op: "gather" })?,
+        };
+        let dim = self.dim;
+        let ids_dims = self.ids_l.dims();
+        let src_dims = src_l.dims();
+        // Both buffers are walked as row-major (left, dim, right) blocks.
+        // NOTE(review): assumes `ids` and `src` agree on every dimension but `dim`; this
+        // is not checked here, presumably the caller validates it.
+        let dst_len: usize = ids_dims.iter().product();
+        let dst_left_len: usize = ids_dims[..dim].iter().product();
+        let dst_dim_len = ids_dims[dim];
+        let dst_right_len: usize = ids_dims[dim + 1..].iter().product();
+
+        let src_dim_len = src_dims[dim];
+        let src_right_len: usize = src_dims[dim + 1..].iter().product();
+
+        let mut dst = vec![T::zero(); dst_len];
+        for left_i in 0..dst_left_len {
+            let start_src_idx = left_i * src_right_len * src_dim_len;
+            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
+            for i in 0..dst_dim_len {
+                let start_dst_idx = start_dst_idx + i * dst_right_len;
+                for right_i in 0..dst_right_len {
+                    let dst_idx = start_dst_idx + right_i;
+                    let index = ids[dst_idx] as usize;
+                    if index >= src_dim_len {
+                        Err(Error::InvalidIndex {
+                            index,
+                            size: src_dim_len,
+                            op: "gather",
+                        }
+                        .bt())?
+                    }
+                    let src_idx = start_src_idx + index * src_right_len + right_i;
+                    dst[dst_idx] = src[src_idx]
+                }
+            }
+        }
+        Ok(dst)
+    }
+}
+
 struct IndexSelect<'a> {
     ids: &'a [u32],
     ids_l: &'a Layout,
@@ -680,6 +746,81 @@ impl<'a> Map1 for IndexSelect<'a> {
     }
 }
 
+/// CPU kernel for `scatter_add`: adds the source values into the destination values at the
+/// positions along `dim` stored in `ids`.
+struct ScatterAdd<'a> {
+    /// Index values; each one is a position along `dim` of the destination.
+    ids: &'a [u32],
+    /// Layout of `ids`; the kernel requires it to be contiguous.
+    ids_l: &'a Layout,
+    /// Dimension along which the values are accumulated.
+    dim: usize,
+}
+
+impl<'a> Map2 for ScatterAdd<'a> {
+    const OP: &'static str = "scatter-add";
+
+    /// Starts from the values of `v1` (copied with `copy_strided_src_`) and applies, along
+    /// `dim`, `dst[.., ids[.., i, ..], ..] += src[.., i, ..]`.
+    ///
+    /// Fails with `Error::RequiresContiguous` when `src` or `ids` is not contiguous and
+    /// with `Error::InvalidIndex` when an index is not below the size of `v1` along `dim`.
+    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+        // Check the layouts first so that a failure does not pay for the destination copy.
+        let src = match src_l.contiguous_offsets() {
+            None => Err(Error::RequiresContiguous { op: Self::OP })?,
+            Some((o1, o2)) => &src[o1..o2],
+        };
+        let ids = match self.ids_l.contiguous_offsets() {
+            Some((a, b)) => &self.ids[a..b],
+            None => Err(Error::RequiresContiguous { op: Self::OP })?,
+        };
+
+        let dst_len = l1.shape().elem_count();
+        let mut dst = vec![T::zero(); dst_len];
+        copy_strided_src_(v1, &mut dst, 0, l1);
+
+        let dim = self.dim;
+        let ids_dims = self.ids_l.dims();
+        let dst_dims = l1.dims();
+        let dst_dim_len = dst_dims[dim];
+        let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
+
+        let ids_left_len: usize = ids_dims[..dim].iter().product();
+        let ids_dim_len = ids_dims[dim];
+        let ids_right_len: usize = ids_dims[dim + 1..].iter().product();
+
+        // `src` is read at the same flat offsets as `ids`.
+        // NOTE(review): assumes `src` has the shape of `ids` and that both agree with the
+        // destination on every dimension but `dim`; presumably the caller validates this.
+        for left_i in 0..ids_left_len {
+            let start_ids_idx = left_i * ids_right_len * ids_dim_len;
+            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
+            for i in 0..ids_dim_len {
+                let start_ids_idx = start_ids_idx + i * ids_right_len;
+                // Bound the walk by the row length of `ids` so that the offsets never run
+                // into the next row of `ids`/`src`.
+                for right_i in 0..ids_right_len {
+                    let ids_idx = start_ids_idx + right_i;
+                    let index = ids[ids_idx] as usize;
+                    if index >= dst_dim_len {
+                        Err(Error::InvalidIndex {
+                            index,
+                            size: dst_dim_len,
+                            op: Self::OP,
+                        }
+                        .bt())?
+                    }
+                    let dst_idx = start_dst_idx + index * dst_right_len + right_i;
+                    dst[dst_idx] += src[ids_idx]
+                }
+            }
+        }
+
+        Ok(dst)
+    }
+}
+
 struct IndexAdd<'a> {
     ids: &'a [u32],
     dim: usize,
@@ -1593,6 +1734,31 @@ impl BackendStorage for CpuStorage {
         IndexSelect { ids, ids_l, dim }.map(self, l)
     }
 
+    /// Picks values of `self` along `dim` at the positions stored in `ids`.
+    ///
+    /// `ids` is read as a `u32` slice; an error from `as_slice` is propagated.
+    fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
+        let ids = ids.as_slice::<u32>()?;
+        Gather { ids, ids_l, dim }.map(self, l)
+    }
+
+    /// Adds the values of `src` to the values of `self` along `dim` at the positions stored
+    /// in `ids`, through the `ScatterAdd` kernel.
+    ///
+    /// `ids` is read as a `u32` slice; an error from `as_slice` is propagated.
+    fn scatter_add(
+        &self,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
+    ) -> Result<Self> {
+        let ids = ids.as_slice::<u32>()?;
+        ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l)
+    }
+
     fn index_add(
         &self,
         l: &Layout,