summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorThomas Santerre <thomas@santerre.xyz>2024-04-04 12:39:06 -0400
committerGitHub <noreply@github.com>2024-04-04 18:39:06 +0200
commit5aebe53dd2470db731bd9ce2d65914f86f1542f7 (patch)
tree0046d3cde6a2a521e4122aacc8ae5bdce176f399 /candle-metal-kernels
parentf76bb7794aa8659c5023797979a3392cdfc01f32 (diff)
downloadcandle-5aebe53dd2470db731bd9ce2d65914f86f1542f7.tar.gz
candle-5aebe53dd2470db731bd9ce2d65914f86f1542f7.tar.bz2
candle-5aebe53dd2470db731bd9ce2d65914f86f1542f7.zip
update dtypes checks for several metal operations (#2010)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/binary.metal22
-rw-r--r--candle-metal-kernels/src/reduce.metal4
2 files changed, 20 insertions, 6 deletions
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
index ae11286a..e83498e4 100644
--- a/candle-metal-kernels/src/binary.metal
+++ b/candle-metal-kernels/src/binary.metal
@@ -60,21 +60,24 @@ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
-#define INT64_BINARY_OP(NAME, FN) \
-BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
-
-#define BFLOAT_BINARY_OP(FN, NAME) \
-BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
-
#define BINARY_OP_OUT(NAME, FN) \
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
+#define INT64_BINARY_OP(NAME, FN) \
+BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
+
#define INT64_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);
+#define BFLOAT_BINARY_OP(FN, NAME) \
+BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
+
+#define BFLOAT_BINARY_OP_OUT(NAME, FN) \
+BINARY(FN, bfloat, uint8_t, NAME##_bf16, NAME##_bf16_strided);
+
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
@@ -112,4 +115,11 @@ BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
BFLOAT_BINARY_OP(MIN(x, y), min)
BFLOAT_BINARY_OP(MAX(x, y), max)
+
+BFLOAT_BINARY_OP_OUT(eq, x == y)
+BFLOAT_BINARY_OP_OUT(ne, x != y)
+BFLOAT_BINARY_OP_OUT(le, x <= y)
+BFLOAT_BINARY_OP_OUT(lt, x < y)
+BFLOAT_BINARY_OP_OUT(ge, x >= y)
+BFLOAT_BINARY_OP_OUT(gt, x > y)
#endif
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 561d1744..acb69299 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -484,9 +484,13 @@ ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
#if defined(__HAVE_BFLOAT__)
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
+REDUCE(x + y, fast_sum_bf16_strided, half, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
+REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
+REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
+REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)