summaryrefslogtreecommitdiff
path: root/candle-core/src/tensor.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-22 08:21:28 +0200
committerGitHub <noreply@github.com>2023-07-22 07:21:28 +0100
commit52c5d8c087f6a2ee91b807467860eb3e96bb6267 (patch)
tree8b5738ae1f5e7fb662f58b192238469d8d28f25f /candle-core/src/tensor.rs
parent6eeea1b04e5bf52a77b3f7e35a5c51e13b383848 (diff)
downloadcandle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.tar.gz
candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.tar.bz2
candle-52c5d8c087f6a2ee91b807467860eb3e96bb6267.zip
Add the gather op. (#219)
* Start adding gather. * Gather cpu implementation + use in simple training. * Add scatter_add for the gradient of gather. * Simple cpu implementation of scatter_add. * Use gather in the simple-training backprop.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r--candle-core/src/tensor.rs85
1 file changed, 85 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 1d6e4e3f..8ba0ba43 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -945,6 +945,57 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
+ pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), "scatter-add")?;
+ let source_dims = source.dims();
+ let self_dims = self.dims();
+ let mismatch = if source_dims.len() != self_dims.len() {
+ true
+ } else {
+ let mut mismatch = false;
+ for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
+ if i != dim && d1 != d2 {
+ mismatch = true;
+ break;
+ }
+ }
+ mismatch
+ };
+ if mismatch {
+ Err(Error::ShapeMismatchBinaryOp {
+ op: "scatter-add (self, src)",
+ lhs: self.shape().clone(),
+ rhs: source.shape().clone(),
+ })?
+ }
+ if indexes.dims() != source.dims() {
+ Err(Error::ShapeMismatchBinaryOp {
+ op: "scatter-add (indexes, src)",
+ lhs: indexes.shape().clone(),
+ rhs: source.shape().clone(),
+ })?
+ }
+ let storage = self.storage().scatter_add(
+ self.layout(),
+ &indexes.storage(),
+ indexes.layout(),
+ &source.storage(),
+ source.layout(),
+ dim,
+ )?;
+ let op = if indexes.track_op() || self.track_op() {
+ Some(Op::ScatterAdd(
+ self.clone(),
+ indexes.clone(),
+ source.clone(),
+ dim,
+ ))
+ } else {
+ None
+ };
+ Ok(from_storage(storage, self.shape(), op, false))
+ }
+
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-add")?;
let source_dims = source.dims();
@@ -992,6 +1043,40 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false))
}
+ pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), "gather")?;
+ let self_dims = self.dims();
+ let indexes_dims = indexes.dims();
+ let mismatch = if indexes_dims.len() != self_dims.len() {
+ true
+ } else {
+ let mut mismatch = false;
+ for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
+ if i != dim && d1 != d2 {
+ mismatch = true;
+ break;
+ }
+ }
+ mismatch
+ };
+ if mismatch {
+ Err(Error::ShapeMismatchBinaryOp {
+ op: "gather",
+ lhs: self.shape().clone(),
+ rhs: indexes.shape().clone(),
+ })?
+ }
+ let storage =
+ self.storage()
+ .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
+ let op = if indexes.track_op() || self.track_op() {
+ Some(Op::Gather(self.clone(), indexes.clone(), dim))
+ } else {
+ None
+ };
+ Ok(from_storage(storage, indexes.shape(), op, false))
+ }
+
pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-select")?;
let indexes_len = match indexes.dims() {