path: root/candle-core/src/backprop.rs
author     Laurent Mazare <laurent.mazare@gmail.com>    2023-07-26 21:31:54 +0100
committer  GitHub <noreply@github.com>                  2023-07-26 21:31:54 +0100
commit     89ba005962495f2bfbda286e185e9c3c7f5300a3 (patch)
tree       70aedcde0bb8c89d6930a977bb2e67a32c9e7738 /candle-core/src/backprop.rs
parent     4f92420132d831e5d344f974c263c9f341e50906 (diff)
Support backprop for a few more ops. (#254)
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r--  candle-core/src/backprop.rs | 41
1 file changed, 30 insertions(+), 11 deletions(-)
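
With this change, backward() no longer aborts with BackwardNotSupported when the graph contains scatter-add, index-add, cmp, abs, argmin/argmax, or relu. A minimal sketch exercising the new relu and abs rules; it uses current candle-core names (Var, GradStore::get, sum_all), which may differ slightly at this exact revision:

    use candle_core::{Device, Var};

    fn main() -> candle_core::Result<()> {
        // Var marks a leaf tensor so that backward() collects a gradient for it.
        let x = Var::new(&[-2f32, -0.5, 0.5, 2.0], &Device::Cpu)?;
        // relu and abs both gain a backward rule in this commit.
        let y = (x.relu()? + x.abs()?)?;
        let grads = y.sum_all()?.backward()?;
        // relu contributes 0 or 1 and abs contributes -1 or +1, so the total
        // gradient is -1 for x < 0 and 2 for x >= 0.
        println!("{:?}", grads.get(&x).map(|g| g.to_vec1::<f32>()));
        Ok(())
    }
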
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 24da23a2..d6beb70e 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -169,8 +169,22 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
}
- Op::ScatterAdd(..) => Err(Error::BackwardNotSupported { op: "scatter-add" })?,
- Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
+ Op::ScatterAdd(init, indexes, src, dim) => {
+ let init_sum_grad = grads.or_insert(init)?;
+ *init_sum_grad = init_sum_grad.add(&grad)?;
+
+ let src_grad = grad.gather(indexes, *dim)?;
+ let src_sum_grad = grads.or_insert(src)?;
+ *src_sum_grad = src_sum_grad.add(&src_grad)?;
+ }
+ Op::IndexAdd(init, indexes, src, dim) => {
+ let init_sum_grad = grads.or_insert(init)?;
+ *init_sum_grad = init_sum_grad.add(&grad)?;
+
+ let src_grad = grad.index_select(indexes, *dim)?;
+ let src_sum_grad = grads.or_insert(src)?;
+ *src_sum_grad = src_sum_grad.add(&src_grad)?;
+ }
Op::IndexSelect(arg, indexes, dim) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.index_add(indexes, &grad, *dim)?;
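
These two new arms are adjoint pairs: the gradient of scatter_add's src is read back with gather along the same dim, the gradient of index_add's src with index_select, and in both cases init receives the upstream gradient unchanged since every init element passes through additively. A quick numeric check of the scatter_add rule (same caveat about current candle-core names):

    use candle_core::{Device, Tensor, Var};

    fn main() -> candle_core::Result<()> {
        let dev = Device::Cpu;
        let init = Var::new(&[[1f32, 2., 3.], [4., 5., 6.]], &dev)?;
        let src = Var::new(&[[10f32, 20., 30.], [40., 50., 60.]], &dev)?;
        // For each element of src, idx picks the destination row along dim 0.
        let idx = Tensor::new(&[[0u32, 1, 0], [1, 0, 1]], &dev)?;
        let y = init.scatter_add(&idx, &src, 0)?;
        let grads = y.sum_all()?.backward()?;
        // init's gradient is the upstream gradient itself (all ones here);
        // src's gradient is that gradient gathered at idx (also all ones).
        println!("{:?}", grads.get(&init).map(|g| g.to_vec2::<f32>()));
        println!("{:?}", grads.get(&src).map(|g| g.to_vec2::<f32>()));
        Ok(())
    }
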
@@ -228,7 +242,7 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad)?;
}
- Op::Cmp(_args, _) => return Err(Error::BackwardNotSupported { op: "cmp" }),
+ Op::Cmp(_args, _) => {}
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
let node = broadcast_back(arg, node, reduced_dims)?;
let grad = broadcast_back(arg, &grad, reduced_dims)?;
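
Comparison outputs are piecewise constant, so their derivative is zero almost everywhere; dropping the gradient here instead of erroring lets masks built from comparisons sit inside a differentiated graph. A sketch (same naming caveat):

    use candle_core::{Device, Var};

    fn main() -> candle_core::Result<()> {
        let x = Var::new(&[-1f32, 2., 3.], &Device::Cpu)?;
        // A mask built from a comparison: 1.0 where x >= 0, else 0.0.
        let mask = x.ge(&x.zeros_like()?)?.to_dtype(x.dtype())?;
        let y = x.mul(&mask)?;
        // This backward() used to fail with BackwardNotSupported("cmp");
        // now the comparison node contributes no gradient of its own.
        let grads = y.sum_all()?.backward()?;
        println!("{:?}", grads.get(&x).map(|g| g.to_vec1::<f32>())); // [0.0, 1.0, 1.0]
        Ok(())
    }
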
@@ -268,7 +282,12 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
- Op::Unary(_, UnaryOp::Abs) => Err(Error::BackwardNotSupported { op: "abs" })?,
+ Op::Unary(arg, UnaryOp::Abs) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let ones = arg.ones_like()?;
+ let abs_grad = arg.ge(&arg.zeros_like()?)?.where_cond(&ones, &ones.neg()?);
+ *sum_grad = sum_grad.add(&(&grad * abs_grad)?)?
+ }
Op::Unary(arg, UnaryOp::Exp) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&(&grad * *node)?)?
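
The Abs arm above builds d|x|/dx as a ±1 tensor: the ge mask feeds where_cond, which selects between ones and -ones, so x = 0 intentionally falls on the +1 branch. A quick check (same naming caveat):

    use candle_core::{Device, Var};

    fn main() -> candle_core::Result<()> {
        let x = Var::new(&[-3f32, 0., 3.], &Device::Cpu)?;
        let grads = x.abs()?.sum_all()?.backward()?;
        // Expected gradient: [-1, 1, 1]; the x = 0 entry takes the +1 branch.
        println!("{:?}", grads.get(&x).map(|g| g.to_vec1::<f32>()));
        Ok(())
    }
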
@@ -303,12 +322,8 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
- Op::Reduce(_, ReduceOp::ArgMin, _) => {
- Err(Error::BackwardNotSupported { op: "argmin" })?
- }
- Op::Reduce(_, ReduceOp::ArgMax, _) => {
- Err(Error::BackwardNotSupported { op: "argmax" })?
- }
+ Op::Reduce(_, ReduceOp::ArgMin, _) => {}
+ Op::Reduce(_, ReduceOp::ArgMax, _) => {}
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
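
Like cmp, argmin and argmax return integer indices that stay constant under small input perturbations, so the new empty arms silently discard any gradient reaching them instead of aborting the whole backward pass. One place this matters (a sketch; argmax_keepdim in particular follows current candle-core naming and may differ at this revision):

    use candle_core::{DType, Device, Var};

    fn main() -> candle_core::Result<()> {
        let x = Var::new(&[[1f32, 5., 2.]], &Device::Cpu)?;
        // An index-valued branch alongside a differentiable one.
        let idx = x.argmax_keepdim(1)?.to_dtype(DType::F32)?;
        let y = (x.sum_all()? + idx.sum_all()?)?;
        // The gradient reaching the argmax node is dropped; x still receives
        // its gradient from the other branch.
        let grads = y.backward()?;
        println!("{:?}", grads.get(&x).map(|g| g.to_vec2::<f32>()));
        Ok(())
    }
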
@@ -316,7 +331,11 @@ impl Tensor {
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
- Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
+ Op::Unary(arg, UnaryOp::Relu) => {
+ let sum_grad = grads.or_insert(arg)?;
+ let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
+ *sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
+ }
Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
Op::CustomOp1(arg, c) => {
if let Some(arg_grad) = c.bwd(arg, node, &grad)? {