diff options
author | Gonzalo <456459+grzuy@users.noreply.github.com> | 2023-09-29 11:49:30 -0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-29 15:49:30 +0100 |
commit | fc59bc31bf49e707d64ed25fd29a8803f9a12fb4 (patch) | |
tree | d1f62fff87b09d25659ea26ed1dadc4f2e0909a1 | |
parent | 03348e2e6f6c1904b4c41504c2fcd5887366c437 (diff) | |
download | candle-fc59bc31bf49e707d64ed25fd29a8803f9a12fb4.tar.gz candle-fc59bc31bf49e707d64ed25fd29a8803f9a12fb4.tar.bz2 candle-fc59bc31bf49e707d64ed25fd29a8803f9a12fb4.zip |
fix: add missing gpu fill_* (#996)
-rw-r--r-- | candle-core/tests/tensor_tests.rs | 26 | ||||
-rw-r--r-- | candle-kernels/src/fill.cu | 9 |
2 files changed, 35 insertions, 0 deletions
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index dbe0dd6a..d3eede48 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -8,6 +8,31 @@ fn zeros(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn ones(device: &Device) -> Result<()> {
+    assert_eq!(
+        Tensor::ones((2, 3), DType::U8, device)?.to_vec2::<u8>()?,
+        [[1, 1, 1], [1, 1, 1]],
+    );
+    assert_eq!(
+        Tensor::ones((2, 3), DType::U32, device)?.to_vec2::<u32>()?,
+        [[1, 1, 1], [1, 1, 1]],
+    );
+    assert_eq!(
+        Tensor::ones((2, 3), DType::I64, device)?.to_vec2::<i64>()?,
+        [[1, 1, 1], [1, 1, 1]],
+    );
+    assert_eq!(
+        Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
+        [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
+    );
+    assert_eq!(
+        Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
+        [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
+    );
+
+    Ok(())
+}
+
 fn add_mul(device: &Device) -> Result<()> {
     let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
     let dim1 = tensor.dims1()?;
@@ -966,6 +991,7 @@ fn randn(device: &Device) -> Result<()> {
 }
 
 test_device!(zeros, zeros_cpu, zeros_gpu);
+test_device!(ones, ones_cpu, ones_gpu);
 test_device!(add_mul, add_mul_cpu, add_mul_gpu);
 test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
 test_device!(narrow, narrow_cpu, narrow_gpu);
diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu
index e24ac1c8..883ca072 100644
--- a/candle-kernels/src/fill.cu
+++ b/candle-kernels/src/fill.cu
@@ -1,3 +1,4 @@
+#include<stdint.h>
 #include "cuda_fp16.h"
 
 template<typename T>
@@ -6,6 +7,14 @@ __device__ void fill_with(T *buf, T value, const size_t numel) {
     buf[i] = value;
   }
 }
+extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); }
 extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
 extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
 extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
+
+#if __CUDA_ARCH__ >= 800
+#include <cuda_bf16.h>
+extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
+#endif