summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/tests/conv_tests.rs12
1 files changed, 5 insertions, 7 deletions
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 7ec83592..f955b4a5 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -12,6 +12,8 @@ print(t.flatten())
print(w.flatten())
res = torch.nn.functional.conv1d(t, w)
print(res.flatten())
+res = torch.nn.functional.conv1d(t, w, padding=1)
+print(res.flatten())
*/
#[test]
fn conv1d() -> Result<()> {
@@ -41,14 +43,10 @@ fn conv1d() -> Result<()> {
);
let res = t.conv1d(&w, /*padding*/ 1, 1)?;
assert_eq!(res.dims(), [1, 2, 5]);
- /* Note that the default for padding is different from PyTorch at the moment: instead of
- padding with zeros, the edge value from the input tensor is used, i.e. this is similiar to:
- t = torch.nn.functional.pad(t, (1, 1), mode='replicate')
- res = torch.nn.functional.conv1d(t, w, padding=0)
- */
+ // Same as pytorch default padding: use zeros.
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
- [2.5209, 2.6357, -1.3336, 4.1393, 0.4951, 3.6855, -1.1784, 3.5675, 0.5069, 4.9562]
+ [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
Ok(())
}
@@ -68,7 +66,7 @@ fn conv1d_small() -> Result<()> {
assert_eq!(res.dims(), [1, 1, 4]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
- [0.4056, 0.4056, -0.8689, -0.0773],
+ [0.0, 0.4056, -0.8689, -0.0773],
);
Ok(())
}