summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Nicolas Patry <patry.nicolas@protonmail.com> 2023-12-18 11:01:18 +0100
committer Nicolas Patry <patry.nicolas@protonmail.com> 2023-12-18 11:01:18 +0100
commit e8ee253ee0766c33ac69f08bb0bcd6601f47ca6f (patch)
tree 6c4bceb7df56ab6530722cc06dd055e61ba0136e
parent 8bd3d6b94bb2449d056d38d5a42c8e9c762f2d7e (diff)
download candle-e8ee253ee0766c33ac69f08bb0bcd6601f47ca6f.tar.gz
candle-e8ee253ee0766c33ac69f08bb0bcd6601f47ca6f.tar.bz2
candle-e8ee253ee0766c33ac69f08bb0bcd6601f47ca6f.zip
Missing cast.
-rw-r--r-- candle-core/src/metal_backend.rs 2
-rw-r--r-- candle-metal-kernels/src/cast.metal 1
2 files changed, 3 insertions, 0 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 21a8967b..0af11a3d 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -578,6 +578,7 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U8, DType::U32) => "cast_u8_u32",
+ (DType::U8, DType::F32) => "cast_u8_f32",
(DType::F32, DType::F16) => "cast_f32_f16",
(DType::F16, DType::F32) => "cast_f16_f32",
(left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
@@ -598,6 +599,7 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "cast_u32_f32_strided",
(DType::U32, DType::U8) => "cast_u32_u8_strided",
(DType::U8, DType::U32) => "cast_u8_u32_strided",
+ (DType::U8, DType::F32) => "cast_u8_f32_strided",
(DType::F32, DType::F16) => "cast_f32_f16_strided",
(DType::F16, DType::F32) => "cast_f16_f32_strided",
(left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index 4398e9d4..8481389d 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -48,6 +48,7 @@ kernel void FN_NAME_STRIDED( \
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
+CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)