diff options
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r-- | candle-core/src/backprop.rs | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index a2548198..67207dce 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -69,7 +69,8 @@ impl Tensor { | Op::Binary(lhs, rhs, _) | Op::Gather(lhs, rhs, _) | Op::IndexSelect(lhs, rhs, _) - | Op::Matmul(lhs, rhs) => { + | Op::Matmul(lhs, rhs) + | Op::SliceScatter0(lhs, rhs, _) => { let (tg, nodes) = walk(lhs, nodes, already_seen); track_grad |= tg; let (tg, nodes) = walk(rhs, nodes, already_seen); @@ -270,6 +271,15 @@ impl Tensor { Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported { op: "upsample-nearest2d", })?, + Op::SliceScatter0(lhs, rhs, start_rhs) => { + let rhs_sum_grad = grads.or_insert(rhs)?; + let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?; + *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; + + let lhs_sum_grad = grads.or_insert(lhs)?; + let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?; + *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)? + } Op::Gather(arg, indexes, dim) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?; |