author      Laurent Mazare <laurent.mazare@gmail.com>   2023-12-30 17:06:07 +0100
committer   GitHub <noreply@github.com>                 2023-12-30 17:06:07 +0100
commit      a0facd0e67b546215ea62b53dc28a1cb2e6dcd47 (patch)
tree        6b980a61b55276cf51c4eb0b40254021ba2f38de /candle-nn
parent      4290b8124479fd0ac2c2eedf0cf8c65dcee4a702 (diff)
download    candle-a0facd0e67b546215ea62b53dc28a1cb2e6dcd47.tar.gz
            candle-a0facd0e67b546215ea62b53dc28a1cb2e6dcd47.tar.bz2
            candle-a0facd0e67b546215ea62b53dc28a1cb2e6dcd47.zip
Small tweaks to batch-norm. (#1505)
Diffstat (limited to 'candle-nn')
-rw-r--r--   candle-nn/src/batch_norm.rs   35
1 file changed, 16 insertions, 19 deletions
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 2b415e90..1782e47a 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -7,7 +7,6 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use crate::Init;
 use candle::{DType, Module, Result, Tensor, Var};

 #[derive(Debug, Clone, Copy, PartialEq)]
@@ -92,7 +91,6 @@ impl BatchNorm {
                 )
             }
         }
-
         Ok(())
     }

@@ -217,34 +215,32 @@ impl BatchNorm {
         let x = x.to_dtype(internal_dtype)?;
         let x = x.transpose(0, 1)?;
         let x_dims_post_transpose = x.dims();
+        // Flatten all the dimensions except the channel one as this performs a Spatial Batch
+        // Normalization.
         let x = x.flatten_from(1)?.contiguous()?;
         let x = if self.remove_mean {
+            // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
             let mean_x = x.mean_keepdim(1)?;
-            {
-                // Update running mean
-                let new_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
-                    + (mean_x.flatten_all()? * self.momentum)?)?;
-
-                self.running_mean.set(&new_mean)?;
-            }
+            let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
+                + (mean_x.flatten_all()? * self.momentum)?)?;
+            self.running_mean.set(&updated_running_mean)?;
             x.broadcast_sub(&mean_x)?
         } else {
             x
         };
+        // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
         let norm_x = x.sqr()?.mean_keepdim(1)?;
-        {
-            // Update running variance
+        let updated_running_var = {
             let batch_size = x.dim(1)? as f64;
             let running_var_weight = 1.0 - self.momentum;
             let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
-
-            let new_var = ((self.running_var.as_tensor() * running_var_weight)?
-                + (&norm_x.flatten_all()? * norm_x_weight)?)?;
-
-            self.running_var.set(&new_var)?;
-        }
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed.to_dtype(x_dtype)?;
+            ((self.running_var.as_tensor() * running_var_weight)?
+                + (&norm_x.flatten_all()? * norm_x_weight)?)?
+        };
+        self.running_var.set(&updated_running_var)?;
+        let x = x
+            .broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+            .to_dtype(x_dtype)?;
         let x = match &self.weight_and_bias {
             None => x,
             Some((weight, bias)) => {
@@ -297,6 +293,7 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
     config: C,
     vb: crate::VarBuilder,
 ) -> Result<BatchNorm> {
+    use crate::Init;
     let config = config.into();
     if config.eps < 0. {
         candle::bail!("batch-norm eps cannot be negative {}", config.eps)
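Note: the refactor above does not change the arithmetic. Both running statistics are exponential moving averages driven by momentum, and the variance update rescales the biased batch estimate norm_x by batch_size / (batch_size - 1.0) so the running variance stays unbiased, while the batch itself is normalized with the biased statistics. A minimal standalone sketch of the same update rule on plain f64 values (illustrative names, no candle types):

// Standalone sketch of the running-statistics update restructured by this
// patch (plain f64, no candle dependency; names are illustrative).
fn main() {
    let momentum = 0.1f64;
    let eps = 1e-5f64;

    // Per-channel state kept by the layer between batches.
    let mut running_mean = 0.0f64;
    let mut running_var = 1.0f64;

    // One channel's values for a single training batch.
    let batch = [0.5f64, 1.5, 2.0, 4.0];
    let n = batch.len() as f64;

    // Biased batch statistics, matching mean_keepdim over the flattened dim.
    let mean_x = batch.iter().sum::<f64>() / n;
    let norm_x = batch.iter().map(|v| (v - mean_x).powi(2)).sum::<f64>() / n;

    // Exponential moving averages; the variance term is rescaled by
    // n / (n - 1) so the running estimate is unbiased.
    running_mean = running_mean * (1.0 - momentum) + mean_x * momentum;
    running_var = running_var * (1.0 - momentum) + norm_x * momentum * n / (n - 1.0);

    // The batch itself is normalized with the biased statistics.
    let normed: Vec<f64> = batch
        .iter()
        .map(|v| (v - mean_x) / (norm_x + eps).sqrt())
        .collect();

    println!("running_mean = {running_mean:.4}");
    println!("running_var  = {running_var:.4}");
    println!("normed = {normed:?}");
}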
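For completeness, a hedged usage sketch of the constructor touched by the last hunk. It assumes candle-nn's VarMap/VarBuilder initialization, a BatchNormConfig::default(), and the forward_train training path that runs the code changed above; exact method names may differ slightly across candle-nn versions:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{batch_norm, BatchNormConfig, VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // 4 channels; input laid out as (batch, channels, spatial).
    let bn = batch_norm(4, BatchNormConfig::default(), vb)?;
    let x = Tensor::randn(0f32, 1.0, (8, 4, 16), &device)?;

    // Training path: normalizes with batch statistics and updates the
    // running mean/variance as in the diff above.
    let y = bn.forward_train(&x)?;
    println!("{:?}", y.dims()); // [8, 4, 16]
    Ok(())
}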