summaryrefslogtreecommitdiff
path: root/candle-core
diff options
context:
space:
mode:
authorLionel Touati <ltouati@gmail.com>2024-06-02 14:30:06 +0200
committerGitHub <noreply@github.com>2024-06-02 14:30:06 +0200
commit1ec3b2cc189fa6020018f2c8dad7b216b4512019 (patch)
treefc643da98f7649780798a2668279d39e3441c47f /candle-core
parentf7773d498a58fc5678784bd4843011974e11f953 (diff)
downloadcandle-1ec3b2cc189fa6020018f2c8dad7b216b4512019.tar.gz
candle-1ec3b2cc189fa6020018f2c8dad7b216b4512019.tar.bz2
candle-1ec3b2cc189fa6020018f2c8dad7b216b4512019.zip
add where_cond f32 for metal (#2236)
Diffstat (limited to 'candle-core')
-rw-r--r--candle-core/src/metal_backend/mod.rs1
1 files changed, 1 insertions, 0 deletions
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs
index 2563607a..06f6cd37 100644
--- a/candle-core/src/metal_backend/mod.rs
+++ b/candle-core/src/metal_backend/mod.rs
@@ -718,6 +718,7 @@ impl BackendStorage for MetalStorage {
}
let name = match (self.dtype, t.dtype()) {
(DType::U8, DType::F32) => "where_u8_f32",
+ (DType::U32, DType::F32) => "where_u32_f32",
(DType::U8, DType::BF16) => "where_u8_bf16",
(DType::U8, DType::F16) => "where_u8_f16",
(DType::U8, DType::I64) => "where_u8_i64",