summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-14 22:03:41 +0100
committerGitHub <noreply@github.com>2023-10-14 22:03:41 +0100
commit8f310cc66625dc19a6f3f269edda0ba6b09a6f52 (patch)
treebfdd781880118d212ad2a8f4542081a5b813d7aa /candle-core
parent8921d5027cee21b8d438af5c1b3ae120b21a3acc (diff)
downloadcandle-8f310cc66625dc19a6f3f269edda0ba6b09a6f52.tar.gz
candle-8f310cc66625dc19a6f3f269edda0ba6b09a6f52.tar.bz2
candle-8f310cc66625dc19a6f3f269edda0ba6b09a6f52.zip
Avoid trying to backprop through non-differentiable layers. (#1094)
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/src/backprop.rs14
1 file changed, 12 insertions, 2 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 16b9cfd9..dfad5f62 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -36,6 +36,8 @@ impl Tensor {
// Do not call recursively on the "leaf" nodes.
track_grad = true;
nodes
+ } else if node.dtype().is_int() {
+ nodes
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
@@ -103,7 +105,6 @@ impl Tensor {
| Op::Broadcast(node)
| Op::Cmp(node, _)
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
- | Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Permute(node, _)
@@ -116,6 +117,15 @@ impl Tensor {
track_grad |= tg;
nodes
}
+ Op::ToDType(node) => {
+ if node.dtype().is_float() {
+ let (tg, nodes) = walk(node, nodes, already_seen);
+ track_grad |= tg;
+ nodes
+ } else {
+ nodes
+ }
+ }
Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
}
} else {
@@ -374,7 +384,7 @@ impl Tensor {
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
- *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
+ *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
}
Op::Copy(arg) => {
let sum_grad = grads.or_insert(arg)?;