diff options
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 1d6e4e3f..8ba0ba43 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -945,6 +945,57 @@ impl Tensor { Ok(from_storage(storage, shape, op, false)) } + pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> { + let dim = dim.to_index(self.shape(), "scatter-add")?; + let source_dims = source.dims(); + let self_dims = self.dims(); + let mismatch = if source_dims.len() != self_dims.len() { + true + } else { + let mut mismatch = false; + for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() { + if i != dim && d1 != d2 { + mismatch = true; + break; + } + } + mismatch + }; + if mismatch { + Err(Error::ShapeMismatchBinaryOp { + op: "scatter-add (self, src)", + lhs: self.shape().clone(), + rhs: source.shape().clone(), + })? + } + if indexes.dims() != source.dims() { + Err(Error::ShapeMismatchBinaryOp { + op: "scatter-add (indexes, src)", + lhs: indexes.shape().clone(), + rhs: source.shape().clone(), + })? + } + let storage = self.storage().scatter_add( + self.layout(), + &indexes.storage(), + indexes.layout(), + &source.storage(), + source.layout(), + dim, + )?; + let op = if indexes.track_op() || self.track_op() { + Some(Op::ScatterAdd( + self.clone(), + indexes.clone(), + source.clone(), + dim, + )) + } else { + None + }; + Ok(from_storage(storage, self.shape(), op, false)) + } + pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> { let dim = dim.to_index(self.shape(), "index-add")?; let source_dims = source.dims(); @@ -992,6 +1043,40 @@ impl Tensor { Ok(from_storage(storage, self.shape(), op, false)) } + pub fn gather<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> { + let dim = dim.to_index(self.shape(), "gather")?; + let self_dims = self.dims(); + let indexes_dims = indexes.dims(); + let mismatch = if indexes_dims.len() != self_dims.len() { + true + } else { + let mut mismatch = false; + for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() { + if i != dim && d1 != d2 { + mismatch = true; + break; + } + } + mismatch + }; + if mismatch { + Err(Error::ShapeMismatchBinaryOp { + op: "gather", + lhs: self.shape().clone(), + rhs: indexes.shape().clone(), + })? + } + let storage = + self.storage() + .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?; + let op = if indexes.track_op() || self.track_op() { + Some(Op::Gather(self.clone(), indexes.clone(), dim)) + } else { + None + }; + Ok(from_storage(storage, indexes.shape(), op, false)) + } + pub fn index_select<D: Dim>(&self, indexes: &Self, dim: D) -> Result<Self> { let dim = dim.to_index(self.shape(), "index-select")?; let indexes_len = match indexes.dims() { |