diff options
author | Kirpal Grewal <45569241+KGrewal1@users.noreply.github.com> | 2024-03-23 06:05:55 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-23 07:05:55 +0100 |
commit | cc856db9ce2541e09731165f88cdd7aae37f558e (patch) | |
tree | 5ffa3b722f84504ba2208c439c2b9253f3ffa1eb /candle-core/tests | |
parent | fc1fe5e45b046771589126c355fdfb4d3bb49fe4 (diff) | |
download | candle-cc856db9ce2541e09731165f88cdd7aae37f558e.tar.gz candle-cc856db9ce2541e09731165f88cdd7aae37f558e.tar.bz2 candle-cc856db9ce2541e09731165f88cdd7aae37f558e.zip |
Backwards for ConvTranspose2D (#1910)
* add documentation for backprop
* add backwards for ConvTranspose2D
* add test python code to test
Diffstat (limited to 'candle-core/tests')
-rw-r--r-- | candle-core/tests/conv_tests.rs | 161 |
1 files changed, 154 insertions, 7 deletions
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 6cc48ec7..3762e02f 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -135,7 +135,7 @@ fn conv2d(dev: &Device) -> Result<()> { 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006, - -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085, + -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085, ], dev, )?; @@ -276,11 +276,10 @@ fn conv2d_small(dev: &Device) -> Result<()> { assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, [ - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1640, -0.0111, -0.1742, 0.0000, 0.0000, - 0.0000, 0.0000, 2.6437, -2.0268, 1.1823, 0.0000, 0.0000, 0.0000, 0.0000, 3.2855, - -1.0324, 0.2539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, - 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000 + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1640, + -0.0111, -0.1742, 0.0, 0.0, 0.0, 0.0, 2.6437, -2.0268, 1.1823, 0.0, 0.0, 0.0, 0.0, + 3.2855, -1.0324, 0.2539, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0 ] ); @@ -398,7 +397,7 @@ fn conv2d_grad(dev: &Device) -> Result<()> { 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, -1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, 0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.7875, 0.5377, -0.6779, -0.8088, 1.9098, 1.2006, - -0.8000, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 0.0085, + -0.8, -0.4983, 1.5480, 0.8265, -0.1025, 0.5138, 0.5748, 0.3821, -0.4607, 
0.0085, ], (1, 4, 5, 5), dev, @@ -583,6 +582,154 @@ fn conv2d_grad(dev: &Device) -> Result<()> { ] ); + // Conv Transpose 2d Test + //tested against following python + + // import torch + // torch.manual_seed(4242) + // padding = 4 + // outpadding = 2 + // dilation = 3 + // stride = 3 + // input = torch.randn((1, 4, 7, 5), requires_grad=True) + // kernel = torch.randn((4, 2, 3, 5), requires_grad=True) + // print("input", input.flatten()) + // print("kernel", kernel.flatten()) + // res = torch.nn.functional.conv_transpose2d( + // input, + // kernel, + // stride=stride, + // padding=padding, + // dilation=dilation, + // output_padding=outpadding, + // ) + // res.retain_grad() + // print(res.shape) + // loss = (res**2).sum() + // print(loss) + // loss.backward() + // print(input.grad.shape) + // print("input grad", torch.round(input.grad, decimals=1)) + // print(kernel.grad.shape) + // print("kernel grad", torch.round(kernel.grad.flatten(), decimals=1)) + + let padding = 4; + let outpadding = 2; + let dilation = 3; + let stride = 3; + + let t = Var::from_slice( + &[ + 0.4056_f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, + 3.0616, 1.6541, 0.0964, -0.8338, -1.6523, -0.8323, -0.1699, 0.0823, 0.3526, 0.6843, + 0.2395, 1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, + -0.6836, 0.2477, 1.3127, -0.2260, 0.2622, -1.2974, -0.8140, -0.8404, -0.3490, 0.0130, + 1.3123, 1.7569, -0.3956, -1.8255, 0.1727, -0.3538, 2.6941, 1.0529, 0.4219, -0.2071, + 1.1586, 0.4717, 0.3865, -0.5690, -0.5010, -0.1310, 0.7796, 0.6630, -0.2021, 2.6090, + 0.2049, 0.6466, -0.5042, -0.0603, -1.6538, -1.2429, 1.8357, 1.6052, -1.3844, 0.3323, + -1.3712, 0.9634, -0.4799, -0.6451, -0.0840, -1.4247, 0.5512, -0.1747, -0.5509, -0.3742, + 0.3790, -0.4431, -0.4720, -0.7890, 0.2620, 0.5411, -1.1715, -2.4997, 2.3249, -0.8912, + -0.4733, -0.5701, -2.8888, -1.4112, -0.5471, -0.9234, -1.1660, 0.4189, -0.7465, + -0.6473, 0.1402, 0.7875, 0.5377, -0.6779, -0.8088, 
-0.4864, -0.2312, 0.9279, 0.1264, + 1.5480, 0.8265, -0.1025, 0.5138, -0.2512, 0.1576, 1.2705, 0.3641, -0.9325, 0.6451, + -0.8537, 0.2378, 0.1794, 0.2752, -0.3687, -1.1149, -0.1410, -0.5829, -0.0892, 1.4258, + -2.2789, 0.5270, 0.1825, 1.7007, -0.5263, -0.2954, 0.4440, 0.5537, 0.3492, 0.6186, + 1.6475, 0.2219, + ], + (1, 4, 7, 5), + dev, + )?; + + #[rustfmt::skip] + let w = Var::from_slice( + &[ + -1.1744_f32, 0.3266, 2.5893, 1.0142, 0.1763, 0.7752, 0.6604, 0.2029, -0.2145, 0.7234, + -0.3441, -1.5400, -0.6333, 0.6613, 0.2083, 0.6230, -1.7002, 0.3393, 0.4049, 1.0762, + 0.2723, 1.4181, 0.0029, -0.2122, 1.7668, 1.4168, 0.3320, -0.2719, 0.7932, -0.7204, + 0.4447, 0.1211, 0.5908, 1.0089, -0.1646, 1.8033, -0.6286, 0.2016, -0.3370, 1.2555, + 0.8009, -0.6488, -0.4652, -1.5685, 1.5860, 0.5583, 0.4623, 0.6026, 0.8828, 2.4990, + 0.6811, -0.3369, 1.3320, 1.7669, -1.1067, 1.2958, -0.9415, -0.9655, -0.4462, 0.7181, + 0.5181, -1.1658, -1.8467, -0.7763, 1.2769, 0.8651, 0.9890, 1.5092, 0.7207, -0.8481, + 0.7417, 0.3375, -1.2685, 1.4572, 1.0915, 0.1093, -0.8550, -0.5831, -0.6309, -0.2509, + 0.5220, -0.0914, 0.7900, 0.1096, 0.3258, 0.2723, -1.0942, -0.3393, -0.1653, 0.5732, + -0.8014, 1.8194, -1.9023, 0.2127, 1.8636, -0.8979, 0.1927, -0.2778, 0.3105, 0.0071, + -1.1823, 0.2476, -0.7178, -1.3821, 1.0769, -0.4376, -0.9967, -0.1227, 1.6197, -1.0604, + 0.1372, 0.8141, -0.6163, 0.7304, -0.8285, 2.0636, -0.7176, 0.2495, -0.2581, -0.4478, + ], + (4, 2, 3, 5), + dev, + )?; + let res = t.conv_transpose2d(&w, padding, outpadding, stride, dilation)?; + let loss = res.sqr()?.sum_all()?; + assert_eq!(test_utils::to_vec0_round(&loss, 0)?, 2904.0); + let grads = loss.backward()?; + + let grad_t = grads.get(&t).unwrap(); + let grad_w = grads.get(&w).unwrap(); + assert_eq!(grad_t.dims(), [1, 4, 7, 5]); + assert_eq!(grad_w.dims(), [4, 2, 3, 5]); + + assert_eq!( + test_utils::to_vec1_round(&grad_w.flatten_all()?, 1)?, + [ + // torch gets 89.1 + -89.0, -135.3, 136.7, 102.0, -53.4, 117.9, 118.6, -43.9, 
-218.0, -58.5, -114.3, -150.0, + -15.6, 172.1, 66.3, -64.3, -27.9, -19.8, 31.7, 62.1, 5.5, 92.6, 28.2, -29.6, 55.9, + 52.7, -72.7, -119.8, 53.8, -25.5, 128.8, 19.3, 68.0, 190.9, -64.1, -86.2, -111.2, + 106.6, -67.7, 37.8, 115.9, 50.4, -77.7, -54.9, 22.3, -4.6, 89.8, 61.7, 122.4, 192.6, + -27.8, -104.6, 57.0, 166.4, 27.1, 6.1, 18.7, -93.2, 31.5, 168.2, -3.7, -99.5, -55.5, + -10.8, 17.5, 20.8, 16.9, 43.8, 42.0, -89.2, 18.8, -9.6, -84.1, 212.6, 19.7, -50.0, + -52.0, -40.0, -166.6, -73.2, -10.8, -73.3, 31.5, -23.4, -79.3, -27.0, -84.4, -42.9, + -20.3, 51.8, -16.7, 76.3, -120.5, -65.8, 96.5, -10.7, -45.9, -88.1, 65.4, -7.0, -1.5, + 92.8, -25.1, -114.2, -5.8, -14.8, -51.2, -20.7, 54.2, -79.8, 47.7, -29.2, -8.8, 53.5, + -28.4, 85.0, -18.3, 107.0, 28.3, -71.8 + ] + ); + + assert_eq!( + test_utils::to_vec3_round(&grad_t.i(0)?, 1)?, + [ + [ + [32.3, -41.6, -24.0, 14.1, 17.6], + [-11.8, 72.5, 87.6, 46.4, 61.5], + [115.0, 108.5, -48.6, -63.4, -50.0], + [51.3, 5.4, 31.3, 91.1, -30.9], + [52.7, 92.8, -68.0, -47.0, 83.0], + // pytorch gets -107.1 + [-10.2, -107.0, -5.4, 213.1, -31.4], + [-2.4, 65.1, 9.2, -146.2, -24.2] + ], + [ + [-72.6, -63.9, -61.9, 45.3, 33.0], + [79.3, -0.5, -26.2, 78.2, 42.7], + [90.9, 141.6, 40.1, -62.7, 37.0], + [32.8, 198.2, -0.8, -31.1, 27.3], + // torch gets 48.0 + [34.5, 34.9, -47.9, 127.6, -12.3], + [-61.4, -3.2, -2.9, -10.9, -16.6], + [74.6, 60.1, -68.9, 34.5, -50.4] + ], + [ + [37.5, -56.9, -43.6, -13.5, -9.9], + [40.0, 97.3, 28.6, 14.2, -30.1], + [-22.3, -126.3, -68.8, -8.2, 26.1], + [-32.9, 37.3, 108.5, -54.8, 29.6], + [34.9, -176.9, -125.0, -28.3, -13.9], + [-54.9, 142.6, 62.1, -80.4, -65.6], + [7.4, -91.1, -67.6, 35.0, 39.7] + ], + [ + [-57.2, -40.9, -10.1, 32.6, 29.4], + [18.7, -18.0, 29.5, -1.2, 59.2], + [-14.0, -74.4, 19.8, -117.0, 58.2], + [-21.8, 163.5, -71.1, -99.0, 80.9], + [-58.9, -10.9, 93.8, -139.6, 98.0], + // torch gets 54.5 + [-54.4, 135.3, 6.0, -79.1, 134.6], + [27.5, -76.0, 43.4, -2.8, -7.8] + ] + ] + ); Ok(()) } |