summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/conv.metal
diff options
context:
space:
mode:
authorThomas Santerre <thomas@santerre.xyz>2024-03-26 01:48:56 -0400
committerGitHub <noreply@github.com>2024-03-26 06:48:56 +0100
commitf5dfe883d768e55208b325b3838474f8fe58e12f (patch)
tree633f064ce26b56f8f83f377f69f2fef20888ba58 /candle-metal-kernels/src/conv.metal
parent196765e995f7f4bd3b9610a22f8ef5b009437a4e (diff)
downloadcandle-f5dfe883d768e55208b325b3838474f8fe58e12f.tar.gz
candle-f5dfe883d768e55208b325b3838474f8fe58e12f.tar.bz2
candle-f5dfe883d768e55208b325b3838474f8fe58e12f.zip
Extend supported dtypes for metal (im2col & upsample_2d) (#1938)
* update im2col dtype implementations * update dtypes for upsample
Diffstat (limited to 'candle-metal-kernels/src/conv.metal')
-rw-r--r--candle-metal-kernels/src/conv.metal8
1 files changed, 8 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/conv.metal b/candle-metal-kernels/src/conv.metal
index e28ac6b3..8fdd0e5f 100644
--- a/candle-metal-kernels/src/conv.metal
+++ b/candle-metal-kernels/src/conv.metal
@@ -486,16 +486,24 @@ kernel void FN_NAME( \
} \
IM2COL_OP(float, im2col_f32)
+IM2COL_OP(half, im2col_f16)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
+#if defined(__HAVE_BFLOAT__)
+IM2COL_OP(bfloat, im2col_bf16)
+#endif
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
+UPSAMPLE_NEAREST2D_OP(half, upsample_nearest2d_f16)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)
+#if defined(__HAVE_BFLOAT__)
+UPSAMPLE_NEAREST2D_OP(bfloat, upsample_nearest2d_bf16)
+#endif
MAXPOOL2D_OP(float, max_pool2d_f32)
MAXPOOL2D_OP(half, max_pool2d_f16)