summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-07 11:52:03 +0100
committerIvar Flakstad <69173633+ivarflakstad@users.noreply.github.com>2024-01-07 11:52:03 +0100
commit6ebe04327397ebf7c9400d68c43fede705f8ce75 (patch)
tree4e701ef55370cb6226593019f9e6daaf8ddf73b4 /candle-metal-kernels
parent6bf52b9fdf82ad775611e82924d73172660a605e (diff)
parent84250bf52f58528cf59dca3b82effd9f07a13cc7 (diff)
downloadcandle-6ebe04327397ebf7c9400d68c43fede705f8ce75.tar.gz
candle-6ebe04327397ebf7c9400d68c43fede705f8ce75.tar.bz2
candle-6ebe04327397ebf7c9400d68c43fede705f8ce75.zip
Merge branch 'main' into ivarflakstad/metal-prng
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/binary.metal29
-rw-r--r--candle-metal-kernels/src/cast.metal8
-rw-r--r--candle-metal-kernels/src/lib.rs13
-rw-r--r--candle-metal-kernels/src/reduce.metal14
-rw-r--r--candle-metal-kernels/src/ternary.metal9
-rw-r--r--candle-metal-kernels/src/unary.metal12
6 files changed, 77 insertions, 8 deletions
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
index 8c3b4a8c..cdc8fef8 100644
--- a/candle-metal-kernels/src/binary.metal
+++ b/candle-metal-kernels/src/binary.metal
@@ -56,15 +56,24 @@ kernel void FN_NAME_STRIDED( \
#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
-BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
+BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
+BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
+BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
+
+#define INT64_BINARY_OP(NAME, FN) \
+BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
#define BINARY_OP_OUT(NAME, FN) \
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
-BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
+BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
+BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
+BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
+#define INT64_BINARY_OP_OUT(NAME, FN) \
+BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
@@ -80,6 +89,22 @@ BINARY_OP_OUT(lt, x < y)
BINARY_OP_OUT(ge, x >= y)
BINARY_OP_OUT(gt, x > y)
+#if __METAL_VERSION__ >= 220
+INT64_BINARY_OP(add, x + y)
+INT64_BINARY_OP(sub, x - y)
+INT64_BINARY_OP(mul, x * y)
+INT64_BINARY_OP(div, x / y)
+INT64_BINARY_OP(min, MIN(x, y))
+INT64_BINARY_OP(max, MAX(x, y))
+
+INT64_BINARY_OP_OUT(eq, x == y)
+INT64_BINARY_OP_OUT(ne, x != y)
+INT64_BINARY_OP_OUT(le, x <= y)
+INT64_BINARY_OP_OUT(lt, x < y)
+INT64_BINARY_OP_OUT(ge, x >= y)
+INT64_BINARY_OP_OUT(gt, x > y)
+#endif
+
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index 8481389d..e9ab17b1 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -52,5 +52,13 @@ CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
+#if __METAL_VERSION__ >= 220
+CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
+CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
+CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
+#endif
+
#if __METAL_VERSION__ >= 310
+CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
+CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
#endif
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index e2603b3b..75f0286d 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -134,6 +134,9 @@ macro_rules! ops{
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
+ pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
+ pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32"));
+ pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8"));
}
)+
pub mod copy {
@@ -141,6 +144,7 @@ macro_rules! ops{
pub const FLOAT: Kernel = Kernel("copy_f32");
pub const HALF: Kernel = Kernel("copy_f16");
pub const BFLOAT: Kernel = Kernel("copy_bf16");
+ pub const I64: Kernel = Kernel("copy_i64");
pub const U32: Kernel = Kernel("copy_u32");
pub const U8: Kernel = Kernel("copy_u8");
}
@@ -154,6 +158,9 @@ macro_rules! ops{
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_strided"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
+ pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
+ pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided"));
+ pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided"));
}
)+
pub mod copy {
@@ -161,6 +168,7 @@ macro_rules! ops{
pub const FLOAT: Kernel = Kernel("copy_f32_strided");
pub const HALF: Kernel = Kernel("copy_f16_strided");
pub const BFLOAT: Kernel = Kernel("copy_bf16_strided");
+ pub const I64: Kernel = Kernel("copy_i64_strided");
pub const U32: Kernel = Kernel("copy_u32_strided");
pub const U8: Kernel = Kernel("copy_u8_strided");
}
@@ -169,7 +177,10 @@ macro_rules! ops{
}
pub mod unary {
- ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
+ ops!(
+ cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
+ recip
+ );
}
pub mod binary {
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 2d584917..83a56f0a 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -263,24 +263,38 @@ kernel void NAME(
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
+REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
REDUCE(x * y, fast_mul_f32_strided, float, 1)
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
REDUCE(x * y, fast_mul_f16_strided, half, 1)
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
+REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
+REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
+ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)
+ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
+
+#if __METAL_VERSION__ >= 220
+REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
+REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
+REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
+ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
+ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
+#endif
+
#if __METAL_VERSION__ >= 310
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal
index 1f9cb38a..40b4bcf4 100644
--- a/candle-metal-kernels/src/ternary.metal
+++ b/candle-metal-kernels/src/ternary.metal
@@ -55,6 +55,9 @@ kernel void FN_NAME( \
WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
-// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
-// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
-// WHERE_OP(int64_t, uint8_t, where_u8_i64)
+WHERE_OP(uint8_t, uint8_t, where_u8_u8)
+WHERE_OP(uint32_t, uint8_t, where_u8_u32)
+
+#if __METAL_VERSION__ >= 220
+WHERE_OP(int64_t, uint8_t, where_u8_i64)
+#endif
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 04fa37a9..7fbb613d 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -19,7 +19,9 @@ METAL_FUNC uint get_strided_index(
}
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
+template <typename T> METAL_FUNC T recip(T in){ return T(1.0 / in); }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
+
template <typename T> METAL_FUNC T erf(T in){
float x = (float) in;
// constants
@@ -57,8 +59,6 @@ template <typename T> METAL_FUNC T gelu(T x) {
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
}
-
-
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
@@ -102,17 +102,24 @@ UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
+UNARY_OP(abs)
UNARY_OP(ceil)
UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
+UNARY_OP(recip)
+
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
+#if __METAL_VERSION__ >= 220
+UNARY(id, int64_t, copy_i64, copy_i64_strided)
+#endif
+
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
@@ -128,6 +135,7 @@ BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
+BFLOAT_UNARY_OP(recip)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
#endif