diff options
author | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-12 07:19:58 +0100 |
---|---|---|
committer | Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-12 07:19:58 +0100 |
commit | e63bb8661beb4ea139f4e7f1d85f56907d918b2b (patch) | |
tree | 2326f731957d56667ba4d432a68f8b37a2b79830 /candle-metal-kernels/src/tests.rs | |
parent | 87efb5d8eb6a6c3f17acf326aadcb11ad6900306 (diff) | |
parent | 41915184bb3e530cc8184fdd8841c66df9285684 (diff) | |
download | candle-e63bb8661beb4ea139f4e7f1d85f56907d918b2b.tar.gz candle-e63bb8661beb4ea139f4e7f1d85f56907d918b2b.tar.bz2 candle-e63bb8661beb4ea139f4e7f1d85f56907d918b2b.zip |
Merge branch 'main' into ivarflakstad/metal-prng
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 154 |
1 files changed, 143 insertions, 11 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 067dece8..b15505f7 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,6 +1,6 @@ use super::*; use half::{bf16, f16}; -use metal::{Device, MTLResourceOptions}; +use metal::{Buffer, Device, MTLResourceOptions}; fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { let ptr = buffer.contents() as *const T; @@ -248,6 +248,34 @@ fn binary_add_f32() { assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); } +#[test] +fn binary_ops_bf16() { + let lhs: Vec<bf16> = [1.1f32, 2.2, 3.3].into_iter().map(bf16::from_f32).collect(); + let rhs: Vec<bf16> = [4.2f32, 5.5f32, 6.91f32] + .into_iter() + .map(bf16::from_f32) + .collect(); + + macro_rules! binary_op { + ($opname:ident, $opexpr:expr) => {{ + let results = run_binary(&lhs, &rhs, binary::contiguous::$opname::BFLOAT); + let expected: Vec<bf16> = lhs + .iter() + .zip(rhs.iter()) + .map(|(x, y): (&bf16, &bf16)| $opexpr(*x, *y)) + .collect(); + assert_eq!(results, expected); + }}; + } + + binary_op!(add, |x, y| x + y); + binary_op!(sub, |x, y| x - y); + binary_op!(mul, |x, y| x * y); + binary_op!(div, |x, y| x / y); + binary_op!(min, |x: bf16, y| x.min(y)); + binary_op!(max, |x: bf16, y| x.max(y)); +} + fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { let device = device(); let fence = device.new_fence(); @@ -296,6 +324,89 @@ fn cast_u32_f32() { assert_eq!(results, vec![1.0f32; 10_000]); } +#[test] +fn it_cast_bf16_u32() { + let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec<u32> = cast(&input, "cast_bf16_u32"); + let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_f32() { + let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec<f32> = cast(&input, "cast_bf16_f32"); + let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_u8_bf16() { + let input: Vec<u8> = (1..=3).map(|v| v as u8).collect(); + + let output: Vec<bf16> = cast(&input, "cast_u8_bf16"); + let expected: Vec<bf16> = input + .iter() + .map(|v| bf16::from_f32(*v as f32)) + .collect::<Vec<_>>(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_u32_bf16() { + let input: Vec<u32> = (1..=3).map(|v| v as u32).collect(); + + let output: Vec<bf16> = cast(&input, "cast_u32_bf16"); + let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_f32_bf16() { + let input: Vec<f32> = (1..=3).map(|v| v as f32).collect(); + + let output: Vec<bf16> = cast(&input, "cast_f32_bf16"); + let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_u8() { + let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec<u8> = cast(&input, "cast_bf16_u8"); + let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_bf16_f16() { + let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect(); + + let output: Vec<f16> = cast(&input, "cast_bf16_f16"); + let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect(); + + assert_eq!(output, expected); +} + +#[test] +fn it_cast_f16_bf16() { + let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect(); + + let output: Vec<bf16> = cast(&input, "cast_f16_bf16"); + let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect(); + + assert_eq!(output, expected); +} + fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { let device = device(); let fence = device.new_fence(); @@ -396,14 +507,14 @@ fn index_select() { let shape = [5, 2]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [2, 5]; let ids = [0u32, 1, 0]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] @@ -419,7 +530,7 @@ fn index_select_f16() { let shape = [5, 2]; let ids = [0u32, 4, 2]; let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f16"); assert_eq!( approx_f16(result, 4), vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] @@ -427,12 +538,38 @@ fn index_select_f16() { } #[test] +fn index_select_is_u32_bf16() { + let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_bf16"); + assert_eq!( + approx_bf16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] +fn index_select_is_u8_bf16() { + let embedding: Vec<bf16> = (1..=10).map(|x| bf16::from_f32(x as f32)).collect(); + let shape = [5, 2]; + let ids = [0u8, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u8_bf16"); + assert_eq!( + approx_bf16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; let ids = [0u32, 1, 0]; let dim = 1; - let result = run_index_select(&embedding, &shape, &ids, dim); + let result = run_index_select(&embedding, &shape, &ids, dim, "is_u32_f32"); assert_eq!( result, vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] @@ -444,6 +581,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( shape: &[usize], ids: &[I], dim: usize, + name: &'static str, ) -> Vec<T> { let device = Device::system_default().expect("no device found"); @@ -457,12 +595,6 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( let dst_el = ids.len() * left_size * right_size; let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); - let name = match core::mem::size_of::<T>() { - 4 => "is_u32_f32", - 2 => "is_u32_f16", - _ => unimplemented!(), - }; - let fence = device.new_fence(); let kernels = Kernels::new(fence); call_index_select( |