summaryrefslogtreecommitdiff
path: root/candle-core/src/backprop.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r--candle-core/src/backprop.rs18
1 files changed, 8 insertions, 10 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index f39eedbb..65d91849 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -112,7 +112,8 @@ impl Tensor {
}
Op::Unary(_node, UnaryOp::Ceil)
| Op::Unary(_node, UnaryOp::Floor)
- | Op::Unary(_node, UnaryOp::Round) => nodes,
+ | Op::Unary(_node, UnaryOp::Round)
+ | Op::Unary(_node, UnaryOp::Sign) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D { arg: node, .. }
| Op::UpsampleNearest2D { arg: node, .. }
@@ -488,7 +489,6 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad)?;
}
- Op::Cmp(_args, _) => {}
Op::Reduce(arg, ReduceOp::Max, reduced_dims) => {
let node = broadcast_back(arg, node, reduced_dims)?;
let grad = broadcast_back(arg, &grad, reduced_dims)?;
@@ -578,20 +578,18 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
- Op::Reduce(_, ReduceOp::ArgMin, _) => {}
- Op::Reduce(_, ReduceOp::ArgMax, _) => {}
+ Op::Unary(_, UnaryOp::Floor)
+ | Op::Unary(_, UnaryOp::Round)
+ | Op::Reduce(_, ReduceOp::ArgMin, _)
+ | Op::Reduce(_, ReduceOp::ArgMax, _)
+ | Op::Unary(_, UnaryOp::Sign)
+ | Op::Cmp(_, _) => {}
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
- Op::Unary(_, UnaryOp::Floor) => {
- Err(Error::BackwardNotSupported { op: "floor" })?
- }
- Op::Unary(_, UnaryOp::Round) => {
- Err(Error::BackwardNotSupported { op: "round" })?
- }
Op::Unary(arg, UnaryOp::Gelu) => {
let sum_grad = grads.or_insert(arg)?;
let cube = arg.powf(3.)?;