diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-15 23:15:40 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-15 22:15:40 +0100 |
commit | 635012d770a75033081008a22044804d277fafa8 (patch) | |
tree | 8c08cf9c9bbddb2d4904477c817d2db4be5877ac /candle-core/src/backprop.rs | |
parent | 3e49f8fce52c6b8f361bfd37d541a99b5e1f8c63 (diff) | |
download | candle-635012d770a75033081008a22044804d277fafa8.tar.gz candle-635012d770a75033081008a22044804d277fafa8.tar.bz2 candle-635012d770a75033081008a22044804d277fafa8.zip |
Do not backprop through argmin/argmax. (#865)
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r-- | candle-core/src/backprop.rs | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index b930a9f4..9c8f685f 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -98,7 +98,7 @@ impl Tensor { | Op::Copy(node) | Op::Broadcast(node) | Op::Cmp(node, _) - | Op::Reduce(node, _, _) + | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _) | Op::ToDType(node) | Op::ToDevice(node) | Op::Transpose(node, _, _) @@ -112,6 +112,7 @@ impl Tensor { track_grad |= tg; nodes } + Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes, } } else { nodes @@ -521,6 +522,7 @@ impl Tensor { } } +#[derive(Debug)] pub struct GradStore(HashMap<TensorId, Tensor>); impl GradStore { |