author    Laurent Mazare <laurent.mazare@gmail.com>  2023-07-08 18:13:25 +0100
committer GitHub <noreply@github.com>                2023-07-08 18:13:25 +0100
commit    eb64ad0d4d2303343c6717e49e211b6b75dfcdad (patch)
tree      0e5cb671b84c7d608f521b5ed2aceca3af9de257 /candle-kernels/src
parent    5c3864f9f765f3283fc82bbe1ea1aafd45adbad6 (diff)
Cuda kernel for the conv1d op (#111)
* Boilerplate code for conv1d.
* Boilerplate code for conv1d.
* More boilerplate for conv1d.
* Conv1d work.
* Get the conv1d cuda kernel to work.
* Conv1d support when no batch dim.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r--  candle-kernels/src/conv.cu  74
-rw-r--r--  candle-kernels/src/lib.rs    1
2 files changed, 75 insertions, 0 deletions
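For reference, a CPU version of the same indexing (not part of this commit) clarifies what the new kernel computes: each output element at (b, c_out, l_out) is a dot product over the kernel width and the input channels, with an implicit zero padding of k_size / 2. A minimal sketch, assuming contiguous (b_size, c_in, l_in) source and (c_out, c_in, k_size) kernel layouts; names are illustrative:

    // CPU reference mirroring the indexing of the new conv1d kernel.
    // The device code reads dims/strides from `info`; here contiguous
    // layouts are hard-coded instead.
    #include <cstddef>
    #include <vector>

    void conv1d_ref(size_t b_size, size_t c_in, size_t l_in,
                    size_t c_out, size_t k_size, size_t stride, size_t l_out,
                    const std::vector<float> &src, const std::vector<float> &k,
                    std::vector<float> &dst) {
      const size_t k_over_2 = k_size / 2;          // implicit zero padding of k_size / 2
      for (size_t b = 0; b < b_size; ++b)
        for (size_t co = 0; co < c_out; ++co)
          for (size_t lo = 0; lo < l_out; ++lo) {
            float d = 0.f;
            for (size_t off = 0; off < k_size; ++off) {
              const size_t lp = stride * lo + off; // same validity test as the device kernel
              if (k_over_2 <= lp && lp < l_in + k_over_2) {
                const size_t li = lp - k_over_2;
                for (size_t ci = 0; ci < c_in; ++ci)
                  d += src[(b * c_in + ci) * l_in + li] * k[(co * c_in + ci) * k_size + off];
              }
            }
            dst[(b * c_out + co) * l_out + lo] = d;
          }
    }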
diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu
new file mode 100644
index 00000000..55ea7863
--- /dev/null
+++ b/candle-kernels/src/conv.cu
@@ -0,0 +1,74 @@
+#include "cuda_utils.cuh"
+#include<stdint.h>
+
+template <typename T>
+__device__ void conv1d(
+ const size_t src_numel,
+ const size_t l_out,
+ const size_t stride,
+ const size_t *info,
+ const T *src,
+ const T *kernel,
+ T *dst
+) {
+ // src: (b_size, c_in, l_in)
+ // k: (c_out, c_in, k_size)
+ const size_t *src_dims = info;
+ const size_t *src_s = info + 3;
+ const size_t *k_dims = info + 6;
+ const size_t *k_s = info + 9;
+ const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+ const size_t k_size = k_dims[2];
+ const size_t k_over_2 = k_size / 2;
+ const size_t c_out = k_dims[0];
+ const size_t c_in = src_dims[1];
+ const size_t l_in = src_dims[2];
+
+ // TODO
+ const size_t b_idx = dst_i / (l_out * c_out);
+ const size_t dst_c_idx = (dst_i / l_out) % c_out;
+ const size_t dst_l = dst_i % l_out;
+
+ const size_t src_idx0 = b_idx * src_s[0];
+ T d = 0;
+ for (size_t offset = 0; offset < k_size; ++offset) {
+ const size_t src_l_plus = stride * dst_l + offset;
+ if (k_over_2 <= src_l_plus && src_l_plus < l_in + k_over_2) {
+ const size_t src_l = src_l_plus - k_over_2;
+ for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
+ const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + src_l * src_s[2];
+ const size_t k_idx = dst_c_idx * k_s[0] + src_c_idx * k_s[1] + offset * k_s[2];
+ d += src[src_idx] * kernel[k_idx];
+ }
+ }
+ }
+ dst[dst_i] = d;
+}
+
+
+#define CONV1D_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+ const size_t src_numel, \
+ const size_t num_dims, \
+ const size_t stride, \
+ const size_t *info, \
+ const TYPENAME *src, \
+ const TYPENAME *kernel, \
+ TYPENAME *dst \
+) { \
+ conv1d(src_numel, num_dims, stride, info, src, kernel, dst); \
+} \
+
+#if __CUDA_ARCH__ >= 800
+CONV1D_OP(__nv_bfloat16, conv1d_bf16)
+#endif
+
+#if __CUDA_ARCH__ >= 530
+CONV1D_OP(__half, conv1d_f16)
+#endif
+
+CONV1D_OP(float, conv1d_f32)
+CONV1D_OP(double, conv1d_f64)
+CONV1D_OP(uint8_t, conv1d_u8)
+CONV1D_OP(uint32_t, conv1d_u32)
+
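The kernel expects a 12-entry info buffer: source dims at offset 0, source strides at 3, kernel dims at 6, kernel strides at 9. A hypothetical host-side launcher (illustrative only, not part of this diff; buffer allocation, copies, and error checking are elided) might look like this:

    // Hypothetical launcher for conv1d_f32; names and layout are illustrative.
    void launch_conv1d_f32(size_t b_size, size_t c_in, size_t l_in,
                           size_t c_out, size_t k_size, size_t stride, size_t l_out,
                           const size_t *d_info,  // 12 values, laid out as described below
                           const float *d_src, const float *d_kernel, float *d_dst) {
      // d_info packs, in order:
      //   [0..3)  src dims    (b_size, c_in, l_in)
      //   [3..6)  src strides (c_in * l_in, l_in, 1 when contiguous)
      //   [6..9)  kernel dims (c_out, c_in, k_size)
      //   [9..12) kernel strides (c_in * k_size, k_size, 1 when contiguous)
      const size_t src_numel = b_size * c_in * l_in;
      const size_t dst_numel = b_size * c_out * l_out;
      const unsigned block = 256;
      const unsigned grid = (dst_numel + block - 1) / block;  // one thread per output element
      // The second kernel argument carries the output length: the macro names it
      // num_dims but forwards it into the template's l_out parameter. The kernel
      // does not bound-check dst_i, so d_dst should cover grid * block elements
      // (or a guard should be added) when dst_numel is not a multiple of block.
      conv1d_f32<<<grid, block>>>(src_numel, l_out, stride, d_info, d_src, d_kernel, d_dst);
    }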
diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs
index c3a927ad..b9d12b7b 100644
--- a/candle-kernels/src/lib.rs
+++ b/candle-kernels/src/lib.rs
@@ -1,6 +1,7 @@
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
+pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
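The new CONV constant embeds the compiled conv.ptx into the crate so the CUDA backend can load it at runtime. As a rough illustration of what loading an embedded PTX string involves (candle goes through cudarc rather than the raw driver API; this sketch assumes cuInit/context setup has already happened and omits error handling):

    // Illustrative only: turning an embedded PTX string into a launchable
    // function with the raw CUDA driver API.
    #include <cuda.h>

    CUfunction load_conv1d_f32(const char *conv_ptx /* contents of conv.ptx */) {
      CUmodule module;
      CUfunction func;
      cuModuleLoadData(&module, conv_ptx);               // JIT-compile the PTX for the current device
      cuModuleGetFunction(&func, module, "conv1d_f32");  // look up the extern "C" entry point
      return func;
    }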