summaryrefslogtreecommitdiff
path: root/candle-nn/tests/batch_norm.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/tests/batch_norm.rs')
-rw-r--r--candle-nn/tests/batch_norm.rs15
1 files changed, 13 insertions, 2 deletions
diff --git a/candle-nn/tests/batch_norm.rs b/candle-nn/tests/batch_norm.rs
index 5bbaf238..6fd7361a 100644
--- a/candle-nn/tests/batch_norm.rs
+++ b/candle-nn/tests/batch_norm.rs
@@ -16,6 +16,8 @@ input = torch.randn(2, 5, 3, 4)
output = m(input)
print(input.flatten())
print(output.flatten())
+print(m.running_mean)
+print(m.running_var)
*/
#[test]
fn batch_norm() -> Result<()> {
@@ -37,7 +39,7 @@ fn batch_norm() -> Result<()> {
1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
];
let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
- let output = bn.forward_learning(&input)?;
+ let output = bn.forward_train(&input)?;
assert_eq!(output.dims(), &[2, 5, 3, 4]);
let output = output.flatten_all()?;
assert_eq!(
@@ -65,11 +67,20 @@ fn batch_norm() -> Result<()> {
Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
1e-8,
)?;
- let output2 = bn2.forward_learning(&input)?;
+ let output2 = bn2.forward_train(&input)?;
assert_eq!(output2.dims(), &[2, 5, 3, 4]);
let output2 = output2.flatten_all()?;
let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
let sum_diff2 = diff2.sum_keepdim(0)?;
assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]);
+
+ assert_eq!(
+ test_utils::to_vec1_round(bn.running_mean(), 4)?,
+ &[-0.0133, 0.0197, -0.0153, -0.0073, -0.0020]
+ );
+ assert_eq!(
+ test_utils::to_vec1_round(bn.running_var(), 4)?,
+ &[0.9972, 0.9842, 0.9956, 0.9866, 0.9898]
+ );
Ok(())
}