diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-14 13:12:17 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-14 13:12:17 +0100 |
commit | c84883ecf2c240792392353175b634f6ec92a011 (patch) | |
tree | 10b14324310421802a68669485c75cc3dcc16c48 /candle-kernels/src | |
parent | a094dc503d69a6ca3db71098ebc26d0d2f2a33a6 (diff) | |
download | candle-c84883ecf2c240792392353175b634f6ec92a011.tar.gz candle-c84883ecf2c240792392353175b634f6ec92a011.tar.bz2 candle-c84883ecf2c240792392353175b634f6ec92a011.zip |
Add a cuda kernel for upsampling. (#441)
* Add a cuda kernel for upsampling.
* Update for the latest tokenizers version.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r-- | candle-kernels/src/conv.cu | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index 2da4d401..afda7d1d 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -220,6 +220,48 @@ __device__ void max_pool2d( dst[dst_i] = d; } +template <typename T> +__device__ void upsample_nearest2d( + const size_t w_out, + const size_t h_out, + const double w_scale, + const double h_scale, + const size_t *info, + const T *src, + T *dst +) { + const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x; + // src: (b_size, c_in, w_in, h_in) + const size_t *src_dims = info; + const size_t *src_s = info + 4; + + const size_t c = src_dims[1]; + const size_t w_in = src_dims[2]; + const size_t h_in = src_dims[3]; + + if (dst_i >= src_dims[0] * c * w_out * h_out) { + return; + } + + // TODO: Improve this. + const size_t b_idx = dst_i / (w_out * h_out * c); + const size_t c_idx = (dst_i / (w_out * h_out)) % c; + const size_t dst_w = (dst_i / h_out) % w_out; + const size_t dst_h = dst_i % h_out; + + size_t src_w = static_cast<size_t>(dst_w * w_scale); + size_t src_h = static_cast<size_t>(dst_h * h_scale); + if (src_w >= w_in) { + src_w = w_in - 1; + } + if (src_h >= h_in) { + src_h = h_in - 1; + } + + const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; + dst[dst_i] = src[src_i]; +} + #define CONV1D_OP(TYPENAME, TYPEACC, FN_NAME) \ extern "C" __global__ void FN_NAME( \ @@ -278,11 +320,25 @@ extern "C" __global__ void FN_NAME( \ max_pool2d<TYPENAME>(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \ } \ +#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t w_out, \ + const size_t h_out, \ + const double w_scale, \ + const double h_scale, \ + const size_t *info, \ + const TYPENAME *src, \ + TYPENAME *dst \ +) { \ + upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, info, src, dst); \ +} \ + #if __CUDA_ARCH__ >= 800 CONV1D_OP(__nv_bfloat16, float, conv1d_bf16) CONV2D_OP(__nv_bfloat16, float, conv2d_bf16) AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16) MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16) +UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16) #endif #if __CUDA_ARCH__ >= 530 @@ -290,6 +346,7 @@ CONV1D_OP(__half, float, conv1d_f16) CONV2D_OP(__half, float, conv2d_f16) AVG_POOL2D_OP(__half, float, avg_pool2d_f16) MAX_POOL2D_OP(__half, max_pool2d_f16) +UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16) #endif CONV1D_OP(float, float, conv1d_f32) @@ -311,3 +368,8 @@ MAX_POOL2D_OP(float, max_pool2d_f32) MAX_POOL2D_OP(double, max_pool2d_f64) MAX_POOL2D_OP(uint8_t, max_pool2d_u8) MAX_POOL2D_OP(uint32_t, max_pool2d_u32) + +UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32) +UPSAMPLE_NEAREST2D_OP(double, upsample_nearest2d_f64) +UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8) +UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32) |