author     Laurent Mazare <laurent.mazare@gmail.com>  2024-01-01 10:13:13 +0100
committer  GitHub <noreply@github.com>  2024-01-01 10:13:13 +0100
commit     b0fe5e4453bacc1aecf0049eaa424c39eb1771d4 (patch)
tree       2f233da5fcf9f64cf3395f5cd6b8081e801eb7f7 /candle-nn
parent     1fb2dd905cb49ce99b7a7c31f5d0809382bc12f3 (diff)
Do not implement Module for BatchNorm. (#1513)
Diffstat (limited to 'candle-nn')
-rw-r--r--  candle-nn/src/batch_norm.rs   26
-rw-r--r--  candle-nn/tests/batch_norm.rs  4
2 files changed, 15 insertions(+), 15 deletions(-)
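With the `Module` impl gone, downstream code that called `bn.forward(x)` no longer compiles; both inference and training now go through the new `ModuleT` impl. A minimal migration sketch, assuming only what the diff below shows (`candle_nn::ModuleT` in scope); the wrapper function itself is hypothetical:

use candle::{Result, Tensor};
use candle_nn::{BatchNorm, ModuleT};

// Hypothetical wrapper illustrating the post-#1513 call sites.
fn run_batch_norm(bn: &BatchNorm, x: &Tensor, train: bool) -> Result<Tensor> {
    // Before: eval used `bn.forward(x)` (the removed `Module` impl) and
    // training used `bn.forward_learning(x)` (renamed `forward_train`).
    // After: a single trait entry point selects between the two paths.
    bn.forward_t(x, train)
}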
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 1782e47a..856c2c7a 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -7,7 +7,7 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Module, Result, Tensor, Var};
+use candle::{DType, Result, Tensor, Var};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct BatchNormConfig {
@@ -192,7 +192,7 @@ impl BatchNorm {
         self.momentum
     }
 
-    pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+    pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
         let num_features = self.running_mean.as_tensor().dim(0)?;
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
@@ -252,17 +252,7 @@ impl BatchNorm {
         x.reshape(x_dims_post_transpose)?.transpose(0, 1)
     }
 
-    pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
-        if train {
-            self.forward_learning(x)
-        } else {
-            self.forward(x)
-        }
-    }
-}
-
-impl Module for BatchNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
         let target_shape: Vec<usize> = x
             .dims()
             .iter()
@@ -288,6 +278,16 @@ impl Module for BatchNorm {
     }
 }
 
+impl crate::ModuleT for BatchNorm {
+    fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+        if train {
+            self.forward_train(x)
+        } else {
+            self.forward_eval(x)
+        }
+    }
+}
+
 pub fn batch_norm<C: Into<BatchNormConfig>>(
     num_features: usize,
     config: C,
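For context on the split above: `forward_train` (unchanged in substance by this patch) normalizes with the current batch's statistics and folds them into the running stats, while the now-private `forward_eval` only applies the stored stats. A scalar sketch of the usual exponential-moving-average update, assuming candle follows the common PyTorch-style momentum convention (the real code works on tensors in an internal dtype):

// Sketch only: per-feature running-stat update under the assumed convention.
fn ema_update(running: f64, batch_stat: f64, momentum: f64) -> f64 {
    // Each training step pulls the running mean/variance toward
    // the statistics observed on the current batch.
    (1.0 - momentum) * running + momentum * batch_stat
}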
diff --git a/candle-nn/tests/batch_norm.rs b/candle-nn/tests/batch_norm.rs
index 73a38545..6fd7361a 100644
--- a/candle-nn/tests/batch_norm.rs
+++ b/candle-nn/tests/batch_norm.rs
@@ -39,7 +39,7 @@ fn batch_norm() -> Result<()> {
         1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
     ];
     let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
-    let output = bn.forward_learning(&input)?;
+    let output = bn.forward_train(&input)?;
     assert_eq!(output.dims(), &[2, 5, 3, 4]);
     let output = output.flatten_all()?;
     assert_eq!(
@@ -67,7 +67,7 @@ fn batch_norm() -> Result<()> {
         Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
         1e-8,
     )?;
-    let output2 = bn2.forward_learning(&input)?;
+    let output2 = bn2.forward_train(&input)?;
     assert_eq!(output2.dims(), &[2, 5, 3, 4]);
     let output2 = output2.flatten_all()?;
     let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
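A side effect worth noting beyond these tests: with `BatchNorm` now behind `ModuleT` rather than `Module`, code written generically against `ModuleT` picks it up for free. A hypothetical sketch:

use candle::{Result, Tensor};
use candle_nn::ModuleT;

// Hypothetical generic helper: any ModuleT implementor, now including
// BatchNorm, runs in eval mode the same way (train = false routes to
// the private forward_eval).
fn eval_module<M: ModuleT>(m: &M, x: &Tensor) -> Result<Tensor> {
    m.forward_t(x, false)
}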