summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2024-01-04 12:12:56 +0100
committerGitHub <noreply@github.com>2024-01-04 12:12:56 +0100
commitfa3ea98ba92835960fdd825a5b4dda30ef2baaa4 (patch)
treea14bcd435d10dc325429887c413fdddda798e66d /candle-core
parent135ae5f3eb90f28327ba5d055291b7b3a2e2a47d (diff)
downloadcandle-fa3ea98ba92835960fdd825a5b4dda30ef2baaa4.tar.gz
candle-fa3ea98ba92835960fdd825a5b4dda30ef2baaa4.tar.bz2
candle-fa3ea98ba92835960fdd825a5b4dda30ef2baaa4.zip
Adding bfloat16 support for the cast kernels. (#1520)
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/src/metal_backend.rs4
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index e168c24b..c1c4aa4b 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -596,6 +596,8 @@ impl BackendStorage for MetalStorage {
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F16, DType::F32) => "cast_f16_f32",
(DType::I64, DType::F32) => "cast_i64_f32",
+ (DType::F32, DType::BF16) => "cast_f32_bf16",
+ (DType::BF16, DType::F32) => "cast_bf16_f32",
(left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
}
@@ -622,6 +624,8 @@ impl BackendStorage for MetalStorage {
(DType::F32, DType::F16) => "cast_f32_f16_strided",
(DType::F16, DType::F32) => "cast_f16_f32_strided",
(DType::I64, DType::F32) => "cast_i64_f32_strided",
+ (DType::F32, DType::BF16) => "cast_f32_bf16_strided",
+ (DType::BF16, DType::F32) => "cast_bf16_f32_strided",
(left, right) => {
crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented")
}