summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-02-13 14:26:32 +0100
committerGitHub <noreply@github.com>2024-02-13 14:26:32 +0100
commitad73e93da2cf7311cb5c5bc39250aa335c5f9b76 (patch)
tree5b5ea591d0fda870f4499869e3a8feb9718cfebf /candle-nn
parent13c67226e68de216b731707067f7e68af0438821 (diff)
downloadcandle-ad73e93da2cf7311cb5c5bc39250aa335c5f9b76.tar.gz
candle-ad73e93da2cf7311cb5c5bc39250aa335c5f9b76.tar.bz2
candle-ad73e93da2cf7311cb5c5bc39250aa335c5f9b76.zip
Detach the tensors on batch-norm eval. (#1702)
* Detach the tensors on batch-norm eval. * Fix pyo3 bindings. * Black tweak. * Formatting. * Also update the pyo3-onnx formatting. * Apply black.
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/src/batch_norm.rs14
1 files changed, 12 insertions, 2 deletions
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 856c2c7a..4c67961d 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -262,9 +262,19 @@ impl BatchNorm {
let target_shape = target_shape.as_slice();
let x = x
- .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
+ .broadcast_sub(
+ &self
+ .running_mean
+ .as_detached_tensor()
+ .reshape(target_shape)?,
+ )?
.broadcast_div(
- &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
+ &(self
+ .running_var
+ .as_detached_tensor()
+ .reshape(target_shape)?
+ + self.eps)?
+ .sqrt()?,
)?;
match &self.weight_and_bias {