diff options
Diffstat (limited to 'candle-nn/tests/group_norm.rs')
-rw-r--r-- | candle-nn/tests/group_norm.rs | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/candle-nn/tests/group_norm.rs b/candle-nn/tests/group_norm.rs index eff66d17..8145a220 100644 --- a/candle-nn/tests/group_norm.rs +++ b/candle-nn/tests/group_norm.rs @@ -25,10 +25,9 @@ extern crate intel_mkl_src; extern crate accelerate_src; use anyhow::Result; +use candle::test_utils::to_vec3_round; use candle::{Device, Tensor}; use candle_nn::{GroupNorm, Module}; -mod test_utils; -use test_utils::to_vec3_round; #[test] fn group_norm() -> Result<()> { @@ -60,7 +59,7 @@ fn group_norm() -> Result<()> { device, )?; assert_eq!( - to_vec3_round(gn2.forward(&input)?, 4)?, + to_vec3_round(&gn2.forward(&input)?, 4)?, &[ [ [-0.1653, 0.3748, -0.7866], @@ -81,7 +80,7 @@ fn group_norm() -> Result<()> { ] ); assert_eq!( - to_vec3_round(gn3.forward(&input)?, 4)?, + to_vec3_round(&gn3.forward(&input)?, 4)?, &[ [ [0.4560, 1.4014, -0.6313], |