Diffstat (limited to 'candle-core/src')
-rw-r--r--   candle-core/src/backprop.rs   12
-rw-r--r--   candle-core/src/op.rs          1
-rw-r--r--   candle-core/src/tensor.rs     71
3 files changed, 83 insertions(+), 1 deletion(-)
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index a2548198..67207dce 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -69,7 +69,8 @@ impl Tensor {
                 | Op::Binary(lhs, rhs, _)
                 | Op::Gather(lhs, rhs, _)
                 | Op::IndexSelect(lhs, rhs, _)
-                | Op::Matmul(lhs, rhs) => {
+                | Op::Matmul(lhs, rhs)
+                | Op::SliceScatter0(lhs, rhs, _) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);
                     track_grad |= tg;
                     let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -270,6 +271,15 @@ impl Tensor {
                 Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                     op: "upsample-nearest2d",
                 })?,
+                Op::SliceScatter0(lhs, rhs, start_rhs) => {
+                    let rhs_sum_grad = grads.or_insert(rhs)?;
+                    let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
+                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+
+                    let lhs_sum_grad = grads.or_insert(lhs)?;
+                    let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
+                    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
+                }
                 Op::Gather(arg, indexes, dim) => {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 4882a205..3083d2c8 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -133,6 +133,7 @@ pub enum Op {
     Copy(Tensor),
     Broadcast(Tensor),
     Narrow(Tensor, usize, usize, usize),
+    SliceScatter0(Tensor, Tensor, usize),
     Reshape(Tensor),
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 9dccf2b5..d3337e16 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1132,6 +1132,74 @@ impl Tensor {
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
+    /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
+    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "slice-scatter")?;
+        if dim == 0 {
+            self.slice_scatter0(src, start)
+        } else {
+            // TODO: Maybe we want to add a more efficient implementation at some point.
+            self.transpose(0, dim)?
+                .slice_scatter0(&src.transpose(0, dim)?, start)?
+                .transpose(0, dim)
+        }
+    }
+
+    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
+    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
+        if self.dtype() != src.dtype() {
+            Err(Error::DTypeMismatchBinaryOp {
+                lhs: self.dtype(),
+                rhs: src.dtype(),
+                op: "slice-scatter",
+            }
+            .bt())?
+        }
+        if self.device().location() != src.device().location() {
+            Err(Error::DeviceMismatchBinaryOp {
+                lhs: self.device().location(),
+                rhs: src.device().location(),
+                op: "slice-scatter",
+            }
+            .bt())?
+        }
+        if self.rank() != src.rank() {
+            Err(Error::UnexpectedNumberOfDims {
+                expected: self.rank(),
+                got: src.rank(),
+                shape: src.shape().clone(),
+            }
+            .bt())?
+        }
+        let shape_ok =
+            self.dims()
+                .iter()
+                .zip(src.dims().iter())
+                .enumerate()
+                .all(|(dim_idx, (&d1, &d2))| {
+                    if 0 == dim_idx {
+                        d2 + start <= d1
+                    } else {
+                        d1 == d2
+                    }
+                });
+        if !shape_ok {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "slice-scatter (self, src)",
+                lhs: self.shape().clone(),
+                rhs: src.shape().clone(),
+            })?
+        }
+        let mut storage = self.device().zeros(self.shape(), self.dtype())?;
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        let offset = start * src.dims()[1..].iter().product::<usize>();
+        src.storage()
+            .copy_strided_src(&mut storage, offset, src.layout())?;
+        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
     /// Accumulate element from `source` at indexes `indexes` and add them to `self`.
     pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "index-add")?;
@@ -1548,6 +1616,9 @@ impl Tensor {
     pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
         let dim1 = dim1.to_index(self.shape(), "transpose")?;
         let dim2 = dim2.to_index(self.shape(), "transpose")?;
+        if dim1 == dim2 {
+            return Ok(self.clone());
+        }
         let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
         let tensor_ = Tensor_ {
             id: TensorId::new(),
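
Usage note (editor's sketch, not part of the commit): a minimal example of the new slice_scatter / slice_scatter0 API added to tensor.rs above, assuming the crate is consumed as candle_core with its usual Tensor, DType and Device types.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // A 4x2 destination tensor of zeros and a 2x2 source tensor of ones.
    let dst = Tensor::zeros((4, 2), DType::F32, &device)?;
    let src = Tensor::ones((2, 2), DType::F32, &device)?;
    // Overwrite rows 1..3 of `dst` with `src`; the remaining rows keep their values.
    let out = dst.slice_scatter0(&src, 1)?;
    println!("{out}");
    // The dimension-generic wrapper routes dim 0 to slice_scatter0 directly.
    let out2 = dst.slice_scatter(&src, 0, 1)?;
    println!("{out2}");
    Ok(())
}

For a non-zero dim the wrapper transposes, scatters on dimension 0, then transposes back; the commit also short-circuits transpose when both dims are equal, returning a clone instead of recording a Transpose op.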
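
Gradient note (again a sketch, not from the commit): how the Op::SliceScatter0 backward rule in backprop.rs above is expected to behave. The src argument receives the slice of the upstream gradient it was scattered into, while the overwritten rows of self receive zeros. This assumes the usual candle_core Var / GradStore API.

use candle_core::{DType, Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let lhs = Var::from_tensor(&Tensor::zeros((4, 2), DType::F32, &device)?)?;
    let rhs = Var::from_tensor(&Tensor::ones((2, 2), DType::F32, &device)?)?;
    // Scatter `rhs` into rows 1..3 of `lhs`, then sum so the upstream gradient is all ones.
    let loss = lhs.slice_scatter0(&rhs, 1)?.sum_all()?;
    let grads = loss.backward()?;
    // Expected: grad(rhs) is all ones (the narrowed slice of the upstream gradient);
    // grad(lhs) is all ones except rows 1..3, which were overwritten and get zeros.
    println!("{:?}", grads.get(&lhs).map(|g| g.to_vec2::<f32>()));
    println!("{:?}", grads.get(&rhs).map(|g| g.to_vec2::<f32>()));
    Ok(())
}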