diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-10 14:43:04 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-10 14:43:04 +0100 |
commit | 71cd3745a90e277c8d5911b7ddc98d70aebcd8ed (patch) | |
tree | 9dfbf7305b34102bad215d725b5dd7ec0ce62a22 | |
parent | dc5825967957e28e6ac4f57da18c7963f2be542c (diff) | |
download | candle-71cd3745a90e277c8d5911b7ddc98d70aebcd8ed.tar.gz candle-71cd3745a90e277c8d5911b7ddc98d70aebcd8ed.tar.bz2 candle-71cd3745a90e277c8d5911b7ddc98d70aebcd8ed.zip |
Add some layer-norm tests. (#121)
-rw-r--r-- | candle-nn/tests/layer_norm.rs | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/candle-nn/tests/layer_norm.rs b/candle-nn/tests/layer_norm.rs new file mode 100644 index 00000000..e4b962b4 --- /dev/null +++ b/candle-nn/tests/layer_norm.rs @@ -0,0 +1,43 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::Result; +use candle::{Device, Tensor}; +use candle_nn::LayerNorm; + +#[test] +fn layer_norm() -> Result<()> { + let device = &Device::Cpu; + let w = Tensor::new(&[3f32], device)?; + let b = Tensor::new(&[0.5f32], device)?; + let ln = LayerNorm::new(w, b, 1e-8); + + let two = Tensor::new(&[[[2f32]]], device)?; + let res = ln.forward(&two)?.flatten_all()?; + assert_eq!(res.to_vec1::<f32>()?, [0.5f32]); + + let inp = Tensor::new(&[[[4f32, 0f32]]], device)?; + let res = ln.forward(&inp)?; + assert_eq!(res.to_vec3::<f32>()?, [[[3.5f32, -2.5]]]); + + let inp = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], device)?; + let res = ln.forward(&inp)?; + assert_eq!( + res.to_vec3::<f32>()?, + [[ + [-3.1742344, 0.5, 4.1742344], + [-3.1742344, 0.5, 4.1742344], + [4.1742344, 0.5, -3.1742344] + ]] + ); + let mean = (res.sum(&[2])? / 3.0)?; + // The average value should be `b`. + assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]); + let std = (res.broadcast_sub(&mean)?.sqr()?.sum(&[2])?.sqrt()? / 3.0)?; + // The standard deviation should be sqrt(`w`). + assert_eq!( + std.to_vec3::<f32>()?, + [[[1.7320508], [1.7320508], [1.7320508]]] + ); + Ok(()) +} |