summaryrefslogtreecommitdiff
path: root/candle-core/src/backprop.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-09-15 23:15:40 +0200
committerGitHub <noreply@github.com>2023-09-15 22:15:40 +0100
commit635012d770a75033081008a22044804d277fafa8 (patch)
tree8c08cf9c9bbddb2d4904477c817d2db4be5877ac /candle-core/src/backprop.rs
parent3e49f8fce52c6b8f361bfd37d541a99b5e1f8c63 (diff)
downloadcandle-635012d770a75033081008a22044804d277fafa8.tar.gz
candle-635012d770a75033081008a22044804d277fafa8.tar.bz2
candle-635012d770a75033081008a22044804d277fafa8.zip
Do not backprop through argmin/argmax. (#865)
Diffstat (limited to 'candle-core/src/backprop.rs')
-rw-r--r--candle-core/src/backprop.rs4
1 files changed, 3 insertions, 1 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index b930a9f4..9c8f685f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -98,7 +98,7 @@ impl Tensor {
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
- | Op::Reduce(node, _, _)
+ | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
@@ -112,6 +112,7 @@ impl Tensor {
track_grad |= tg;
nodes
}
+ Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
}
} else {
nodes
@@ -521,6 +522,7 @@ impl Tensor {
}
}
+#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore {