#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{test_utils, DType, Device, Tensor};
use candle_nn::{batch_norm, BatchNorm, BatchNormConfig, VarBuilder, VarMap};
/* The test below has been generated using the following PyTorch code:
import torch
torch.manual_seed(19551105)
m = torch.nn.BatchNorm2d(5, affine=False)
input = torch.randn(2, 5, 3, 4)
output = m(input)
print(input.flatten())
print(output.flatten())
print(m.running_mean)
print(m.running_var)
*/
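// For reference, in training mode the layer normalizes each channel with the statistics of the
// current batch rather than with the running estimates:
//     y = (x - batch_mean) / sqrt(batch_var + eps) * weight + bias
// and the running statistics are updated as an exponential moving average:
//     running_stat = (1 - momentum) * running_stat + momentum * batch_stat
// PyTorch's BatchNorm2d defaults to momentum = 0.1 and uses the unbiased batch variance for the
// running estimate, which is how the expected values below were generated.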

#[test]
fn batch_norm_test() -> Result<()> {
    let running_mean = Tensor::zeros(5, DType::F32, &Device::Cpu)?;
    let running_var = Tensor::ones(5, DType::F32, &Device::Cpu)?;
    let bn = BatchNorm::new_no_bias(5, running_mean.clone(), running_var.clone(), 1e-8)?;
    let input: [f32; 120] = [
        -0.7493, -1.0410, 1.6977, -0.6579, 1.7982, -0.0087, 0.2812, -0.1190, 0.2908, -0.5975,
        -0.0278, -0.2138, -1.3130, -1.6048, -2.2028, 0.9452, 0.4002, 0.0831, 1.0004, 0.1860,
        0.5004, 0.5539, 0.9991, -0.2540, -0.0703, -0.3752, -0.1096, -0.2374, 1.0258, -2.2208,
        -0.0257, 0.6073, -1.1627, -0.0964, -1.9718, 1.6577, 0.1931, -0.3692, -0.8011, 0.9059,
        0.4797, 0.6521, -0.0165, -0.6683, -0.4148, 2.0649, -0.8276, 1.7947, -0.2061, 0.5812,
        -1.3598, 1.6192, 1.0466, -0.4423, 0.4202, 0.1749, 0.6969, 0.2616, -0.0369, -1.4951,
        -0.0814, -0.1877, 0.0267, 0.6150, 0.2402, -1.1440, -2.0068, 0.6032, -2.6639, 0.8260,
        0.1085, -0.1693, 1.2805, 0.7654, -0.4930, 0.3770, 1.1309, 0.2303, 0.2949, -0.2634,
        -0.5225, 0.4269, 0.6341, 1.5736, 0.9827, -1.2499, 0.3509, -1.6243, -0.8123, 0.7634,
        -0.3047, 0.0143, -0.4032, 0.0537, 0.7022, 0.8405, -1.2221, -1.6847, -0.0714, -0.1608,
        0.5579, -1.5858, 0.4617, -0.6480, 0.1332, 0.0419, -0.9784, 0.4173, 1.2313, -1.9046,
        -0.1656, 0.1259, 0.0763, 1.4252, -0.9115, -0.1093, -0.3100, -0.6734, -1.4357, 0.9205,
    ];
    let input = Tensor::new(&input, &Device::Cpu)?.reshape((2, 5, 3, 4))?;
    let output = bn.forward_train(&input)?;
    assert_eq!(output.dims(), &[2, 5, 3, 4]);
    let output = output.flatten_all()?;
    assert_eq!(
        test_utils::to_vec1_round(&output, 4)?,
        &[
            -0.6391, -0.9414, 1.8965, -0.5444, 2.0007, 0.1283, 0.4287, 0.014, 0.4387, -0.4818,
            0.1085, -0.0842, -1.6809, -2.0057, -2.6714, 0.8328, 0.2262, -0.1268, 0.8943, -0.0123,
            0.3377, 0.3973, 0.8928, -0.5021, 0.0861, -0.2324, 0.0451, -0.0884, 1.2311, -2.1603,
            0.1327, 0.7939, -1.055, 0.0589, -1.9002, 1.8912, 0.2918, -0.3253, -0.7993, 1.0741,
            0.6063, 0.7955, 0.0617, -0.6536, -0.3754, 2.3461, -0.8284, 2.0495, -0.201, 0.6476,
            -1.4446, 1.7665, 1.1493, -0.4556, 0.4741, 0.2097, 0.7723, 0.3031, -0.0186, -1.5905,
            0.053, -0.0572, 0.165, 0.7746, 0.3862, -1.0481, -1.9422, 0.7624, -2.6231, 0.9933,
            0.2498, -0.0381, 1.2061, 0.6327, -0.7681, 0.2004, 1.0396, 0.037, 0.109, -0.5125,
            -0.8009, 0.2559, 0.4865, 1.5324, 1.1861, -1.1461, 0.5261, -1.5372, -0.689, 0.957,
            -0.1587, 0.1745, -0.2616, 0.2156, 0.8931, 1.0375, -1.2614, -1.7691, 0.0015, -0.0966,
            0.6921, -1.6605, 0.5866, -0.6313, 0.226, 0.1258, -0.9939, 0.5378, 1.3484, -2.0319,
            -0.1574, 0.1568, 0.1034, 1.5574, -0.9614, -0.0967, -0.313, -0.7047, -1.5264, 1.0134
        ]
    );
    let bn2 = BatchNorm::new(
        5,
        running_mean,
        running_var,
        Tensor::new(&[0.5f32], &Device::Cpu)?.broadcast_as(5)?,
        Tensor::new(&[-1.5f32], &Device::Cpu)?.broadcast_as(5)?,
        1e-8,
    )?;
    let output2 = bn2.forward_train(&input)?;
    assert_eq!(output2.dims(), &[2, 5, 3, 4]);
    let output2 = output2.flatten_all()?;
    // bn2 applies weight = 0.5 and bias = -1.5 on top of the same normalization, so its output
    // should match 0.5 * output - 1.5 from the no-affine batch norm above.
    let diff2 = ((output2 - (output * 0.5)?)? + 1.5)?.sqr()?;
    let sum_diff2 = diff2.sum_keepdim(0)?;
    assert_eq!(test_utils::to_vec1_round(&sum_diff2, 4)?, &[0f32]);
    assert_eq!(
        test_utils::to_vec1_round(bn.running_mean(), 4)?,
        &[-0.0133, 0.0197, -0.0153, -0.0073, -0.0020]
    );
    assert_eq!(
        test_utils::to_vec1_round(bn.running_var(), 4)?,
        &[0.9972, 0.9842, 0.9956, 0.9866, 0.9898]
    );
    Ok(())
}

// This test makes sure that we can train a batch norm layer using a VarMap.
#[test]
fn train_batch_norm() -> Result<()> {
    let vm = VarMap::new();
    let vb = VarBuilder::from_varmap(&vm, DType::F32, &Device::Cpu);
    let bn = batch_norm(1, BatchNormConfig::default(), vb)?;
    // Keep a copy of the original running mean so that we can check later that it gets updated.
    let original_mean = bn.running_mean().detach().copy()?;
    let var_map_mean = {
        vm.data()
            .lock()
            .unwrap()
            .get("running_mean")
            .unwrap()
            .clone()
    };
    // Ensure that the var map mean is the same as the running mean.
    assert_eq!(
        test_utils::to_vec1_round(bn.running_mean(), 4)?,
        test_utils::to_vec1_round(var_map_mean.as_tensor(), 4)?,
    );
    // Train on an input that is guaranteed to be different from the running mean.
    let mean_plus_one = {
        let one = original_mean.ones_like()?;
        original_mean.add(&one)?.reshape((1, 1))?
    };
    bn.forward_train(&mean_plus_one)?;
    // Assert that the running mean has been updated.
    assert_ne!(
        test_utils::to_vec1_round(bn.running_mean(), 4)?,
        test_utils::to_vec1_round(&original_mean, 4)?,
    );
    // Assert that the var map mean has been updated as well.
    assert_eq!(
        test_utils::to_vec1_round(bn.running_mean(), 4)?,
        test_utils::to_vec1_round(var_map_mean.as_tensor(), 4)?,
    );
    Ok(())
}
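
// Because the running statistics live in the VarMap (the test above looks the mean up under the
// name "running_mean"), they can be persisted together with any trainable parameters. A minimal,
// untested sketch, assuming the VarMap safetensors helpers and a hypothetical output path:
//
//     let vm = VarMap::new();
//     let vb = VarBuilder::from_varmap(&vm, DType::F32, &Device::Cpu);
//     let bn = batch_norm(1, BatchNormConfig::default(), vb)?;
//     // ... call bn.forward_train(...) on training batches ...
//     vm.save("batch_norm.safetensors")?;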