summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r--candle-metal-kernels/src/tests.rs58
1 files changed, 34 insertions, 24 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index b15d9b36..b91c92d8 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -12,7 +12,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const c_void;
- let size = (data.len() * std::mem::size_of::<T>()) as u64;
+ let size = std::mem::size_of_val(data) as u64;
device.new_buffer_with_data(ptr, size, options)
}
@@ -41,6 +41,10 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
+ let input = BufferOffset {
+ buffer: &input,
+ offset_in_bytes: 0,
+ };
let output = new_buffer(&device, v);
call_unary_contiguous(
&device,
@@ -48,7 +52,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
&kernels,
name,
v.len(),
- &input,
+ input,
&output,
)
.unwrap();
@@ -72,8 +76,8 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
&kernels,
name,
x.len(),
- &left,
- &right,
+ BufferOffset::zero_offset(&left),
+ BufferOffset::zero_offset(&right),
&output,
)
.unwrap();
@@ -93,7 +97,15 @@ fn run_strided<T: Clone>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let output = new_buffer(&device, v);
+ let input = BufferOffset {
+ buffer: &input,
+ offset_in_bytes: offset,
+ };
+ let output_b = new_buffer(&device, v);
+ let output = BufferOffset {
+ buffer: &output_b,
+ offset_in_bytes: 0,
+ };
let kernels = Kernels::new();
call_unary_strided(
&device,
@@ -101,16 +113,14 @@ fn run_strided<T: Clone>(
&kernels,
kernel,
shape,
- &input,
+ input,
strides,
- offset,
- &output,
- 0,
+ output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- read_to_vec(&output, v.len())
+ read_to_vec(&output_b, v.len())
}
#[test]
@@ -308,8 +318,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
&kernels,
name,
v.len(),
- &input,
- 0,
+ BufferOffset::zero_offset(&input),
&output,
)
.unwrap();
@@ -521,7 +530,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
&kernels,
"affine_f32",
size,
- &input,
+ BufferOffset::zero_offset(&input),
&output,
mul as f32,
add as f32,
@@ -554,9 +563,8 @@ fn run_affine_strided<T: Clone>(
&kernels,
"affine_f32_strided",
shape,
- &input,
+ BufferOffset::zero_offset(&input),
strides,
- 0,
&output,
mul as f32,
add as f32,
@@ -633,7 +641,7 @@ fn index_select_strided() {
fn index_select_f16() {
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
.into_iter()
- .map(|x| f16::from_f32(x))
+ .map(f16::from_f32)
.collect();
let shape = [5, 2];
let stride = [2, 1];
@@ -700,8 +708,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
- let embeddings_buffer = new_buffer(&device, &embeddings);
- let ids_buffer = new_buffer(&device, &ids);
+ let embeddings_buffer = new_buffer(&device, embeddings);
+ let ids_buffer = new_buffer(&device, ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
@@ -711,7 +719,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let kernels = Kernels::new();
call_index_select(
&device,
- &command_buffer,
+ command_buffer,
&kernels,
name,
shape,
@@ -746,8 +754,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
- let embeddings_buffer = new_buffer(&device, &embeddings);
- let ids_buffer = new_buffer(&device, &ids);
+ let embeddings_buffer = new_buffer(&device, embeddings);
+ let ids_buffer = new_buffer(&device, ids);
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
@@ -757,7 +765,7 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
let kernels = Kernels::new();
call_index_select(
&device,
- &command_buffer,
+ command_buffer,
&kernels,
name,
shape,
@@ -931,6 +939,7 @@ fn softmax() {
);
}
+#[allow(clippy::too_many_arguments)]
fn run_where_cond<I: Clone, T: Clone>(
shape: &[usize],
cond: &[I],
@@ -1148,7 +1157,7 @@ fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b:
#[test]
fn random() {
fn calc_mean(data: &[f32]) -> f32 {
- let sum = data.iter().sum::<f32>() as f32;
+ let sum = data.iter().sum::<f32>();
let count = data.len();
assert!(count > 0);
sum / count as f32
@@ -1162,7 +1171,7 @@ fn random() {
let variance = data
.iter()
.map(|value| {
- let diff = mean - (*value as f32);
+ let diff = mean - *value;
diff * diff
})
.sum::<f32>()
@@ -1787,6 +1796,7 @@ fn avg_pool2d_u32() {
assert_eq!(results, expected);
}
+#[allow(clippy::too_many_arguments)]
fn run_conv_transpose1d<T: Clone>(
input: &[T],
input_shape: &[usize],