diff options
author | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-12 07:19:58 +0100 |
---|---|---|
committer | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-12 07:19:58 +0100 |
commit | e63bb8661beb4ea139f4e7f1d85f56907d918b2b (patch) | |
tree | 2326f731957d56667ba4d432a68f8b37a2b79830 /candle-metal-kernels | |
parent | 87efb5d8eb6a6c3f17acf326aadcb11ad6900306 (diff) | |
parent | 41915184bb3e530cc8184fdd8841c66df9285684 (diff) | |
download | candle-e63bb8661beb4ea139f4e7f1d85f56907d918b2b.tar.gz candle-e63bb8661beb4ea139f4e7f1d85f56907d918b2b.tar.bz2 candle-e63bb8661beb4ea139f4e7f1d85f56907d918b2b.zip |
Merge branch 'main' into ivarflakstad/metal-prng
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/Cargo.toml | 9 | ||||
-rw-r--r-- | candle-metal-kernels/src/affine.metal | 2 | ||||
-rw-r--r-- | candle-metal-kernels/src/binary.metal | 2 | ||||
-rw-r--r-- | candle-metal-kernels/src/cast.metal | 42 | ||||
-rw-r--r-- | candle-metal-kernels/src/indexing.metal | 5 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 4 | ||||
-rw-r--r-- | candle-metal-kernels/src/reduce.metal | 2 | ||||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 154 | ||||
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 10 |
9 files changed, 206 insertions, 24 deletions
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 441d2e88..187cb4fd 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -9,12 +9,17 @@ keywords = ["blas", "tensor", "machine-learning"] categories = ["science"] license = "MIT OR Apache-2.0" + [dependencies] -metal = { version = "0.27.0", features = ["mps"]} +metal = { version = "0.27.0", features = ["mps"] } once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" [dev-dependencies] -half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } +half = { version = "2.3.1", features = [ + "num-traits", + "use-intrinsics", + "rand_distr", +] } rand = "0.8.5" diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index 4166d811..3d8e7f0d 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -117,7 +117,7 @@ ELU(elu_f32, float) ELU(elu_f16, half) -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) AFFINE(affine_bf16, bfloat); POWF(powf_bf16, bfloat); ELU(elu_bf16, bfloat); diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index cdc8fef8..eb560f16 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -105,7 +105,7 @@ INT64_BINARY_OP_OUT(ge, x >= y) INT64_BINARY_OP_OUT(gt, x > y) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) BFLOAT_BINARY_OP(x + y, add) BFLOAT_BINARY_OP(x - y, sub) BFLOAT_BINARY_OP(x * y, mul) diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index e9ab17b1..9aead139 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -28,7 +28,7 @@ kernel void FN_NAME( \ if (tid >= dim) { \ return; \ } \ - output[tid] = RIGHT_TYPENAME(input[tid]); \ + output[tid] = static_cast<RIGHT_TYPENAME>(input[tid]); \ } \ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -42,7 +42,34 @@ kernel void FN_NAME_STRIDED( \ if (tid >= dim) { \ return; \ } \ - output[tid] = RIGHT_TYPENAME(input[get_strided_index(tid, num_dims, dims, strides)]); \ + output[tid] = static_cast<RIGHT_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)]); \ +} \ + +#define CAST_THROUGH(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME, IR_TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[tid])); \ +} \ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint tid [[ thread_position_in_grid ]] \ +) { \ + if (tid >= dim) { \ + return; \ + } \ + output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \ } \ CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) @@ -58,7 +85,14 @@ CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) +CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) +CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) +CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) -#endif + +CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) +CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) +CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) +#endif
\ No newline at end of file diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 63357428..2a57bdbb 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -173,7 +173,10 @@ SCATTER_ADD_OP(sa_u32_f32, uint, float) SCATTER_ADD_OP(sa_u32_f16, uint, half) -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_u32_bf16, uint32_t, bfloat) +INDEX_OP(is_u8_bf16, uint8_t, bfloat) + INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 75f0286d..c427a690 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -178,8 +178,8 @@ macro_rules! ops{ pub mod unary { ops!( - cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh, - recip + cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf, + tanh, recip ); } pub mod binary { diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 83a56f0a..93dac662 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -295,7 +295,7 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) REDUCE(x + y, fast_sum_bf16, bfloat, 0) REDUCE(x * y, fast_mul_bf16, bfloat, 1) REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF) diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 067dece8..b15505f7 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,6 +1,6 @@ use super::*; use half::{bf16, f16}; -use metal::{Device, MTLResourceOptions}; +use metal::{Buffer, Device, MTLResourceOptions}; fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { let ptr = buffer.contents() as *const T; @@ -248,6 +248,34 @@ fn binary_add_f32() { assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); } +#[test] +fn binary_ops_bf16() { + let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect(); + let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32] + .into_iter() + .map(bf16::from_f32) + .collect(); + + macro_rules! binary_op { + ($opname:ident, $opexpr:expr) => {{ + let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT); + let expected: Vec<bf16> = lhs + .iter() + .zip(rhs.iter()) + .map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y)) + .collect(); + assert_eq!(results, expected); + }}; + } + + binary_op!(add, |x, y| x + y); + binary_op!(sub, |x, y| x - y); + binary_op!(mul, |x, y| x * y); + binary_op!(div, |x, y| x / y); + binary_op!(min, |x: bf16, y| x.min(y)); + binary_op!(max, |x: bf16, y| x.max(y)); +} + fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { let device = device(); let fence = device.new_fence(); @@ -296,6 +324,89 @@ fn cast_u32_f32() { assert_eq!(results, vec![1.0f32; 10_000]); } +#[test] +fn it_cast_bf16_u32() { + let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec<u32> = cast(&input, "cast_bf16_u32"); + let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_f32() { + let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec<f32> = cast(&input, "cast_bf16_f32"); + let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_u8_bf16() { + let input: Vec<u8> = (1..=3).map(|v| v as u8).collect(); + + let output: Vec<bf16> = cast(&input, "cast_u8_bf16"); + let expected: Vec<bf16> = input + .iter() + .map(|v| bf16::from_f32(*v as f32)) + .collect::<Vec<_>>(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_u32_bf16() { + let input: Vec<u32> = (1..=3).map(|v| v as u32).collect(); + + let output: Vec<bf16> = cast(&input, "cast_u32_bf16"); + let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_f32_bf16() { + let input: Vec<f32> = (1..=3).map(|v| v as f32).collect(); + + let output: Vec<bf16> = cast(&input, "cast_f32_bf16"); + let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_u8() { + let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec<u8> = cast(&input, "cast_bf16_u8"); + let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_f16() { + let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec<f16> = cast(&input, "cast_bf16_f16"); + let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_f16_bf16() { + let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect(); + + let output: Vec<bf16> = cast(&input, "cast_f16_bf16"); + let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect(); + + assert_eq!(output, expected); +} + fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { let device = device(); let fence = device.new_fence(); @@ -396,14 +507,14 @@ fn index_select() { let shape = [5, 2]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [2, 5]; let ids = [0u32, 1, 0]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] @@ -419,7 +530,7 @@ fn index_select_f16() { let shape = [5, 2]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16"); assert_eq!( approx_f16(result, 4), vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] @@ -427,12 +538,38 @@ fn index_select_f16() { } #[test] +fn index_select_is_u32_bf16() { + let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16"); + assert_eq!( + approx_bf16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] +fn index_select_is_u8_bf16() { + let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); + let shape = [5, 2]; + let ids = [0u8, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16"); + assert_eq!( + approx_bf16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; let ids = [0u32, 1, 0]; let dim = 1; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] @@ -444,6 +581,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( shape: &[usize], ids: &[I], dim: usize, + name: &'static str, ) -> Vec<T> { let device = Device::system_default().expect("no device found"); @@ -457,12 +595,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( let dst_el = ids.len() * left_size * right_size; let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); - let name = match core::mem::size_of::<T>() { - 4 => "is_u32_f32", - 2 => "is_u32_f16", - _ => unimplemented!(), - }; - let fence = device.new_fence(); let kernels = Kernels::new(fence); call_index_select( diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 7fbb613d..dcf803d8 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -58,6 +58,12 @@ template <typename T> METAL_FUNC T gelu(T x) { T beta = (static_cast<T>(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta))); } +template <typename T> METAL_FUNC T relu(T in){ + if (in < 0) { + return 0; + } + return in; +} #define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ kernel void FN_NAME( \ @@ -110,6 +116,7 @@ UNARY_OP(gelu_erf) UNARY_OP(erf) UNARY_OP(tanh) UNARY_OP(recip) +UNARY_OP(relu) UNARY(id, float, copy_f32, copy_f32_strided) UNARY(id, half, copy_f16, copy_f16_strided) @@ -120,7 +127,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided) UNARY(id, int64_t, copy_i64, copy_i64_strided) #endif -#if __METAL_VERSION__ >= 310 +#if defined(__HAVE_BFLOAT__) BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(sin) BFLOAT_UNARY_OP(sqr) @@ -136,6 +143,7 @@ BFLOAT_UNARY_OP(gelu_erf) BFLOAT_UNARY_OP(erf) BFLOAT_UNARY_OP(tanh) BFLOAT_UNARY_OP(recip) +BFLOAT_UNARY_OP(relu) UNARY(id, bfloat, copy_bf16, copy_bf16_strided) #endif |