path: root/candle-core/tests/conv_tests.rs
author    Laurent Mazare <laurent.mazare@gmail.com>  2023-08-08 21:50:20 +0200
committer GitHub <noreply@github.com>                2023-08-08 20:50:20 +0100
commit    608b2358c6ecb8ded59d31874fca8d426b310133 (patch)
tree      120aa5d2fdb6b9498c2956cd78d029299001e542 /candle-core/tests/conv_tests.rs
parent    1e6dbeac0133d67c1df75818cd099a6ec5d276b1 (diff)
Add some conv1d test + bugfix using padding. (#349)
Diffstat (limited to 'candle-core/tests/conv_tests.rs')
-rw-r--r--  candle-core/tests/conv_tests.rs  40
1 file changed, 40 insertions, 0 deletions
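
The commit message mentions a padding bugfix, but the test added below only exercises the padding = 0 path. As a quick illustration (not part of the commit), here is a minimal sketch of a padded conv1d call using the same two-integer (padding, stride) call shape as the test below; the helper name is hypothetical and the expected values are hand-computed sliding-window sums, so they only hold under those assumptions.

use candle_core::{Device, Result, Tensor};

// Hypothetical sketch, not part of this diff: conv1d with padding = 1.
fn conv1d_with_padding_sketch() -> Result<()> {
    let dev = &Device::Cpu;
    // Input of shape (batch = 1, channels = 1, length = 5).
    let t = Tensor::new(&[1f32, 2., 3., 4., 5.], dev)?.reshape((1, 1, 5))?;
    // Kernel of shape (out_channels = 1, in_channels = 1, kernel_size = 3), all ones.
    let w = Tensor::new(&[1f32, 1., 1.], dev)?.reshape((1, 1, 3))?;
    // padding = 1, stride = 1 keeps the input length: 5 + 2 * 1 - 3 + 1 = 5.
    let res = t.conv1d(&w, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 5]);
    // Sliding-window sums over the zero-padded input [0, 1, 2, 3, 4, 5, 0].
    assert_eq!(res.flatten_all()?.to_vec1::<f32>()?, [3., 6., 9., 12., 9.]);
    Ok(())
}
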
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 59240079..4ef47780 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -6,6 +6,46 @@ use candle_core::{Device, Tensor};
import torch
torch.manual_seed(4242)
+t = torch.randn((1, 4, 5))
+w = torch.randn((2, 4, 3))
+print(t.flatten())
+print(w.flatten())
+res = torch.nn.functional.conv1d(t, w)
+print(res.flatten())
+*/
+#[test]
+fn conv1d() -> Result<()> {
+ let dev = &Device::Cpu;
+ let t = Tensor::new(
+ &[
+ 0.4056f32, -0.8689, -0.0773, -1.5630, 1.2279, -0.9287, -1.7030, 0.1370, 0.1866, 0.4145,
+ 1.8025, -0.1536, 2.2013, -0.6836, 0.2477, 1.3127, -0.6957, 0.3278, -1.0124, 0.5599,
+ ],
+ dev,
+ )?
+ .reshape((1, 4, 5))?;
+ let w = Tensor::new(
+ &[
+ -0.8404f32, -0.3490, 0.0130, 1.3123, 0.1763, -1.9249, 1.4270, 0.9421, 0.8670, -0.7181,
+ -1.1111, 0.8869, -1.2429, 1.8357, 1.6052, -1.3844, 0.3951, -1.2036, 0.6686, 1.6261,
+ -0.6451, -0.0840, -1.4247, 0.5512,
+ ],
+ dev,
+ )?
+ .reshape((2, 4, 3))?;
+ let res = t.conv1d(&w, 0, 1)?;
+ assert_eq!(res.dims(), [1, 2, 3]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069]
+ );
+ Ok(())
+}
+
+/* This test is based on the following script.
+import torch
+torch.manual_seed(4242)
+
t = torch.randn((1, 4, 5, 5))
w = torch.randn((2, 4, 3, 3))
print(t.flatten())