summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r--candle-metal-kernels/src/tests.rs51
1 files changed, 27 insertions, 24 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index b91c92d8..960ae1df 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -728,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
true,
shape,
stride,
- &embeddings_buffer,
- 0,
- &ids_buffer,
- 0,
+ BufferOffset::zero_offset(&embeddings_buffer),
+ BufferOffset::zero_offset(&ids_buffer),
&dst_buffer,
)
.unwrap();
@@ -774,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
false,
shape,
stride,
- &embeddings_buffer,
- 0,
- &ids_buffer,
- 0,
+ BufferOffset::zero_offset(&embeddings_buffer),
+ BufferOffset::zero_offset(&ids_buffer),
&dst_buffer,
)
.unwrap();
@@ -819,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
&dims,
&strides,
out_length,
- &input,
- 0,
+ BufferOffset::zero_offset(&input),
&output,
)
.unwrap();
@@ -974,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>(
);
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+ let cond = BufferOffset {
+ buffer: &cond,
+ offset_in_bytes: cond_offset,
+ };
+ let left = BufferOffset {
+ buffer: &left,
+ offset_in_bytes: left_offset,
+ };
+ let right = BufferOffset {
+ buffer: &right,
+ offset_in_bytes: cond_offset,
+ };
call_where_cond_strided(
&device,
command_buffer,
&kernels,
name,
shape,
- &cond,
- (&cond_stride, cond_offset),
- &left,
- (&left_stride, left_offset),
- &right,
- (&cond_stride, cond_offset),
+ cond,
+ &cond_stride,
+ left,
+ &left_stride,
+ right,
+ &cond_stride,
&output,
)
.unwrap();
@@ -1250,10 +1257,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
shape,
shape,
dim,
- &input_buffer,
- 0,
- &ids_buffer,
- 0,
+ BufferOffset::zero_offset(&input_buffer),
+ BufferOffset::zero_offset(&ids_buffer),
&output,
)
.unwrap();
@@ -1355,10 +1360,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
shape,
shape,
dim,
- &input_buffer,
- 0,
- &indices_buffer,
- 0,
+ BufferOffset::zero_offset(&input_buffer),
+ BufferOffset::zero_offset(&indices_buffer),
&output,
)
.unwrap();