Diffstat (limited to 'candle-nn/src')
 -rw-r--r-- candle-nn/src/batch_norm.rs | 215
 -rw-r--r-- candle-nn/src/encoding.rs   | 150
 -rw-r--r-- candle-nn/src/lib.rs        |   1
 3 files changed, 311 insertions, 55 deletions
diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs
index 8cfc6740..856c2c7a 100644
--- a/candle-nn/src/batch_norm.rs
+++ b/candle-nn/src/batch_norm.rs
@@ -7,15 +7,21 @@
 //! running stats.
 //!
 //! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
-use candle::{DType, Result, Tensor};
+use candle::{DType, Result, Tensor, Var};
 
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub struct BatchNormConfig {
     pub eps: f64,
     pub remove_mean: bool,
+
     /// The meaning of affine here is different from LayerNorm: when false there is no learnable
     /// parameter at all, 1 is used for gamma and 0 for beta.
     pub affine: bool,
+
+    /// Controls the exponential moving average of the running stats. Defaults to 0.1.
+    ///
+    /// `running_stat * (1.0 - momentum) + stat * momentum`.
+    pub momentum: f64,
 }
 
 impl Default for BatchNormConfig {
@@ -24,6 +30,7 @@ impl Default for BatchNormConfig {
             eps: 1e-5,
             remove_mean: true,
             affine: true,
+            momentum: 0.1,
         }
     }
 }
@@ -32,23 +39,61 @@ impl From<f64> for BatchNormConfig {
     fn from(eps: f64) -> Self {
         Self {
             eps,
-            remove_mean: true,
-            affine: true,
+            ..Default::default()
         }
     }
 }
 
 #[derive(Clone, Debug)]
 pub struct BatchNorm {
-    running_mean: Tensor,
-    running_var: Tensor,
+    running_mean: Var,
+    running_var: Var,
     weight_and_bias: Option<(Tensor, Tensor)>,
     remove_mean: bool,
     eps: f64,
-    num_features: usize,
+    momentum: f64,
 }
 
 impl BatchNorm {
+    fn check_validity(&self, num_features: usize) -> Result<()> {
+        if self.eps < 0. {
+            candle::bail!("batch-norm eps cannot be negative {}", self.eps)
+        }
+        if !(0.0..=1.0).contains(&self.momentum) {
+            candle::bail!(
+                "batch-norm momentum must be between 0 and 1, is {}",
+                self.momentum
+            )
+        }
+        if self.running_mean.dims() != [num_features] {
+            candle::bail!(
+                "batch-norm running mean has unexpected shape {:?} should have shape [{num_features}]",
+                self.running_mean.shape(),
+            )
+        }
+        if self.running_var.dims() != [num_features] {
+            candle::bail!(
+                "batch-norm running variance has unexpected shape {:?} should have shape [{num_features}]",
+                self.running_var.shape(),
+            )
+        }
+        if let Some((ref weight, ref bias)) = self.weight_and_bias.as_ref() {
+            if weight.dims() != [num_features] {
+                candle::bail!(
+                    "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
+                    weight.shape(),
+                )
+            }
+            if bias.dims() != [num_features] {
+                candle::bail!(
+                    "batch-norm bias has unexpected shape {:?} should have shape [{num_features}]",
+                    bias.shape(),
+                )
+            }
+        }
+        Ok(())
+    }
+
     pub fn new(
         num_features: usize,
         running_mean: Tensor,
@@ -57,29 +102,16 @@ impl BatchNorm {
         bias: Tensor,
         eps: f64,
     ) -> Result<Self> {
-        if eps < 0. {
-            candle::bail!("batch-norm eps cannot be negative {eps}")
-        }
-        if weight.dims() != [num_features] {
-            candle::bail!(
-                "batch-norm unexpected weight shape {:?} {num_features}",
-                weight.shape()
-            )
-        }
-        if bias.dims() != [num_features] {
-            candle::bail!(
-                "batch-norm unexpected bias shape {:?} {num_features}",
-                bias.shape()
-            )
-        }
-        Ok(Self {
-            running_mean,
-            running_var,
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
             weight_and_bias: Some((weight, bias)),
             remove_mean: true,
             eps,
-            num_features,
-        })
+            momentum: 0.1,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
     }
 
     pub fn new_no_bias(
@@ -88,25 +120,64 @@ impl BatchNorm {
         running_var: Tensor,
         eps: f64,
     ) -> Result<Self> {
-        if eps < 0. {
-            candle::bail!("batch-norm eps cannot be negative {eps}")
-        }
-        Ok(Self {
-            running_mean,
-            running_var,
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
+            weight_and_bias: None,
+            remove_mean: true,
+            eps,
+            momentum: 0.1,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
+    }
+
+    pub fn new_with_momentum(
+        num_features: usize,
+        running_mean: Tensor,
+        running_var: Tensor,
+        weight: Tensor,
+        bias: Tensor,
+        eps: f64,
+        momentum: f64,
+    ) -> Result<Self> {
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
+            weight_and_bias: Some((weight, bias)),
+            remove_mean: true,
+            eps,
+            momentum,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
+    }
+
+    pub fn new_no_bias_with_momentum(
+        num_features: usize,
+        running_mean: Tensor,
+        running_var: Tensor,
+        eps: f64,
+        momentum: f64,
+    ) -> Result<Self> {
+        let out = Self {
+            running_mean: Var::from_tensor(&running_mean)?,
+            running_var: Var::from_tensor(&running_var)?,
             weight_and_bias: None,
             remove_mean: true,
             eps,
-            num_features,
-        })
+            momentum,
+        };
+        out.check_validity(num_features)?;
+        Ok(out)
     }
 
     pub fn running_mean(&self) -> &Tensor {
-        &self.running_mean
+        self.running_mean.as_tensor()
     }
 
     pub fn running_var(&self) -> &Tensor {
-        &self.running_var
+        self.running_var.as_tensor()
     }
 
     pub fn eps(&self) -> f64 {
@@ -117,7 +188,12 @@ impl BatchNorm {
         self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
     }
 
-    pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
+    pub fn momentum(&self) -> f64 {
+        self.momentum
+    }
+
+    pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
+        let num_features = self.running_mean.as_tensor().dim(0)?;
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
             DType::F16 | DType::BF16 => DType::F32,
@@ -129,40 +205,54 @@ impl BatchNorm {
                 x.shape()
             )
         }
-        if x.dim(1)? != self.num_features {
+        if x.dim(1)? != num_features {
             candle::bail!(
                 "batch-norm input doesn't have the expected number of features ({:?} <> {})",
                 x.shape(),
-                self.num_features
+                num_features
             )
         }
         let x = x.to_dtype(internal_dtype)?;
         let x = x.transpose(0, 1)?;
         let x_dims_post_transpose = x.dims();
+        // Flatten all the dimensions except for the channel one as this performs a Spatial Batch
+        // Normalization.
         let x = x.flatten_from(1)?.contiguous()?;
         let x = if self.remove_mean {
+            // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
            let mean_x = x.mean_keepdim(1)?;
+            let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
+                + (mean_x.flatten_all()? * self.momentum)?)?;
+            self.running_mean.set(&updated_running_mean)?;
             x.broadcast_sub(&mean_x)?
         } else {
             x
         };
+        // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
         let norm_x = x.sqr()?.mean_keepdim(1)?;
-        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
-        let x = x_normed.to_dtype(x_dtype)?;
+        let updated_running_var = {
+            let batch_size = x.dim(1)? as f64;
+            let running_var_weight = 1.0 - self.momentum;
+            let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
+            ((self.running_var.as_tensor() * running_var_weight)?
+                + (&norm_x.flatten_all()? * norm_x_weight)?)?
+        };
+        self.running_var.set(&updated_running_var)?;
+        let x = x
+            .broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
+            .to_dtype(x_dtype)?;
         let x = match &self.weight_and_bias {
             None => x,
             Some((weight, bias)) => {
-                let weight = weight.reshape((self.num_features, 1))?;
-                let bias = bias.reshape((self.num_features, 1))?;
+                let weight = weight.reshape(((), 1))?;
+                let bias = bias.reshape(((), 1))?;
                 x.broadcast_mul(&weight)?.broadcast_add(&bias)?
             }
         };
         x.reshape(x_dims_post_transpose)?.transpose(0, 1)
     }
-}
 
-impl crate::Module for BatchNorm {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
         let target_shape: Vec<usize> = x
             .dims()
             .iter()
@@ -170,9 +260,13 @@ impl crate::Module for BatchNorm {
             .map(|(idx, v)| if idx == 1 { *v } else { 1 })
             .collect();
         let target_shape = target_shape.as_slice();
+
         let x = x
-            .broadcast_sub(&self.running_mean.reshape(target_shape)?)?
-            .broadcast_div(&(self.running_var.reshape(target_shape)? + self.eps)?.sqrt()?)?;
+            .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
+            .broadcast_div(
+                &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
+            )?;
+
         match &self.weight_and_bias {
             None => Ok(x),
             Some((weight, bias)) => {
@@ -184,30 +278,41 @@ impl crate::Module for BatchNorm {
     }
 }
 
+impl crate::ModuleT for BatchNorm {
+    fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
+        if train {
+            self.forward_train(x)
+        } else {
+            self.forward_eval(x)
+        }
+    }
+}
+
 pub fn batch_norm<C: Into<BatchNormConfig>>(
     num_features: usize,
     config: C,
     vb: crate::VarBuilder,
 ) -> Result<BatchNorm> {
+    use crate::Init;
     let config = config.into();
     if config.eps < 0. {
         candle::bail!("batch-norm eps cannot be negative {}", config.eps)
     }
-    let running_mean = vb.get_with_hints(num_features, "running_mean", crate::Init::Const(0.))?;
-    let running_var = vb.get_with_hints(num_features, "running_var", crate::Init::Const(1.))?;
+    let running_mean = vb.get_with_hints(num_features, "running_mean", Init::Const(0.))?;
+    let running_var = vb.get_with_hints(num_features, "running_var", Init::Const(1.))?;
     let weight_and_bias = if config.affine {
-        let weight = vb.get_with_hints(num_features, "weight", crate::Init::Const(1.))?;
-        let bias = vb.get_with_hints(num_features, "bias", crate::Init::Const(0.))?;
+        let weight = vb.get_with_hints(num_features, "weight", Init::Const(1.))?;
+        let bias = vb.get_with_hints(num_features, "bias", Init::Const(0.))?;
         Some((weight, bias))
     } else {
        None
    };
     Ok(BatchNorm {
-        running_mean,
-        running_var,
+        running_mean: Var::from_tensor(&running_mean)?,
+        running_var: Var::from_tensor(&running_var)?,
         weight_and_bias,
         remove_mean: config.remove_mean,
         eps: config.eps,
-        num_features,
+        momentum: config.momentum,
     })
 }
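Note the asymmetry between the two running-stat updates in `forward_train`: the running mean folds in the raw batch mean, while the running variance first rescales the biased batch variance (`norm_x`, a mean of squares over the N flattened elements) by `N / (N - 1)` to get the unbiased estimate. A standalone sanity check of that arithmetic, with made-up numbers (only the update rule itself comes from the patch):

```rust
// Hypothetical values; the update rule mirrors forward_train above.
fn main() {
    let momentum = 0.1_f64; // the BatchNormConfig default
    let batch_size = 8.0_f64; // elements per channel after flattening
    let running_var = 1.0_f64; // as initialized by batch_norm()
    let norm_x = 0.5_f64; // biased batch variance (mean of squares over N)

    let running_var_weight = 1.0 - momentum;
    // N / (N - 1) turns the biased batch variance into the unbiased estimate.
    let norm_x_weight = momentum * batch_size / (batch_size - 1.0);
    let updated = running_var * running_var_weight + norm_x * norm_x_weight;
    assert!((updated - 0.957143).abs() < 1e-6);
    println!("updated running_var = {updated}");
}
```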
diff --git a/candle-nn/src/encoding.rs b/candle-nn/src/encoding.rs
new file mode 100644
index 00000000..38e2cc3b
--- /dev/null
+++ b/candle-nn/src/encoding.rs
@@ -0,0 +1,150 @@
+//! Encoding Utilities. (e.g., one-hot/cold encoding)
+
+use candle::{bail, DType, Result, Tensor, WithDType};
+
+/// One-hot/cold encoding.
+///
+/// Given an input tensor of indices, this function returns a tensor of the same shape as the input
+/// tensor with an additional dimension of the given depth size. The values in the returned tensor are
+/// all set to the `off_value` except for the positions represented by the indices, which are set to the `on_value`.
+///
+/// This method returns a tensor with a rank that is one larger than the input tensor's.
+///
+/// As an example, the following tensor:
+///
+/// `[[0i64, 2], [1, -1]]`
+///
+/// with a depth of 4 will be encoded to:
+///
+/// `[[[1, 0, 0, 0], [0, 0, 1, 0]], [[0, 1, 0, 0], [0, 0, 0, 0]]]`
+///
+/// When an input index has the value -1, the corresponding one-hot vector is skipped,
+/// resulting in a vector with every value set to the `off_value`.
+///
+/// This method supports one-cold encoding by setting `on_value` to `0` and `off_value` to `1`.
+/// By default `on_value` is `1` and `off_value` is `0`.
+///
+/// Other encoding values can be used by setting `on_value` and `off_value` to the desired values.
+///
+/// # Examples
+///
+/// ## One-hot encoding
+///
+/// ```rust
+/// use candle::{Shape, Tensor, Device};
+/// use candle_nn::encoding::one_hot;
+///
+/// let device = candle::Device::Cpu;
+///
+/// let indices = Tensor::new(vec![vec![0i64, 2], vec![1, -1]], &device).unwrap();
+/// let depth = 4;
+/// let one_hot = one_hot(indices, depth, 1f32, 0f32).unwrap();
+///
+/// let expected_matrix = [
+///     [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],
+///     [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
+/// ];
+///
+/// assert_eq!(one_hot.shape(), &Shape::from((2, 2, depth)));
+///
+/// let matrix = one_hot.to_vec3::<f32>().unwrap();
+///
+/// assert_eq!(matrix, expected_matrix);
+/// ```
+///
+/// ## One-cold Encoding
+///
+/// ```rust
+/// use candle::{Shape, Tensor, Device};
+/// use candle_nn::encoding::one_hot;
+///
+/// let device = candle::Device::Cpu;
+/// let depth = 4;
+/// let indices = Tensor::new(vec![vec![0u8, 2], vec![1, 3]], &device).unwrap();
+/// let one_cold = one_hot(indices, depth, 0u8, 1u8).unwrap();
+///
+/// let expected_matrix = [[[0, 1, 1, 1], [1, 1, 0, 1]], [[1, 0, 1, 1], [1, 1, 1, 0]]];
+///
+/// assert_eq!(one_cold.shape(), &Shape::from((2, 2, depth)));
+///
+/// let matrix = one_cold.to_vec3::<u8>().unwrap();
+///
+/// assert_eq!(matrix, expected_matrix);
+/// ```
+///
+/// # Bails
+///
+/// This method bails if:
+/// - One of the index values is less than -1.
+/// - One of the index values is greater than or equal to the depth value.
+/// - The input data type is not `U8`, `U32`, or `I64`.
+///
+/// # API Design
+///
+/// The API design for this method is loosely based on the [TensorFlow One-Hot](https://www.tensorflow.org/api_docs/python/tf/one_hot) method.
+pub fn one_hot<D: WithDType>(
+    indices: Tensor,
+    depth: usize,
+    on_value: D,
+    off_value: D,
+) -> Result<Tensor> {
+    let mut target_shape = indices.dims().to_vec();
+    target_shape.push(depth);
+    let indices = indices.flatten_all()?;
+    let mut out = vec![off_value; depth * indices.elem_count()];
+    match indices.dtype() {
+        DType::U8 => {
+            let indices = indices.to_vec1::<u8>()?;
+            for (i, &index) in indices.iter().enumerate() {
+                set_at_index(index, i * depth, depth, &mut out, on_value)?;
+            }
+        }
+        DType::U32 => {
+            let indices = indices.to_vec1::<u32>()?;
+            for (i, &index) in indices.iter().enumerate() {
+                set_at_index(index, i * depth, depth, &mut out, on_value)?;
+            }
+        }
+        DType::I64 => {
+            let indices = indices.to_vec1::<i64>()?;
+            for (i, &index) in indices.iter().enumerate() {
+                set_at_index(index, i * depth, depth, &mut out, on_value)?;
+            }
+        }
+        dtype => {
+            bail!("one_hot: unsupported data type {dtype:?}, expected U8, U32, or I64")
+        }
+    };
+    Tensor::from_vec(out, target_shape, indices.device())
+}
+
+fn set_at_index<D: WithDType, I: Into<i64>>(
+    value: I,
+    offset: usize,
+    depth: usize,
+    v: &mut Vec<D>,
+    on_value: D,
+) -> Result<()> {
+    let value = value.into();
+    // Skip -1 entirely, leaving a full row of off_values.
+    if value == -1 {
+        return Ok(());
+    }
+    if value < -1 {
+        bail!(
+            "one_hot: invalid negative index value {value}, expected a non-negative index value or -1"
+        );
+    }
+    let value = value as usize;
+    if value >= depth {
+        bail!("one_hot: index value {value} exceeds depth {depth}")
+    }
+    let idx = offset + value;
+    if idx >= v.len() {
+        bail!("one_hot: index out of bounds {idx}, len {}", v.len());
+    }
+    v[idx] = on_value;
+    Ok(())
+}
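The doc-tests above already pin down the expected matrices; another quick way to see the shape contract is to round-trip indices through `one_hot` and `argmax`. A minimal sketch, assuming `candle` and `candle-nn` as dependencies (the label values are invented, and `argmax` is regular candle tensor API, not part of this patch):

```rust
use candle::{Device, Result, Tensor};
use candle_nn::encoding::one_hot;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let labels = Tensor::new(vec![0u32, 3, 1], &dev)?;
    // The rank grows by one: (3,) indices become a (3, 4) matrix.
    let encoded = one_hot(labels, 4, 1f32, 0f32)?;
    assert_eq!(encoded.dims(), &[3, 4]);
    // argmax over the new trailing dim recovers the indices; this only
    // holds when no index was -1 (those rows are entirely off_value).
    let recovered = encoded.argmax(1)?;
    assert_eq!(recovered.to_vec1::<u32>()?, vec![0u32, 3, 1]);
    Ok(())
}
```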
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 8f00e54c..6306c55a 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -2,6 +2,7 @@ pub mod activation;
 pub mod batch_norm;
 pub mod conv;
 pub mod embedding;
+pub mod encoding;
 pub mod func;
 pub mod group_norm;
 pub mod init;
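Putting the batch-norm side of the change to work: `BatchNorm` now goes through `ModuleT`, so a single module serves both phases, and the `train` flag picks between `forward_train` (which mutates the running stats through the interior `Var`s) and `forward_eval`. A minimal usage sketch; the shapes and values are invented, and it assumes `BatchNorm` and `ModuleT` are re-exported at the `candle_nn` crate root:

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{BatchNorm, ModuleT};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let num_features = 4;
    // Fresh running stats, matching what batch_norm() would initialize.
    let running_mean = Tensor::zeros(num_features, DType::F32, &dev)?;
    let running_var = Tensor::ones(num_features, DType::F32, &dev)?;
    let weight = Tensor::ones(num_features, DType::F32, &dev)?;
    let bias = Tensor::zeros(num_features, DType::F32, &dev)?;
    let bn = BatchNorm::new_with_momentum(
        num_features, running_mean, running_var, weight, bias, 1e-5, 0.1,
    )?;

    let x = Tensor::randn(0f32, 1f32, (8, num_features), &dev)?;
    // train=true normalizes with batch statistics and updates the running stats.
    let _y = bn.forward_t(&x, true)?;
    // train=false normalizes with the (now updated) running statistics.
    let _y = bn.forward_t(&x, false)?;
    println!("running mean after one step: {}", bn.running_mean());
    Ok(())
}
```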