summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/backprop.rs41
1 files changed, 30 insertions, 11 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 24da23a2..d6beb70e 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -169,8 +169,22 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
}
- Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?,
- Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
+ Op::ScatterAdd(init, indexes, src, dim) => {
+ let init_sum_grad = grads.or_insert(init)?;
+ *init_sum_grad = init_sum_grad.add(&grad)?;
+
+ let src_grad = grad.gather(indexes, *dim)?;
+ let src_sum_grad = grads.or_insert(src)?;
+ *src_sum_grad = src_sum_grad.add(&src_grad)?;
+ }
+ Op::IndexAdd(init, indexes, src, dim) => {
+ let init_sum_grad = grads.or_insert(init)?;
+ *init_sum_grad = init_sum_grad.add(&grad)?;
+
+ let src_grad = grad.index_select(indexes, *dim)?;
+ let src_sum_grad = grads.or_insert(src)?;
+ *src_sum_grad = src_sum_grad.add(&src_grad)?;
+ }
Op::IndexSelect(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
@@ -228,7 +242,7 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad)?;
}
- Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
+ Op::Cmp(_args, _) => {}
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
let node = broadcast_back(arg, node, reduced_dims)?;
let grad = broadcast_back(arg, &grad, reduced_dims)?;
@@ -268,7 +282,12 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
- Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
+ Op::Unary(arg, UnaryOp::Abs) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let ones = arg.ones_like()?;
+ let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
+ *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
+ }
Op::Unary(arg, UnaryOp::Exp) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad * *node)?)?
@@ -303,12 +322,8 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
- Op::Reduce(_, ReduceOp::ArgMin, _) => {
- Err(Error::BackwardNotSupported { op: "argmin" })?
- }
- Op::Reduce(_, ReduceOp::ArgMax, _) => {
- Err(Error::BackwardNotSupported { op: "argmax" })?
- }
+ Op::Reduce(_, ReduceOp::ArgMin, _) => {}
+ Op::Reduce(_, ReduceOp::ArgMax, _) => {}
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
@@ -316,7 +331,11 @@ impl Tensor {
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
- Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
+ Op::Unary(arg, UnaryOp::Relu) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
+ *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
+ }
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::CustomOp1(arg, c) => {
if let Some(arg_grad) = c.bwd(arg, node, &grad)? {