summaryrefslogtreecommitdiff
path: root/candle-kernels/src
diff options
context:
space:
mode:
author: Laurent Mazare <laurent.mazare@gmail.com> 2023-07-24 16:53:45 +0100
committer: GitHub <noreply@github.com> 2023-07-24 16:53:45 +0100
commit: b50f932e7cb39d727f21d4e935d7f18eb5a49ad3 (patch)
tree: 7b0f34fbc91bf67fc7e1f40760c38906d911c234 /candle-kernels/src
parent: 160ba09d3062ede06a770ca4b8fc5c42b16a2d6a (diff)
download: candle-b50f932e7cb39d727f21d4e935d7f18eb5a49ad3.tar.gz
download: candle-b50f932e7cb39d727f21d4e935d7f18eb5a49ad3.tar.bz2
download: candle-b50f932e7cb39d727f21d4e935d7f18eb5a49ad3.zip
Add some cmp tests. (#233)
* Add some cmp tests.
* Add the cuda kernels for comparison operations.
Diffstat (limited to 'candle-kernels/src')
-rw-r--r--  candle-kernels/src/binary.cu            42
-rw-r--r--  candle-kernels/src/binary_op_macros.cuh 24
2 files changed, 56 insertions, 10 deletions
diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu
index c99d96fd..299ae7e3 100644
--- a/candle-kernels/src/binary.cu
+++ b/candle-kernels/src/binary.cu
@@ -6,6 +6,12 @@ BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
BINARY_OP(__nv_bfloat16, bdiv_bf16, x / y)
BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
BINARY_OP(__nv_bfloat16, bsub_bf16, x - y)
+BINARY_OP_OUT(__nv_bfloat16, uint8_t, eq_bf16, x == y)
+BINARY_OP_OUT(__nv_bfloat16, uint8_t, ne_bf16, x != y)
+BINARY_OP_OUT(__nv_bfloat16, uint8_t, lt_bf16, x < y)
+BINARY_OP_OUT(__nv_bfloat16, uint8_t, le_bf16, x <= y)
+BINARY_OP_OUT(__nv_bfloat16, uint8_t, gt_bf16, x > y)
+BINARY_OP_OUT(__nv_bfloat16, uint8_t, ge_bf16, x >= y)
#endif
#if __CUDA_ARCH__ >= 530
@@ -13,6 +19,12 @@ BINARY_OP(__half, badd_f16, x + y)
BINARY_OP(__half, bdiv_f16, x / y)
BINARY_OP(__half, bmul_f16, x * y)
BINARY_OP(__half, bsub_f16, x - y)
+BINARY_OP_OUT(__half, uint8_t, eq_f16, x == y)
+BINARY_OP_OUT(__half, uint8_t, ne_f16, x != y)
+BINARY_OP_OUT(__half, uint8_t, lt_f16, x < y)
+BINARY_OP_OUT(__half, uint8_t, le_f16, x <= y)
+BINARY_OP_OUT(__half, uint8_t, gt_f16, x > y)
+BINARY_OP_OUT(__half, uint8_t, ge_f16, x >= y)
#endif
BINARY_OP(float, badd_f32, x + y)
@@ -31,3 +43,33 @@ BINARY_OP(float, bsub_f32, x - y)
BINARY_OP(double, bsub_f64, x - y);
BINARY_OP(uint8_t, bsub_u8, x - y);
BINARY_OP(uint32_t, bsub_u32, x - y);
+
+BINARY_OP_OUT(float, uint8_t, eq_f32, x == y)
+BINARY_OP_OUT(double, uint8_t, eq_f64, x == y)
+BINARY_OP_OUT(uint8_t, uint8_t, eq_u8, x == y)
+BINARY_OP_OUT(uint32_t, uint8_t, eq_u32, x == y)
+
+BINARY_OP_OUT(float, uint8_t, ne_f32, x != y)
+BINARY_OP_OUT(double, uint8_t, ne_f64, x != y)
+BINARY_OP_OUT(uint8_t, uint8_t, ne_u8, x != y)
+BINARY_OP_OUT(uint32_t, uint8_t, ne_u32, x != y)
+
+BINARY_OP_OUT(float, uint8_t, lt_f32, x < y)
+BINARY_OP_OUT(double, uint8_t, lt_f64, x < y)
+BINARY_OP_OUT(uint8_t, uint8_t, lt_u8, x < y)
+BINARY_OP_OUT(uint32_t, uint8_t, lt_u32, x < y)
+
+BINARY_OP_OUT(float, uint8_t, le_f32, x <= y)
+BINARY_OP_OUT(double, uint8_t, le_f64, x <= y)
+BINARY_OP_OUT(uint8_t, uint8_t, le_u8, x <= y)
+BINARY_OP_OUT(uint32_t, uint8_t, le_u32, x <= y)
+
+BINARY_OP_OUT(float, uint8_t, gt_f32, x > y)
+BINARY_OP_OUT(double, uint8_t, gt_f64, x > y)
+BINARY_OP_OUT(uint8_t, uint8_t, gt_u8, x > y)
+BINARY_OP_OUT(uint32_t, uint8_t, gt_u32, x > y)
+
+BINARY_OP_OUT(float, uint8_t, ge_f32, x >= y)
+BINARY_OP_OUT(double, uint8_t, ge_f64, x >= y)
+BINARY_OP_OUT(uint8_t, uint8_t, ge_u8, x >= y)
+BINARY_OP_OUT(uint32_t, uint8_t, ge_u32, x >= y)
diff --git a/candle-kernels/src/binary_op_macros.cuh b/candle-kernels/src/binary_op_macros.cuh
index 219ee09c..05d0c3df 100644
--- a/candle-kernels/src/binary_op_macros.cuh
+++ b/candle-kernels/src/binary_op_macros.cuh
@@ -1,13 +1,13 @@
#include "cuda_utils.cuh"
-#define BINARY_OP(TYPENAME, FN_NAME, FUNC) \
+#define BINARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
const size_t *dims_and_strides, \
const TYPENAME *lhs, \
const TYPENAME *rhs, \
- TYPENAME *out \
+ OUT_TYPENAME *out \
) { \
const size_t *dims = dims_and_strides; \
const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
@@ -16,8 +16,8 @@ extern "C" __global__ void FN_NAME( \
bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
if (lhs_cont && rhs_cont) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
- TYPENAME x = lhs ? lhs[i] : out[i]; \
- TYPENAME y = rhs ? rhs[i] : out[i]; \
+ TYPENAME x = lhs[i]; \
+ TYPENAME y = rhs[i]; \
out[i] = FUNC; \
} \
} else if (lhs_cont) { \
@@ -29,8 +29,8 @@ extern "C" __global__ void FN_NAME( \
rhs_i += i_dim * rhs_strides[d]; \
tmp_i /= dims[d]; \
} \
- TYPENAME x = lhs ? lhs[i] : out[i]; \
- TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
+ TYPENAME x = lhs[i]; \
+ TYPENAME y = rhs[rhs_i]; \
out[i] = FUNC; \
} \
} else if (rhs_cont) { \
@@ -42,8 +42,8 @@ extern "C" __global__ void FN_NAME( \
lhs_i += i_dim * lhs_strides[d]; \
tmp_i /= dims[d]; \
} \
- TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
- TYPENAME y = rhs ? rhs[i] : out[i]; \
+ TYPENAME x = lhs[lhs_i]; \
+ TYPENAME y = rhs[i]; \
out[i] = FUNC; \
} \
} else { \
@@ -57,9 +57,13 @@ extern "C" __global__ void FN_NAME( \
rhs_i += i_dim * rhs_strides[d]; \
tmp_i /= dims[d]; \
} \
- TYPENAME x = lhs ? lhs[lhs_i] : out[i]; \
- TYPENAME y = rhs ? rhs[rhs_i] : out[i]; \
+ TYPENAME x = lhs[lhs_i]; \
+ TYPENAME y = rhs[rhs_i]; \
out[i] = FUNC; \
} \
} \
} \
+
+
+#define BINARY_OP(TYPENAME, FN_NAME, FUNC) \
+ BINARY_OP_OUT(TYPENAME, TYPENAME, FN_NAME, FUNC)