diff options
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 51 |
1 files changed, 27 insertions, 24 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index b91c92d8..960ae1df 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -728,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>( true, shape, stride, - &embeddings_buffer, - 0, - &ids_buffer, - 0, + BufferOffset::zero_offset(&embeddings_buffer), + BufferOffset::zero_offset(&ids_buffer), &dst_buffer, ) .unwrap(); @@ -774,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>( false, shape, stride, - &embeddings_buffer, - 0, - &ids_buffer, - 0, + BufferOffset::zero_offset(&embeddings_buffer), + BufferOffset::zero_offset(&ids_buffer), &dst_buffer, ) .unwrap(); @@ -819,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T &dims, &strides, out_length, - &input, - 0, + BufferOffset::zero_offset(&input), &output, ) .unwrap(); @@ -974,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>( ); let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); + let cond = BufferOffset { + buffer: &cond, + offset_in_bytes: cond_offset, + }; + let left = BufferOffset { + buffer: &left, + offset_in_bytes: left_offset, + }; + let right = BufferOffset { + buffer: &right, + offset_in_bytes: cond_offset, + }; call_where_cond_strided( &device, command_buffer, &kernels, name, shape, - &cond, - (&cond_stride, cond_offset), - &left, - (&left_stride, left_offset), - &right, - (&cond_stride, cond_offset), + cond, + &cond_stride, + left, + &left_stride, + right, + &cond_stride, &output, ) .unwrap(); @@ -1250,10 +1257,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>( shape, shape, dim, - &input_buffer, - 0, - &ids_buffer, - 0, + BufferOffset::zero_offset(&input_buffer), + BufferOffset::zero_offset(&ids_buffer), &output, ) .unwrap(); @@ -1355,10 +1360,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>( shape, shape, dim, - &input_buffer, - 0, - &indices_buffer, - 0, + BufferOffset::zero_offset(&input_buffer), + BufferOffset::zero_offset(&indices_buffer), &output, ) .unwrap(); |