summaryrefslogtreecommitdiff
path: root/candle-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-23 10:42:19 +0100
committerGitHub <noreply@github.com>2023-08-23 10:42:19 +0100
commit9a5c7db91a40bfeab1dbaf1622c67a21f5ad19b8 (patch)
tree4c7fef2cdb78409ca30e14981c783d717cd49f97 /candle-kernels
parent3743bed2d7bc02069770902e4a956aeabaef5453 (diff)
downloadcandle-9a5c7db91a40bfeab1dbaf1622c67a21f5ad19b8.tar.gz
candle-9a5c7db91a40bfeab1dbaf1622c67a21f5ad19b8.tar.bz2
candle-9a5c7db91a40bfeab1dbaf1622c67a21f5ad19b8.zip
Add support for i64 (#563)
* Add the i64 dtype. * Adapt the cuda kernels.
Diffstat (limited to 'candle-kernels')
-rw-r--r--candle-kernels/src/affine.cu1
-rw-r--r--candle-kernels/src/binary.cu12
-rw-r--r--candle-kernels/src/cuda_utils.cuh2
-rw-r--r--candle-kernels/src/indexing.cu38
-rw-r--r--candle-kernels/src/reduce.cu1
-rw-r--r--candle-kernels/src/ternary.cu12
6 files changed, 65 insertions, 1 deletions
diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu
index a02ce7a6..152b9463 100644
--- a/candle-kernels/src/affine.cu
+++ b/candle-kernels/src/affine.cu
@@ -40,3 +40,4 @@ AFFINE_OP(float, affine_f32)
AFFINE_OP(double, affine_f64)
AFFINE_OP(uint8_t, affine_u8)
AFFINE_OP(uint32_t, affine_u32)
+AFFINE_OP(int64_t, affine_i64)
diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu
index bd3c2a88..d44e3b20 100644
--- a/candle-kernels/src/binary.cu
+++ b/candle-kernels/src/binary.cu
@@ -35,53 +35,65 @@ BINARY_OP(float, badd_f32, x + y)
BINARY_OP(double, badd_f64, x + y);
BINARY_OP(uint8_t, badd_u8, x + y);
BINARY_OP(uint32_t, badd_u32, x + y);
+BINARY_OP(int64_t, badd_i64, x + y);
BINARY_OP(float, bdiv_f32, x / y)
BINARY_OP(double, bdiv_f64, x / y);
BINARY_OP(uint8_t, bdiv_u8, x / y);
BINARY_OP(uint32_t, bdiv_u32, x / y);
+BINARY_OP(int64_t, bdiv_i64, x / y);
BINARY_OP(float, bmul_f32, x * y)
BINARY_OP(double, bmul_f64, x * y);
BINARY_OP(uint8_t, bmul_u8, x * y);
BINARY_OP(uint32_t, bmul_u32, x * y);
+BINARY_OP(int64_t, bmul_i64, x * y);
BINARY_OP(float, bsub_f32, x - y)
BINARY_OP(double, bsub_f64, x - y);
BINARY_OP(uint8_t, bsub_u8, x - y);
BINARY_OP(uint32_t, bsub_u32, x - y);
+BINARY_OP(int64_t, bsub_i64, x - y);
BINARY_OP(float, bminimum_f32, ming(x, y));
BINARY_OP(double, bminimum_f64, ming(x, y));
BINARY_OP(uint8_t, bminimum_u8, ming(x, y));
BINARY_OP(uint32_t, bminimum_u32, ming(x, y));
+BINARY_OP(int64_t, bminimum_i64, ming(x, y));
BINARY_OP(float, bmaximum_f32, maxg(x, y));
BINARY_OP(double, bmaximum_f64, maxg(x, y));
BINARY_OP(uint8_t, bmaximum_u8, maxg(x, y));
BINARY_OP(uint32_t, bmaximum_u32, maxg(x, y));
+BINARY_OP(int64_t, bmaximum_i64, maxg(x, y));
BINARY_OP_OUT(float, uint8_t, eq_f32, x == y)
BINARY_OP_OUT(double, uint8_t, eq_f64, x == y)
BINARY_OP_OUT(uint8_t, uint8_t, eq_u8, x == y)
BINARY_OP_OUT(uint32_t, uint8_t, eq_u32, x == y)
+BINARY_OP_OUT(int64_t, uint8_t, eq_i64, x == y)
BINARY_OP_OUT(float, uint8_t, ne_f32, x != y)
BINARY_OP_OUT(double, uint8_t, ne_f64, x != y)
BINARY_OP_OUT(uint8_t, uint8_t, ne_u8, x != y)
BINARY_OP_OUT(uint32_t, uint8_t, ne_u32, x != y)
+BINARY_OP_OUT(int64_t, uint8_t, ne_i64, x != y)
BINARY_OP_OUT(float, uint8_t, lt_f32, x < y)
BINARY_OP_OUT(double, uint8_t, lt_f64, x < y)
BINARY_OP_OUT(uint8_t, uint8_t, lt_u8, x < y)
BINARY_OP_OUT(uint32_t, uint8_t, lt_u32, x < y)
+BINARY_OP_OUT(int64_t, uint8_t, lt_i64, x < y)
BINARY_OP_OUT(float, uint8_t, le_f32, x <= y)
BINARY_OP_OUT(double, uint8_t, le_f64, x <= y)
BINARY_OP_OUT(uint8_t, uint8_t, le_u8, x <= y)
BINARY_OP_OUT(uint32_t, uint8_t, le_u32, x <= y)
+BINARY_OP_OUT(int64_t, uint8_t, le_i64, x <= y)
BINARY_OP_OUT(float, uint8_t, gt_f32, x > y)
BINARY_OP_OUT(double, uint8_t, gt_f64, x > y)
BINARY_OP_OUT(uint8_t, uint8_t, gt_u8, x > y)
BINARY_OP_OUT(uint32_t, uint8_t, gt_u32, x > y)
+BINARY_OP_OUT(int64_t, uint8_t, gt_i64, x > y)
BINARY_OP_OUT(float, uint8_t, ge_f32, x >= y)
BINARY_OP_OUT(double, uint8_t, ge_f64, x >= y)
BINARY_OP_OUT(uint8_t, uint8_t, ge_u8, x >= y)
BINARY_OP_OUT(uint32_t, uint8_t, ge_u32, x >= y)
+BINARY_OP_OUT(int64_t, uint8_t, ge_i64, x >= y)
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index ffdf4026..4096d2d1 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -142,6 +142,8 @@ __device__ __forceinline__ double absg(double a) { return fabs(a); }
__device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); }
__device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); }
+__device__ __forceinline__ int64_t ming(int64_t a, int64_t b) { return min(a, b); }
+__device__ __forceinline__ int64_t maxg(int64_t a, int64_t b) { return max(a, b); }
__device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); }
__device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a, b); }
__device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu
index 7723d3bc..c57be129 100644
--- a/candle-kernels/src/indexing.cu
+++ b/candle-kernels/src/indexing.cu
@@ -151,19 +151,25 @@ extern "C" __global__ void FN_NAME( \
#if __CUDA_ARCH__ >= 800
+IS_OP(__nv_bfloat16, int64_t, is_i64_bf16)
IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16)
+GATHER_OP(__nv_bfloat16, int64_t, gather_i64_bf16)
GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16)
GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16)
+IA_OP(__nv_bfloat16, int64_t, ia_i64_bf16)
IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16)
IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
+SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
#endif
#if __CUDA_ARCH__ >= 530
+IS_OP(__half, int64_t, is_i64_f16)
IS_OP(__half, uint32_t, is_u32_f16)
IS_OP(__half, uint8_t, is_u8_f16)
+GATHER_OP(__half, int64_t, gather_i64_f16)
GATHER_OP(__half, uint32_t, gather_u32_f16)
GATHER_OP(__half, uint8_t, gather_u8_f16)
IA_OP(__half, uint32_t, ia_u32_f16)
@@ -172,42 +178,74 @@ SA_OP(__half, uint32_t, sa_u32_f16)
SA_OP(__half, uint8_t, sa_u8_f16)
#endif
+IS_OP(float, int64_t, is_i64_f32)
+IS_OP(double, int64_t, is_i64_f64)
+IS_OP(uint8_t, int64_t, is_i64_u8)
+IS_OP(uint32_t, int64_t, is_i64_u32)
+IS_OP(int64_t, int64_t, is_i64_i64)
+
IS_OP(float, uint32_t, is_u32_f32)
IS_OP(double, uint32_t, is_u32_f64)
IS_OP(uint8_t, uint32_t, is_u32_u8)
+IS_OP(int64_t, uint32_t, is_u32_i64)
IS_OP(uint32_t, uint32_t, is_u32_u32)
IS_OP(float, uint8_t, is_u8_f32)
IS_OP(double, uint8_t, is_u8_f64)
IS_OP(uint8_t, uint8_t, is_u8_u8)
IS_OP(uint32_t, uint8_t, is_u8_u32)
+IS_OP(int64_t, uint8_t, is_u8_i64)
+
+GATHER_OP(float, int64_t, gather_i64_f32)
+GATHER_OP(double, int64_t, gather_i64_f64)
+GATHER_OP(uint8_t, int64_t, gather_i64_u8)
+GATHER_OP(uint32_t, int64_t, gather_i64_u32)
+GATHER_OP(int64_t, int64_t, gather_i64_i64)
GATHER_OP(float, uint32_t, gather_u32_f32)
GATHER_OP(double, uint32_t, gather_u32_f64)
GATHER_OP(uint8_t, uint32_t, gather_u32_u8)
+GATHER_OP(int64_t, uint32_t, gather_u32_i64)
GATHER_OP(uint32_t, uint32_t, gather_u32_u32)
GATHER_OP(float, uint8_t, gather_u8_f32)
GATHER_OP(double, uint8_t, gather_u8_f64)
GATHER_OP(uint8_t, uint8_t, gather_u8_u8)
GATHER_OP(uint32_t, uint8_t, gather_u8_u32)
+GATHER_OP(int64_t, uint8_t, gather_u8_i64)
+
+IA_OP(float, int64_t, ia_i64_f32)
+IA_OP(double, int64_t, ia_i64_f64)
+IA_OP(uint8_t, int64_t, ia_i64_u8)
+IA_OP(int64_t, int64_t, ia_i64_i64)
+IA_OP(uint32_t, int64_t, ia_i64_u32)
IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(double, uint32_t, ia_u32_f64)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
+IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)
IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(double, uint8_t, ia_u8_f64)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
+IA_OP(int64_t, uint8_t, ia_u8_i64)
+
+SA_OP(float, int64_t, sa_i64_f32)
+SA_OP(double, int64_t, sa_i64_f64)
+SA_OP(uint8_t, int64_t, sa_i64_u8)
+SA_OP(int64_t, int64_t, sa_i64_i64)
+SA_OP(uint32_t, int64_t, sa_i64_u32)
SA_OP(float, uint32_t, sa_u32_f32)
SA_OP(double, uint32_t, sa_u32_f64)
SA_OP(uint8_t, uint32_t, sa_u32_u8)
+SA_OP(int64_t, uint32_t, sa_u32_i64)
SA_OP(uint32_t, uint32_t, sa_u32_u32)
SA_OP(float, uint8_t, sa_u8_f32)
SA_OP(double, uint8_t, sa_u8_f64)
SA_OP(uint8_t, uint8_t, sa_u8_u8)
SA_OP(uint32_t, uint8_t, sa_u8_u32)
+SA_OP(int64_t, uint8_t, sa_u8_i64)
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 9d4fc710..271502c5 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -307,4 +307,5 @@ SUM_OP(uint32_t, sum_u32)
FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32)
FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64)
FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32)
+FAST_OP(int64_t, fast_min_i64, fast_max_i64, fast_argmin_i64, fast_argmax_i64, fast_sum_i64)
FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8)
diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu
index eceb45c8..aaa8a881 100644
--- a/candle-kernels/src/ternary.cu
+++ b/candle-kernels/src/ternary.cu
@@ -33,21 +33,31 @@ extern "C" __global__ void FN_NAME( \
} \
#if __CUDA_ARCH__ >= 800
+WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16)
WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16)
WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16)
#endif
#if __CUDA_ARCH__ >= 530
+WHERE_OP(__half, int64_t, where_i64_f16)
WHERE_OP(__half, uint32_t, where_u32_f16)
WHERE_OP(__half, uint8_t, where_u8_f16)
#endif
+WHERE_OP(float, int64_t, where_i64_f32)
+WHERE_OP(double, int64_t, where_i64_f64)
+WHERE_OP(uint8_t, int64_t, where_i64_u8)
+WHERE_OP(uint32_t, int64_t, where_i64_u32)
+WHERE_OP(int64_t, int64_t, where_i64_i64)
+
WHERE_OP(float, uint32_t, where_u32_f32)
WHERE_OP(double, uint32_t, where_u32_f64)
WHERE_OP(uint8_t, uint32_t, where_u32_u8)
WHERE_OP(uint32_t, uint32_t, where_u32_u32)
+WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(float, uint8_t, where_u8_f32)
WHERE_OP(double, uint8_t, where_u8_f64)
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
-WHERE_OP(uint8_t, uint32_t, where_u8_u32)
+WHERE_OP(uint32_t, uint8_t, where_u8_u32)
+WHERE_OP(int64_t, uint8_t, where_u8_i64)