author    Laurent Mazare <laurent.mazare@gmail.com>  2023-11-05 14:07:41 +0100
committer GitHub <noreply@github.com>                2023-11-05 14:07:41 +0100
commit    60fdab4e17d3e420f20610ec75df3deccd8e1f69 (patch)
tree      e479708736d4e1f630ab09c7e77ba3a9111d0d59 /candle-core
parent    928a9d906e7de51f1d9a458b417f26865ebc9c41 (diff)
Detach all grads during backprop. (#1243)
* Detach all grads during backprop.
* Add an environment variable to select the backprop behavior.
* Update the comment.
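Usage note (not part of the commit message): detaching becomes the default, and the environment variable opts back out of it. Below is a minimal, illustrative sketch of toggling the flag from Rust; exporting CANDLE_GRAD_DO_NOT_DETACH=1 in the shell has the same effect.

    fn main() {
        // Opt back into the previous behavior where gradients stay attached
        // to the autograd graph. The flag is read lazily from the environment,
        // once per thread, so it must be set before the first backward() call
        // on that thread; `export CANDLE_GRAD_DO_NOT_DETACH=1` works as well.
        std::env::set_var("CANDLE_GRAD_DO_NOT_DETACH", "1");
        // ... build candle tensors and call .backward() as usual ...
    }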
Diffstat (limited to 'candle-core')
-rw-r--r--  candle-core/src/backprop.rs | 25
1 file changed, 21 insertions(+), 4 deletions(-)
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 1448a6f4..fc0c79a2 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -15,6 +15,17 @@ fn broadcast_back(arg: &Tensor, node: &Tensor, reduced_dims: &[usize]) -> Result
    }
}

+thread_local! {
+    static CANDLE_GRAD_DO_NOT_DETACH: bool = {
+        match std::env::var("CANDLE_GRAD_DO_NOT_DETACH") {
+            Ok(s) => {
+                !s.is_empty() && s != "0"
+            },
+            Err(_) => false,
+        }
+    }
+}
+
impl Tensor {
/// Return all the nodes that lead to this value in a topologically sorted vec, the first
/// elements having dependencies on the latter ones, e.g. the first element if any is the
@@ -155,10 +166,16 @@ impl Tensor {
            if node.is_variable() {
                continue;
            }
-            let grad = grads.remove(node).unwrap();
-            // TODO: We should perform all these operations in place (or at least not track the
-            // whole graph). The only drawback would be if we wanted to support grad of grad but
-            // this is out of scope.
+            let grad = grads
+                .remove(node)
+                .expect("candle internal error - grad not populated");
+            // https://github.com/huggingface/candle/issues/1241
+            // Ideally, we would make these operations in place where possible to ensure that we
+            // do not have to allocate too often. Here we just call `.detach` to avoid computing
+            // the backprop graph of the backprop itself. This would be an issue for second order
+            // derivatives but these are out of scope at the moment.
+            let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
+            let grad = if do_not_detach { grad } else { grad.detach()? };
            if let Some(op) = node.op() {
                match op {
                    Op::Binary(lhs, rhs, BinaryOp::Add) => {
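The diff continues with the per-op gradient accumulation and is truncated here. As a rough illustration of what the change means for callers, here is a minimal sketch (not part of the commit; the Var, backward, GradStore::get and to_scalar calls are assumed from candle-core's public API of this period): with the new default, the gradients returned by backward() are already detached, so reusing them in further tensor arithmetic does not keep extending the autograd graph.

    use candle_core::{Device, Result, Var};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // y = x * x at x = 3, so dy/dx = 2 * x = 6.
        let x = Var::new(3f32, &dev)?;
        let y = x.as_tensor().mul(x.as_tensor())?;
        let grads = y.backward()?;
        let dx = grads.get(x.as_tensor()).expect("no gradient for x");
        // With this patch (and CANDLE_GRAD_DO_NOT_DETACH unset), `dx` is
        // detached, so this crude gradient step does not grow the graph that
        // a later backward() call would have to traverse.
        let updated = (x.as_tensor() - dx)?;
        println!("dy/dx = {}", dx.to_scalar::<f32>()?);
        println!("x - dy/dx = {}", updated.to_scalar::<f32>()?);
        Ok(())
    }

Setting CANDLE_GRAD_DO_NOT_DETACH to a non-empty value other than "0" restores the previous, non-detaching behavior via the thread_local! flag added above.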