diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-29 16:12:11 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-29 16:12:11 +0100 |
commit | a044907ffce553a0394db3a1204f21e3691e54af (patch) | |
tree | 8ce11fae8ee11e4eb181f7240344994356625791 /candle-kernels/src | |
parent | ee8bb1bde1a44738c314dfaacba743f4eabf917c (diff) | |
download | candle-a044907ffce553a0394db3a1204f21e3691e54af.tar.gz candle-a044907ffce553a0394db3a1204f21e3691e54af.tar.bz2 candle-a044907ffce553a0394db3a1204f21e3691e54af.zip |
Dilated convolutions (#657)
* Add the dilation parameter.
* Restore the basic optimizer example.
* Dilation support in cudnn.
* Use the dilation parameter in the cpu backend.
* More dilation support.
* No support for dilation in transposed convolutions.
* Add dilation to a test.
* Remove a print.
* Helper function.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r-- | candle-kernels/src/conv.cu | 18 |
1 file changed, 12 insertions, 6 deletions
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index 5ccce317..c67a4300 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -8,6 +8,7 @@ __device__ void conv1d( const size_t l_out, const size_t stride, const size_t padding, + const size_t dilation, const size_t *info, const T *src, const T *kernel, @@ -36,7 +37,7 @@ __device__ void conv1d( const size_t src_idx0 = b_idx * src_s[0]; A d = 0; for (size_t offset = 0; offset < k_size; ++offset) { - size_t src_l = stride * dst_l + offset; + size_t src_l = (stride * dst_l + offset) * dilation; if (src_l < padding || src_l >= padding + l_in) { continue; } @@ -58,6 +59,7 @@ __device__ void conv2d( const size_t h_out, const size_t stride, const size_t padding, + const size_t dilation, const size_t *info, const T *src, const T *kernel, @@ -90,13 +92,13 @@ __device__ void conv2d( const size_t src_idx0 = b_idx * src_s[0]; A d = 0; for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { - size_t src_w = stride * dst_w + w_offset; + size_t src_w = (stride * dst_w + w_offset) * dilation; if (src_w < padding || src_w >= w_in + padding) { continue; } src_w -= padding; for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { - size_t src_h = stride * dst_h + h_offset; + size_t src_h = (stride * dst_h + h_offset) * dilation; if (src_h < padding || src_h >= h_in + padding) { continue; } @@ -120,6 +122,7 @@ __device__ void conv_transpose2d( const size_t stride, const size_t padding, const size_t out_padding, + const size_t dilation, const size_t *info, const T *src, const T *kernel, @@ -335,12 +338,13 @@ extern "C" __global__ void FN_NAME( \ const size_t num_dims, \ const size_t stride, \ const size_t padding, \ + const size_t dilation, \ const size_t *info, \ const TYPENAME *src, \ const TYPENAME *kernel, \ TYPENAME *dst \ ) { \ - conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, info, src, kernel, dst); \ + conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, 
dilation, info, src, kernel, dst); \ } \ #define CONV2D_OP(TYPENAME, TYPEACC, FN_NAME) \ @@ -350,12 +354,13 @@ extern "C" __global__ void FN_NAME( \ const size_t h_out, \ const size_t stride, \ const size_t padding, \ + const size_t dilation, \ const size_t *info, \ const TYPENAME *src, \ const TYPENAME *kernel, \ TYPENAME *dst \ ) { \ - conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \ + conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \ } \ #define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \ @@ -366,12 +371,13 @@ extern "C" __global__ void FN_NAME( \ const size_t stride, \ const size_t padding, \ const size_t out_padding, \ + const size_t dilation, \ const size_t *info, \ const TYPENAME *src, \ const TYPENAME *kernel, \ TYPENAME *dst \ ) { \ - conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, info, src, kernel, dst); \ + conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \ } \ #define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \ |