summary | refs | log | tree | commit | diff
path: root/candle-kernels/src
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com>, 2023-08-29 16:12:11 +0100
committer: GitHub <noreply@github.com>, 2023-08-29 16:12:11 +0100
commit: a044907ffce553a0394db3a1204f21e3691e54af (patch)
tree: 8ce11fae8ee11e4eb181f7240344994356625791 /candle-kernels/src
parent: ee8bb1bde1a44738c314dfaacba743f4eabf917c (diff)
download: candle-a044907ffce553a0394db3a1204f21e3691e54af.tar.gz
candle-a044907ffce553a0394db3a1204f21e3691e54af.tar.bz2
candle-a044907ffce553a0394db3a1204f21e3691e54af.zip
Dilated convolutions (#657)
* Add the dilation parameter.
* Restore the basic optimizer example.
* Dilation support in cudnn.
* Use the dilation parameter in the cpu backend.
* More dilation support.
* No support for dilation in transposed convolutions.
* Add dilation to a test.
* Remove a print.
* Helper function.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r-- candle-kernels/src/conv.cu | 18
1 file changed, 12 insertions, 6 deletions
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 5ccce317..c67a4300 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -8,6 +8,7 @@ __device__ void conv1d(
const size_t l_out,
const size_t stride,
const size_t padding,
+ const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
@@ -36,7 +37,7 @@ __device__ void conv1d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (size_t offset = 0; offset < k_size; ++offset) {
- size_t src_l = stride * dst_l + offset;
+ size_t src_l = (stride * dst_l + offset) * dilation;
if (src_l < padding || src_l >= padding + l_in) {
continue;
}
@@ -58,6 +59,7 @@ __device__ void conv2d(
const size_t h_out,
const size_t stride,
const size_t padding,
+ const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
@@ -90,13 +92,13 @@ __device__ void conv2d(
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
- size_t src_w = stride * dst_w + w_offset;
+ size_t src_w = (stride * dst_w + w_offset) * dilation;
if (src_w < padding || src_w >= w_in + padding) {
continue;
}
src_w -= padding;
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
- size_t src_h = stride * dst_h + h_offset;
+ size_t src_h = (stride * dst_h + h_offset) * dilation;
if (src_h < padding || src_h >= h_in + padding) {
continue;
}
@@ -120,6 +122,7 @@ __device__ void conv_transpose2d(
const size_t stride,
const size_t padding,
const size_t out_padding,
+ const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
@@ -335,12 +338,13 @@ extern "C" __global__ void FN_NAME( \
const size_t num_dims, \
const size_t stride, \
const size_t padding, \
+ const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
- conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, info, src, kernel, dst); \
+ conv1d<TYPENAME, TYPEACC>(src_numel, num_dims, stride, padding, dilation, info, src, kernel, dst); \
} \
#define CONV2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@@ -350,12 +354,13 @@ extern "C" __global__ void FN_NAME( \
const size_t h_out, \
const size_t stride, \
const size_t padding, \
+ const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
- conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
+ conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
} \
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
@@ -366,12 +371,13 @@ extern "C" __global__ void FN_NAME( \
const size_t stride, \
const size_t padding, \
const size_t out_padding, \
+ const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
- conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, info, src, kernel, dst); \
+ conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
} \
#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \