diff options
Diffstat (limited to 'candle-kernels/src')
-rw-r--r-- | candle-kernels/src/fill.cu | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu index e24ac1c8..883ca072 100644 --- a/candle-kernels/src/fill.cu +++ b/candle-kernels/src/fill.cu @@ -1,3 +1,4 @@ +#include<stdint.h> #include "cuda_fp16.h" template<typename T> @@ -6,6 +7,14 @@ __device__ void fill_with(T *buf, T value, const size_t numel) { buf[i] = value; } } +extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); } + +#if __CUDA_ARCH__ >= 800 +#include <cuda_bf16.h> +extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); } +#endif |