summaryrefslogtreecommitdiff
path: root/candle-kernels/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-28 20:58:49 +0100
committerGitHub <noreply@github.com>2023-08-28 20:58:49 +0100
commit037b41c9dc5255ef5bf03745451c8175ddc4bc37 (patch)
treee2f428f5c188114a79dd08b8146bcdd8ad2c818c /candle-kernels/src
parent72fae3140cd5f3b781bb779601fed4241ba4cb77 (diff)
downloadcandle-037b41c9dc5255ef5bf03745451c8175ddc4bc37.tar.gz
candle-037b41c9dc5255ef5bf03745451c8175ddc4bc37.tar.bz2
candle-037b41c9dc5255ef5bf03745451c8175ddc4bc37.zip
Cuda conv transpose (#645)
* Cuda kernel for conv-transpose. * Fix the cuda kernel. * Fix the tests.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r--candle-kernels/src/conv.cu88
1 files changed, 88 insertions, 0 deletions
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
index 19d94385..5ccce317 100644
--- a/candle-kernels/src/conv.cu
+++ b/candle-kernels/src/conv.cu
@@ -111,6 +111,71 @@ __device__ void conv2d(
dst[dst_i] = static_cast<T>(d);
}
+// Naive implementation of conv_transpose2d.
+template <typename T, typename A>
+__device__ void conv_transpose2d(
+ const size_t src_numel,
+ const size_t w_out,
+ const size_t h_out,
+ const size_t stride,
+ const size_t padding,
+ const size_t out_padding,
+ const size_t *info,
+ const T *src,
+ const T *kernel,
+ T *dst
+) {
+ const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+ // src: (b_size, c_in, h_in, w_in)
+ // k: (c_in, c_out, h_k, w_k)
+ const size_t *src_dims = info;
+ const size_t *src_s = info + 4;
+ const size_t *k_dims = info + 8;
+ const size_t *k_s = info + 12;
+ const size_t h_k = k_dims[2];
+ const size_t w_k = k_dims[3];
+ const size_t c_out = k_dims[1];
+ const size_t c_in = src_dims[1];
+ const size_t h_in = src_dims[2];
+ const size_t w_in = src_dims[3];
+ if (dst_i >= src_dims[0] * c_out * w_out * h_out) {
+ return;
+ }
+
+ // TODO
+ const size_t b_idx = dst_i / (w_out * h_out * c_out);
+ const size_t dst_c_idx = (dst_i / (w_out * h_out)) % c_out;
+ // NCHW layout.
+ const size_t out_y = (dst_i / w_out) % h_out;
+ const size_t out_x = dst_i % w_out;
+
+ const size_t src_idx0 = b_idx * src_s[0];
+ A d = 0;
+ for (int k_x = 0; k_x < (int)w_k; ++k_x) {
+ // let out_x = inp_x * p.stride + k_x - p.padding;
+ int inp_x_stride = (int)(out_x + padding) - k_x;
+ if (inp_x_stride < 0 || inp_x_stride % stride) {
+ continue;
+ }
+ int inp_x = inp_x_stride / stride;
+ if (inp_x >= w_in) continue;
+ for (int k_y = 0; k_y < (int)h_k; ++k_y) {
+ int inp_y_stride = (int)(out_y + padding) - k_y;
+ if (inp_y_stride < 0 || inp_y_stride % stride) {
+ continue;
+ }
+ int inp_y = inp_y_stride / stride;
+ if (inp_y >= h_in) continue;
+ for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
+ const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_y * src_s[2] + inp_x * src_s[3];
+ const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_y * k_s[2] + k_x * k_s[3];
+ d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
+ }
+ }
+ }
+ dst[dst_i] = static_cast<T>(d);
+}
+
template <typename T, typename A>
__device__ void avg_pool2d(
const size_t src_numel,
@@ -293,6 +358,22 @@ extern "C" __global__ void FN_NAME( \
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, info, src, kernel, dst); \
} \
+#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+ const size_t src_numel, \
+ const size_t w_out, \
+ const size_t h_out, \
+ const size_t stride, \
+ const size_t padding, \
+ const size_t out_padding, \
+ const size_t *info, \
+ const TYPENAME *src, \
+ const TYPENAME *kernel, \
+ TYPENAME *dst \
+) { \
+ conv_transpose2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, out_padding, info, src, kernel, dst); \
+} \
+
#define AVG_POOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t src_numel, \
@@ -337,6 +418,7 @@ extern "C" __global__ void FN_NAME( \
#if __CUDA_ARCH__ >= 800
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
+CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
@@ -345,6 +427,7 @@ UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
#if __CUDA_ARCH__ >= 530
CONV1D_OP(__half, float, conv1d_f16)
CONV2D_OP(__half, float, conv2d_f16)
+CONVT2D_OP(__half, float, conv_transpose2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
@@ -360,6 +443,11 @@ CONV2D_OP(double, double, conv2d_f64)
CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
CONV2D_OP(uint32_t, uint32_t, conv2d_u32)
+CONVT2D_OP(float, float, conv_transpose2d_f32)
+CONVT2D_OP(double, double, conv_transpose2d_f64)
+CONVT2D_OP(uint8_t, uint8_t, conv_transpose2d_u8)
+CONVT2D_OP(uint32_t, uint32_t, conv_transpose2d_u32)
+
AVG_POOL2D_OP(float, float, avg_pool2d_f32)
AVG_POOL2D_OP(double, double, avg_pool2d_f64)
AVG_POOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)