summaryrefslogtreecommitdiff
path: root/candle-core/src/metal_backend
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-10 12:49:25 +0200
committerGitHub <noreply@github.com>2024-04-10 12:49:25 +0200
commita4d5a414e3ae79642ecfd6b7bb410c26a8a62a06 (patch)
treeb43f5eabd83c6a97909b965058a3d88c06bfc926 /candle-core/src/metal_backend
parent798e0335cd2c4661f0fd0429cdf06abe3b45f4ea (diff)
downloadcandle-a4d5a414e3ae79642ecfd6b7bb410c26a8a62a06.tar.gz
candle-a4d5a414e3ae79642ecfd6b7bb410c26a8a62a06.tar.bz2
candle-a4d5a414e3ae79642ecfd6b7bb410c26a8a62a06.zip
Support gather on bf16 for metal. (#2035)
Diffstat (limited to 'candle-core/src/metal_backend')
-rw-r--r--candle-core/src/metal_backend/mod.rs1
1 files changed, 1 insertions, 0 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 50149a9d..158eb8e0 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -1042,6 +1042,7 @@ impl BackendStorage for MetalStorage {
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
+ (DType::U32, DType::BF16) => "gather_u32_bf16",
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;