summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels/src/tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r-- candle-metal-kernels/src/tests.rs 365
1 files changed, 312 insertions, 53 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 2330d48d..1b3153b1 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1,7 +1,14 @@
use super::*;
-use half::f16;
+use half::{bf16, f16};
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
+/// Copies the first `n` elements of type `T` out of `buffer` into a freshly allocated `Vec`.
+///
+/// # Panics
+/// Panics if the buffer's contents pointer is null, if `n * size_of::<T>()` overflows,
+/// or if the buffer is shorter than `n` elements of `T`.
+fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
+    let ptr = buffer.contents() as *const T;
+    assert!(!ptr.is_null());
+    // Reject reads past the end of the allocation instead of invoking undefined behaviour.
+    let byte_len = n
+        .checked_mul(std::mem::size_of::<T>())
+        .expect("requested element count overflows usize");
+    assert!(
+        byte_len as u64 <= buffer.length(),
+        "buffer is too small for the requested number of elements"
+    );
+    // SAFETY: `ptr` is non-null and the assertion above guarantees that `n` elements of `T`
+    // lie inside the buffer. The slice is copied into a `Vec` before this function returns,
+    // so no reference outlives the borrow of `buffer`.
+    // NOTE(review): this assumes the contents pointer is suitably aligned for `T` and that
+    // the bytes form valid `T` values -- confirm for any non-primitive `T`.
+    let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
+    slice.to_vec()
+}
+
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const core::ffi::c_void;
@@ -23,13 +30,19 @@ fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
+/// Widens every `bf16` value to `f32` and rounds it to `digits` decimal places.
+fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
+    let scale = 10f32.powi(digits);
+    let mut rounded = Vec::with_capacity(v.len());
+    for value in &v {
+        // Scale up, round to the nearest integer, then scale back down.
+        rounded.push((value.to_f32() * scale).round() / scale);
+    }
+    rounded
+}
+
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
call_unary_contiguous(
&device,
command_buffer,
@@ -37,23 +50,24 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
name,
v.len(),
&input,
- &mut output,
+ &output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(v.len())
+ read_to_vec(&output, v.len())
}
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let left = new_buffer(&device, x);
let right = new_buffer(&device, y);
- let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
+ let output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
call_binary_contiguous(
&device,
command_buffer,
@@ -62,12 +76,12 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
x.len(),
&left,
&right,
- &mut output,
+ &output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(x.len())
+ read_to_vec(&output, x.len())
}
fn run_strided<T: Clone>(
@@ -81,8 +95,9 @@ fn run_strided<T: Clone>(
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
- let kernels = Kernels::new();
+ let output = new_buffer(&device, v);
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
call_unary_strided(
&device,
command_buffer,
@@ -92,13 +107,13 @@ fn run_strided<T: Clone>(
&input,
strides,
offset,
- &mut output,
+ &output,
0,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(v.len())
+ read_to_vec(&output, v.len())
}
#[test]
@@ -201,6 +216,25 @@ fn cos_strided_random() {
}
#[test]
+fn gelu_f16() {
+    // Inputs are written as f32 literals and narrowed to f16 before reaching the kernel.
+    let input: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
+        .into_iter()
+        .map(f16::from_f32)
+        .collect();
+    let output = run(&input, unary::contiguous::gelu::HALF);
+    // Results are compared after rounding to two decimal places.
+    let want: Vec<f32> = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0];
+    assert_eq!(approx_f16(output, 2), want);
+}
+
+#[test]
+fn gelu_f32() {
+    let input = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
+    let output = run(&input, unary::contiguous::gelu::FLOAT);
+    // Results are compared after rounding to three decimal places.
+    let want: [f32; 8] = [-0.0, -0.159, 0.0, 0.841, 1.955, 2.996, 10.0, 20.0];
+    assert_eq!(approx(output, 3), want);
+}
+
+#[test]
fn binary_add_f32() {
let left = vec![1.0f32, 2.0, 3.0];
let right = vec![2.0f32, 3.1, 4.2];
@@ -216,11 +250,14 @@ fn binary_add_f32() {
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let options = MTLResourceOptions::StorageModeManaged;
+ let size = (v.len() * std::mem::size_of::<U>()) as u64;
+ let output = device.new_buffer(size, options);
call_cast_contiguous(
&device,
@@ -229,12 +266,13 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
name,
v.len(),
&input,
- &mut output,
+ 0,
+ &output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<U>(v.len())
+ read_to_vec(&output, v.len())
}
#[test]
@@ -245,21 +283,28 @@ fn cast_u32_f32() {
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
+ let v = vec![1.0f32, 2.0, 3.0];
+ let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
+ let results: Vec<f32> = cast(&input, "cast_f16_f32");
+ assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
+
let v = vec![1.0f32; 10_000];
- let results = run(&v, unary::contiguous::cos::FLOAT);
- let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
- assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
- assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
+ let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
+ let results: Vec<f32> = cast(&input, "cast_f16_f32");
+ assert_eq!(results.len(), 10_000);
+ assert_eq!(&results[..10], vec![1.0f32; 10]);
+ assert_eq!(results, vec![1.0f32; 10_000]);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
let size = v.len();
@@ -267,9 +312,46 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
&device,
command_buffer,
&kernels,
+ "affine_f32",
size,
&input,
- &mut output,
+ &output,
+ mul as f32,
+ add as f32,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ read_to_vec(&output, v.len())
+}
+
+fn run_affine_strided<T: Clone>(
+ v: &[T],
+ shape: &[usize],
+ strides: &[usize],
+ mul: f64,
+ add: f64,
+) -> Vec<T> {
+ let device = device();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+
+ let input = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
+
+ call_affine_strided(
+ &device,
+ command_buffer,
+ &kernels,
+ "affine_f32_strided",
+ shape,
+ &input,
+ strides,
+ 0,
+ &output,
mul as f32,
add as f32,
)
@@ -277,7 +359,8 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(v.len())
+ let len: usize = shape.iter().product();
+ read_to_vec(&output, len)
}
#[test]
@@ -296,6 +379,18 @@ fn affine() {
}
#[test]
+fn affine_strided() {
+    let (mul, add) = (1.5, 1.1);
+    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+    // Four elements read with a stride of two: every other input value (1, 3, 5, 7),
+    // which is what the expected `x * mul + add` results below correspond to.
+    let (shape, strides) = ([4], [2]);
+    let result = run_affine_strided(&input, &shape, &strides, mul, add);
+    assert_eq!(result, vec![2.6, 5.6, 8.6, 11.6]);
+}
+
+#[test]
fn index_select() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
@@ -313,7 +408,26 @@ fn index_select() {
result,
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
);
+}
+#[test]
+fn index_select_f16() {
+    // Ten f16 values interpreted with shape [5, 2].
+    let embedding: Vec<f16> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
+        .into_iter()
+        .map(f16::from_f32)
+        .collect();
+    let (shape, dim) = ([5, 2], 0);
+    let ids = [0u32, 4, 2];
+    let selected = run_index_select(&embedding, &shape, &ids, dim);
+    // Indices 0, 4 and 2 along dim 0 yield the pairs (1, 2), (9, 10) and (5, 6).
+    let want = vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0];
+    assert_eq!(approx_f16(selected, 4), want);
+}
+
+#[test]
+fn index_select_dim1() {
let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
let shape = [5, 2];
let ids = [0u32, 1, 0];
@@ -321,7 +435,7 @@ fn index_select() {
let result = run_index_select(&embedding, &shape, &ids, dim);
assert_eq!(
result,
- vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
+ vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0f32, 7.0, 9.0, 10.0, 9.0]
);
}
@@ -341,27 +455,34 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
let dst_el = ids.len() * left_size * right_size;
- let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
+ let dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
+
+ let name = match core::mem::size_of::<T>() {
+ 4 => "is_u32_f32",
+ 2 => "is_u32_f16",
+ _ => unimplemented!(),
+ };
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
call_index_select(
&device,
&command_buffer,
&kernels,
- "is_u32_f32",
+ name,
shape,
ids.len(),
dim,
&embeddings_buffer,
&ids_buffer,
- &mut dst_buffer,
+ &dst_buffer,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- dst_buffer.read_to_vec::<T>(dst_el)
+ read_to_vec(&dst_buffer, dst_el)
}
#[test]
@@ -427,7 +548,7 @@ fn index_add() {
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
- let result = outputs_buffer.read_to_vec::<f32>(right.len());
+ let result: Vec<f32> = read_to_vec(&outputs_buffer, right.len());
assert_eq!(result, expected);
}
@@ -439,43 +560,49 @@ fn cos_f16() {
.collect();
let results = run(&v, unary::contiguous::cos::HALF);
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
- assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]);
- assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
+ assert_eq!(approx_f16(results, 2), vec![0.54, -0.42, -0.99]);
+ assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
}
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
- let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
- call_reduce_contiguous(
+ let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
+ let dims = vec![v.len()];
+ let strides = vec![1];
+ call_reduce_strided(
&device,
command_buffer,
&kernels,
name,
- v.len(),
+ &dims,
+ &strides,
out_length,
&input,
- &mut output,
+ 0,
+ &output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(out_length)
+ read_to_vec(&output, out_length)
}
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
- let mut output = new_buffer(&device, v);
+ let output = new_buffer(&device, v);
call_last_softmax(
&device,
command_buffer,
@@ -484,13 +611,14 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
v.len(),
last_dim,
&input,
- &mut output,
+ 0,
+ &output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(v.len())
+ read_to_vec(&output, v.len())
}
#[test]
@@ -498,7 +626,7 @@ fn reduce_sum() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 1;
- let results = run_reduce(&v, out_length, "fast_sum_float");
+ let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![21.0]);
}
@@ -507,7 +635,7 @@ fn reduce_sum2() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 2;
- let results = run_reduce(&v, out_length, "fast_sum_float");
+ let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}
@@ -515,15 +643,33 @@ fn reduce_sum2() {
fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 6;
- let results = run_softmax(&v, last_dim, "softmax_float");
+ let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
);
+ let last_dim = 4096;
+ let n = 200;
+ let mut v = vec![0.0; n * last_dim];
+ for i in 0..n {
+ v[i * last_dim] = 20.0;
+ }
+ let results = run_softmax(&v, last_dim, "softmax_f32");
+ let results = approx(results, 4);
+ println!("{results:?}");
+ assert_eq!(
+ results.iter().map(|&s| s.round() as usize).sum::<usize>(),
+ n
+ );
+ assert_eq!(results[0], 1.0);
+ assert_eq!(results[1], 0.0);
+ assert_eq!(results[last_dim], 1.0);
+ assert_eq!(results[2 * last_dim], 1.0);
+
let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
let last_dim = 6;
- let results = run_softmax(&v, last_dim, "softmax_float");
+ let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@@ -531,11 +677,33 @@ fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 3;
- let results = run_softmax(&v, last_dim, "softmax_float");
+ let results = run_softmax(&v, last_dim, "softmax_f32");
assert_eq!(
approx(results, 4),
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
);
+
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
+ .iter()
+ .map(|v| f16::from_f32(*v))
+ .collect::<Vec<_>>();
+ let last_dim = 6;
+ let results = run_softmax(&v, last_dim, "softmax_f16");
+ assert_eq!(
+ approx_f16(results, 4),
+ vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
+ );
+
+ let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
+ .iter()
+ .map(|v| bf16::from_f32(*v))
+ .collect::<Vec<_>>();
+ let last_dim = 6;
+ let results = run_softmax(&v, last_dim, "softmax_bf16");
+ assert_eq!(
+ approx_bf16(results, 4),
+ vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
+ );
}
fn run_where_cond<I: Clone, T: Clone>(
@@ -549,7 +717,8 @@ fn run_where_cond<I: Clone, T: Clone>(
name: &'static str,
) -> Vec<T> {
let device = device();
- let kernels = Kernels::new();
+ let fence = device.new_fence();
+ let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
@@ -571,7 +740,7 @@ fn run_where_cond<I: Clone, T: Clone>(
options,
);
- let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+ let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_where_cond_strided(
&device,
command_buffer,
@@ -584,13 +753,13 @@ fn run_where_cond<I: Clone, T: Clone>(
(&left_stride, left_offset),
&right,
(&cond_stride, cond_offset),
- &mut output,
+ &output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
- output.read_to_vec::<T>(length)
+ read_to_vec(&output, length)
}
#[test]
@@ -614,3 +783,93 @@ fn where_cond() {
);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
+
+/// Uploads `lhs` and `rhs` to the device, dispatches `call_gemm` with the `"sgemm"` kernel
+/// and returns the `b * m * n` output elements read back from the device.
+///
+/// NOTE(review): the kernel name is hard-coded, so although this helper is generic over `T`
+/// it presumably only makes sense for `f32` data -- confirm before using other types.
+fn run_gemm<T: Clone>(
+    (b, m, n, k): (usize, usize, usize, usize),
+    lhs: &[T],
+    lhs_stride: Vec<usize>,
+    lhs_offset: usize,
+    rhs: &[T],
+    rhs_stride: Vec<usize>,
+    rhs_offset: usize,
+) -> Vec<T> {
+    let device = device();
+    let kernels = Kernels::new(device.new_fence());
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+    let options = MTLResourceOptions::StorageModeManaged;
+
+    // Copies a host slice into a new device buffer of exactly the same byte length.
+    let upload = |data: &[T]| {
+        device.new_buffer_with_data(
+            data.as_ptr() as *const core::ffi::c_void,
+            std::mem::size_of_val(data) as u64,
+            options,
+        )
+    };
+    let lhs_buffer = upload(lhs);
+    let rhs_buffer = upload(rhs);
+
+    // The output holds one `T` per (batch, row, column) triple.
+    let out_len = b * m * n;
+    let out_bytes = out_len * core::mem::size_of::<T>();
+    let output = device.new_buffer(out_bytes as u64, options);
+
+    call_gemm(
+        &device,
+        command_buffer,
+        &kernels,
+        "sgemm",
+        (b, m, n, k),
+        &lhs_stride,
+        lhs_offset,
+        &lhs_buffer,
+        &rhs_stride,
+        rhs_offset,
+        &rhs_buffer,
+        &output,
+    )
+    .unwrap();
+
+    // Submit the work and block until the device has finished before reading results.
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    read_to_vec(&output, out_len)
+}
+
+#[test]
+fn gemm() {
+    // Test data `[0.0, 1.0, ..., len - 1]` as f32.
+    let ramp = |len: usize| -> Vec<f32> { (0..len).map(|i| i as f32).collect() };
+
+    // Single batch: an (m x k) matrix times a (k x n) matrix.
+    let (b, m, n, k) = (1, 2, 4, 3);
+    let (lhs, rhs) = (ramp(b * m * k), ramp(b * n * k));
+    let results = run_gemm(
+        (b, m, n, k),
+        &lhs,
+        vec![m * k, k, 1],
+        0,
+        &rhs,
+        vec![n * k, n, 1],
+        0,
+    );
+    assert_eq!(
+        approx(results, 4),
+        vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+    );
+
+    // Two batches: the first batch reproduces the single-batch result above.
+    let (b, m, n, k) = (2, 2, 4, 3);
+    let (lhs, rhs) = (ramp(b * m * k), ramp(b * n * k));
+    let results = run_gemm(
+        (b, m, n, k),
+        &lhs,
+        vec![m * k, k, 1],
+        0,
+        &rhs,
+        vec![n * k, n, 1],
+        0,
+    );
+    assert_eq!(
+        approx(results, 4),
+        vec![
+            20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
+            518.0, 548.0, 578.0
+        ]
+    );
+
+    // OFFSET
+    let (b, m, n, k) = (2, 2, 4, 3);
+    let (lhs, rhs) = (ramp(b * m * k), ramp(b * n * k));
+    // Data is built for two batches, but only one batch is computed, with the rhs offset
+    // set to 12 elements * 4 bytes per f32.
+    let results = run_gemm(
+        (1, m, n, k),
+        &lhs,
+        vec![m * k, k, 1],
+        0,
+        &rhs,
+        vec![n * k, n, 1],
+        12 * 4,
+    );
+    assert_eq!(
+        approx(results, 4),
+        vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
+    );
+}