summaryrefslogtreecommitdiff
path: root/candle-nn/tests/group_norm.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/tests/group_norm.rs')
-rw-r--r--candle-nn/tests/group_norm.rs7
1 files changed, 3 insertions, 4 deletions
diff --git a/candle-nn/tests/group_norm.rs b/candle-nn/tests/group_norm.rs
index eff66d17..8145a220 100644
--- a/candle-nn/tests/group_norm.rs
+++ b/candle-nn/tests/group_norm.rs
@@ -25,10 +25,9 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use anyhow::Result;
+use candle::test_utils::to_vec3_round;
use candle::{Device, Tensor};
use candle_nn::{GroupNorm, Module};
-mod test_utils;
-use test_utils::to_vec3_round;
#[test]
fn group_norm() -> Result<()> {
@@ -60,7 +59,7 @@ fn group_norm() -> Result<()> {
device,
)?;
assert_eq!(
- to_vec3_round(gn2.forward(&input)?, 4)?,
+ to_vec3_round(&gn2.forward(&input)?, 4)?,
&[
[
[-0.1653, 0.3748, -0.7866],
@@ -81,7 +80,7 @@ fn group_norm() -> Result<()> {
]
);
assert_eq!(
- to_vec3_round(gn3.forward(&input)?, 4)?,
+ to_vec3_round(&gn3.forward(&input)?, 4)?,
&[
[
[0.4560, 1.4014, -0.6313],