diff options
author | nkoppel <nathankoppel0@gmail.com> | 2023-12-30 15:42:08 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-30 16:42:08 +0100 |
commit | 4290b8124479fd0ac2c2eedf0cf8c65dcee4a702 (patch) | |
tree | 1b7174466430a9b42f02e5720ae79716e4db75d7 /candle-nn | |
parent | 51e577a682ab9497d6022b4080f3b54bbbd75f1b (diff) | |
download | candle-4290b8124479fd0ac2c2eedf0cf8c65dcee4a702.tar.gz candle-4290b8124479fd0ac2c2eedf0cf8c65dcee4a702.tar.bz2 candle-4290b8124479fd0ac2c2eedf0cf8c65dcee4a702.zip |
[Breaking] Add training to batchnorm with exponential moving average (#1504)
* Add training to batchnorm with exponential moving average
* Add more checks to batch norm
* Resolve some review comments
* Add with_momentum varients of `new` methods
* Add check for range of momentum variable; update batch norm test
* Run cargo fmt
* Add back num_features parameter
* Format; tiny simplification
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/batch_norm.rs | 208 | ||||
-rw-r--r-- | candle-nn/tests/batch_norm.rs | 11 |
2 files changed, 169 insertions, 50 deletions
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs index 8cfc6740..2b415e90 100644 --- a/candle-nn/src/batch_norm.rs +++ b/candle-nn/src/batch_norm.rs @@ -7,15 +7,22 @@ //! running stats. //! //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167 -use candle::{DType, Result, Tensor}; +use crate::Init; +use candle::{DType, Module, Result, Tensor, Var}; #[derive(Debug, Clone, Copy, PartialEq)] pub struct BatchNormConfig { pub eps: f64, pub remove_mean: bool, + /// The meaning of affine here is different from LayerNorm: when false there is no learnable /// parameter at all, 1 used for gamma and 0 for beta. pub affine: bool, + + /// Controls exponential moving average of running stats. Defaults to 0.1 + /// + /// `running_stat * (1.0 - momentum) + stat * momentum`. + pub momentum: f64, } impl Default for BatchNormConfig { @@ -24,6 +31,7 @@ impl Default for BatchNormConfig { eps: 1e-5, remove_mean: true, affine: true, + momentum: 0.1, } } } @@ -32,23 +40,62 @@ impl From<f64> for BatchNormConfig { fn from(eps: f64) -> Self { Self { eps, - remove_mean: true, - affine: true, + ..Default::default() } } } #[derive(Clone, Debug)] pub struct BatchNorm { - running_mean: Tensor, - running_var: Tensor, + running_mean: Var, + running_var: Var, weight_and_bias: Option<(Tensor, Tensor)>, remove_mean: bool, eps: f64, - num_features: usize, + momentum: f64, } impl BatchNorm { + fn check_validity(&self, num_features: usize) -> Result<()> { + if self.eps < 0. { + candle::bail!("batch-norm eps cannot be negative {}", self.eps) + } + if !(0.0..=1.0).contains(&self.momentum) { + candle::bail!( + "batch-norm momentum must be between 0 and 1, is {}", + self.momentum + ) + } + if self.running_mean.dims() != [num_features] { + candle::bail!( + "batch-norm running mean has unexpected shape {:?} should have shape [{num_features}]", + self.running_mean.shape(), + ) + } + if self.running_var.dims() != [num_features] { + candle::bail!( + "batch-norm running variance has unexpected shape {:?} should have shape [{num_features}]", + self.running_var.shape(), + ) + } + if let Some((ref weight, ref bias)) = self.weight_and_bias.as_ref() { + if weight.dims() != [num_features] { + candle::bail!( + "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]", + weight.shape(), + ) + } + if bias.dims() != [num_features] { + candle::bail!( + "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]", + bias.shape(), + ) + } + } + + Ok(()) + } + pub fn new( num_features: usize, running_mean: Tensor, @@ -57,29 +104,16 @@ impl BatchNorm { bias: Tensor, eps: f64, ) -> Result<Self> { - if eps < 0. { - candle::bail!("batch-norm eps cannot be negative {eps}") - } - if weight.dims() != [num_features] { - candle::bail!( - "batch-norm unexpected weight shape {:?} {num_features}", - weight.shape() - ) - } - if bias.dims() != [num_features] { - candle::bail!( - "batch-norm unexpected bias shape {:?} {num_features}", - bias.shape() - ) - } - Ok(Self { - running_mean, - running_var, + let out = Self { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, weight_and_bias: Some((weight, bias)), remove_mean: true, eps, - num_features, - }) + momentum: 0.1, + }; + out.check_validity(num_features)?; + Ok(out) } pub fn new_no_bias( @@ -88,25 +122,64 @@ impl BatchNorm { running_var: Tensor, eps: f64, ) -> Result<Self> { - if eps < 0. { - candle::bail!("batch-norm eps cannot be negative {eps}") - } - Ok(Self { - running_mean, - running_var, + let out = Self { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, + weight_and_bias: None, + remove_mean: true, + eps, + momentum: 0.1, + }; + out.check_validity(num_features)?; + Ok(out) + } + + pub fn new_with_momentum( + num_features: usize, + running_mean: Tensor, + running_var: Tensor, + weight: Tensor, + bias: Tensor, + eps: f64, + momentum: f64, + ) -> Result<Self> { + let out = Self { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, + weight_and_bias: Some((weight, bias)), + remove_mean: true, + eps, + momentum, + }; + out.check_validity(num_features)?; + Ok(out) + } + + pub fn new_no_bias_with_momentum( + num_features: usize, + running_mean: Tensor, + running_var: Tensor, + eps: f64, + momentum: f64, + ) -> Result<Self> { + let out = Self { + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, weight_and_bias: None, remove_mean: true, eps, - num_features, - }) + momentum, + }; + out.check_validity(num_features)?; + Ok(out) } pub fn running_mean(&self) -> &Tensor { - &self.running_mean + self.running_mean.as_tensor() } pub fn running_var(&self) -> &Tensor { - &self.running_var + self.running_var.as_tensor() } pub fn eps(&self) -> f64 { @@ -117,7 +190,12 @@ impl BatchNorm { self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1)) } + pub fn momentum(&self) -> f64 { + self.momentum + } + pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> { + let num_features = self.running_mean.as_tensor().dim(0)?; let x_dtype = x.dtype(); let internal_dtype = match x_dtype { DType::F16 | DType::BF16 => DType::F32, @@ -129,11 +207,11 @@ impl BatchNorm { x.shape() ) } - if x.dim(1)? != self.num_features { + if x.dim(1)? != num_features { candle::bail!( "batch-norm input doesn't have the expected number of features ({:?} <> {})", x.shape(), - self.num_features + num_features ) } let x = x.to_dtype(internal_dtype)?; @@ -142,26 +220,52 @@ impl BatchNorm { let x = x.flatten_from(1)?.contiguous()?; let x = if self.remove_mean { let mean_x = x.mean_keepdim(1)?; + { + // Update running mean + let new_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))? + + (mean_x.flatten_all()? * self.momentum)?)?; + + self.running_mean.set(&new_mean)?; + } x.broadcast_sub(&mean_x)? } else { x }; let norm_x = x.sqr()?.mean_keepdim(1)?; + { + // Update running variance + let batch_size = x.dim(1)? as f64; + let running_var_weight = 1.0 - self.momentum; + let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0); + + let new_var = ((self.running_var.as_tensor() * running_var_weight)? + + (&norm_x.flatten_all()? * norm_x_weight)?)?; + + self.running_var.set(&new_var)?; + } let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; let x = x_normed.to_dtype(x_dtype)?; let x = match &self.weight_and_bias { None => x, Some((weight, bias)) => { - let weight = weight.reshape((self.num_features, 1))?; - let bias = bias.reshape((self.num_features, 1))?; + let weight = weight.reshape(((), 1))?; + let bias = bias.reshape(((), 1))?; x.broadcast_mul(&weight)?.broadcast_add(&bias)? } }; x.reshape(x_dims_post_transpose)?.transpose(0, 1) } + + pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> { + if train { + self.forward_learning(x) + } else { + self.forward(x) + } + } } -impl crate::Module for BatchNorm { +impl Module for BatchNorm { fn forward(&self, x: &Tensor) -> Result<Tensor> { let target_shape: Vec<usize> = x .dims() @@ -170,9 +274,13 @@ impl crate::Module for BatchNorm { .map(|(idx, v)| if idx == 1 { *v } else { 1 }) .collect(); let target_shape = target_shape.as_slice(); + let x = x - .broadcast_sub(&self.running_mean.reshape(target_shape)?)? - .broadcast_div(&(self.running_var.reshape(target_shape)? + self.eps)?.sqrt()?)?; + .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)? + .broadcast_div( + &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?, + )?; + match &self.weight_and_bias { None => Ok(x), Some((weight, bias)) => { @@ -193,21 +301,21 @@ pub fn batch_norm<C: Into<BatchNormConfig>>( if config.eps < 0. { candle::bail!("batch-norm eps cannot be negative {}", config.eps) } - let running_mean = vb.get_with_hints(num_features, "running_mean", crate::Init::Const(0.))?; - let running_var = vb.get_with_hints(num_features, "running_var", crate::Init::Const(1.))?; + let running_mean = vb.get_with_hints(num_features, "running_mean", Init::Const(0.))?; + let running_var = vb.get_with_hints(num_features, "running_var", Init::Const(1.))?; let weight_and_bias = if config.affine { - let weight = vb.get_with_hints(num_features, "weight", crate::Init::Const(1.))?; - let bias = vb.get_with_hints(num_features, "bias", crate::Init::Const(0.))?; + let weight = vb.get_with_hints(num_features, "weight", Init::Const(1.))?; + let bias = vb.get_with_hints(num_features, "bias", Init::Const(0.))?; Some((weight, bias)) } else { None }; Ok(BatchNorm { - running_mean, - running_var, + running_mean: Var::from_tensor(&running_mean)?, + running_var: Var::from_tensor(&running_var)?, weight_and_bias, remove_mean: config.remove_mean, eps: config.eps, - num_features, + momentum: config.momentum, }) } diff --git a/candle-nn/tests/batch_norm.rs b/candle-nn/tests/batch_norm.rs index 5bbaf238..73a38545 100644 --- a/candle-nn/tests/batch_norm.rs +++ b/candle-nn/tests/batch_norm.rs @@ -16,6 +16,8 @@ input = torch.randn(2, 5, 3, 4) output = m(input) print(input.flatten()) print(output.flatten()) +print(m.running_mean) +print(m.running_var) */ #[test] fn batch_norm() -> Result<()> { @@ -71,5 +73,14 @@ fn batch_norm() -> Result<()> { let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?; let sum_diff2 = diff2.sum_keepdim(0)?; assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]); + + assert_eq!( + test_utils::to_vec1_round(bn.running_mean(), 4)?, + &[-0.0133, 0.0197, -0.0153, -0.0073, -0.0020] + ); + assert_eq!( + test_utils::to_vec1_round(bn.running_var(), 4)?, + &[0.9972, 0.9842, 0.9956, 0.9866, 0.9898] + ); Ok(()) } |