diff options
Diffstat (limited to 'candle-nn/tests/batch_norm.rs')
-rw-r--r-- | candle-nn/tests/batch_norm.rs | 15 |
1 files changed, 13 insertions, 2 deletions
diff --git a/candle-nn/tests/batch_norm.rs b/candle-nn/tests/batch_norm.rs index 5bbaf238..6fd7361a 100644 --- a/candle-nn/tests/batch_norm.rs +++ b/candle-nn/tests/batch_norm.rs @@ -16,6 +16,8 @@ input = torch.randn(2, 5, 3, 4) output = m(input) print(input.flatten()) print(output.flatten()) +print(m.running_mean) +print(m.running_var) */ #[test] fn batch_norm() -> Result<()> { @@ -37,7 +39,7 @@ fn batch_norm() -> Result<()> { 1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205, ]; let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?; - let output = bn.forward_learning(&input)?; + let output = bn.forward_train(&input)?; assert_eq!(output.dims(), &[2, 5, 3, 4]); let output = output.flatten_all()?; assert_eq!( @@ -65,11 +67,20 @@ fn batch_norm() -> Result<()> { Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?, 1e-8, )?; - let output2 = bn2.forward_learning(&input)?; + let output2 = bn2.forward_train(&input)?; assert_eq!(output2.dims(), &[2, 5, 3, 4]); let output2 = output2.flatten_all()?; let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?; let sum_diff2 = diff2.sum_keepdim(0)?; assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]); + + assert_eq!( + test_utils::to_vec1_round(bn.running_mean(), 4)?, + &[-0.0133, 0.0197, -0.0153, -0.0073, -0.0020] + ); + assert_eq!( + test_utils::to_vec1_round(bn.running_var(), 4)?, + &[0.9972, 0.9842, 0.9956, 0.9866, 0.9898] + ); Ok(()) } |