summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/Cargo.toml1
-rw-r--r--candle-core/src/metal_backend.rs54
-rw-r--r--candle-metal-kernels/Cargo.toml9
-rw-r--r--candle-metal-kernels/src/cast.metal40
-rw-r--r--candle-metal-kernels/src/indexing.metal3
-rw-r--r--candle-metal-kernels/src/tests.rs154
6 files changed, 243 insertions, 18 deletions
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index d9fc7526..92a04917 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -48,4 +48,3 @@ metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]]
name = "bench_main"
harness = false
-
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 5d72bd68..aa2898ff 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -590,14 +590,26 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8",
(DType::U32, DType::I64) => "cast_u32_i64",
+ (DType::U32, DType::BF16) => "cast_u32_bf16",
+
(DType::U8, DType::U32) => "cast_u8_u32",
(DType::U8, DType::F32) => "cast_u8_f32",
(DType::U8, DType::I64) => "cast_u8_i64",
+ (DType::U8, DType::BF16) => "cast_u8_bf16",
+
(DType::F32, DType::F16) => "cast_f32_f16",
- (DType::F16, DType::F32) => "cast_f16_f32",
- (DType::I64, DType::F32) => "cast_i64_f32",
(DType::F32, DType::BF16) => "cast_f32_bf16",
+
+ (DType::I64, DType::F32) => "cast_i64_f32",
+
+ (DType::F16, DType::BF16) => "cast_f16_bf16",
+ (DType::F16, DType::F32) => "cast_f16_f32",
+
+ (DType::BF16, DType::U8) => "cast_bf16_u8",
+ (DType::BF16, DType::U32) => "cast_bf16_u32",
+ (DType::BF16, DType::F16) => "cast_bf16_f16",
(DType::BF16, DType::F32) => "cast_bf16_f32",
+
(left, right) => {
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
}
@@ -1131,8 +1143,12 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
+ (DType::U8, DType::BF16) => "is_u8_bf16",
+
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
+ (DType::U32, DType::BF16) => "is_u32_bf16",
+
(left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
}
@@ -1322,6 +1338,7 @@ impl MetalStorage {
("lt", DType::F32) => (contiguous::lt::FLOAT, DType::U8),
("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8),
("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8),
+
("add", DType::F16) => (contiguous::add::HALF, self.dtype),
("sub", DType::F16) => (contiguous::sub::HALF, self.dtype),
("mul", DType::F16) => (contiguous::mul::HALF, self.dtype),
@@ -1332,6 +1349,18 @@ impl MetalStorage {
("lt", DType::F16) => (contiguous::lt::HALF, DType::U8),
("ge", DType::F16) => (contiguous::ge::HALF, DType::U8),
("gt", DType::F16) => (contiguous::gt::HALF, DType::U8),
+
+ ("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype),
+ ("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype),
+ ("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype),
+ ("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype),
+ ("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8),
+ ("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8),
+ ("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8),
+ ("lt", DType::BF16) => (contiguous::lt::BFLOAT, DType::U8),
+ ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8),
+ ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8),
+
("add", DType::I64) => (contiguous::add::I64, self.dtype),
("sub", DType::I64) => (contiguous::sub::I64, self.dtype),
("mul", DType::I64) => (contiguous::mul::I64, self.dtype),
@@ -1342,6 +1371,7 @@ impl MetalStorage {
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
+
("add", DType::U32) => (contiguous::add::U32, self.dtype),
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
@@ -1352,6 +1382,7 @@ impl MetalStorage {
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
+
("add", DType::U8) => (contiguous::add::U8, self.dtype),
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
@@ -1362,6 +1393,7 @@ impl MetalStorage {
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
+
(name, dtype) => {
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
}
@@ -1395,6 +1427,7 @@ impl MetalStorage {
("lt", DType::F32) => (strided::lt::FLOAT, DType::U8),
("ge", DType::F32) => (strided::ge::FLOAT, DType::U8),
("gt", DType::F32) => (strided::gt::FLOAT, DType::U8),
+
("badd", DType::F16) => (strided::add::HALF, self.dtype),
("bsub", DType::F16) => (strided::sub::HALF, self.dtype),
("bmul", DType::F16) => (strided::mul::HALF, self.dtype),
@@ -1407,6 +1440,20 @@ impl MetalStorage {
("lt", DType::F16) => (strided::lt::HALF, DType::U8),
("ge", DType::F16) => (strided::ge::HALF, DType::U8),
("gt", DType::F16) => (strided::gt::HALF, DType::U8),
+
+ ("badd", DType::BF16) => (strided::add::BFLOAT, self.dtype),
+ ("bsub", DType::BF16) => (strided::sub::BFLOAT, self.dtype),
+ ("bmul", DType::BF16) => (strided::mul::BFLOAT, self.dtype),
+ ("bdiv", DType::BF16) => (strided::div::BFLOAT, self.dtype),
+ ("bminimum", DType::BF16) => (strided::min::BFLOAT, self.dtype),
+ ("bmaximum", DType::BF16) => (strided::max::BFLOAT, self.dtype),
+ ("eq", DType::BF16) => (strided::eq::BFLOAT, DType::U8),
+ ("ne", DType::BF16) => (strided::ne::BFLOAT, DType::U8),
+ ("le", DType::BF16) => (strided::le::BFLOAT, DType::U8),
+ ("lt", DType::BF16) => (strided::lt::BFLOAT, DType::U8),
+ ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8),
+ ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8),
+
("badd", DType::I64) => (strided::add::I64, self.dtype),
("bsub", DType::I64) => (strided::sub::I64, self.dtype),
("bmul", DType::I64) => (strided::mul::I64, self.dtype),
@@ -1419,6 +1466,7 @@ impl MetalStorage {
("lt", DType::I64) => (strided::lt::I64, DType::U8),
("ge", DType::I64) => (strided::ge::I64, DType::U8),
("gt", DType::I64) => (strided::gt::I64, DType::U8),
+
("badd", DType::U32) => (strided::add::U32, self.dtype),
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
@@ -1431,6 +1479,7 @@ impl MetalStorage {
("lt", DType::U32) => (strided::lt::U32, DType::U8),
("ge", DType::U32) => (strided::ge::U32, DType::U8),
("gt", DType::U32) => (strided::gt::U32, DType::U8),
+
("badd", DType::U8) => (strided::add::U8, self.dtype),
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
@@ -1443,6 +1492,7 @@ impl MetalStorage {
("lt", DType::U8) => (strided::lt::U8, DType::U8),
("ge", DType::U8) => (strided::ge::U8, DType::U8),
("gt", DType::U8) => (strided::gt::U8, DType::U8),
+
(name, dtype) => {
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
}
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index 441d2e88..187cb4fd 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -9,12 +9,17 @@ keywords = ["blas", "tensor", "machine-learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"
+
[dependencies]
-metal = { version = "0.27.0", features = ["mps"]}
+metal = { version = "0.27.0", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
-half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
+half = { version = "2.3.1", features = [
+ "num-traits",
+ "use-intrinsics",
+ "rand_distr",
+] }
rand = "0.8.5"
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index 5aacac4a..e08931cf 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -28,7 +28,7 @@ kernel void FN_NAME( \
if (tid >= dim) { \
return; \
} \
- output[tid] = RIGHT_TYPENAME(input[tid]); \
+ output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \
if (tid >= dim) { \
return; \
} \
- output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \
+ output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \
+} \
+
+#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \
+kernel void FN_NAME( \
+ constant size_t &dim, \
+ device const LEFT_TYPENAME *input, \
+ device RIGHT_TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ if (tid >= dim) { \
+ return; \
+ } \
+ output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \
+} \
+kernel void FN_NAME_STRIDED( \
+ constant size_t &dim, \
+ constant size_t &num_dims, \
+ constant size_t *dims, \
+ constant size_t *strides, \
+ device const LEFT_TYPENAME *input, \
+ device RIGHT_TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ if (tid >= dim) { \
+ return; \
+ } \
+ output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
@@ -59,6 +86,15 @@ CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
#endif
#if defined(__HAVE_BFLOAT__)
+#if __METAL_VERSION__ >= 310
+CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
+
+CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
+CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
+
+CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
+CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
+CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 32f3f410..2a57bdbb 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -174,6 +174,9 @@ SCATTER_ADD_OP(sa_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__)
+INDEX_OP(is_u32_bf16, uint32_t, bfloat)
+INDEX_OP(is_u8_bf16, uint8_t, bfloat)
+
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index c955abca..87f8ac45 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,6 +1,6 @@
use super::*;
use half::{bf16, f16};
-use metal::{Device, MTLResourceOptions};
+use metal::{Buffer, Device, MTLResourceOptions};
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
@@ -248,6 +248,34 @@ fn binary_add_f32() {
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}
+#[test]
+fn binary_ops_bf16() {
+ let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect();
+ let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32]
+ .into_iter()
+ .map(bf16::from_f32)
+ .collect();
+
+ macro_rules! binary_op {
+ ($opname:ident, $opexpr:expr) => {{
+ let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT);
+ let expected: Vec<bf16> = lhs
+ .iter()
+ .zip(rhs.iter())
+ .map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y))
+ .collect();
+ assert_eq!(results, expected);
+ }};
+ }
+
+ binary_op!(add, |x, y| x + y);
+ binary_op!(sub, |x, y| x - y);
+ binary_op!(mul, |x, y| x * y);
+ binary_op!(div, |x, y| x / y);
+ binary_op!(min, |x: bf16, y| x.min(y));
+ binary_op!(max, |x: bf16, y| x.max(y));
+}
+
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let fence = device.new_fence();
@@ -296,6 +324,89 @@ fn cast_u32_f32() {
assert_eq!(results, vec![1.0f32; 10_000]);
}
+#[test]
+fn it_cast_bf16_u32() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<u32> = cast(&input, "cast_bf16_u32");
+ let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_bf16_f32() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<f32> = cast(&input, "cast_bf16_f32");
+ let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_u8_bf16() {
+ let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
+ let expected: Vec<bf16> = input
+ .iter()
+ .map(|v| bf16::from_f32(*v as f32))
+ .collect::<Vec<_>>();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_u32_bf16() {
+ let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
+ let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_f32_bf16() {
+ let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
+ let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_bf16_u8() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<u8> = cast(&input, "cast_bf16_u8");
+ let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_bf16_f16() {
+ let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
+
+ let output: Vec<f16> = cast(&input, "cast_bf16_f16");
+ let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
+
+ assert_eq!(output, expected);
+}
+
+#[test]
+fn it_cast_f16_bf16() {
+ let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
+
+ let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
+ let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
+
+ assert_eq!(output, expected);
+}
+
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let fence = device.new_fence();
@@ -396,14 +507,14 @@ fn index_select() {
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [2, 5];
let ids = [0u32, 1, 0];
let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
@@ -419,7 +530,7 @@ fn index_select_f16() {
let shape = [5, 2];
let ids = [0u32, 4, 2];
let dim = 0;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16");
assert_eq!(
approx_f16(result, 4),
vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
@@ -427,12 +538,38 @@ fn index_select_f16() {
}
#[test]
+fn index_select_is_u32_bf16() {
+ let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
+ let shape = [5, 2];
+ let ids = [0u32, 4, 2];
+ let dim = 0;
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16");
+ assert_eq!(
+ approx_bf16(result, 4),
+ vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
+ );
+}
+
+#[test]
+fn index_select_is_u8_bf16() {
+ let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect();
+ let shape = [5, 2];
+ let ids = [0u8, 4, 2];
+ let dim = 0;
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16");
+ assert_eq!(
+ approx_bf16(result, 4),
+ vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]
+ );
+}
+
+#[test]
fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
let dim = 1;
- let result = run_index_select(&embedding, &shape, &ids, dim);
+ let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32");
assert_eq!(
result,
vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
@@ -444,6 +581,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
shape: &[usize],
ids: &[I],
dim: usize,
+ name: &'static str,
) -> Vec<T> {
let device = Device::system_default().expect("no device found");
@@ -457,12 +595,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let dst_el = ids.len() * left_size * right_size;
let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
- let name = match core::mem::size_of::<T>() {
- 4 => "is_u32_f32",
- 2 => "is_u32_f16",
- _ => unimplemented!(),
- };
-
let fence = device.new_fence();
let kernels = Kernels::new(fence);
call_index_select(