Diffstat (limited to 'candle-core/src')
-rw-r--r--  candle-core/src/backprop.rs  12
-rw-r--r--  candle-core/src/op.rs         1
-rw-r--r--  candle-core/src/tensor.rs    71
3 files changed, 83 insertions(+), 1 deletion(-)
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index a2548198..67207dce 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -69,7 +69,8 @@ impl Tensor {
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
- | Op::Matmul(lhs, rhs) => {
+ | Op::Matmul(lhs, rhs)
+ | Op::SliceScatter0(lhs, rhs, _) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -270,6 +271,15 @@ impl Tensor {
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
+ Op::SliceScatter0(lhs, rhs, start_rhs) => {
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
+ *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
+ }
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
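The backward rule added above splits the incoming gradient in two: the rows that `src` overwrote flow to `src` via a `narrow`, while `self` receives the gradient with those same rows zeroed out. A minimal sketch of the user-visible effect (not part of the patch), assuming candle's `Var`/`backward` API and arbitrarily chosen shapes:

```rust
use candle_core::{DType, Device, Result, Var};

fn slice_scatter0_grads() -> Result<()> {
    let dev = Device::Cpu;
    let lhs = Var::zeros((4, 2), DType::F32, &dev)?; // destination, 4 rows
    let rhs = Var::ones((2, 2), DType::F32, &dev)?;  // source, written at rows 1..3
    let y = lhs.slice_scatter0(&rhs, 1)?;
    let grads = y.sum_all()?.backward()?;
    // grad w.r.t. rhs = grad.narrow(0, 1, 2): all ones for a sum loss.
    let g_rhs = grads.get(&rhs).unwrap();
    // grad w.r.t. lhs = grad with the overwritten rows 1..3 zeroed out.
    let g_lhs = grads.get(&lhs).unwrap();
    assert_eq!(g_rhs.to_vec2::<f32>()?, vec![vec![1., 1.]; 2]);
    assert_eq!(g_lhs.to_vec2::<f32>()?[1], vec![0., 0.]);
    Ok(())
}
```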
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 4882a205..3083d2c8 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -133,6 +133,7 @@ pub enum Op {
Copy(Tensor),
Broadcast(Tensor),
Narrow(Tensor, usize, usize, usize),
+ SliceScatter0(Tensor, Tensor, usize),
Reshape(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 9dccf2b5..d3337e16 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1132,6 +1132,74 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false))
}
+ /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
+ pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), "slice-scatter")?;
+ if dim == 0 {
+ self.slice_scatter0(src, start)
+ } else {
+ // TODO: Maybe we want to add a more efficient implementation at some point.
+ self.transpose(0, dim)?
+ .slice_scatter0(&src.transpose(0, dim)?, start)?
+ .transpose(0, dim)
+ }
+ }
+
+ /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
+ pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
+ if self.dtype() != src.dtype() {
+ Err(Error::DTypeMismatchBinaryOp {
+ lhs: self.dtype(),
+ rhs: src.dtype(),
+ op: "slice-scatter",
+ }
+ .bt())?
+ }
+ if self.device().location() != src.device().location() {
+ Err(Error::DeviceMismatchBinaryOp {
+ lhs: self.device().location(),
+ rhs: src.device().location(),
+ op: "slice-scatter",
+ }
+ .bt())?
+ }
+ if self.rank() != src.rank() {
+ Err(Error::UnexpectedNumberOfDims {
+ expected: self.rank(),
+ got: src.rank(),
+ shape: src.shape().clone(),
+ }
+ .bt())?
+ }
+ let shape_ok =
+ self.dims()
+ .iter()
+ .zip(src.dims().iter())
+ .enumerate()
+ .all(|(dim_idx, (&d1, &d2))| {
+ if 0 == dim_idx {
+ d2 + start <= d1
+ } else {
+ d1 == d2
+ }
+ });
+ if !shape_ok {
+ Err(Error::ShapeMismatchBinaryOp {
+ op: "slice-scatter (self, src)",
+ lhs: self.shape().clone(),
+ rhs: src.shape().clone(),
+ })?
+ }
+ let mut storage = self.device().zeros(self.shape(), self.dtype())?;
+ self.storage()
+ .copy_strided_src(&mut storage, 0, self.layout())?;
+ let offset = start * src.dims()[1..].iter().product::<usize>();
+ src.storage()
+ .copy_strided_src(&mut storage, offset, src.layout())?;
+ let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
+ Ok(from_storage(storage, self.shape(), op, false))
+ }
+
/// Accumulate element from `source` at indexes `indexes` and add them to `self`.
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-add")?;
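A quick usage sketch for the two new methods (separate from the patch, shapes picked arbitrarily): `slice_scatter0` writes `src` at a row offset of `self`, and the dim-generic `slice_scatter` reaches other dimensions by transposing to dim 0 and back:

```rust
use candle_core::{DType, Device, Result, Tensor};

fn slice_scatter_demo() -> Result<()> {
    let dev = Device::Cpu;
    let dst = Tensor::zeros((4, 3), DType::F32, &dev)?;
    let src = Tensor::ones((2, 3), DType::F32, &dev)?;

    // Overwrite rows 1..3 of dst along dim 0.
    let out = dst.slice_scatter0(&src, 1)?;
    assert_eq!(out.to_vec2::<f32>()?[1], vec![1., 1., 1.]);

    // Overwrite column 2 along dim 1; internally this transposes,
    // calls slice_scatter0, then transposes back.
    let col = Tensor::ones((4, 1), DType::F32, &dev)?;
    let out = dst.slice_scatter(&col, 1, 2)?;
    assert_eq!(out.to_vec2::<f32>()?[0], vec![0., 0., 1.]);
    Ok(())
}
```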
@@ -1548,6 +1616,9 @@ impl Tensor {
pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
let dim1 = dim1.to_index(self.shape(), "transpose")?;
let dim2 = dim2.to_index(self.shape(), "transpose")?;
+ if dim1 == dim2 {
+ return Ok(self.clone());
+ }
let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
let tensor_ = Tensor_ {
id: TensorId::new(),
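The final hunk short-circuits `transpose` when both dims are equal. A rough sketch of the observable effect, assuming the usual constructors: the call returns a cheap clone of the handle (storage is shared) instead of recording a new `Transpose` node.

```rust
use candle_core::{DType, Device, Result, Tensor};

fn self_transpose_is_noop() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::zeros((2, 3), DType::F32, &dev)?;
    // dim1 == dim2: early return with self.clone(), shape unchanged.
    let same = t.transpose(1, 1)?;
    assert_eq!(same.dims(), t.dims());
    Ok(())
}
```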