diff options
Diffstat (limited to 'candle-nn/tests/layer_norm.rs')
-rw-r--r-- | candle-nn/tests/layer_norm.rs | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/candle-nn/tests/layer_norm.rs b/candle-nn/tests/layer_norm.rs index 0f43d804..f81c29bd 100644 --- a/candle-nn/tests/layer_norm.rs +++ b/candle-nn/tests/layer_norm.rs @@ -5,11 +5,9 @@ extern crate intel_mkl_src; extern crate accelerate_src; use anyhow::Result; -use candle::{Device, Tensor}; +use candle::{test_utils, Device, Tensor}; use candle_nn::{LayerNorm, Module}; -mod test_utils; - #[test] fn layer_norm() -> Result<()> { let device = &Device::Cpu; @@ -28,7 +26,7 @@ fn layer_norm() -> Result<()> { let inp = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], device)?; let res = ln.forward(&inp)?; assert_eq!( - test_utils::to_vec3_round(res.clone(), 4)?, + test_utils::to_vec3_round(&res, 4)?, [[ [-3.1742, 0.5, 4.1742], [-3.1742, 0.5, 4.1742], @@ -41,7 +39,7 @@ fn layer_norm() -> Result<()> { let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(2)?.sqrt()? / 3.0)?; // The standard deviation should be sqrt(`w`). assert_eq!( - test_utils::to_vec3_round(std, 4)?, + test_utils::to_vec3_round(&std, 4)?, [[[1.7321], [1.7321], [1.7321]]] ); Ok(()) |