diff options
Diffstat (limited to 'candle-core/src/metal_backend.rs')
-rw-r--r-- | candle-core/src/metal_backend.rs | 54 |
1 files changed, 52 insertions, 2 deletions
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 5d72bd68..aa2898ff 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -590,14 +590,26 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "cast_u32_f32", (DType::U32, DType::U8) => "cast_u32_u8", (DType::U32, DType::I64) => "cast_u32_i64", + (DType::U32, DType::BF16) => "cast_u32_bf16", + (DType::U8, DType::U32) => "cast_u8_u32", (DType::U8, DType::F32) => "cast_u8_f32", (DType::U8, DType::I64) => "cast_u8_i64", + (DType::U8, DType::BF16) => "cast_u8_bf16", + (DType::F32, DType::F16) => "cast_f32_f16", - (DType::F16, DType::F32) => "cast_f16_f32", - (DType::I64, DType::F32) => "cast_i64_f32", (DType::F32, DType::BF16) => "cast_f32_bf16", + + (DType::I64, DType::F32) => "cast_i64_f32", + + (DType::F16, DType::BF16) => "cast_f16_bf16", + (DType::F16, DType::F32) => "cast_f16_f32", + + (DType::BF16, DType::U8) => "cast_bf16_u8", + (DType::BF16, DType::U32) => "cast_bf16_u32", + (DType::BF16, DType::F16) => "cast_bf16_f16", (DType::BF16, DType::F32) => "cast_bf16_f32", + (left, right) => { crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented") } @@ -1131,8 +1143,12 @@ impl BackendStorage for MetalStorage { let device = self.device(); let buffer = device.new_buffer(dst_el, dtype, "index_select")?; let name = match (ids.dtype, self.dtype) { + (DType::U8, DType::BF16) => "is_u8_bf16", + (DType::U32, DType::F32) => "is_u32_f32", (DType::U32, DType::F16) => "is_u32_f16", + (DType::U32, DType::BF16) => "is_u32_bf16", + (left, right) => { crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented") } @@ -1322,6 +1338,7 @@ impl MetalStorage { ("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8), ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8), ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8), + ("add", DType::F16) => (contiguous::add::HALF, self.dtype), ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype), ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype), @@ -1332,6 +1349,18 @@ impl MetalStorage { ("lt", DType::F16) => (contiguous::lt::HALF, DType::U8), ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), + + ("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype), + ("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype), + ("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype), + ("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype), + ("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8), + ("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8), + ("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8), + ("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8), + ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8), + ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8), + ("add", DType::I64) => (contiguous::add::I64, self.dtype), ("sub", DType::I64) => (contiguous::sub::I64, self.dtype), ("mul", DType::I64) => (contiguous::mul::I64, self.dtype), @@ -1342,6 +1371,7 @@ impl MetalStorage { ("lt", DType::I64) => (contiguous::lt::I64, DType::U8), ("ge", DType::I64) => (contiguous::ge::I64, DType::U8), ("gt", DType::I64) => (contiguous::gt::I64, DType::U8), + ("add", DType::U32) => (contiguous::add::U32, self.dtype), ("sub", DType::U32) => (contiguous::sub::U32, self.dtype), ("mul", DType::U32) => (contiguous::mul::U32, self.dtype), @@ -1352,6 +1382,7 @@ impl MetalStorage { ("lt", DType::U32) => (contiguous::lt::U32, DType::U8), ("ge", DType::U32) => (contiguous::ge::U32, DType::U8), ("gt", DType::U32) => (contiguous::gt::U32, DType::U8), + ("add", DType::U8) => (contiguous::add::U8, self.dtype), ("sub", DType::U8) => (contiguous::sub::U8, self.dtype), ("mul", DType::U8) => (contiguous::mul::U8, self.dtype), @@ -1362,6 +1393,7 @@ impl MetalStorage { ("lt", DType::U8) => (contiguous::lt::U8, DType::U8), ("ge", DType::U8) => (contiguous::ge::U8, DType::U8), ("gt", DType::U8) => (contiguous::gt::U8, DType::U8), + (name, dtype) => { crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented") } @@ -1395,6 +1427,7 @@ impl MetalStorage { ("lt", DType::F32) => (strided::lt::FLOAT, DType::U8), ("ge", DType::F32) => (strided::ge::FLOAT, DType::U8), ("gt", DType::F32) => (strided::gt::FLOAT, DType::U8), + ("badd", DType::F16) => (strided::add::HALF, self.dtype), ("bsub", DType::F16) => (strided::sub::HALF, self.dtype), ("bmul", DType::F16) => (strided::mul::HALF, self.dtype), @@ -1407,6 +1440,20 @@ impl MetalStorage { ("lt", DType::F16) => (strided::lt::HALF, DType::U8), ("ge", DType::F16) => (strided::ge::HALF, DType::U8), ("gt", DType::F16) => (strided::gt::HALF, DType::U8), + + ("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype), + ("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype), + ("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype), + ("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype), + ("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype), + ("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype), + ("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8), + ("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8), + ("le", DType::BF16) => (strided::le::BFLOAT, DType::U8), + ("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8), + ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8), + ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8), + ("badd", DType::I64) => (strided::add::I64, self.dtype), ("bsub", DType::I64) => (strided::sub::I64, self.dtype), ("bmul", DType::I64) => (strided::mul::I64, self.dtype), @@ -1419,6 +1466,7 @@ impl MetalStorage { ("lt", DType::I64) => (strided::lt::I64, DType::U8), ("ge", DType::I64) => (strided::ge::I64, DType::U8), ("gt", DType::I64) => (strided::gt::I64, DType::U8), + ("badd", DType::U32) => (strided::add::U32, self.dtype), ("bsub", DType::U32) => (strided::sub::U32, self.dtype), ("bmul", DType::U32) => (strided::mul::U32, self.dtype), @@ -1431,6 +1479,7 @@ impl MetalStorage { ("lt", DType::U32) => (strided::lt::U32, DType::U8), ("ge", DType::U32) => (strided::ge::U32, DType::U8), ("gt", DType::U32) => (strided::gt::U32, DType::U8), + ("badd", DType::U8) => (strided::add::U8, self.dtype), ("bsub", DType::U8) => (strided::sub::U8, self.dtype), ("bmul", DType::U8) => (strided::mul::U8, self.dtype), @@ -1443,6 +1492,7 @@ impl MetalStorage { ("lt", DType::U8) => (strided::lt::U8, DType::U8), ("ge", DType::U8) => (strided::ge::U8, DType::U8), ("gt", DType::U8) => (strided::gt::U8, DType::U8), + (name, dtype) => { crate::bail!("Metal strided binary {name} {dtype:?} not implemented") } |