author     Laurent Mazare <laurent.mazare@gmail.com>   2023-08-23 10:42:19 +0100
committer  GitHub <noreply@github.com>                 2023-08-23 10:42:19 +0100
commit     9a5c7db91a40bfeab1dbaf1622c67a21f5ad19b8 (patch)
tree       4c7fef2cdb78409ca30e14981c783d717cd49f97 /candle-kernels
parent     3743bed2d7bc02069770902e4a956aeabaef5453 (diff)
Add support for i64 (#563)
* Add the i64 dtype.
* Adapt the cuda kernels.
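The kernel sources in this crate (affine.cu, binary.cu, reduce.cu, indexing.cu, ternary.cu) stamp out one kernel per dtype through macros such as AFFINE_OP, BINARY_OP, FAST_OP and WHERE_OP, so adding i64 support is mostly a matter of adding one instantiation line per operation, as the diff below shows. For orientation only, here is a minimal sketch of the kind of elementwise kernel such a macro generates for the new dtype, assuming a contiguous layout; the name example_badd_i64 and its signature are illustrative, not candle's actual ABI (the real expansion comes from the macro definitions, which are not part of this diff).

// Illustrative only: the real i64 kernels are generated by candle's
// BINARY_OP/AFFINE_OP macros; this hand-written equivalent assumes a
// contiguous layout and a made-up name/signature.
#include <stdint.h>

extern "C" __global__ void example_badd_i64(
    const size_t numel,
    const int64_t *lhs,
    const int64_t *rhs,
    int64_t *out
) {
    // Grid-stride loop: each thread walks the array in steps of the grid size.
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += (size_t)blockDim.x * gridDim.x) {
        out[i] = lhs[i] + rhs[i];
    }
}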
Diffstat (limited to 'candle-kernels')
-rw-r--r--  candle-kernels/src/affine.cu      |  1
-rw-r--r--  candle-kernels/src/binary.cu      | 12
-rw-r--r--  candle-kernels/src/cuda_utils.cuh |  2
-rw-r--r--  candle-kernels/src/indexing.cu    | 38
-rw-r--r--  candle-kernels/src/reduce.cu      |  1
-rw-r--r--  candle-kernels/src/ternary.cu     | 12
6 files changed, 65 insertions(+), 1 deletion(-)
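Most of the new lines land in indexing.cu, where the IS_OP, GATHER_OP, IA_OP and SA_OP macros are instantiated for every supported combination of index dtype and value dtype; i64 now appears both as an index type (is_i64_*, gather_i64_*, ...) and as a value type (is_u32_i64, gather_u8_i64, ...). As a rough sketch of what a gather over i64 indices does (the actual kernels are produced by the GATHER_OP macro, which is not shown in this diff; the simplified 1-D kernel and its name below are assumptions):

// Illustrative only: a simplified 1-D gather with i64 indices and f32 values.
// candle's generated gather_i64_f32 comes from the GATHER_OP macro; this
// sketch just shows the core indexing pattern.
#include <stdint.h>

extern "C" __global__ void example_gather_i64_f32(
    const size_t numel,    // number of indices == number of output elements
    const int64_t *ids,    // i64 indices into the input
    const float *inp,
    float *out
) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += (size_t)blockDim.x * gridDim.x) {
        out[i] = inp[ids[i]];
    }
}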
diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu
index a02ce7a6..152b9463 100644
--- a/candle-kernels/src/affine.cu
+++ b/candle-kernels/src/affine.cu
@@ -40,3 +40,4 @@ AFFINE_OP(float, affine_f32)
 AFFINE_OP(double, affine_f64)
 AFFINE_OP(uint8_t, affine_u8)
 AFFINE_OP(uint32_t, affine_u32)
+AFFINE_OP(int64_t, affine_i64)
diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu
index bd3c2a88..d44e3b20 100644
--- a/candle-kernels/src/binary.cu
+++ b/candle-kernels/src/binary.cu
@@ -35,53 +35,65 @@ BINARY_OP(float, badd_f32, x + y)
 BINARY_OP(double, badd_f64, x + y);
 BINARY_OP(uint8_t, badd_u8, x + y);
 BINARY_OP(uint32_t, badd_u32, x + y);
+BINARY_OP(int64_t, badd_i64, x + y);
 BINARY_OP(float, bdiv_f32, x / y)
 BINARY_OP(double, bdiv_f64, x / y);
 BINARY_OP(uint8_t, bdiv_u8, x / y);
 BINARY_OP(uint32_t, bdiv_u32, x / y);
+BINARY_OP(int64_t, bdiv_i64, x / y);
 BINARY_OP(float, bmul_f32, x * y)
 BINARY_OP(double, bmul_f64, x * y);
 BINARY_OP(uint8_t, bmul_u8, x * y);
 BINARY_OP(uint32_t, bmul_u32, x * y);
+BINARY_OP(int64_t, bmul_i64, x * y);
 BINARY_OP(float, bsub_f32, x - y)
 BINARY_OP(double, bsub_f64, x - y);
 BINARY_OP(uint8_t, bsub_u8, x - y);
 BINARY_OP(uint32_t, bsub_u32, x - y);
+BINARY_OP(int64_t, bsub_i64, x - y);
 BINARY_OP(float, bminimum_f32, ming(x, y));
 BINARY_OP(double, bminimum_f64, ming(x, y));
 BINARY_OP(uint8_t, bminimum_u8, ming(x, y));
 BINARY_OP(uint32_t, bminimum_u32, ming(x, y));
+BINARY_OP(int64_t, bminimum_i64, ming(x, y));
 BINARY_OP(float, bmaximum_f32, maxg(x, y));
 BINARY_OP(double, bmaximum_f64, maxg(x, y));
 BINARY_OP(uint8_t, bmaximum_u8, maxg(x, y));
 BINARY_OP(uint32_t, bmaximum_u32, maxg(x, y));
+BINARY_OP(int64_t, bmaximum_i64, maxg(x, y));
 BINARY_OP_OUT(float, uint8_t, eq_f32, x == y)
 BINARY_OP_OUT(double, uint8_t, eq_f64, x == y)
 BINARY_OP_OUT(uint8_t, uint8_t, eq_u8, x == y)
 BINARY_OP_OUT(uint32_t, uint8_t, eq_u32, x == y)
+BINARY_OP_OUT(int64_t, uint8_t, eq_i64, x == y)
 BINARY_OP_OUT(float, uint8_t, ne_f32, x != y)
 BINARY_OP_OUT(double, uint8_t, ne_f64, x != y)
 BINARY_OP_OUT(uint8_t, uint8_t, ne_u8, x != y)
 BINARY_OP_OUT(uint32_t, uint8_t, ne_u32, x != y)
+BINARY_OP_OUT(int64_t, uint8_t, ne_i64, x != y)
 BINARY_OP_OUT(float, uint8_t, lt_f32, x < y)
 BINARY_OP_OUT(double, uint8_t, lt_f64, x < y)
 BINARY_OP_OUT(uint8_t, uint8_t, lt_u8, x < y)
 BINARY_OP_OUT(uint32_t, uint8_t, lt_u32, x < y)
+BINARY_OP_OUT(int64_t, uint8_t, lt_i64, x < y)
 BINARY_OP_OUT(float, uint8_t, le_f32, x <= y)
 BINARY_OP_OUT(double, uint8_t, le_f64, x <= y)
 BINARY_OP_OUT(uint8_t, uint8_t, le_u8, x <= y)
 BINARY_OP_OUT(uint32_t, uint8_t, le_u32, x <= y)
+BINARY_OP_OUT(int64_t, uint8_t, le_i64, x <= y)
 BINARY_OP_OUT(float, uint8_t, gt_f32, x > y)
 BINARY_OP_OUT(double, uint8_t, gt_f64, x > y)
 BINARY_OP_OUT(uint8_t, uint8_t, gt_u8, x > y)
 BINARY_OP_OUT(uint32_t, uint8_t, gt_u32, x > y)
+BINARY_OP_OUT(int64_t, uint8_t, gt_i64, x > y)
 BINARY_OP_OUT(float, uint8_t, ge_f32, x >= y)
 BINARY_OP_OUT(double, uint8_t, ge_f64, x >= y)
 BINARY_OP_OUT(uint8_t, uint8_t, ge_u8, x >= y)
 BINARY_OP_OUT(uint32_t, uint8_t, ge_u32, x >= y)
+BINARY_OP_OUT(int64_t, uint8_t, ge_i64, x >= y)
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index ffdf4026..4096d2d1 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -142,6 +142,8 @@ __device__ __forceinline__ double absg(double a) { return fabs(a); }
 __device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); }
 __device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); }
+__device__ __forceinline__ int64_t ming(int64_t a, int64_t b) { return min(a, b); }
+__device__ __forceinline__ int64_t maxg(int64_t a, int64_t b) { return max(a, b); }
 __device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); }
 __device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }
 __device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu
index 7723d3bc..c57be129 100644
--- a/candle-kernels/src/indexing.cu
+++ b/candle-kernels/src/indexing.cu
@@ -151,19 +151,25 @@ extern "C" __global__ void FN_NAME( \
 #if __CUDA_ARCH__ >= 800
+IS_OP(__nv_bfloat16, int64_t, is_i64_bf16)
 IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
 IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16)
+GATHER_OP(__nv_bfloat16, int64_t, gather_i64_bf16)
 GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16)
 GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16)
+IA_OP(__nv_bfloat16, int64_t, ia_i64_bf16)
 IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16)
 IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
+SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
 SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
 SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
 #endif
 #if __CUDA_ARCH__ >= 530
+IS_OP(__half, int64_t, is_i64_f16)
 IS_OP(__half, uint32_t, is_u32_f16)
 IS_OP(__half, uint8_t, is_u8_f16)
+GATHER_OP(__half, int64_t, gather_i64_f16)
 GATHER_OP(__half, uint32_t, gather_u32_f16)
 GATHER_OP(__half, uint8_t, gather_u8_f16)
 IA_OP(__half, uint32_t, ia_u32_f16)
@@ -172,42 +178,74 @@ SA_OP(__half, uint32_t, sa_u32_f16)
 SA_OP(__half, uint8_t, sa_u8_f16)
 #endif
+IS_OP(float, int64_t, is_i64_f32)
+IS_OP(double, int64_t, is_i64_f64)
+IS_OP(uint8_t, int64_t, is_i64_u8)
+IS_OP(uint32_t, int64_t, is_i64_u32)
+IS_OP(int64_t, int64_t, is_i64_i64)
+
 IS_OP(float, uint32_t, is_u32_f32)
 IS_OP(double, uint32_t, is_u32_f64)
 IS_OP(uint8_t, uint32_t, is_u32_u8)
+IS_OP(int64_t, uint32_t, is_u32_i64)
 IS_OP(uint32_t, uint32_t, is_u32_u32)
 IS_OP(float, uint8_t, is_u8_f32)
 IS_OP(double, uint8_t, is_u8_f64)
 IS_OP(uint8_t, uint8_t, is_u8_u8)
 IS_OP(uint32_t, uint8_t, is_u8_u32)
+IS_OP(int64_t, uint8_t, is_u8_i64)
+
+GATHER_OP(float, int64_t, gather_i64_f32)
+GATHER_OP(double, int64_t, gather_i64_f64)
+GATHER_OP(uint8_t, int64_t, gather_i64_u8)
+GATHER_OP(uint32_t, int64_t, gather_i64_u32)
+GATHER_OP(int64_t, int64_t, gather_i64_i64)
 GATHER_OP(float, uint32_t, gather_u32_f32)
 GATHER_OP(double, uint32_t, gather_u32_f64)
 GATHER_OP(uint8_t, uint32_t, gather_u32_u8)
+GATHER_OP(int64_t, uint32_t, gather_u32_i64)
 GATHER_OP(uint32_t, uint32_t, gather_u32_u32)
 GATHER_OP(float, uint8_t, gather_u8_f32)
 GATHER_OP(double, uint8_t, gather_u8_f64)
 GATHER_OP(uint8_t, uint8_t, gather_u8_u8)
 GATHER_OP(uint32_t, uint8_t, gather_u8_u32)
+GATHER_OP(int64_t, uint8_t, gather_u8_i64)
+
+IA_OP(float, int64_t, ia_i64_f32)
+IA_OP(double, int64_t, ia_i64_f64)
+IA_OP(uint8_t, int64_t, ia_i64_u8)
+IA_OP(int64_t, int64_t, ia_i64_i64)
+IA_OP(uint32_t, int64_t, ia_i64_u32)
 IA_OP(float, uint32_t, ia_u32_f32)
 IA_OP(double, uint32_t, ia_u32_f64)
 IA_OP(uint8_t, uint32_t, ia_u32_u8)
+IA_OP(int64_t, uint32_t, ia_u32_i64)
 IA_OP(uint32_t, uint32_t, ia_u32_u32)
 IA_OP(float, uint8_t, ia_u8_f32)
 IA_OP(double, uint8_t, ia_u8_f64)
 IA_OP(uint8_t, uint8_t, ia_u8_u8)
 IA_OP(uint32_t, uint8_t, ia_u8_u32)
+IA_OP(int64_t, uint8_t, ia_u8_i64)
+
+SA_OP(float, int64_t, sa_i64_f32)
+SA_OP(double, int64_t, sa_i64_f64)
+SA_OP(uint8_t, int64_t, sa_i64_u8)
+SA_OP(int64_t, int64_t, sa_i64_i64)
+SA_OP(uint32_t, int64_t, sa_i64_u32)
 SA_OP(float, uint32_t, sa_u32_f32)
 SA_OP(double, uint32_t, sa_u32_f64)
 SA_OP(uint8_t, uint32_t, sa_u32_u8)
+SA_OP(int64_t, uint32_t, sa_u32_i64)
 SA_OP(uint32_t, uint32_t, sa_u32_u32)
 SA_OP(float, uint8_t, sa_u8_f32)
 SA_OP(double, uint8_t, sa_u8_f64)
 SA_OP(uint8_t, uint8_t, sa_u8_u8)
 SA_OP(uint32_t, uint8_t, sa_u8_u32)
+SA_OP(int64_t, uint8_t, sa_u8_i64)
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 9d4fc710..271502c5 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -307,4 +307,5 @@ SUM_OP(uint32_t, sum_u32)
 FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
 FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
 FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32)
+FAST_OP(int64_t, fast_min_i64, fast_max_i64, fast_argmin_i64, fast_argmax_i64, fast_sum_i64)
 FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8)
diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu
index eceb45c8..aaa8a881 100644
--- a/candle-kernels/src/ternary.cu
+++ b/candle-kernels/src/ternary.cu
@@ -33,21 +33,31 @@ extern "C" __global__ void FN_NAME( \
 } \
 #if __CUDA_ARCH__ >= 800
+WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16)
 WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16)
 WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)
 #endif
 #if __CUDA_ARCH__ >= 530
+WHERE_OP(__half, int64_t, where_i64_f16)
 WHERE_OP(__half, uint32_t, where_u32_f16)
 WHERE_OP(__half, uint8_t, where_u8_f16)
 #endif
+WHERE_OP(float, int64_t, where_i64_f32)
+WHERE_OP(double, int64_t, where_i64_f64)
+WHERE_OP(uint8_t, int64_t, where_i64_u8)
+WHERE_OP(uint32_t, int64_t, where_i64_u32)
+WHERE_OP(int64_t, int64_t, where_i64_i64)
+
 WHERE_OP(float, uint32_t, where_u32_f32)
 WHERE_OP(double, uint32_t, where_u32_f64)
 WHERE_OP(uint8_t, uint32_t, where_u32_u8)
 WHERE_OP(uint32_t, uint32_t, where_u32_u32)
+WHERE_OP(int64_t, uint32_t, where_u32_i64)
 WHERE_OP(float, uint8_t, where_u8_f32)
 WHERE_OP(double, uint8_t, where_u8_f64)
 WHERE_OP(uint8_t, uint8_t, where_u8_u8)
-WHERE_OP(uint8_t, uint32_t, where_u8_u32)
+WHERE_OP(uint32_t, uint8_t, where_u8_u32)
+WHERE_OP(int64_t, uint8_t, where_u8_i64)
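Besides the new instantiations, the ternary.cu hunk also swaps the argument order of the existing where_u8_u32 line: going by the surrounding entries, WHERE_OP takes the value dtype first and the condition dtype second, and the names follow where_<condition dtype>_<value dtype>, so where_u8_u32 selects between u32 values using a u8 condition. A minimal sketch of what such a select kernel does for the new i64 condition over f32 values, again assuming a contiguous layout and an illustrative name/signature rather than the actual macro expansion:

// Illustrative only: element-wise select with an i64 condition and f32 values,
// in the spirit of where_i64_f32. The real kernel comes from the WHERE_OP
// macro; this shows just the core select logic.
#include <stdint.h>

extern "C" __global__ void example_where_i64_f32(
    const size_t numel,
    const int64_t *cond,   // nonzero picks `t`, zero picks `f`
    const float *t,
    const float *f,
    float *out
) {
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
         i += (size_t)blockDim.x * gridDim.x) {
        out[i] = cond[i] ? t[i] : f[i];
    }
}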