Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r-- | candle-core/src/backprop.rs | 7
1 file changed, 7 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 678dbabd..38898b7b 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -39,6 +39,7 @@ impl Tensor {
             } else if let Some(op) = node.op() {
                 match op {
                     Op::IndexAdd(t1, t2, t3, _)
+                    | Op::ScatterAdd(t1, t2, t3, _)
                     | Op::CustomOp3(t1, t2, t3, _)
                     | Op::WhereCond(t1, t2, t3) => {
                         let (tg, nodes) = walk(t1, nodes, already_seen);
@@ -56,6 +57,7 @@ impl Tensor {
                     }
                     | Op::CustomOp2(lhs, rhs, _)
                     | Op::Binary(lhs, rhs, _)
+                    | Op::Gather(lhs, rhs, _)
                     | Op::IndexSelect(lhs, rhs, _)
                     | Op::Embedding(lhs, rhs)
                     | Op::Matmul(lhs, rhs) => {
@@ -162,6 +164,11 @@ impl Tensor {
                         *f_sum_grad = f_sum_grad.add(&f_grad)?;
                     }
                     Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
+                    Op::Gather(arg, indexes, dim) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
+                    }
+                    Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?,
                     Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
                     Op::IndexSelect(arg, indexes, dim) => {
                         let sum_grad = grads.or_insert(arg)?;
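
For context, the new Op::Gather branch in backward() follows the standard rule that the gradient of a gather is a scatter-add of the output gradient back onto the gathered positions, which is exactly what the sum_grad.scatter_add(indexes, &grad, *dim) line above does. Below is a minimal standalone Rust sketch of that relationship for a 2-D, dim-0 gather; it is not candle code, and the helper names gather_dim0 and scatter_add_dim0 are illustrative only, not part of candle's API.

// Standalone illustration (not candle code): gather along dim 0 of a 2-D
// matrix, and the matching backward pass expressed as a scatter-add.
// Matrices are stored row-major in flat Vec<f32> buffers.

/// out[i][j] = arg[idx[i][j]][j]
fn gather_dim0(arg: &[f32], idx: &[usize], cols: usize) -> Vec<f32> {
    idx.iter()
        .enumerate()
        .map(|(pos, &src_row)| {
            let j = pos % cols;
            arg[src_row * cols + j]
        })
        .collect()
}

/// grad_arg[idx[i][j]][j] += grad_out[i][j]
/// Every output element that read from a given input slot sends its gradient
/// back to that slot, so repeated indices accumulate (hence scatter-*add*).
fn scatter_add_dim0(grad_out: &[f32], idx: &[usize], cols: usize, arg_rows: usize) -> Vec<f32> {
    let mut grad_arg = vec![0.0; arg_rows * cols];
    for (pos, &dst_row) in idx.iter().enumerate() {
        let j = pos % cols;
        grad_arg[dst_row * cols + j] += grad_out[pos];
    }
    grad_arg
}

fn main() {
    // arg is 3x2; idx is 2x2 and picks rows [0, 2] for the first output row
    // and rows [2, 0] for the second.
    let arg = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let idx = [0, 2, 2, 0];
    let out = gather_dim0(&arg, &idx, 2);
    assert_eq!(out, vec![1.0, 6.0, 5.0, 2.0]);

    // With an all-ones upstream gradient, each input slot receives one unit of
    // gradient per time it was gathered; untouched slots stay at zero.
    let grad_out = vec![1.0; out.len()];
    let grad_arg = scatter_add_dim0(&grad_out, &idx, 2, 3);
    assert_eq!(grad_arg, vec![1.0, 1.0, 0.0, 0.0, 1.0, 1.0]);
}

The accumulation over repeated indices is why the adjoint is a scatter-add rather than a plain scatter. The reverse direction is not handled here: the patch marks Op::ScatterAdd itself as BackwardNotSupported { op: "scatter-add" }, so only gather gains a gradient in this change.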