summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2024-01-04 12:12:56 +0100
committerGitHub <noreply@github.com>2024-01-04 12:12:56 +0100
commitfa3ea98ba92835960fdd825a5b4dda30ef2baaa4 (patch)
treea14bcd435d10dc325429887c413fdddda798e66d /candle-metal-kernels
parent135ae5f3eb90f28327ba5d055291b7b3a2e2a47d (diff)
downloadcandle-fa3ea98ba92835960fdd825a5b4dda30ef2baaa4.tar.gz
candle-fa3ea98ba92835960fdd825a5b4dda30ef2baaa4.tar.bz2
candle-fa3ea98ba92835960fdd825a5b4dda30ef2baaa4.zip
Adding bfloat16 support for the cast kernels. (#1520)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/cast.metal2
1 files changed, 2 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index 3baefcc2..e9ab17b1 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -59,4 +59,6 @@ CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif
#if __METAL_VERSION__ >= 310
+CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
+CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
#endif