diff options
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 58 |
1 files changed, 34 insertions, 24 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b15d9b36..b91c92d8 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -12,7 +12,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> { fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer { let options = MTLResourceOptions::StorageModeManaged; let ptr = data.as_ptr() as *const c_void; - let size = (data.len() * std::mem::size_of::<T>()) as u64; + let size = std::mem::size_of_val(data) as u64; device.new_buffer_with_data(ptr, size, options) } @@ -41,6 +41,10 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); + let input = BufferOffset { + buffer: &input, + offset_in_bytes: 0, + }; let output = new_buffer(&device, v); call_unary_contiguous( &device, @@ -48,7 +52,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> { &kernels, name, v.len(), - &input, + input, &output, ) .unwrap(); @@ -72,8 +76,8 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V &kernels, name, x.len(), - &left, - &right, + BufferOffset::zero_offset(&left), + BufferOffset::zero_offset(&right), &output, ) .unwrap(); @@ -93,7 +97,15 @@ fn run_strided<T: Clone>( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let input = new_buffer(&device, v); - let output = new_buffer(&device, v); + let input = BufferOffset { + buffer: &input, + offset_in_bytes: offset, + }; + let output_b = new_buffer(&device, v); + let output = BufferOffset { + buffer: &output_b, + offset_in_bytes: 0, + }; let kernels = Kernels::new(); call_unary_strided( &device, @@ -101,16 +113,14 @@ fn run_strided<T: Clone>( &kernels, kernel, shape, - &input, + input, strides, - offset, - &output, - 0, + output, ) .unwrap(); command_buffer.commit(); command_buffer.wait_until_completed(); - read_to_vec(&output, v.len()) + read_to_vec(&output_b, v.len()) } #[test] @@ -308,8 +318,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { &kernels, name, v.len(), - &input, - 0, + BufferOffset::zero_offset(&input), &output, ) .unwrap(); @@ -521,7 +530,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> { &kernels, "affine_f32", size, - &input, + BufferOffset::zero_offset(&input), &output, mul as f32, add as f32, @@ -554,9 +563,8 @@ fn run_affine_strided<T: Clone>( &kernels, "affine_f32_strided", shape, - &input, + BufferOffset::zero_offset(&input), strides, - 0, &output, mul as f32, add as f32, @@ -633,7 +641,7 @@ fn index_select_strided() { fn index_select_f16() { let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0] .into_iter() - .map(|x| f16::from_f32(x)) + .map(f16::from_f32) .collect(); let shape = [5, 2]; let stride = [2, 1]; @@ -700,8 +708,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let embeddings_buffer = new_buffer(&device, &embeddings); - let ids_buffer = new_buffer(&device, &ids); + let embeddings_buffer = new_buffer(&device, embeddings); + let ids_buffer = new_buffer(&device, ids); let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); @@ -711,7 +719,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( let kernels = Kernels::new(); call_index_select( &device, - &command_buffer, + command_buffer, &kernels, name, shape, @@ -746,8 +754,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>( let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let embeddings_buffer = new_buffer(&device, &embeddings); - let ids_buffer = new_buffer(&device, &ids); + let embeddings_buffer = new_buffer(&device, embeddings); + let ids_buffer = new_buffer(&device, ids); let left_size: usize = shape[..dim].iter().product(); let right_size: usize = shape[dim + 1..].iter().product(); @@ -757,7 +765,7 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>( let kernels = Kernels::new(); call_index_select( &device, - &command_buffer, + command_buffer, &kernels, name, shape, @@ -931,6 +939,7 @@ fn softmax() { ); } +#[allow(clippy::too_many_arguments)] fn run_where_cond<I: Clone, T: Clone>( shape: &[usize], cond: &[I], @@ -1148,7 +1157,7 @@ fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: #[test] fn random() { fn calc_mean(data: &[f32]) -> f32 { - let sum = data.iter().sum::<f32>() as f32; + let sum = data.iter().sum::<f32>(); let count = data.len(); assert!(count > 0); sum / count as f32 @@ -1162,7 +1171,7 @@ fn random() { let variance = data .iter() .map(|value| { - let diff = mean - (*value as f32); + let diff = mean - *value; diff * diff }) .sum::<f32>() @@ -1787,6 +1796,7 @@ fn avg_pool2d_u32() { assert_eq!(results, expected); } +#[allow(clippy::too_many_arguments)] fn run_conv_transpose1d<T: Clone>( input: &[T], input_shape: &[usize], |