diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-14 22:03:41 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-14 22:03:41 +0100 |
commit | 8f310cc66625dc19a6f3f269edda0ba6b09a6f52 (patch) | |
tree | bfdd781880118d212ad2a8f4542081a5b813d7aa /candle-core | |
parent | 8921d5027cee21b8d438af5c1b3ae120b21a3acc (diff) | |
download | candle-8f310cc66625dc19a6f3f269edda0ba6b09a6f52.tar.gz candle-8f310cc66625dc19a6f3f269edda0ba6b09a6f52.tar.bz2 candle-8f310cc66625dc19a6f3f269edda0ba6b09a6f52.zip |
Avoid trying to backprop through non-differentiable layers. (#1094)
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/src/backprop.rs | 14 |
1 file changed, 12 insertions, 2 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 16b9cfd9..dfad5f62 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -36,6 +36,8 @@ impl Tensor { // Do not call recursively on the "leaf" nodes. track_grad = true; nodes + } else if node.dtype().is_int() { + nodes } else if let Some(op) = node.op() { match op { Op::IndexAdd(t1, t2, t3, _) @@ -103,7 +105,6 @@ impl Tensor { | Op::Broadcast(node) | Op::Cmp(node, _) | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _) - | Op::ToDType(node) | Op::ToDevice(node) | Op::Transpose(node, _, _) | Op::Permute(node, _) @@ -116,6 +117,15 @@ impl Tensor { track_grad |= tg; nodes } + Op::ToDType(node) => { + if node.dtype().is_float() { + let (tg, nodes) = walk(node, nodes, already_seen); + track_grad |= tg; + nodes + } else { + nodes + } + } Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes, } } else { @@ -374,7 +384,7 @@ impl Tensor { } Op::ToDType(arg) => { let sum_grad = grads.or_insert(arg)?; - *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)? + *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)? } Op::Copy(arg) => { let sum_grad = grads.or_insert(arg)?; |