summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cpu_backend.rs1
-rw-r--r--candle-core/tests/conv_tests.rs14
2 files changed, 15 insertions, 0 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index f912c1b2..05e8c979 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1263,6 +1263,7 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
+ let k = &k[k_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 211a1fe0..b967515d 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -18,6 +18,9 @@ w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape)
print(res)
+res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
+print(res.shape)
+print(res)
*/
fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@@ -59,6 +62,17 @@ fn conv1d(dev: &Device) -> Result<()> {
4.7076, -5.9745, -0.8276, 1.621
],
);
+ let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 2)?;
+ assert_eq!(res.dims(), [1, 4, 7]);
+ assert_eq!(
+ test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
+ [
+ [-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
+ [0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
+ [1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
+ [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
+ ]
+ );
Ok(())
}