summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cpu_backend.rs10
-rw-r--r--candle-core/src/cuda_backend.rs6
-rw-r--r--candle-core/tests/conv_tests.rs35
-rw-r--r--candle-kernels/src/conv.cu6
4 files changed, 40 insertions, 17 deletions
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 4a061c39..17d64b10 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1186,12 +1186,6 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
const OP: &'static str = "conv_transpose2d";
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
- if p.dilation != 1 {
- crate::bail!(
- "dilation {} is not supported for conv-transpose2d",
- p.dilation
- )
- }
let inp = &inp[inp_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
let k = &k[k_l.start_offset()..];
@@ -1235,8 +1229,8 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
for b_idx in 0..p.b_size {
for inp_y in 0..p.i_h {
for inp_x in 0..p.i_w {
- let out_x = inp_x * p.stride + k_x;
- let out_y = inp_y * p.stride + k_y;
+ let out_x = inp_x * p.stride + k_x * p.dilation;
+ let out_y = inp_y * p.stride + k_y * p.dilation;
if out_x < p.padding || out_y < p.padding {
continue;
}
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 14a77b52..663f2319 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1046,12 +1046,6 @@ impl<'a> Map2 for ConvTranspose2D<'a> {
// Kernel shape: (c_in_k, c_out, h_k, w_k)
// Input shape: (b_size, c_in, h_in, w_in)
let p = &self.0;
- if p.dilation != 1 {
- crate::bail!(
- "dilation {} is not supported for conv-transpose2d",
- p.dilation
- )
- }
let (out_w, out_h) = (p.out_w(), p.out_h());
let dst_el = p.c_out * out_w * out_h * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs
index 8196a27e..937ddf67 100644
--- a/candle-core/tests/conv_tests.rs
+++ b/candle-core/tests/conv_tests.rs
@@ -85,6 +85,10 @@ print(res)
res = torch.nn.functional.conv2d(t, w, dilation=2)
print(res.shape)
print(res[0])
+
+res = torch.nn.functional.conv_transpose2d(t, w_t, dilation=2)
+print(res.shape)
+print(res)
*/
fn conv2d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@@ -158,6 +162,37 @@ fn conv2d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.45, -2.3504],
);
+
+ // Transpose and dilations.
+ let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
+ assert_eq!(res.dims(), [1, 2, 9, 9]);
+ assert_eq!(
+ test_utils::to_vec3_round(&res.i(0)?, 4)?,
+ [
+ [
+ [-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
+ [2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
+ [-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
+ [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
+ [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
+ [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
+ [-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
+ [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
+ [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
+ ],
+ [
+ [1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
+ [-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
+ [1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
+ [1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
+ [1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
+ [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
+ [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
+ [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
+ [-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
+ ]
+ ]
+ );
Ok(())
}
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 91f4c7b2..ba2fa1ad 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -155,15 +155,15 @@ __device__ void conv_transpose2d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (int k_x = 0; k_x < (int)w_k; ++k_x) {
- // let out_x = inp_x * p.stride + k_x - p.padding;
- int inp_x_stride = (int)(out_x + padding) - k_x;
+ // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
+ int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
if (inp_x_stride < 0 || inp_x_stride % stride) {
continue;
}
int inp_x = inp_x_stride / stride;
if (inp_x >= w_in) continue;
for (int k_y = 0; k_y < (int)h_k; ++k_y) {
- int inp_y_stride = (int)(out_y + padding) - k_y;
+ int inp_y_stride = (int)(out_y + padding) - k_y * dilation;
if (inp_y_stride < 0 || inp_y_stride % stride) {
continue;
}