summaryrefslogtreecommitdiff
path: root/candle-kernels/src
diff options
context:
space:
mode:
authorGonzalo <456459+grzuy@users.noreply.github.com>2023-09-29 11:49:30 -0300
committerGitHub <noreply@github.com>2023-09-29 15:49:30 +0100
commitfc59bc31bf49e707d64ed25fd29a8803f9a12fb4 (patch)
treed1f62fff87b09d25659ea26ed1dadc4f2e0909a1 /candle-kernels/src
parent03348e2e6f6c1904b4c41504c2fcd5887366c437 (diff)
downloadcandle-fc59bc31bf49e707d64ed25fd29a8803f9a12fb4.tar.gz
candle-fc59bc31bf49e707d64ed25fd29a8803f9a12fb4.tar.bz2
candle-fc59bc31bf49e707d64ed25fd29a8803f9a12fb4.zip
fix: add missing gpu fill_* (#996)
Diffstat (limited to 'candle-kernels/src')
-rw-r--r--candle-kernels/src/fill.cu9
1 files changed, 9 insertions, 0 deletions
diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu
index e24ac1c8..883ca072 100644
--- a/candle-kernels/src/fill.cu
+++ b/candle-kernels/src/fill.cu
@@ -1,3 +1,4 @@
+#include<stdint.h>
#include "cuda_fp16.h"
template<typename T>
@@ -6,6 +7,14 @@ __device__ void fill_with(T *buf, T value, const size_t numel) {
buf[i] = value;
}
}
+extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
+
+#if __CUDA_ARCH__ >= 800
+#include <cuda_bf16.h>
+extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
+#endif