author     Laurent Mazare <laurent.mazare@gmail.com>   2023-11-05 14:07:41 +0100
committer  GitHub <noreply@github.com>                 2023-11-05 14:07:41 +0100
commit     60fdab4e17d3e420f20610ec75df3deccd8e1f69
tree       e479708736d4e1f630ab09c7e77ba3a9111d0d59 /candle-core
parent     928a9d906e7de51f1d9a458b417f26865ebc9c41
Detach all grads during backprop. (#1243)
* Detach all grads during backprop.
* Add an environment variable to select the backprop behavior (see the sketch after this list).
* Update the comment.
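The selection mentioned in the second item works through the `CANDLE_GRAD_DO_NOT_DETACH` environment variable introduced in the diff below: any non-empty value other than "0" keeps the previous, non-detaching behavior. Below is a minimal sketch of opting back into that behavior from Rust; `run_training` is a hypothetical placeholder, and the usual route is simply to export the variable in the shell before launching the process.

    use std::env;

    fn main() {
        // The flag is read lazily, once per thread, so it must be set before the
        // first backward pass runs on that thread. (On the Rust 2024 edition this
        // call would need an `unsafe` block; earlier editions accept it as written.)
        env::set_var("CANDLE_GRAD_DO_NOT_DETACH", "1");
        run_training();
    }

    fn run_training() {
        // Hypothetical placeholder for whatever code ends up calling `.backward()`.
    }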
Diffstat (limited to 'candle-core')
-rw-r--r--  candle-core/src/backprop.rs | 25
1 file changed, 21 insertions, 4 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 1448a6f4..fc0c79a2 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -15,6 +15,17 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
     }
 }
 
+thread_local! {
+    static CANDLE_GRAD_DO_NOT_DETACH: bool = {
+        match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
 impl Tensor {
     /// Return all the nodes that lead to this value in a topologically sorted vec, the first
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
@@ -155,10 +166,16 @@ impl Tensor {
             if node.is_variable() {
                 continue;
             }
-            let grad = grads.remove(node).unwrap();
-            // TODO: We should perform all these operations in place (or at least not track the
-            // whole graph). The only drawback would be if we wanted to support grad of grad but
-            // this is out of scope.
+            let grad = grads
+                .remove(node)
+                .expect("candle internal error - grad not populated");
+            // https://github.com/huggingface/candle/issues/1241
+            // Ideally, we would make these operations in place where possible to ensure that we
+            // do not have to allocate too often. Here we just call `.detach` to avoid computing
+            // the backprop graph of the backprop itself. This would be an issue for second order
+            // derivatives but these are out of scope at the moment.
+            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
+            let grad = if do_not_detach { grad } else { grad.detach()? };
             if let Some(op) = node.op() {
                 match op {
                     Op::Binary(lhs, rhs, BinaryOp::Add) => {
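To make the effect of the change concrete, here is a minimal sketch of a backward pass, assuming the usual candle-core API (`Var::new`, `Tensor::backward`, `GradStore::get`). After this commit, the gradient retrieved below is detached, i.e. it no longer carries its own backprop graph, unless `CANDLE_GRAD_DO_NOT_DETACH` is set as described above.

    use candle_core::{Device, Result, Var};

    fn main() -> Result<()> {
        // y = x * x, so dy/dx = 2 * x = 6 at x = 3.
        let x = Var::new(3f32, &Device::Cpu)?;
        let y = x.as_tensor().mul(x.as_tensor())?;

        // The gradients stored here are now detached from the computation graph
        // unless CANDLE_GRAD_DO_NOT_DETACH opts back into the old behavior.
        let grads = y.backward()?;
        let dy_dx = grads.get(x.as_tensor()).expect("no gradient for x");
        println!("dy/dx = {dy_dx}");
        Ok(())
    }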