| author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-01-01 10:13:13 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-01-01 10:13:13 +0100 |
| commit | b0fe5e4453bacc1aecf0049eaa424c39eb1771d4 (patch) | |
| tree | 2f233da5fcf9f64cf3395f5cd6b8081e801eb7f7 /candle-nn | |
| parent | 1fb2dd905cb49ce99b7a7c31f5d0809382bc12f3 (diff) | |
Do not implement Module for BatchNorm. (#1513)
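`BatchNorm` behaves differently in training (normalizing with batch statistics and updating the running stats) and in evaluation (normalizing with the stored running statistics), so a mode-agnostic `Module::forward` was easy to call in the wrong mode. This change drops the `Module` impl and exposes the layer through `ModuleT`, whose `forward_t` takes an explicit `train` flag. A minimal sketch of the trait shape implied by the diff below; the single-method form is an assumption, and the actual definition in candle-nn may include provided helper methods:

```rust
use candle::{Result, Tensor};

// Minimal sketch of the `ModuleT` trait that `BatchNorm` now implements
// (see `impl crate::ModuleT for BatchNorm` in the diff below). The shape
// is an assumption; only `forward_t`'s signature is confirmed by the diff.
pub trait ModuleT {
    /// Run the layer, selecting the training or evaluation code path.
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}
```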
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/batch_norm.rs | 26
-rw-r--r-- | candle-nn/tests/batch_norm.rs | 4
2 files changed, 15 insertions, 15 deletions
```diff
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 1782e47a..856c2c7a 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -7,7 +7,7 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Module, Result, Tensor, Var};
+use candle::{DType, Result, Tensor, Var};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct BatchNormConfig {
@@ -192,7 +192,7 @@ impl BatchNorm {
         self.momentum
     }
 
-    pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+    pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
         let num_features = self.running_mean.as_tensor().dim(0)?;
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -252,17 +252,7 @@ impl BatchNorm {
         x.reshape(x_dims_post_transpose)?.transpose(0, 1)
     }
 
-    pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
-        if train {
-            self.forward_learning(x)
-        } else {
-            self.forward(x)
-        }
-    }
-}
-
-impl Module for BatchNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
         let target_shape: Vec<usize> = x
             .dims()
             .iter()
@@ -288,6 +278,16 @@ impl Module for BatchNorm {
     }
 }
 
+impl crate::ModuleT for BatchNorm {
+    fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+        if train {
+            self.forward_train(x)
+        } else {
+            self.forward_eval(x)
+        }
+    }
+}
+
 pub fn batch_norm<C: Into<BatchNormConfig>>(
     num_features: usize,
     config: C,
diff --git a/candle-nn/tests/batch_norm.rs b/candle-nn/tests/batch_norm.rs
index 73a38545..6fd7361a 100644
--- a/candle-nn/tests/batch_norm.rs
+++ b/candle-nn/tests/batch_norm.rs
@@ -39,7 +39,7 @@ fn batch_norm() -> Result<()> {
         1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
     ];
     let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
-    let output = bn.forward_learning(&input)?;
+    let output = bn.forward_train(&input)?;
     assert_eq!(output.dims(), &[2, 5, 3, 4]);
     let output = output.flatten_all()?;
     assert_eq!(
@@ -67,7 +67,7 @@ fn batch_norm() -> Result<()> {
         Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
         1e-8,
     )?;
-    let output2 = bn2.forward_learning(&input)?;
+    let output2 = bn2.forward_train(&input)?;
     assert_eq!(output2.dims(), &[2, 5, 3, 4]);
     let output2 = output2.flatten_all()?;
     let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
```
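For reference, a minimal sketch of a call site after this change. Two assumptions not confirmed by the diff itself: `BatchNorm::new` is assumed to take `(num_features, running_mean, running_var, weight, bias, eps)`, as the test above suggests, and `ModuleT` is assumed to be exported at the `candle_nn` crate root.

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{BatchNorm, ModuleT};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Per-channel running statistics and affine parameters for 5 features.
    let running_mean = Tensor::zeros(5, DType::F32, &dev)?;
    let running_var = Tensor::ones(5, DType::F32, &dev)?;
    let weight = Tensor::ones(5, DType::F32, &dev)?;
    let bias = Tensor::zeros(5, DType::F32, &dev)?;
    // Assumed constructor signature; see the hedge in the lead-in above.
    let bn = BatchNorm::new(5, running_mean, running_var, weight, bias, 1e-8)?;

    let x = Tensor::randn(0f32, 1f32, (2, 5, 3, 4), &dev)?;
    // Training path, was `bn.forward_learning(&x)`: uses batch statistics.
    let y_train = bn.forward_t(&x, true)?;
    // Evaluation path, was `bn.forward(&x)` via `Module`: uses running stats.
    let y_eval = bn.forward_t(&x, false)?;
    assert_eq!(y_train.dims(), &[2, 5, 3, 4]);
    assert_eq!(y_eval.dims(), &[2, 5, 3, 4]);
    Ok(())
}
```

The explicit flag makes the evaluation path a deliberate choice rather than a silent default, which is the point of removing the `Module` impl.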