//! Layer Normalization.
//!
//! This layer applies Layer Normalization over a mini-batch of inputs as described in [`Layer
//! Normalization`]. The input is expected to have three dimensions: a batch dimension, a length,
//! and a hidden size; the normalization is applied over the last dimension.
//!
//! # Example
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu, test_utils::to_vec3_round};
//! use candle_nn::{LayerNorm, Module};
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?;
//! let b = Tensor::new(&[0f32, 0f32, 0f32], &Cpu)?;
//! let layer = LayerNorm::new(w, b, 1e-5);
//!
//! let xs = Tensor::new(
//!     &[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]],
//!     &Cpu)?;
//! let ys = layer.forward(&xs)?;
//! assert_eq!(
//!     to_vec3_round(&ys, 4)?,
//!     &[[[-1.2247, 0.0,  1.2247],
//!        [-1.2247, 0.0,  1.2247],
//!        [ 1.2247, 0.0, -1.2247]]]);
//! # Ok(()) }
//! ```
//!
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
use candle::{DType, Module, Result, Tensor, D};

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerNormConfig {
    pub eps: f64,
    /// Whether to subtract the mean before normalizing. Defaults to true; when set to false,
    /// this layer becomes an RmsNorm.
    pub remove_mean: bool,
    pub affine: bool,
}

impl Default for LayerNormConfig {
    fn default() -> Self {
        Self {
            eps: 1e-5,
            remove_mean: true,
            affine: true,
        }
    }
}

impl From<f64> for LayerNormConfig {
    fn from(eps: f64) -> Self {
        Self {
            eps,
            remove_mean: true,
            affine: true,
        }
    }
}

// This layer norm implementation covers the full version (weight, bias, and mean removal) as
// well as the bias-free and RmsNorm (no mean removal) variants.
#[derive(Clone, Debug)]
pub struct LayerNorm {
    weight: Tensor,
    bias: Option<Tensor>,
    remove_mean: bool,
    eps: f64,
}

impl LayerNorm {
    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
        Self {
            weight,
            bias: Some(bias),
            remove_mean: true,
            eps,
        }
    }

    pub fn new_no_bias(weight: Tensor, eps: f64) -> Self {
        Self {
            weight,
            bias: None,
            remove_mean: true,
            eps,
        }
    }

    pub fn rms_norm(weight: Tensor, eps: f64) -> Self {
        Self {
            weight,
            bias: None,
            remove_mean: false,
            eps,
        }
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}

impl Module for LayerNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Fast path: the fused kernel handles the full layer norm on contiguous inputs.
        if x.is_contiguous() && self.remove_mean {
            if let Some(bias) = self.bias.as_ref() {
                return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32);
            }
        }
        let x_dtype = x.dtype();
        // Accumulate in f32 for half-precision inputs to avoid numerical issues.
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let hidden_size = x.dim(D::Minus1)?;
        let x = x.to_dtype(internal_dtype)?;
        let x = if self.remove_mean {
            let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
            x.broadcast_sub(&mean_x)?
        } else {
            x
        };
        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => x.broadcast_add(bias),
        }
    }
}

pub fn layer_norm<C: Into<LayerNormConfig>>(
    size: usize,
    config: C,
    vb: crate::VarBuilder,
) -> Result<LayerNorm> {
    let config = config.into();
    let weight = vb.get_with_hints(size, "weight", crate::Init::Const(1.))?;
    let bias = if config.affine {
        Some(vb.get_with_hints(size, "bias", crate::Init::Const(0.))?)
    } else {
        None
    };
    Ok(LayerNorm {
        weight,
        bias,
        remove_mean: config.remove_mean,
        eps: config.eps,
    })
}

/// RmsNorm is a specialized version of the LayerNorm module.
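///
/// # Example
///
/// A minimal usage sketch mirroring the module-level `LayerNorm` example; the expected values
/// assume the usual RmsNorm formula `x / sqrt(mean(x^2) + eps) * weight` with a unit weight,
/// rounded to 4 decimals.
///
/// ```rust
/// use candle::{Tensor, Device::Cpu, test_utils::to_vec3_round};
/// use candle_nn::{RmsNorm, Module};
/// # fn main() -> candle::Result<()> {
/// let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?;
/// let layer = RmsNorm::new(w, 1e-5);
/// let xs = Tensor::new(&[[[1f32, 2., 3.]]], &Cpu)?;
/// let ys = layer.forward(&xs)?;
/// assert_eq!(to_vec3_round(&ys, 4)?, &[[[0.4629, 0.9258, 1.3887]]]);
/// # Ok(()) }
/// ```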
#[derive(Clone, Debug)]
pub struct RmsNorm(LayerNorm);

impl RmsNorm {
    pub fn new(weight: Tensor, eps: f64) -> Self {
        Self(LayerNorm::rms_norm(weight, eps))
    }

    pub fn into_inner(self) -> LayerNorm {
        self.0
    }

    /// Variant of the forward pass that always goes through the plain, differentiable
    /// implementation rather than the fused kernel, whatever the input layout.
    pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
        self.0.forward(xs)
    }
}

impl Module for RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        if xs.is_contiguous() {
            // Fast path: the fused kernel requires a contiguous input.
            crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
        } else {
            self.0.forward(xs)
        }
    }
}

pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<RmsNorm> {
    let config = LayerNormConfig {
        eps,
        remove_mean: false,
        affine: false,
    };
    Ok(RmsNorm(layer_norm(size, config, vb)?))
}
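
// A minimal smoke-test sketch for the `layer_norm` builder: it uses a `VarMap`-backed
// `VarBuilder` so that the Const(1.)/Const(0.) init hints above are honored, and assumes the
// `VarMap` re-export at the crate root.
#[cfg(test)]
mod tests {
    use super::*;
    use candle::test_utils::to_vec3_round;
    use candle::Device;

    #[test]
    fn layer_norm_builder() -> Result<()> {
        let dev = Device::Cpu;
        let varmap = crate::VarMap::new();
        let vb = crate::VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        // Default config: eps = 1e-5, mean removal and the affine bias enabled.
        let ln = layer_norm(3, LayerNormConfig::default(), vb)?;
        let xs = Tensor::new(&[[[1f32, 2., 3.]]], &dev)?;
        let ys = ln.forward(&xs)?;
        // With a unit weight and zero bias this matches the module-level doc example.
        assert_eq!(to_vec3_round(&ys, 4)?, &[[[-1.2247, 0.0, 1.2247]]]);
        Ok(())
    }
}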