summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorGonzalo <456459+grzuy@users.noreply.github.com>2023-12-29 19:56:21 -0300
committerGitHub <noreply@github.com>2023-12-29 23:56:21 +0100
commit87d7f81b438db6a1696f2ac79606b40e61d448e8 (patch)
tree31f3fb7a25091b0a9076d1aa19d510c5b3d9b044 /candle-metal-kernels
parent4373534d59d3a6357aef0b3f35a247f695f4700a (diff)
downloadcandle-87d7f81b438db6a1696f2ac79606b40e61d448e8.tar.gz
candle-87d7f81b438db6a1696f2ac79606b40e61d448e8.tar.bz2
candle-87d7f81b438db6a1696f2ac79606b40e61d448e8.zip
Metal: more u8/u32 (#1502)
* Adds more metal u8 * Metal: more u32
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/binary.metal8
-rw-r--r--candle-metal-kernels/src/lib.rs4
-rw-r--r--candle-metal-kernels/src/reduce.metal5
-rw-r--r--candle-metal-kernels/src/ternary.metal4
4 files changed, 17 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
index 30c90ff1..cdc8fef8 100644
--- a/candle-metal-kernels/src/binary.metal
+++ b/candle-metal-kernels/src/binary.metal
@@ -56,7 +56,9 @@ kernel void FN_NAME_STRIDED( \
#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
-BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
+BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \
+BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \
+BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
#define INT64_BINARY_OP(NAME, FN) \
BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided);
@@ -66,7 +68,9 @@ BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
#define BINARY_OP_OUT(NAME, FN) \
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
-BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
+BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \
+BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
+BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
#define INT64_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 7b0084d9..d080ef52 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -131,6 +131,8 @@ macro_rules! ops{
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64"));
+ pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32"));
+ pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8"));
}
)+
pub mod copy {
@@ -153,6 +155,8 @@ macro_rules! ops{
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided"));
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided"));
+ pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided"));
+ pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided"));
}
)+
pub mod copy {
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 38252967..83a56f0a 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -263,21 +263,26 @@ kernel void NAME(
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
+REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
REDUCE(x * y, fast_mul_f32_strided, float, 1)
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
REDUCE(x * y, fast_mul_f16_strided, half, 1)
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
+REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
+REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
+ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)
+ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal
index dfa0dd12..40b4bcf4 100644
--- a/candle-metal-kernels/src/ternary.metal
+++ b/candle-metal-kernels/src/ternary.metal
@@ -55,8 +55,8 @@ kernel void FN_NAME( \
WHERE_OP(float, uint8_t, where_u8_f32)
// WHERE_OP(double, uint8_t, where_u8_f64)
-// WHERE_OP(uint8_t, uint8_t, where_u8_u8)
-// WHERE_OP(uint32_t, uint8_t, where_u8_u32)
+WHERE_OP(uint8_t, uint8_t, where_u8_u8)
+WHERE_OP(uint32_t, uint8_t, where_u8_u32)
#if __METAL_VERSION__ >= 220
WHERE_OP(int64_t, uint8_t, where_u8_i64)