summaryrefslogtreecommitdiff
path: root/candle-core/src/cpu_backend.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-22 08:21:28 +0200
committerGitHub <noreply@github.com>2023-07-22 07:21:28 +0100
commit52c5d8c087f6a2ee91b807467860eb3e96bb6267 (patch)
tree8b5738ae1f5e7fb662f58b192238469d8d28f25f /candle-core/src/cpu_backend.rs
parent6eeea1b04e5bf52a77b3f7e35a5c51e13b383848 (diff)
downloadcandle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.tar.gz
candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.tar.bz2
candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.zip
Add the gather op. (#219)
* Start adding gather. * Gather cpu implementation + use in simple training. * Add scatter_add for the gradient of gather. * Simple cpu implementation of scatter_add. * Use gather in the simple-training backprop.
Diffstat (limited to 'candle-core/src/cpu_backend.rs')
-rw-r--r--candle-core/src/cpu_backend.rs128
1 files changed, 128 insertions, 0 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 9e2d8699..b8d52c95 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -628,6 +628,59 @@ impl Map1 for Affine {
}
}
+struct Gather<'a> {
+ ids: &'a [u32],
+ ids_l: &'a Layout,
+ dim: usize,
+}
+
+impl<'a> Map1 for Gather<'a> {
+ fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+ let ids = match self.ids_l.contiguous_offsets() {
+ Some((a, b)) => &self.ids[a..b],
+ None => Err(Error::RequiresContiguous { op: "gather" })?,
+ };
+ let src = match src_l.contiguous_offsets() {
+ Some((a, b)) => &src[a..b],
+ None => Err(Error::RequiresContiguous { op: "gather" })?,
+ };
+ let dim = self.dim;
+ let ids_dims = self.ids_l.dims();
+ let src_dims = src_l.dims();
+ let dst_len: usize = ids_dims.iter().product();
+ let dst_left_len: usize = ids_dims[..dim].iter().product();
+ let dst_dim_len = ids_dims[dim];
+ let dst_right_len: usize = ids_dims[dim + 1..].iter().product();
+
+ let src_dim_len = src_dims[dim];
+ let src_right_len: usize = src_dims[dim + 1..].iter().product();
+
+ let mut dst = vec![T::zero(); dst_len];
+ for left_i in 0..dst_left_len {
+ let start_src_idx = left_i * src_right_len * src_dim_len;
+ let start_dst_idx = left_i * dst_right_len * dst_dim_len;
+ for i in 0..dst_dim_len {
+ let start_dst_idx = start_dst_idx + i * dst_right_len;
+ for right_i in 0..dst_right_len {
+ let dst_idx = start_dst_idx + right_i;
+ let index = ids[dst_idx] as usize;
+ if index >= src_dim_len {
+ Err(Error::InvalidIndex {
+ index,
+ size: src_dim_len,
+ op: "gather",
+ }
+ .bt())?
+ }
+ let src_idx = start_src_idx + index * src_right_len + right_i;
+ dst[dst_idx] = src[src_idx]
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct IndexSelect<'a> {
ids: &'a [u32],
ids_l: &'a Layout,
@@ -680,6 +733,63 @@ impl<'a> Map1 for IndexSelect<'a> {
}
}
+struct ScatterAdd<'a> {
+ ids: &'a [u32],
+ ids_l: &'a Layout,
+ dim: usize,
+}
+
+impl<'a> Map2 for ScatterAdd<'a> {
+ const OP: &'static str = "scatter-add";
+ fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+ let dst_len = l1.shape().elem_count();
+ let mut dst = vec![T::zero(); dst_len];
+ copy_strided_src_(v1, &mut dst, 0, l1);
+ let src = match src_l.contiguous_offsets() {
+ None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
+ Some((o1, o2)) => &src[o1..o2],
+ };
+
+ let dim = self.dim;
+ let ids_dims = self.ids_l.dims();
+ let dst_dims = l1.dims();
+ let dst_dim_len = dst_dims[dim];
+ let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
+
+ let ids_left_len: usize = ids_dims[..dim].iter().product();
+ let ids_dim_len = ids_dims[dim];
+ let ids_right_len: usize = ids_dims[dim + 1..].iter().product();
+
+ let ids = match self.ids_l.contiguous_offsets() {
+ Some((a, b)) => &self.ids[a..b],
+ None => Err(Error::RequiresContiguous { op: "gather" })?,
+ };
+ for left_i in 0..ids_left_len {
+ let start_ids_idx = left_i * ids_right_len * ids_dim_len;
+ let start_dst_idx = left_i * dst_right_len * dst_dim_len;
+ for i in 0..ids_dim_len {
+ let start_ids_idx = start_ids_idx + i * ids_right_len;
+ for right_i in 0..dst_right_len {
+ let ids_idx = start_ids_idx + right_i;
+ let index = ids[ids_idx] as usize;
+ if index >= dst_dim_len {
+ Err(Error::InvalidIndex {
+ index,
+ size: dst_dim_len,
+ op: "gather",
+ }
+ .bt())?
+ }
+ let dst_idx = start_dst_idx + index * dst_right_len + right_i;
+ dst[dst_idx] += src[ids_idx]
+ }
+ }
+ }
+
+ Ok(dst)
+ }
+}
+
struct IndexAdd<'a> {
ids: &'a [u32],
dim: usize,
@@ -1593,6 +1703,24 @@ impl BackendStorage for CpuStorage {
IndexSelect { ids, ids_l, dim }.map(self, l)
}
+ fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
+ let ids = ids.as_slice::<u32>()?;
+ Gather { ids, ids_l, dim }.map(self, l)
+ }
+
+ fn scatter_add(
+ &self,
+ l: &Layout,
+ ids: &Self,
+ ids_l: &Layout,
+ src: &Self,
+ src_l: &Layout,
+ dim: usize,
+ ) -> Result<Self> {
+ let ids = ids.as_slice::<u32>()?;
+ ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l)
+ }
+
fn index_add(
&self,
l: &Layout,