diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-29 16:37:42 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-29 16:37:42 +0100 |
commit | 71221559d306b1d504820a9561533cb521ffb39a (patch) | |
tree | 556d608e4e395495b7820ddad66e71b95fc25954 /candle-kernels/src/conv.cu | |
parent | a044907ffce553a0394db3a1204f21e3691e54af (diff) | |
download | candle-71221559d306b1d504820a9561533cb521ffb39a.tar.gz candle-71221559d306b1d504820a9561533cb521ffb39a.tar.bz2 candle-71221559d306b1d504820a9561533cb521ffb39a.zip |
Fix the dilated convolutions. (#659)
Diffstat (limited to 'candle-kernels/src/conv.cu')
-rw-r--r-- | candle-kernels/src/conv.cu | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index c67a4300..91f4c7b2 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -92,13 +92,13 @@ __device__ void conv2d( const size_t src_idx0 = b_idx * src_s[0]; A d = 0; for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { - size_t src_w = (stride * dst_w + w_offset) * dilation; + size_t src_w = stride * dst_w + w_offset * dilation; if (src_w < padding || src_w >= w_in + padding) { continue; } src_w -= padding; for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { - size_t src_h = (stride * dst_h + h_offset) * dilation; + size_t src_h = stride * dst_h + h_offset * dilation; if (src_h < padding || src_h >= h_in + padding) { continue; } |