path: root/candle-nn
author    Laurent Mazare <laurent.mazare@gmail.com>    2023-12-30 17:06:07 +0100
committer    GitHub <noreply@github.com>    2023-12-30 17:06:07 +0100
commit    a0facd0e67b546215ea62b53dc28a1cb2e6dcd47 (patch)
tree    6b980a61b55276cf51c4eb0b40254021ba2f38de /candle-nn
parent    4290b8124479fd0ac2c2eedf0cf8c65dcee4a702 (diff)
Small tweaks to batch-norm. (#1505)
Diffstat (limited to 'candle-nn')
-rw-r--r--    candle-nn/src/batch_norm.rs    35
1 file changed, 16 insertions, 19 deletions
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 2b415e90..1782e47a 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -7,7 +7,6 @@
//! running stats.
//!
//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use crate::Init;
use candle::{DType, Module, Result, Tensor, Var};
#[derive(Debug, Clone, Copy, PartialEq)]
@@ -92,7 +91,6 @@ impl BatchNorm {
)
}
}
-
Ok(())
}
@@ -217,34 +215,32 @@ impl BatchNorm {
let x = x.to_dtype(internal_dtype)?;
let x = x.transpose(0, 1)?;
let x_dims_post_transpose = x.dims();
+ // Flatten all the dimensions except the channel one as this performs a Spatial Batch
+ // Normalization.
let x = x.flatten_from(1)?.contiguous()?;
let x = if self.remove_mean {
+ // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
let mean_x = x.mean_keepdim(1)?;
- {
- // Update running mean
- let new_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
- + (mean_x.flatten_all()? * self.momentum)?)?;
-
- self.running_mean.set(&new_mean)?;
- }
+ let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
+ + (mean_x.flatten_all()? * self.momentum)?)?;
+ self.running_mean.set(&updated_running_mean)?;
x.broadcast_sub(&mean_x)?
} else {
x
};
+ // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
let norm_x = x.sqr()?.mean_keepdim(1)?;
- {
- // Update running variance
+ let updated_running_var = {
let batch_size = x.dim(1)? as f64;
let running_var_weight = 1.0 - self.momentum;
let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
-
- let new_var = ((self.running_var.as_tensor() * running_var_weight)?
- + (&norm_x.flatten_all()? * norm_x_weight)?)?;
-
- self.running_var.set(&new_var)?;
- }
- let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
- let x = x_normed.to_dtype(x_dtype)?;
+ ((self.running_var.as_tensor() * running_var_weight)?
+ + (&norm_x.flatten_all()? * norm_x_weight)?)?
+ };
+ self.running_var.set(&updated_running_var)?;
+ let x = x
+ .broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+ .to_dtype(x_dtype)?;
let x = match &self.weight_and_bias {
None => x,
Some((weight, bias)) => {
@@ -297,6 +293,7 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
config: C,
vb: crate::VarBuilder,
) -> Result<BatchNorm> {
+ use crate::Init;
let config = config.into();
if config.eps < 0. {
candle::bail!("batch-norm eps cannot be negative {}", config.eps)
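
For reference, the running-statistics update that this commit restructures can be sketched in plain Rust, outside of candle's tensor API. This is an illustrative sketch only: the function name update_running_stats and its slice-based signature are invented here and are not part of the commit. It mirrors the two steps in the diff above: an exponential moving average for the running mean, and the same blend for the running variance, with the biased batch variance rescaled by batch_size / (batch_size - 1.0) via norm_x_weight.

    // Sketch of the per-channel running-statistics update, with plain
    // f64 slices standing in for candle tensors (names are illustrative).
    fn update_running_stats(
        running_mean: &mut [f64],
        running_var: &mut [f64],
        batch_mean: &[f64],
        batch_var: &[f64], // biased variance, E[(x - mean)^2] over the batch
        batch_size: f64,
        momentum: f64,
    ) {
        for (rm, &bm) in running_mean.iter_mut().zip(batch_mean) {
            // Matches: running_mean * (1 - momentum) + mean_x * momentum.
            *rm = *rm * (1.0 - momentum) + bm * momentum;
        }
        // The biased batch variance is rescaled by n / (n - 1) (Bessel's
        // correction) before being folded in, matching norm_x_weight above.
        let norm_x_weight = momentum * batch_size / (batch_size - 1.0);
        for (rv, &bv) in running_var.iter_mut().zip(batch_var) {
            *rv = *rv * (1.0 - momentum) + bv * norm_x_weight;
        }
    }

    fn main() {
        let mut running_mean = vec![0.0];
        let mut running_var = vec![1.0];
        update_running_stats(&mut running_mean, &mut running_var, &[2.0], &[0.25], 4.0, 0.1);
        // running_mean = 0.0 * 0.9 + 2.0 * 0.1 = 0.2
        // running_var  = 1.0 * 0.9 + 0.25 * (0.1 * 4 / 3) ≈ 0.9333
        println!("mean={running_mean:?} var={running_var:?}");
    }

The n / (n - 1) rescaling stores an unbiased estimate in the running variance, which matches the convention PyTorch's BatchNorm uses for its running stats.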