summaryrefslogtreecommitdiff
path: root/candle-nn/tests/layer_norm.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/tests/layer_norm.rs')
-rw-r--r--candle-nn/tests/layer_norm.rs8
1 files changed, 3 insertions, 5 deletions
diff --git a/candle-nn/tests/layer_norm.rs b/candle-nn/tests/layer_norm.rs
index 0f43d804..f81c29bd 100644
--- a/candle-nn/tests/layer_norm.rs
+++ b/candle-nn/tests/layer_norm.rs
@@ -5,11 +5,9 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use anyhow::Result;
-use candle::{Device, Tensor};
+use candle::{test_utils, Device, Tensor};
use candle_nn::{LayerNorm, Module};
-mod test_utils;
-
#[test]
fn layer_norm() -> Result<()> {
let device = &Device::Cpu;
@@ -28,7 +26,7 @@ fn layer_norm() -> Result<()> {
let inp = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], device)?;
let res = ln.forward(&inp)?;
assert_eq!(
- test_utils::to_vec3_round(res.clone(), 4)?,
+ test_utils::to_vec3_round(&res, 4)?,
[[
[-3.1742, 0.5, 4.1742],
[-3.1742, 0.5, 4.1742],
@@ -41,7 +39,7 @@ fn layer_norm() -> Result<()> {
let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(2)?.sqrt()? / 3.0)?;
// The standard deviation should be sqrt(`w`).
assert_eq!(
- test_utils::to_vec3_round(std, 4)?,
+ test_utils::to_vec3_round(&std, 4)?,
[[[1.7321], [1.7321], [1.7321]]]
);
Ok(())