diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-08 21:50:20 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-08 20:50:20 +0100 |
commit | 608b2358c6ecb8ded59d31874fca8d426b310133 (patch) | |
tree | 120aa5d2fdb6b9498c2956cd78d029299001e542 /candle-core/tests/conv_tests.rs | |
parent | 1e6dbeac0133d67c1df75818cd099a6ec5d276b1 (diff) | |
download | candle-608b2358c6ecb8ded59d31874fca8d426b310133.tar.gz candle-608b2358c6ecb8ded59d31874fca8d426b310133.tar.bz2 candle-608b2358c6ecb8ded59d31874fca8d426b310133.zip |
Add some conv1d test + bugfix using padding. (#349)
Diffstat (limited to 'candle-core/tests/conv_tests.rs')
-rw-r--r-- | candle-core/tests/conv_tests.rs | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 59240079..4ef47780 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -6,6 +6,46 @@ use candle_core::{Device, Tensor}; import torch torch.manual_seed(4242) +t = torch.randn((1, 4, 5)) +w = torch.randn((2, 4, 3)) +print(t.flatten()) +print(w.flatten()) +res = torch.nn.functional.conv1d(t, w) +print(res.flatten()) +*/ +#[test] +fn conv1d() -> Result<()> { + let dev = &Device::Cpu; + let t = Tensor::new( + &[ + 0.4056f32, -0.8689, -0.0773, -1.5630, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145, + 1.8025, -0.1536, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278, -1.0124, 0.5599, + ], + dev, + )? + .reshape((1, 4, 5))?; + let w = Tensor::new( + &[ + -0.8404f32, -0.3490, 0.0130, 1.3123, 0.1763, -1.9249, 1.4270, 0.9421, 0.8670, -0.7181, + -1.1111, 0.8869, -1.2429, 1.8357, 1.6052, -1.3844, 0.3951, -1.2036, 0.6686, 1.6261, + -0.6451, -0.0840, -1.4247, 0.5512, + ], + dev, + )? + .reshape((2, 4, 3))?; + let res = t.conv1d(&w, 0, 1)?; + assert_eq!(res.dims(), [1, 2, 3]); + assert_eq!( + test_utils::to_vec1_round(&res.flatten_all()?, 4)?, + [2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069] + ); + Ok(()) +} + +/* This test is based on the following script. +import torch +torch.manual_seed(4242) + t = torch.randn((1, 4, 5, 5)) w = torch.randn((2, 4, 3, 3)) print(t.flatten()) |