From 8f310cc66625dc19a6f3f269edda0ba6b09a6f52 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 14 Oct 2023 22:03:41 +0100
Subject: Avoid trying to backprop through non-differentiable layers. (#1094)

---
 candle-core/src/backprop.rs | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

(limited to 'candle-core')

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 16b9cfd9..dfad5f62 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -36,6 +36,8 @@ impl Tensor {
                 // Do not call recursively on the "leaf" nodes.
                 track_grad = true;
                 nodes
+            } else if node.dtype().is_int() {
+                nodes
             } else if let Some(op) = node.op() {
                 match op {
                     Op::IndexAdd(t1, t2, t3, _)
@@ -103,7 +105,6 @@ impl Tensor {
                     | Op::Broadcast(node)
                     | Op::Cmp(node, _)
                     | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
-                    | Op::ToDType(node)
                     | Op::ToDevice(node)
                     | Op::Transpose(node, _, _)
                     | Op::Permute(node, _)
@@ -116,6 +117,15 @@ impl Tensor {
                         track_grad |= tg;
                         nodes
                     }
+                    Op::ToDType(node) => {
+                        if node.dtype().is_float() {
+                            let (tg, nodes) = walk(node, nodes, already_seen);
+                            track_grad |= tg;
+                            nodes
+                        } else {
+                            nodes
+                        }
+                    }
                     Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
                 }
             } else {
@@ -374,7 +384,7 @@ impl Tensor {
                 }
                 Op::ToDType(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
+                    *sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
                 }
                 Op::Copy(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
--
cgit v1.2.3
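
Not part of the patch: a minimal sketch of the behaviour this change enables,
assuming the candle-core API of that era (Var, Tensor::ge, to_dtype, backward);
tensor names and values are illustrative only. The comparison yields an
integer-dtype (u8) mask, so the backward graph walk now treats that node as
non-differentiable and stops there, and the Op::ToDType backward pass casts the
incoming gradient to the argument's dtype before accumulating it.

use candle_core::{DType, Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A float variable that we want a gradient for.
    let x = Var::new(&[1f32, 2., 3.], &dev)?;
    // A hard mask built from a comparison; the comparison result has an
    // integer dtype (u8), hence it is non-differentiable.
    let threshold = Tensor::new(&[2f32, 2., 2.], &dev)?;
    let mask = x.ge(&threshold)?.to_dtype(DType::F32)?;
    // loss = sum(x * mask); with this patch, backward() no longer tries to
    // propagate a gradient through the u8 comparison node.
    let loss = x.mul(&mask)?.sum_all()?;
    let grads = loss.backward()?;
    // d(loss)/dx is the mask itself, i.e. [0, 1, 1] here, in x's dtype (f32).
    let gx = grads.get(&x).expect("no gradient for x");
    println!("{gx}");
    Ok(())
}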