summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/backprop.rs12
-rw-r--r--candle-core/src/op.rs1
-rw-r--r--candle-core/src/tensor.rs71
-rw-r--r--candle-core/tests/grad_tests.rs16
-rw-r--r--candle-core/tests/tensor_tests.rs43
5 files changed, 142 insertions, 1 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index a2548198..67207dce 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -69,7 +69,8 @@ impl Tensor {
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
- | Op::Matmul(lhs, rhs) => {
+ | Op::Matmul(lhs, rhs)
+ | Op::SliceScatter0(lhs, rhs, _) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -270,6 +271,15 @@ impl Tensor {
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
+ Op::SliceScatter0(lhs, rhs, start_rhs) => {
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
+ *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
+ }
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 4882a205..3083d2c8 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -133,6 +133,7 @@ pub enum Op {
Copy(Tensor),
Broadcast(Tensor),
Narrow(Tensor, usize, usize, usize),
+ SliceScatter0(Tensor, Tensor, usize),
Reshape(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 9dccf2b5..d3337e16 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1132,6 +1132,74 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false))
}
+ /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
+ pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: usize, start: usize) -> Result<Self> {
+ let dim = dim.to_index(self.shape(), "slice-scatter")?;
+ if dim == 0 {
+ self.slice_scatter0(src, start)
+ } else {
+ // TODO: Maybe we want to add a more efficient implementation at some point.
+ self.transpose(0, dim)?
+ .slice_scatter0(&src.transpose(0, dim)?, start)?
+ .transpose(0, dim)
+ }
+ }
+
+ /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
+ pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
+ if self.dtype() != src.dtype() {
+ Err(Error::DTypeMismatchBinaryOp {
+ lhs: self.dtype(),
+ rhs: src.dtype(),
+ op: "slice-scatter",
+ }
+ .bt())?
+ }
+ if self.device().location() != src.device.location() {
+ Err(Error::DeviceMismatchBinaryOp {
+ lhs: self.device().location(),
+ rhs: src.device().location(),
+ op: "slice-scatter",
+ }
+ .bt())?
+ }
+ if self.rank() != src.rank() {
+ Err(Error::UnexpectedNumberOfDims {
+ expected: self.rank(),
+ got: src.rank(),
+ shape: src.shape().clone(),
+ }
+ .bt())?
+ }
+ let shape_ok =
+ self.dims()
+ .iter()
+ .zip(src.dims().iter())
+ .enumerate()
+ .all(|(dim_idx, (&d1, &d2))| {
+ if 0 == dim_idx {
+ d2 + start <= d1
+ } else {
+ d1 == d2
+ }
+ });
+ if !shape_ok {
+ Err(Error::ShapeMismatchBinaryOp {
+ op: "slice-scatter (self, src)",
+ lhs: self.shape().clone(),
+ rhs: src.shape().clone(),
+ })?
+ }
+ let mut storage = self.device().zeros(self.shape(), self.dtype())?;
+ self.storage()
+ .copy_strided_src(&mut storage, 0, self.layout())?;
+ let offset = start * src.dims()[1..].iter().product::<usize>();
+ src.storage()
+ .copy_strided_src(&mut storage, offset, src.layout())?;
+ let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
+ Ok(from_storage(storage, self.shape(), op, false))
+ }
+
/// Accumulate element from `source` at indexes `indexes` and add them to `self`.
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-add")?;
@@ -1548,6 +1616,9 @@ impl Tensor {
pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
let dim1 = dim1.to_index(self.shape(), "transpose")?;
let dim2 = dim2.to_index(self.shape(), "transpose")?;
+ if dim1 == dim2 {
+ return Ok(self.clone());
+ }
let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
let tensor_ = Tensor_ {
id: TensorId::new(),
diff --git a/candle-core/tests/grad_tests.rs b/candle-core/tests/grad_tests.rs
index ad09c90f..2a70cfc4 100644
--- a/candle-core/tests/grad_tests.rs
+++ b/candle-core/tests/grad_tests.rs
@@ -218,6 +218,22 @@ fn binary_grad(device: &Device) -> Result<()> {
let grad_x = grads.get(x).context("no grad for x")?;
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
+
+ let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
+ let x = x_var.as_tensor();
+ let y_var = Var::new(&[2f32, 7., 1.], device)?;
+ let y = y_var.as_tensor();
+
+ let ss = x
+ .reshape((2, 3))?
+ .slice_scatter0(&y.reshape((1, 3))?, 1)?
+ .sqr()?;
+ let grads = ss.backward()?;
+ let grad_x = grads.get(x).context("no grad for x")?;
+ let grad_y = grads.get(y).context("no grad for y")?;
+ assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
+ assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
+ assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
Ok(())
}
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index edd0bd79..dbe0dd6a 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -674,6 +674,48 @@ fn index_add(device: &Device) -> Result<()> {
Ok(())
}
+fn slice_scatter(device: &Device) -> Result<()> {
+ let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
+ assert_eq!(
+ t.to_vec2::<f32>()?,
+ &[
+ [0.0, 1.0, 2.0],
+ [3.0, 4.0, 5.0],
+ [6.0, 7.0, 8.0],
+ [9.0, 10.0, 11.0]
+ ]
+ );
+ let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
+ assert_eq!(
+ t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
+ &[
+ [100.0, 101.0, 102.0],
+ [103.0, 104.0, 105.0],
+ [6.0, 7.0, 8.0],
+ [9.0, 10.0, 11.0]
+ ]
+ );
+ assert_eq!(
+ t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
+ &[
+ [0.0, 1.0, 2.0],
+ [100.0, 101.0, 102.0],
+ [103.0, 104.0, 105.0],
+ [9.0, 10.0, 11.0]
+ ]
+ );
+ assert_eq!(
+ t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
+ &[
+ [0.0, 1.0, 2.0],
+ [3.0, 4.0, 5.0],
+ [100.0, 101.0, 102.0],
+ [103.0, 104.0, 105.0],
+ ]
+ );
+ Ok(())
+}
+
fn scatter_add(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
@@ -946,6 +988,7 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);
test_device!(gather, gather_cpu, gather_gpu);
test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
test_device!(randn, randn_cpu, randn_gpu);
test_device!(clamp, clamp_cpu, clamp_gpu);