summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/conv.rs2
-rw-r--r--candle-core/tests/conv_tests.rs16
2 files changed, 17 insertions, 1 deletions
diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs
index 710f6abf..fe923087 100644
--- a/candle-core/src/conv.rs
+++ b/candle-core/src/conv.rs
@@ -196,8 +196,8 @@ impl Tensor {
stride: usize,
dilation: usize,
) -> Result<Self> {
- let (c_out, c_in_k, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
+ let (c_in_k, c_out, k_size) = kernel.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index e7fdf138..a5375c11 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -13,6 +13,11 @@ res = torch.nn.functional.conv1d(t, w)
print(res.flatten())
res = torch.nn.functional.conv1d(t, w, padding=1)
print(res.flatten())
+
+w_t = w.transpose(0, 1)
+res = torch.nn.functional.conv_transpose1d(t, w_t)
+print(res.shape)
+print(res)
*/
fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@@ -45,6 +50,17 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
+ if dev.is_cpu() {
+ let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
+ assert_eq!(res.dims(), [1, 2, 7]);
+ assert_eq!(
+ test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
+ [
+ 0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
+ 4.7076, -5.9745, -0.8276, 1.621
+ ],
+ );
+ }
Ok(())
}