diff options
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 365 |
1 files changed, 312 insertions, 53 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 2330d48d..1b3153b1 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1,7 +1,14 @@ use super::*; -use half::f16; +use half::{bf16, f16}; use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; +fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { + let ptr = buffer.contents() as *const T; + assert!(!ptr.is_null()); + let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; + slice.to_vec() +} + fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer { let options = MTLResourceOptions::StorageModeManaged; let ptr = data.as_ptr() as *const core::ffi::c_void; @@ -23,13 +30,19 @@ fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> { v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() } +fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() +} + fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_unary_contiguous( &device, command_buffer, @@ -37,23 +50,24 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { name, v.len(), &input, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(v.len()) + read_to_vec(&output, v.len()) } fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; let left = new_buffer(&device, x); let right = new_buffer(&device, y); - let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); + let output = device.new_buffer(std::mem::size_of_val(x) as u64, options); call_binary_contiguous( &device, command_buffer, @@ -62,12 +76,12 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V x.len(), &left, &right, - &mut output, + &output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(x.len()) + read_to_vec(&output, x.len()) } fn run_strided<T: Clone>( @@ -81,8 +95,9 @@ fn run_strided<T: Clone>( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); - let kernels = Kernels::new(); + let output = new_buffer(&device, v); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); call_unary_strided( &device, command_buffer, @@ -92,13 +107,13 @@ fn run_strided<T: Clone>( &input, strides, offset, - &mut output, + &output, 0, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(v.len()) + read_to_vec(&output, v.len()) } #[test] @@ -201,6 +216,25 @@ fn cos_strided_random() { } #[test] +fn gelu_f16() { + let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let expected: Vec<f32> = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::HALF); + assert_eq!(approx_f16(results, 2), expected); +} + +#[test] +fn gelu_f32() { + let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]; + let expected: Vec<f32> = vec![-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0]; + let results = run(&v, unary::contiguous::gelu::FLOAT); + assert_eq!(approx(results, 3), expected); +} + +#[test] fn binary_add_f32() { let left = vec![1.0f32, 2.0, 3.0]; let right = vec![2.0f32, 3.1, 4.2]; @@ -216,11 +250,14 @@ fn binary_add_f32() { fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let options = MTLResourceOptions::StorageModeManaged; + let size = (v.len() * std::mem::size_of::<U>()) as u64; + let output = device.new_buffer(size, options); call_cast_contiguous( &device, @@ -229,12 +266,13 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { name, v.len(), &input, - &mut output, + 0, + &output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<U>(v.len()) + read_to_vec(&output, v.len()) } #[test] @@ -245,21 +283,28 @@ fn cast_u32_f32() { assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); + let v = vec![1.0f32, 2.0, 3.0]; + let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect(); + let results: Vec<f32> = cast(&input, "cast_f16_f32"); + assert_eq!(results, vec![1.0f32, 2.0, 3.0]); + let v = vec![1.0f32; 10_000]; - let results = run(&v, unary::contiguous::cos::FLOAT); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403; 10_000]); - assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); + let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect(); + let results: Vec<f32> = cast(&input, "cast_f16_f32"); + assert_eq!(results.len(), 10_000); + assert_eq!(&results[..10], vec![1.0f32; 10]); + assert_eq!(results, vec![1.0f32; 10_000]); } fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); let size = v.len(); @@ -267,9 +312,46 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { &device, command_buffer, &kernels, + "affine_f32", size, &input, - &mut output, + &output, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, v.len()) +} + +fn run_affine_strided<T: Clone>( + v: &[T], + shape: &[usize], + strides: &[usize], + mul: f64, + add: f64, +) -> Vec<T> { + let device = device(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let input = new_buffer(&device, v); + let output = new_buffer(&device, v); + + call_affine_strided( + &device, + command_buffer, + &kernels, + "affine_f32_strided", + shape, + &input, + strides, + 0, + &output, mul as f32, add as f32, ) @@ -277,7 +359,8 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(v.len()) + let len: usize = shape.iter().product(); + read_to_vec(&output, len) } #[test] @@ -296,6 +379,18 @@ fn affine() { } #[test] +fn affine_strided() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let shape = [4]; + let strides = [2]; + let result = run_affine_strided(&input, &shape, &strides, mul, add); + // 1 on 2 + assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]); +} + +#[test] fn index_select() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; @@ -313,7 +408,26 @@ fn index_select() { result, vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] ); +} +#[test] +fn index_select_f16() { + let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] + .into_iter() + .map(|x| f16::from_f32(x)) + .collect(); + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + approx_f16(result, 4), + vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0] + ); +} + +#[test] +fn index_select_dim1() { let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; let shape = [5, 2]; let ids = [0u32, 1, 0]; @@ -321,7 +435,7 @@ fn index_select() { let result = run_index_select(&embedding, &shape, &ids, dim); assert_eq!( result, - vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0] ); } @@ -341,27 +455,34 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); let dst_el = ids.len() * left_size * right_size; - let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + let name = match core::mem::size_of::<T>() { + 4 => "is_u32_f32", + 2 => "is_u32_f16", + _ => unimplemented!(), + }; - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); call_index_select( &device, &command_buffer, &kernels, - "is_u32_f32", + name, shape, ids.len(), dim, &embeddings_buffer, &ids_buffer, - &mut dst_buffer, + &dst_buffer, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - dst_buffer.read_to_vec::<T>(dst_el) + read_to_vec(&dst_buffer, dst_el) } #[test] @@ -427,7 +548,7 @@ fn index_add() { let expected = vec![ 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, ]; - let result = outputs_buffer.read_to_vec::<f32>(right.len()); + let result: Vec<f32> = read_to_vec(&outputs_buffer, right.len()); assert_eq!(result, expected); } @@ -439,43 +560,49 @@ fn cos_f16() { .collect(); let results = run(&v, unary::contiguous::cos::HALF); let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); - assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]); - assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); + assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]); + assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]); } fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); let options = MTLResourceOptions::StorageModeManaged; - let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); - call_reduce_contiguous( + let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); + let dims = vec![v.len()]; + let strides = vec![1]; + call_reduce_strided( &device, command_buffer, &kernels, name, - v.len(), + &dims, + &strides, out_length, &input, - &mut output, + 0, + &output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(out_length) + read_to_vec(&output, out_length) } fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); + let output = new_buffer(&device, v); call_last_softmax( &device, command_buffer, @@ -484,13 +611,14 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta v.len(), last_dim, &input, - &mut output, + 0, + &output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(v.len()) + read_to_vec(&output, v.len()) } #[test] @@ -498,7 +626,7 @@ fn reduce_sum() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let out_length = 1; - let results = run_reduce(&v, out_length, "fast_sum_float"); + let results = run_reduce(&v, out_length, "fast_sum_f32_strided"); assert_eq!(approx(results, 4), vec![21.0]); } @@ -507,7 +635,7 @@ fn reduce_sum2() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let out_length = 2; - let results = run_reduce(&v, out_length, "fast_sum_float"); + let results = run_reduce(&v, out_length, "fast_sum_f32_strided"); assert_eq!(approx(results, 4), vec![6.0, 15.0]); } @@ -515,15 +643,33 @@ fn reduce_sum2() { fn softmax() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); assert_eq!( approx(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] ); + let last_dim = 4096; + let n = 200; + let mut v = vec![0.0; n * last_dim]; + for i in 0..n { + v[i * last_dim] = 20.0; + } + let results = run_softmax(&v, last_dim, "softmax_f32"); + let results = approx(results, 4); + println!("{results:?}"); + assert_eq!( + results.iter().map(|&s| s.round() as usize).sum::<usize>(), + n + ); + assert_eq!(results[0], 1.0); + assert_eq!(results[1], 0.0); + assert_eq!(results[last_dim], 1.0); + assert_eq!(results[2 * last_dim], 1.0); + let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); assert_eq!( approx(results, 4), vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] @@ -531,11 +677,33 @@ fn softmax() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; let last_dim = 3; - let results = run_softmax(&v, last_dim, "softmax_float"); + let results = run_softmax(&v, last_dim, "softmax_f32"); assert_eq!( approx(results, 4), vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect::<Vec<_>>(); + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_f16"); + assert_eq!( + approx_f16(results, 4), + vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338] + ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + .iter() + .map(|v| bf16::from_f32(*v)) + .collect::<Vec<_>>(); + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_bf16"); + assert_eq!( + approx_bf16(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328] + ); } fn run_where_cond<I: Clone, T: Clone>( @@ -549,7 +717,8 @@ fn run_where_cond<I: Clone, T: Clone>( name: &'static str, ) -> Vec<T> { let device = device(); - let kernels = Kernels::new(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; @@ -571,7 +740,7 @@ fn run_where_cond<I: Clone, T: Clone>( options, ); - let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); + let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); call_where_cond_strided( &device, command_buffer, @@ -584,13 +753,13 @@ fn run_where_cond<I: Clone, T: Clone>( (&left_stride, left_offset), &right, (&cond_stride, cond_offset), - &mut output, + &output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - output.read_to_vec::<T>(length) + read_to_vec(&output, length) } #[test] @@ -614,3 +783,93 @@ fn where_cond() { ); assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); } + +fn run_gemm<T: Clone>( + (b, m, n, k): (usize, usize, usize, usize), + lhs: &[T], + lhs_stride: Vec<usize>, + lhs_offset: usize, + rhs: &[T], + rhs_stride: Vec<usize>, + rhs_offset: usize, +) -> Vec<T> { + let device = device(); + let fence = device.new_fence(); + let kernels = Kernels::new(fence); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(rhs) as u64, + options, + ); + let length = b * m * n; + let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); + call_gemm( + &device, + command_buffer, + &kernels, + "sgemm", + (b, m, n, k), + &lhs_stride, + lhs_offset, + &lhs, + &rhs_stride, + rhs_offset, + &rhs, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, length) +} + +#[test] +fn gemm() { + let (b, m, n, k) = (1, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0); + assert_eq!( + approx(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + + let (b, m, n, k) = (2, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0); + assert_eq!( + approx(results, 4), + vec![ + 20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0, + 518.0, 548.0, 578.0 + ] + ); + + // OFFSET + let (b, m, n, k) = (2, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); + // Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32 + let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4); + assert_eq!( + approx(results, 4), + vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] + ); +} |