summaryrefslogtreecommitdiff
path: root/candle-core/tests
diff options
context:
space:
mode:
authorKirpal Grewal <45569241+KGrewal1@users.noreply.github.com>2024-03-23 06:05:55 +0000
committerGitHub <noreply@github.com>2024-03-23 07:05:55 +0100
commitcc856db9ce2541e09731165f88cdd7aae37f558e (patch)
tree5ffa3b722f84504ba2208c439c2b9253f3ffa1eb /candle-core/tests
parentfc1fe5e45b046771589126c355fdfb4d3bb49fe4 (diff)
downloadcandle-cc856db9ce2541e09731165f88cdd7aae37f558e.tar.gz
candle-cc856db9ce2541e09731165f88cdd7aae37f558e.tar.bz2
candle-cc856db9ce2541e09731165f88cdd7aae37f558e.zip
Backwards for ConvTranspose2D (#1910)
* add documentation for nackprop * add backwards for ConvTranspose2D * add test python code to test
Diffstat (limited to 'candle-core/tests')
-rw-r--r--candle-core/tests/conv_tests.rs161
1 files changed, 154 insertions, 7 deletions
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 6cc48ec7..3762e02f 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -135,7 +135,7 @@ fn conv2d(dev: &Device) -> Result<()> {
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
- -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
+ -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
],
dev,
)?;
@@ -276,11 +276,10 @@ fn conv2d_small(dev: &Device) -> Result<()> {
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[
- 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
- 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000,
- 0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855,
- -1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
- 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
+ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1640,
+ -0.0111, -0.1742, 0.0, 0.0, 0.0, 0.0, 2.6437, -2.0268, 1.1823, 0.0, 0.0, 0.0, 0.0,
+ 3.2855, -1.0324, 0.2539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0
]
);
@@ -398,7 +397,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712,
0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790,
-0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006,
- -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
+ -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085,
],
(1, 4, 5, 5),
dev,
@@ -583,6 +582,154 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
]
);
+ // Conv Transpose 2d Test
+ //tested against following python
+
+ // import torch
+ // torch.manual_seed(4242)
+ // padding = 4
+ // outpadding = 2
+ // dilation = 3
+ // stride = 3
+ // input = torch.randn((1, 4, 7, 5), requires_grad=True)
+ // kernel = torch.randn((4, 2, 3, 5), requires_grad=True)
+ // print("input", input.flatten())
+ // print("kernel", kernel.flatten())
+ // res = torch.nn.functional.conv_transpose2d(
+ // input,
+ // kernel,
+ // stride=stride,
+ // padding=padding,
+ // dilation=dilation,
+ // output_padding=outpadding,
+ // )
+ // res.retain_grad()
+ // print(res.shape)
+ // loss = (res**2).sum()
+ // print(loss)
+ // loss.backward()
+ // print(input.grad.shape)
+ // print("input grad", torch.round(input.grad, decimals=1))
+ // print(kernel.grad.shape)
+ // print("kernel grad", torch.round(kernel.grad.flatten(), decimals=1))
+
+ let padding = 4;
+ let outpadding = 2;
+ let dilation = 3;
+ let stride = 3;
+
+ let t = Var::from_slice(
+ &[
+ 0.4056_f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997,
+ 3.0616, 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843,
+ 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013,
+ -0.6836, 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130,
+ 1.3123, 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071,
+ 1.1586, 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090,
+ 0.2049, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323,
+ -1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742,
+ 0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.5411, -1.1715, -2.4997, 2.3249, -0.8912,
+ -0.4733, -0.5701, -2.8888, -1.4112, -0.5471, -0.9234, -1.1660, 0.4189, -0.7465,
+ -0.6473, 0.1402, 0.7875, 0.5377, -0.6779, -0.8088, -0.4864, -0.2312, 0.9279, 0.1264,
+ 1.5480, 0.8265, -0.1025, 0.5138, -0.2512, 0.1576, 1.2705, 0.3641, -0.9325, 0.6451,
+ -0.8537, 0.2378, 0.1794, 0.2752, -0.3687, -1.1149, -0.1410, -0.5829, -0.0892, 1.4258,
+ -2.2789, 0.5270, 0.1825, 1.7007, -0.5263, -0.2954, 0.4440, 0.5537, 0.3492, 0.6186,
+ 1.6475, 0.2219,
+ ],
+ (1, 4, 7, 5),
+ dev,
+ )?;
+
+ #[rustfmt::skip]
+ let w = Var::from_slice(
+ &[
+ -1.1744_f32, 0.3266, 2.5893, 1.0142, 0.1763, 0.7752, 0.6604, 0.2029, -0.2145, 0.7234,
+ -0.3441, -1.5400, -0.6333, 0.6613, 0.2083, 0.6230, -1.7002, 0.3393, 0.4049, 1.0762,
+ 0.2723, 1.4181, 0.0029, -0.2122, 1.7668, 1.4168, 0.3320, -0.2719, 0.7932, -0.7204,
+ 0.4447, 0.1211, 0.5908, 1.0089, -0.1646, 1.8033, -0.6286, 0.2016, -0.3370, 1.2555,
+ 0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026, 0.8828, 2.4990,
+ 0.6811, -0.3369, 1.3320, 1.7669, -1.1067, 1.2958, -0.9415, -0.9655, -0.4462, 0.7181,
+ 0.5181, -1.1658, -1.8467, -0.7763, 1.2769, 0.8651, 0.9890, 1.5092, 0.7207, -0.8481,
+ 0.7417, 0.3375, -1.2685, 1.4572, 1.0915, 0.1093, -0.8550, -0.5831, -0.6309, -0.2509,
+ 0.5220, -0.0914, 0.7900, 0.1096, 0.3258, 0.2723, -1.0942, -0.3393, -0.1653, 0.5732,
+ -0.8014, 1.8194, -1.9023, 0.2127, 1.8636, -0.8979, 0.1927, -0.2778, 0.3105, 0.0071,
+ -1.1823, 0.2476, -0.7178, -1.3821, 1.0769, -0.4376, -0.9967, -0.1227, 1.6197, -1.0604,
+ 0.1372, 0.8141, -0.6163, 0.7304, -0.8285, 2.0636, -0.7176, 0.2495, -0.2581, -0.4478,
+ ],
+ (4, 2, 3, 5),
+ dev,
+ )?;
+ let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?;
+ let loss = res.sqr()?.sum_all()?;
+ assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 2904.0);
+ let grads = loss.backward()?;
+
+ let grad_t = grads.get(&t).unwrap();
+ let grad_w = grads.get(&w).unwrap();
+ assert_eq!(grad_t.dims(), [1, 4, 7, 5]);
+ assert_eq!(grad_w.dims(), [4, 2, 3, 5]);
+
+ assert_eq!(
+ test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?,
+ [
+ // torch gets 89.1
+ -89.0, -135.3, 136.7, 102.0, -53.4, 117.9, 118.6, -43.9, -218.0, -58.5, -114.3, -150.0,
+ -15.6, 172.1, 66.3, -64.3, -27.9, -19.8, 31.7, 62.1, 5.5, 92.6, 28.2, -29.6, 55.9,
+ 52.7, -72.7, -119.8, 53.8, -25.5, 128.8, 19.3, 68.0, 190.9, -64.1, -86.2, -111.2,
+ 106.6, -67.7, 37.8, 115.9, 50.4, -77.7, -54.9, 22.3, -4.6, 89.8, 61.7, 122.4, 192.6,
+ -27.8, -104.6, 57.0, 166.4, 27.1, 6.1, 18.7, -93.2, 31.5, 168.2, -3.7, -99.5, -55.5,
+ -10.8, 17.5, 20.8, 16.9, 43.8, 42.0, -89.2, 18.8, -9.6, -84.1, 212.6, 19.7, -50.0,
+ -52.0, -40.0, -166.6, -73.2, -10.8, -73.3, 31.5, -23.4, -79.3, -27.0, -84.4, -42.9,
+ -20.3, 51.8, -16.7, 76.3, -120.5, -65.8, 96.5, -10.7, -45.9, -88.1, 65.4, -7.0, -1.5,
+ 92.8, -25.1, -114.2, -5.8, -14.8, -51.2, -20.7, 54.2, -79.8, 47.7, -29.2, -8.8, 53.5,
+ -28.4, 85.0, -18.3, 107.0, 28.3, -71.8
+ ]
+ );
+
+ assert_eq!(
+ test_utils::to_vec3_round(&grad_t.i(0)?, 1)?,
+ [
+ [
+ [32.3, -41.6, -24.0, 14.1, 17.6],
+ [-11.8, 72.5, 87.6, 46.4, 61.5],
+ [115.0, 108.5, -48.6, -63.4, -50.0],
+ [51.3, 5.4, 31.3, 91.1, -30.9],
+ [52.7, 92.8, -68.0, -47.0, 83.0],
+ // pytorch gets -107.1
+ [-10.2, -107.0, -5.4, 213.1, -31.4],
+ [-2.4, 65.1, 9.2, -146.2, -24.2]
+ ],
+ [
+ [-72.6, -63.9, -61.9, 45.3, 33.0],
+ [79.3, -0.5, -26.2, 78.2, 42.7],
+ [90.9, 141.6, 40.1, -62.7, 37.0],
+ [32.8, 198.2, -0.8, -31.1, 27.3],
+ // torch gets 48.0
+ [34.5, 34.9, -47.9, 127.6, -12.3],
+ [-61.4, -3.2, -2.9, -10.9, -16.6],
+ [74.6, 60.1, -68.9, 34.5, -50.4]
+ ],
+ [
+ [37.5, -56.9, -43.6, -13.5, -9.9],
+ [40.0, 97.3, 28.6, 14.2, -30.1],
+ [-22.3, -126.3, -68.8, -8.2, 26.1],
+ [-32.9, 37.3, 108.5, -54.8, 29.6],
+ [34.9, -176.9, -125.0, -28.3, -13.9],
+ [-54.9, 142.6, 62.1, -80.4, -65.6],
+ [7.4, -91.1, -67.6, 35.0, 39.7]
+ ],
+ [
+ [-57.2, -40.9, -10.1, 32.6, 29.4],
+ [18.7, -18.0, 29.5, -1.2, 59.2],
+ [-14.0, -74.4, 19.8, -117.0, 58.2],
+ [-21.8, 163.5, -71.1, -99.0, 80.9],
+ [-58.9, -10.9, 93.8, -139.6, 98.0],
+ // torch gets 54.5
+ [-54.4, 135.3, 6.0, -79.1, 134.6],
+ [27.5, -76.0, 43.4, -2.8, -7.8]
+ ]
+ ]
+ );
Ok(())
}