summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-07-10 14:43:04 +0100
committerGitHub <noreply@github.com>2023-07-10 14:43:04 +0100
commit71cd3745a90e277c8d5911b7ddc98d70aebcd8ed (patch)
tree9dfbf7305b34102bad215d725b5dd7ec0ce62a22
parentdc5825967957e28e6ac4f57da18c7963f2be542c (diff)
downloadcandle-71cd3745a90e277c8d5911b7ddc98d70aebcd8ed.tar.gz
candle-71cd3745a90e277c8d5911b7ddc98d70aebcd8ed.tar.bz2
candle-71cd3745a90e277c8d5911b7ddc98d70aebcd8ed.zip
Add some layer-norm tests. (#121)
-rw-r--r--candle-nn/tests/layer_norm.rs43
1 files changed, 43 insertions, 0 deletions
diff --git a/candle-nn/tests/layer_norm.rs b/candle-nn/tests/layer_norm.rs
new file mode 100644
index 00000000..e4b962b4
--- /dev/null
+++ b/candle-nn/tests/layer_norm.rs
@@ -0,0 +1,43 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use anyhow::Result;
+use candle::{Device, Tensor};
+use candle_nn::LayerNorm;
+
+#[test]
+fn layer_norm() -> Result<()> {
+ let device = &Device::Cpu;
+ let w = Tensor::new(&[3f32], device)?;
+ let b = Tensor::new(&[0.5f32], device)?;
+ let ln = LayerNorm::new(w, b, 1e-8);
+
+ let two = Tensor::new(&[[[2f32]]], device)?;
+ let res = ln.forward(&two)?.flatten_all()?;
+ assert_eq!(res.to_vec1::<f32>()?, [0.5f32]);
+
+ let inp = Tensor::new(&[[[4f32, 0f32]]], device)?;
+ let res = ln.forward(&inp)?;
+ assert_eq!(res.to_vec3::<f32>()?, [[[3.5f32, -2.5]]]);
+
+ let inp = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], device)?;
+ let res = ln.forward(&inp)?;
+ assert_eq!(
+ res.to_vec3::<f32>()?,
+ [[
+ [-3.1742344, 0.5, 4.1742344],
+ [-3.1742344, 0.5, 4.1742344],
+ [4.1742344, 0.5, -3.1742344]
+ ]]
+ );
+ let mean = (res.sum(&[2])? / 3.0)?;
+ // The average value should be `b`.
+ assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
+ let std = (res.broadcast_sub(&mean)?.sqr()?.sum(&[2])?.sqrt()? / 3.0)?;
+ // The standard deviation should be sqrt(`w`).
+ assert_eq!(
+ std.to_vec3::<f32>()?,
+ [[[1.7320508], [1.7320508], [1.7320508]]]
+ );
+ Ok(())
+}