diff options
Diffstat (limited to 'candle-kernels')
-rw-r--r-- | candle-kernels/src/conv.cu | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index fed920f1..fa834faa 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -98,6 +98,50 @@ __device__ void im2col1d( } template <typename T> +__device__ void col2im1d( + const size_t dst_el, + const size_t l_out, + const size_t l_in, + const size_t c_out, + const size_t k_size, + const size_t stride, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // src: (b_size, l_in, c_out, l_k) + // dst: (b_size, c_out, l_out) + if (dst_i >= dst_el) { + return; + } + + const size_t dst_s0 = c_out * l_out; + const size_t dst_s1 = l_out; + const size_t src_s0 = c_out * k_size * l_in; + const size_t src_s1 = c_out * k_size; + const size_t src_s2 = k_size; + + size_t tmp_dst_i = dst_i; + const size_t b_idx = tmp_dst_i / dst_s0; + tmp_dst_i -= b_idx * dst_s0; + const size_t c_idx = tmp_dst_i / dst_s1; + tmp_dst_i -= c_idx * dst_s1; + const int l_out_idx = tmp_dst_i; + + dst[dst_i] = static_cast<T>(0); + + int l_in_idx = l_out_idx / stride; + int k0 = l_out_idx - l_in_idx * stride; + // l_out_idx = l_in_idx * stride + k0 + for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) { + if (l_in_idx < l_in) { + const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0; + dst[dst_i] += src[src_i]; + } + } +} + +template <typename T> __device__ void im2col( const size_t dst_numel, const size_t h_out, @@ -542,6 +586,20 @@ extern "C" __global__ void FN_NAME( \ im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \ } \ +#define COL2IM1D_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t dst_el, \ + const size_t l_out, \ + const size_t l_in, \ + const size_t c_out, \ + const size_t k_size, \ + const size_t stride, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + col2im1d<TYPENAME>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst); \ +} \ + #define IM2COL_OP(TYPENAME, FN_NAME) \ extern "C" __global__ void FN_NAME( \ const size_t dst_numel, \ @@ -643,6 +701,7 @@ MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16) UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16) IM2COL_OP(__nv_bfloat16, im2col_bf16) IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16) +COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16) #endif #if __CUDA_ARCH__ >= 530 @@ -655,6 +714,7 @@ MAX_POOL2D_OP(__half, max_pool2d_f16) UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16) IM2COL_OP(__half, im2col_f16) IM2COL1D_OP(__half, im2col1d_f16) +COL2IM1D_OP(__half, col2im1d_f16) #endif CONV1D_OP(float, float, conv1d_f32) @@ -701,3 +761,8 @@ IM2COL1D_OP(float, im2col1d_f32) IM2COL1D_OP(double, im2col1d_f64) IM2COL1D_OP(uint8_t, im2col1d_u8) IM2COL1D_OP(uint32_t, im2col1d_u32) + +COL2IM1D_OP(float, col2im1d_f32) +COL2IM1D_OP(double, col2im1d_f64) +COL2IM1D_OP(uint8_t, col2im1d_u8) +COL2IM1D_OP(uint32_t, col2im1d_u32) |