summaryrefslogtreecommitdiff
path: root/candle-core/src/backprop.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r--candle-core/src/backprop.rs7
1 files changed, 7 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 678dbabd..38898b7b 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -39,6 +39,7 @@ impl Tensor {
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
+ | Op::ScatterAdd(t1, t2, t3, _)
| Op::CustomOp3(t1, t2, t3, _)
| Op::WhereCond(t1, t2, t3) => {
let (tg, nodes) = walk(t1, nodes, already_seen);
@@ -56,6 +57,7 @@ impl Tensor {
}
| Op::CustomOp2(lhs, rhs, _)
| Op::Binary(lhs, rhs, _)
+ | Op::Gather(lhs, rhs, _)
| Op::IndexSelect(lhs, rhs, _)
| Op::Embedding(lhs, rhs)
| Op::Matmul(lhs, rhs) => {
@@ -162,6 +164,11 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+ Op::Gather(arg, indexes, dim) => {
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
+ }
+ Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?,
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
Op::IndexSelect(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;