diff options
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r-- | candle-core/src/backprop.rs | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index f39eedbb..65d91849 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -112,7 +112,8 @@ impl Tensor { } Op::Unary(_node, UnaryOp::Ceil) | Op::Unary(_node, UnaryOp::Floor) - | Op::Unary(_node, UnaryOp::Round) => nodes, + | Op::Unary(_node, UnaryOp::Round) + | Op::Unary(_node, UnaryOp::Sign) => nodes, Op::Reshape(node) | Op::UpsampleNearest1D { arg: node, .. } | Op::UpsampleNearest2D { arg: node, .. } @@ -488,7 +489,6 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad)?; } - Op::Cmp(_args, _) => {} Op::Reduce(arg, ReduceOp::Max, reduced_dims) => { let node = broadcast_back(arg, node, reduced_dims)?; let grad = broadcast_back(arg, &grad, reduced_dims)?; @@ -578,20 +578,18 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)? } - Op::Reduce(_, ReduceOp::ArgMin, _) => {} - Op::Reduce(_, ReduceOp::ArgMax, _) => {} + Op::Unary(_, UnaryOp::Floor) + | Op::Unary(_, UnaryOp::Round) + | Op::Reduce(_, ReduceOp::ArgMin, _) + | Op::Reduce(_, ReduceOp::ArgMax, _) + | Op::Unary(_, UnaryOp::Sign) + | Op::Cmp(_, _) => {} Op::Reshape(arg) => { let arg_grad = grad.reshape(arg.dims())?; let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&arg_grad)? } Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?, - Op::Unary(_, UnaryOp::Floor) => { - Err(Error::BackwardNotSupported { op: "floor" })? - } - Op::Unary(_, UnaryOp::Round) => { - Err(Error::BackwardNotSupported { op: "round" })? - } Op::Unary(arg, UnaryOp::Gelu) => { let sum_grad = grads.or_insert(arg)?; let cube = arg.powf(3.)?; |