summaryrefslogtreecommitdiff
path: root/candle-core/src/backprop.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r--candle-core/src/backprop.rs12
1 files changed, 11 insertions, 1 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index a2548198..67207dce 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -69,7 +69,8 @@ impl Tensor {
| Op::Binary(lhs, rhs, _)
| Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
- | Op::Matmul(lhs, rhs) => {
+ | Op::Matmul(lhs, rhs)
+ | Op::SliceScatter0(lhs, rhs, _) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -270,6 +271,15 @@ impl Tensor {
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
+ Op::SliceScatter0(lhs, rhs, start_rhs) => {
+ let rhs_sum_grad = grads.or_insert(rhs)?;
+ let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
+ *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+
+ let lhs_sum_grad = grads.or_insert(lhs)?;
+ let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
+ *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
+ }
Op::Gather(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;