diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-08 18:13:25 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-08 18:13:25 +0100 |
commit | eb64ad0d4d2303343c6717e49e211b6b75dfcdad (patch) | |
tree | 0e5cb671b84c7d608f521b5ed2aceca3af9de257 /candle-kernels/src | |
parent | 5c3864f9f765f3283fc82bbe1ea1aafd45adbad6 (diff) | |
download | candle-eb64ad0d4d2303343c6717e49e211b6b75dfcdad.tar.gz candle-eb64ad0d4d2303343c6717e49e211b6b75dfcdad.tar.bz2 candle-eb64ad0d4d2303343c6717e49e211b6b75dfcdad.zip |
Cuda kernel for the conv1d op (#111)
* Boilerplate code for conv1d.
* Boilerplate code for conv1d.
* More boilerplate for conv1d.
* Conv1d work.
* Get the conv1d cuda kernel to work.
* Conv1d support when no batch dim.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r-- | candle-kernels/src/conv.cu | 74 | ||||
-rw-r--r-- | candle-kernels/src/lib.rs | 1 |
2 files changed, 75 insertions, 0 deletions
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
new file mode 100644
index 00000000..55ea7863
--- /dev/null
+++ b/candle-kernels/src/conv.cu
@@ -0,0 +1,74 @@
#include "cuda_utils.cuh"
#include <stdint.h>

// Naive direct 1D convolution with implicit zero padding of k_size/2 on each
// side of the length axis.
//
// Expected launch layout: 1-D grid of 1-D blocks, one thread per output
// element, flat output index = b_idx * (c_out * l_out) + dst_c_idx * l_out + dst_l.
//
//   src: (b_size, c_in, l_in)   — dims/strides (in elements) come from `info`
//   k:   (c_out, c_in, k_size)
//   info layout: [src_dims(3) | src_strides(3) | k_dims(3) | k_strides(3)]
//
// `src_numel` is part of the host-side calling convention but is not needed
// here: the output element count is recovered from the dims in `info`.
// NOTE(review): when the input has no batch dim, the host presumably passes a
// leading dim of 1 in src_dims — confirm against the caller.
template <typename T>
__device__ void conv1d(
    const size_t src_numel,
    const size_t l_out,
    const size_t stride,
    const size_t *info,
    const T *src,
    const T *kernel,
    T *dst
) {
    const size_t *src_dims = info;
    const size_t *src_s = info + 3;
    const size_t *k_dims = info + 6;
    const size_t *k_s = info + 9;
    const size_t k_size = k_dims[2];
    const size_t k_over_2 = k_size / 2;
    const size_t c_out = k_dims[0];
    const size_t c_in = src_dims[1];
    const size_t l_in = src_dims[2];

    // Bounds guard: the grid is normally launched with ceil-div rounding, so
    // trailing threads would otherwise write past the end of `dst`.
    const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
    if (dst_i >= src_dims[0] * c_out * l_out) {
        return;
    }

    // Decompose the flat output index into (batch, out-channel, position).
    const size_t b_idx = dst_i / (l_out * c_out);
    const size_t dst_c_idx = (dst_i / l_out) % c_out;
    const size_t dst_l = dst_i % l_out;

    const size_t src_idx0 = b_idx * src_s[0];
    T d = 0;
    for (size_t offset = 0; offset < k_size; ++offset) {
        // Input position read by this tap; taps that fall inside the implicit
        // zero padding contribute nothing and are skipped.
        const size_t src_l_plus = stride * dst_l + offset;
        if (k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2) {
            const size_t src_l = src_l_plus - k_over_2;
            for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
                const size_t src_idx =
                    src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2];
                const size_t k_idx =
                    dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2];
                d += src[src_idx] * kernel[k_idx];
            }
        }
    }
    dst[dst_i] = d;
}

// Stamps out an `extern "C"` entry point per dtype. The second scalar the
// host passes is the output length, so it is named `l_out` here (it was
// previously misnamed `num_dims`; C linkage depends only on types, not names).
#define CONV1D_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t src_numel, \
    const size_t l_out, \
    const size_t stride, \
    const size_t *info, \
    const TYPENAME *src, \
    const TYPENAME *kernel, \
    TYPENAME *dst \
) { \
    conv1d(src_numel, l_out, stride, info, src, kernel, dst); \
} \

// bf16 arithmetic requires SM80+ (Ampere).
#if __CUDA_ARCH__ >= 800
CONV1D_OP(__nv_bfloat16, conv1d_bf16)
#endif

// fp16 arithmetic requires SM53+.
#if __CUDA_ARCH__ >= 530
CONV1D_OP(__half, conv1d_f16)
#endif

CONV1D_OP(float, conv1d_f32)
CONV1D_OP(double, conv1d_f64)
CONV1D_OP(uint8_t, conv1d_u8)
CONV1D_OP(uint32_t, conv1d_u32)

// ---------------------------------------------------------------------------
// Companion change from the same commit, in candle-kernels/src/lib.rs
// (registers the compiled PTX alongside the existing kernels):
//
//   diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs
//   index c3a927ad..b9d12b7b 100644
//   --- a/candle-kernels/src/lib.rs
//   +++ b/candle-kernels/src/lib.rs
//   @@ -1,6 +1,7 @@
//    pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
//    pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
//    pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
//   +pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
//    pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
//    pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
//    pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
// ---------------------------------------------------------------------------