Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/backend.rs             |  10
-rw-r--r--  candle-core/src/backprop.rs            |   7
-rw-r--r--  candle-core/src/cpu_backend.rs         | 128
-rw-r--r--  candle-core/src/cuda_backend.rs        |  14
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs  |  16
-rw-r--r--  candle-core/src/op.rs                  |   2
-rw-r--r--  candle-core/src/storage.rs             |  45
-rw-r--r--  candle-core/src/tensor.rs              |  85
8 files changed, 307 insertions(+), 0 deletions(-)
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 2b873e6e..8815c08d 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -40,6 +40,16 @@ pub trait BackendStorage: Sized {
) -> Result<Self>;
fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result<Self>;
+ fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
+ fn scatter_add(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: usize,
+ ) -> Result<Self>;
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
fn index_add(
&self,
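
Note: the two new trait methods mirror the existing index_select/index_add pair: gather reads along one dimension, scatter_add accumulates along one. The trait elides parameter names with `_`; spelled out with the names the CPU implementation below uses (illustrative only, not part of the trait), the signatures read:

    // Gather values along `dim`; the output takes the shape of the ids tensor.
    fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self>;
    // Scatter-add `src` into a copy of self along `dim`; duplicate ids accumulate.
    fn scatter_add(
        &self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<Self>;
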
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 678dbabd..38898b7b 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -39,6 +39,7 @@ impl Tensor {
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
+ | Op::ScatterAdd(t1, t2, t3, _)
| Op::CustomOp3(t1, t2, t3, _)
| Op::WhereCond(t1, t2, t3) => {
let (tg, nodes) = walk(t1, nodes, already_seen);
@@ -56,6 +57,7 @@ impl Tensor {
}
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
+ | Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
| Op::Embedding(lhs, rhs)
| Op::Matmul(lhs, rhs) => {
@@ -162,6 +164,11 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+ Op::Gather(arg, indexes, dim) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
+ }
+ Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?,
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
Op::IndexSelect(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
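
Note: the new Gather backward works because gather and scatter-add are adjoint: y = x.gather(ids, dim) reads x at the listed positions, so the gradient flows back by scatter-adding grad(y) into those same positions, with duplicated indices accumulating. A minimal sketch of that relationship through the public API (values illustrative; assumes the crate is imported as candle_core):

    use candle_core::{Device, Result, Tensor};

    fn gather_adjoint_sketch() -> Result<()> {
        let x = Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?;
        let ids = Tensor::new(&[[0u32, 0], [1, 0]], &Device::Cpu)?;
        let y = x.gather(&ids, 0)?; // y[i][j] = x[ids[i][j]][j]
        let dy = y.ones_like()?; // stand-in for an upstream gradient
        // Backward of gather: scatter the gradient back along the same dim.
        let dx = x.zeros_like()?.scatter_add(&ids, &dy, 0)?;
        // ids points at row 0 twice in column 1, so that entry accumulates to 2.
        assert_eq!(dx.to_vec2::<f32>()?, [[1.0, 2.0], [1.0, 0.0]]);
        Ok(())
    }
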
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 9e2d8699..b8d52c95 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -628,6 +628,59 @@ impl Map1 for Affine {
}
}
+struct Gather<'a> {
+ ids: &'a [u32],
+ ids_l: &'a Layout,
+ dim: usize,
+}
+
+impl<'a> Map1 for Gather<'a> {
+ fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+ let ids = match self.ids_l.contiguous_offsets() {
+ Some((a, b)) => &self.ids[a..b],
+ None => Err(Error::RequiresContiguous { op: "gather" })?,
+ };
+ let src = match src_l.contiguous_offsets() {
+ Some((a, b)) => &src[a..b],
+ None => Err(Error::RequiresContiguous { op: "gather" })?,
+ };
+ let dim = self.dim;
+ let ids_dims = self.ids_l.dims();
+ let src_dims = src_l.dims();
+ let dst_len: usize = ids_dims.iter().product();
+ let dst_left_len: usize = ids_dims[..dim].iter().product();
+ let dst_dim_len = ids_dims[dim];
+ let dst_right_len: usize = ids_dims[dim + 1..].iter().product();
+
+ let src_dim_len = src_dims[dim];
+ let src_right_len: usize = src_dims[dim + 1..].iter().product();
+
+ let mut dst = vec![T::zero(); dst_len];
+ for left_i in 0..dst_left_len {
+ let start_src_idx = left_i * src_right_len * src_dim_len;
+ let start_dst_idx = left_i * dst_right_len * dst_dim_len;
+ for i in 0..dst_dim_len {
+ let start_dst_idx = start_dst_idx + i * dst_right_len;
+ for right_i in 0..dst_right_len {
+ let dst_idx = start_dst_idx + right_i;
+ let index = ids[dst_idx] as usize;
+ if index >= src_dim_len {
+ Err(Error::InvalidIndex {
+ index,
+ size: src_dim_len,
+ op: "gather",
+ }
+ .bt())?
+ }
+ let src_idx = start_src_idx + index * src_right_len + right_i;
+ dst[dst_idx] = src[src_idx]
+ }
+ }
+ }
+ Ok(dst)
+ }
+}
+
struct IndexSelect<'a> {
ids: &'a [u32],
ids_l: &'a Layout,
@@ -680,6 +733,63 @@ impl<'a> Map1 for IndexSelect<'a> {
}
}
+struct ScatterAdd<'a> {
+ ids: &'a [u32],
+ ids_l: &'a Layout,
+ dim: usize,
+}
+
+impl<'a> Map2 for ScatterAdd<'a> {
+ const OP: &'static str = "scatter-add";
+ fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+ let dst_len = l1.shape().elem_count();
+ let mut dst = vec![T::zero(); dst_len];
+ copy_strided_src_(v1, &mut dst, 0, l1);
+ let src = match src_l.contiguous_offsets() {
+ None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
+ Some((o1, o2)) => &src[o1..o2],
+ };
+
+ let dim = self.dim;
+ let ids_dims = self.ids_l.dims();
+ let dst_dims = l1.dims();
+ let dst_dim_len = dst_dims[dim];
+ let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
+
+ let ids_left_len: usize = ids_dims[..dim].iter().product();
+ let ids_dim_len = ids_dims[dim];
+ let ids_right_len: usize = ids_dims[dim + 1..].iter().product();
+
+ let ids = match self.ids_l.contiguous_offsets() {
+ Some((a, b)) => &self.ids[a..b],
+ None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
+ };
+ for left_i in 0..ids_left_len {
+ let start_ids_idx = left_i * ids_right_len * ids_dim_len;
+ let start_dst_idx = left_i * dst_right_len * dst_dim_len;
+ for i in 0..ids_dim_len {
+ let start_ids_idx = start_ids_idx + i * ids_right_len;
+ for right_i in 0..dst_right_len {
+ let ids_idx = start_ids_idx + right_i;
+ let index = ids[ids_idx] as usize;
+ if index >= dst_dim_len {
+ Err(Error::InvalidIndex {
+ index,
+ size: dst_dim_len,
+ op: "gather",
+ }
+ .bt())?
+ }
+ let dst_idx = start_dst_idx + index * dst_right_len + right_i;
+ dst[dst_idx] += src[ids_idx]
+ }
+ }
+ }
+
+ Ok(dst)
+ }
+}
+
struct IndexAdd<'a> {
ids: &'a [u32],
dim: usize,
@@ -1593,6 +1703,24 @@ impl BackendStorage for CpuStorage {
IndexSelect { ids, ids_l, dim }.map(self, l)
}
+ fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
+ let ids = ids.as_slice::<u32>()?;
+ Gather { ids, ids_l, dim }.map(self, l)
+ }
+
+ fn scatter_add(
+ &self,
+ l: &Layout,
+ ids: &Self,
+ ids_l: &Layout,
+ src: &Self,
+ src_l: &Layout,
+ dim: usize,
+ ) -> Result<Self> {
+ let ids = ids.as_slice::<u32>()?;
+ ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l)
+ }
+
fn index_add(
&self,
l: &Layout,
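
Note: both CPU kernels linearize the tensors into (left, dim, right) blocks: for contiguous data, the element at logical position (l, i, r) sits at flat offset (l * dim_len + i) * right_len + r. Gather reads src at the index taken from ids; scatter-add writes into dst at that index. The shape checks in tensor.rs guarantee that the left/right extents of src, ids, and dst agree, which is why one set of loop bounds serves all of them. A standalone sketch of the same addressing on plain slices (not the candle code itself):

    // Gather on flat f32 slices using the (left, dim, right) decomposition.
    // Assumes contiguous data and that `ids_dims` matches `src_dims` on every
    // dimension except `dim`, as the tensor-level checks enforce.
    fn gather_flat(src: &[f32], src_dims: &[usize], ids: &[u32], ids_dims: &[usize], dim: usize) -> Vec<f32> {
        let left: usize = ids_dims[..dim].iter().product();
        let right: usize = ids_dims[dim + 1..].iter().product();
        let (ids_d, src_d) = (ids_dims[dim], src_dims[dim]);
        let mut dst = vec![0f32; ids.len()];
        for l in 0..left {
            for i in 0..ids_d {
                for r in 0..right {
                    let dst_idx = (l * ids_d + i) * right + r;
                    let index = ids[dst_idx] as usize; // bounds-checked in the real kernel
                    dst[dst_idx] = src[(l * src_d + index) * right + r];
                }
            }
        }
        dst
    }
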
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index a5633836..43bfef2d 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1064,6 +1064,20 @@ impl BackendStorage for CudaStorage {
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement index-select").into())
}
+ fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
+ Err(CudaError::InternalError("TODO: implement gather").into())
+ }
+ fn scatter_add(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: usize,
+ ) -> Result<Self> {
+ Err(CudaError::InternalError("TODO: implement scatter-add").into())
+ }
fn index_add(
&self,
_: &Layout,
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index 633f146e..c195cade 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -85,6 +85,22 @@ impl crate::backend::BackendStorage for CudaStorage {
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
+ fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
+ fn scatter_add(
+ &self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: &Self,
+ _: &Layout,
+ _: usize,
+ ) -> Result<Self> {
+ Err(Error::NotCompiledWithCudaSupport)
+ }
+
fn index_add(
&self,
_: &Layout,
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index d36aa301..de5094bd 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -66,6 +66,8 @@ pub(crate) enum Op {
Reduce(Tensor, ReduceOp, Vec<usize>),
Matmul(Tensor, Tensor),
Embedding(Tensor, Tensor),
+ Gather(Tensor, Tensor, usize),
+ ScatterAdd(Tensor, Tensor, Tensor, usize),
IndexSelect(Tensor, Tensor, usize),
IndexAdd(Tensor, Tensor, Tensor, usize),
WhereCond(Tensor, Tensor, Tensor),
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 62e2d5e7..5e6cfdf2 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -325,6 +325,51 @@ impl Storage {
}
}
+ pub(crate) fn gather(
+ &self,
+ l: &Layout,
+ indexes: &Self,
+ indexes_l: &Layout,
+ d: usize,
+ ) -> Result<Self> {
+ self.same_device(indexes, "gather")?;
+ match (self, indexes) {
+ (Self::Cpu(s), Self::Cpu(indexes)) => {
+ let storage = s.gather(l, indexes, indexes_l, d)?;
+ Ok(Self::Cpu(storage))
+ }
+ (Self::Cuda(s), Self::Cuda(indexes)) => {
+ let storage = s.gather(l, indexes, indexes_l, d)?;
+ Ok(Self::Cuda(storage))
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ pub(crate) fn scatter_add(
+ &self,
+ l: &Layout,
+ indexes: &Self,
+ indexes_l: &Layout,
+ source: &Self,
+ source_l: &Layout,
+ d: usize,
+ ) -> Result<Self> {
+ self.same_device(indexes, "scatter-add")?;
+ self.same_device(source, "scatter-add")?;
+ match (self, indexes, source) {
+ (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
+ let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
+ Ok(Self::Cpu(storage))
+ }
+ (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
+ let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
+ Ok(Self::Cuda(storage))
+ }
+ _ => unreachable!(),
+ }
+ }
+
pub(crate) fn index_add(
&self,
l: &Layout,
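
Note: the new Storage methods follow the dispatch pattern used throughout this file: validate devices first, then match on the backend pair. The wildcard arm can use unreachable!() precisely because same_device has already rejected mixed pairs. A self-contained sketch of that invariant (the types and method bodies here are illustrative, not the candle definitions):

    // The `_ => unreachable!()` arm is dead code by construction: a mixed
    // (Cpu, Cuda) pair errors out in `same_device` before the match runs.
    enum Storage { Cpu(Vec<f32>), Cuda(u64 /* opaque device handle */) }

    impl Storage {
        fn same_device(&self, rhs: &Self, op: &'static str) -> Result<(), String> {
            match (self, rhs) {
                (Storage::Cpu(_), Storage::Cpu(_)) | (Storage::Cuda(_), Storage::Cuda(_)) => Ok(()),
                _ => Err(format!("device mismatch in {op}")),
            }
        }

        fn gather_dispatch(&self, ids: &Self) -> Result<(), String> {
            self.same_device(ids, "gather")?; // errors out on mixed devices...
            match (self, ids) {
                (Storage::Cpu(_), Storage::Cpu(_)) => { /* run the CPU kernel */ }
                (Storage::Cuda(_), Storage::Cuda(_)) => { /* run the CUDA kernel */ }
                _ => unreachable!(), // ...so this arm can never be reached
            }
            Ok(())
        }
    }
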
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 1d6e4e3f..8ba0ba43 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -945,6 +945,57 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
+ pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), "scatter-add")?;
+ let source_dims = source.dims();
+ let self_dims = self.dims();
+ let mismatch = if source_dims.len() != self_dims.len() {
+ true
+ } else {
+ let mut mismatch = false;
+ for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
+ if i != dim && d1 != d2 {
+ mismatch = true;
+ break;
+ }
+ }
+ mismatch
+ };
+ if mismatch {
+ Err(Error::ShapeMismatchBinaryOp {
+ op: "scatter-add (self, src)",
+ lhs: self.shape().clone(),
+ rhs: source.shape().clone(),
+ })?
+ }
+ if indexes.dims() != source.dims() {
+ Err(Error::ShapeMismatchBinaryOp {
+ op: "scatter-add (indexes, src)",
+ lhs: indexes.shape().clone(),
+ rhs: source.shape().clone(),
+ })?
+ }
+ let storage = self.storage().scatter_add(
+ self.layout(),
+ &indexes.storage(),
+ indexes.layout(),
+ &source.storage(),
+ source.layout(),
+ dim,
+ )?;
+ let op = if indexes.track_op() || self.track_op() {
+ Some(Op::ScatterAdd(
+ self.clone(),
+ indexes.clone(),
+ source.clone(),
+ dim,
+ ))
+ } else {
+ None
+ };
+ Ok(from_storage(storage, self.shape(), op, false))
+ }
+
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-add")?;
let source_dims = source.dims();
@@ -992,6 +1043,40 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false))
}
+ pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), "gather")?;
+ let self_dims = self.dims();
+ let indexes_dims = indexes.dims();
+ let mismatch = if indexes_dims.len() != self_dims.len() {
+ true
+ } else {
+ let mut mismatch = false;
+ for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
+ if i != dim && d1 != d2 {
+ mismatch = true;
+ break;
+ }
+ }
+ mismatch
+ };
+ if mismatch {
+ Err(Error::ShapeMismatchBinaryOp {
+ op: "gather",
+ lhs: self.shape().clone(),
+ rhs: indexes.shape().clone(),
+ })?
+ }
+ let storage =
+ self.storage()
+ .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
+ let op = if indexes.track_op() || self.track_op() {
+ Some(Op::Gather(self.clone(), indexes.clone(), dim))
+ } else {
+ None
+ };
+ Ok(from_storage(storage, indexes.shape(), op, false))
+ }
+
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-select")?;
let indexes_len = match indexes.dims() {
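
Note: at the tensor level the shape contract is: for gather, indexes must match self on every dimension except dim, and the output takes indexes' shape; for scatter_add, indexes and source must have identical shapes, source must match self outside dim, and the output keeps self's shape. A usage sketch of the two new public methods (assumes the crate is imported as candle_core; values illustrative):

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let x = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
        let ids = Tensor::new(&[[2u32, 2], [1, 2]], &dev)?;
        // gather along dim 1: out[i][j] = x[i][ids[i][j]], output shaped like `ids`.
        let g = x.gather(&ids, 1)?;
        assert_eq!(g.to_vec2::<f32>()?, [[3.0, 3.0], [5.0, 6.0]]);
        // scatter_add along dim 1: out[i][ids[i][j]] += src[i][j], output shaped
        // like `x`; the duplicated index 2 in row 0 accumulates both values.
        let src = Tensor::new(&[[10f32, 20.], [30., 40.]], &dev)?;
        let s = x.scatter_add(&ids, &src, 1)?;
        assert_eq!(s.to_vec2::<f32>()?, [[1.0, 2.0, 33.0], [4.0, 35.0, 46.0]]);
        Ok(())
    }
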