diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-05-24 16:48:21 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-24 16:48:21 +0200 |
commit | 3ceca9901a5ebc4ded3ac2cd793d0125f7a12562 (patch) | |
tree | 364793408840c261956f04fed2b0caf430655c41 /candle-nn | |
parent | 1df2bddccfbb4ab511a8cc3a87476d1fa72416bc (diff) | |
download | candle-3ceca9901a5ebc4ded3ac2cd793d0125f7a12562.tar.gz candle-3ceca9901a5ebc4ded3ac2cd793d0125f7a12562.tar.bz2 candle-3ceca9901a5ebc4ded3ac2cd793d0125f7a12562.zip |
Enable the new layer-norm. (#2213)
* Enable the new layer-norm.
* Shape fixes.
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/src/layer_norm.rs | 9 | ||||
-rw-r--r-- | candle-nn/tests/layer_norm.rs | 15 |
2 files changed, 19 insertions, 5 deletions
diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index 23d0c01b..b7dd61cb 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -11,8 +11,8 @@ //! use candle_nn::{LayerNorm, Module}; //! # fn main() -> candle::Result<()> { //! -//! let w = Tensor::new(1f32, &Cpu)?; -//! let b = Tensor::new(0f32, &Cpu)?; +//! let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?; +//! let b = Tensor::new(&[0f32, 0f32, 0f32], &Cpu)?; //! let layer = LayerNorm::new(w, b, 1e-5); //! //! let xs = Tensor::new( @@ -107,6 +107,11 @@ impl LayerNorm { impl Module for LayerNorm { fn forward(&self, x: &Tensor) -> Result<Tensor> { + if x.is_contiguous() && self.remove_mean { + if let Some(bias) = self.bias.as_ref() { + return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32); + } + } let x_dtype = x.dtype(); let internal_dtype = match x_dtype { DType::F16 | DType::BF16 => DType::F32, diff --git a/candle-nn/tests/layer_norm.rs b/candle-nn/tests/layer_norm.rs index f81c29bd..30f598b3 100644 --- a/candle-nn/tests/layer_norm.rs +++ b/candle-nn/tests/layer_norm.rs @@ -13,6 +13,12 @@ fn layer_norm() -> Result<()> { let device = &Device::Cpu; let w = Tensor::new(&[3f32], device)?; let b = Tensor::new(&[0.5f32], device)?; + let ln2 = LayerNorm::new(Tensor::cat(&[&w, &w], 0)?, Tensor::cat(&[&b, &b], 0)?, 1e-8); + let ln3 = LayerNorm::new( + Tensor::cat(&[&w, &w, &w], 0)?, + Tensor::cat(&[&b, &b, &b], 0)?, + 1e-8, + ); let ln = LayerNorm::new(w, b, 1e-8); let two = Tensor::new(&[[[2f32]]], device)?; @@ -20,11 +26,11 @@ fn layer_norm() -> Result<()> { assert_eq!(res.to_vec1::<f32>()?, [0.5f32]); let inp = Tensor::new(&[[[4f32, 0f32]]], device)?; - let res = ln.forward(&inp)?; + let res = ln2.forward(&inp)?; assert_eq!(res.to_vec3::<f32>()?, [[[3.5f32, -2.5]]]); let inp = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], device)?; - let res = ln.forward(&inp)?; + let res = ln3.forward(&inp)?; assert_eq!( test_utils::to_vec3_round(&res, 4)?, [[ @@ -35,7 +41,10 @@ fn layer_norm() -> Result<()> { ); let mean = (res.sum_keepdim(2)? / 3.0)?; // The average value should be `b`. - assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]); + assert_eq!( + test_utils::to_vec3_round(&mean, 4)?, + [[[0.5], [0.5], [0.5]]] + ); let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(2)?.sqrt()? / 3.0)?; // The standard deviation should be sqrt(`w`). assert_eq!( |