From ad73e93da2cf7311cb5c5bc39250aa335c5f9b76 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 13 Feb 2024 14:26:32 +0100 Subject: Detach the tensors on batch-norm eval. (#1702) * Detach the tensors on batch-norm eval. * Fix pyo3 bindings. * Black tweak. * Formatting. * Also update the pyo3-onnx formatting. * Apply black. --- candle-nn/src/batch_norm.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) (limited to 'candle-nn') diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs index 856c2c7a..4c67961d 100644 --- a/candle-nn/src/batch_norm.rs +++ b/candle-nn/src/batch_norm.rs @@ -262,9 +262,19 @@ impl BatchNorm { let target_shape = target_shape.as_slice(); let x = x - .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)? + .broadcast_sub( + &self + .running_mean + .as_detached_tensor() + .reshape(target_shape)?, + )? .broadcast_div( - &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?, + &(self + .running_var + .as_detached_tensor() + .reshape(target_shape)? + + self.eps)? + .sqrt()?, )?; match &self.weight_and_bias { -- cgit v1.2.3