summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src/tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r--candle-metal-kernels/src/tests.rs269
1 files changed, 234 insertions, 35 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 30c454af..8b1adbde 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -329,7 +329,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
#[test]
fn cast_f32() {
- let v_f64 = vec![1.0f64, 2.0, 3.0];
+ let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@@ -360,7 +360,7 @@ fn cast_f32() {
#[test]
fn cast_f16() {
- let v_f64 = vec![1.0f64, 2.0, 3.0];
+ let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@@ -391,7 +391,7 @@ fn cast_f16() {
#[test]
fn cast_bf16() {
- let v_f64 = vec![1.0f64, 2.0, 3.0];
+ let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@@ -422,7 +422,7 @@ fn cast_bf16() {
#[test]
fn cast_u32() {
- let v_f64 = vec![1.0f64, 2.0, 3.0];
+ let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@@ -453,7 +453,7 @@ fn cast_u32() {
#[test]
fn cast_u8() {
- let v_f64 = vec![1.0f64, 2.0, 3.0];
+ let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@@ -484,7 +484,7 @@ fn cast_u8() {
#[test]
fn cast_i64() {
- let v_f64 = vec![1.0f64, 2.0, 3.0];
+ let v_f64 = [1.0f64, 2.0, 3.0];
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
@@ -911,7 +911,7 @@ fn softmax() {
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
);
- let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
+ let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
@@ -922,7 +922,7 @@ fn softmax() {
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
);
- let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
+ let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();
@@ -1045,14 +1045,15 @@ fn where_cond_u32_f32() {
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
+#[allow(clippy::too_many_arguments)]
fn run_gemm<T: Clone>(
name: &'static str,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
- lhs_stride: Vec<usize>,
+ lhs_stride: &[usize],
lhs_offset: usize,
rhs: &[T],
- rhs_stride: Vec<usize>,
+ rhs_stride: &[usize],
rhs_offset: usize,
) -> Vec<T> {
let device = device();
@@ -1079,10 +1080,10 @@ fn run_gemm<T: Clone>(
&kernels,
name,
(b, m, n, k),
- &lhs_stride,
+ lhs_stride,
lhs_offset,
&lhs,
- &rhs_stride,
+ rhs_stride,
rhs_offset,
&rhs,
&output,
@@ -1105,10 +1106,10 @@ fn gemm() {
"sgemm",
(b, m, n, k),
&lhs,
- lhs_stride,
+ &lhs_stride,
0,
&rhs,
- rhs_stride,
+ &rhs_stride,
0,
);
assert_eq!(
@@ -1125,10 +1126,10 @@ fn gemm() {
"sgemm",
(b, m, n, k),
&lhs,
- lhs_stride,
+ &lhs_stride,
0,
&rhs,
- rhs_stride,
+ &rhs_stride,
0,
);
assert_eq!(
@@ -1150,10 +1151,10 @@ fn gemm() {
"sgemm",
(1, m, n, k),
&lhs,
- lhs_stride,
+ &lhs_stride,
0,
&rhs,
- rhs_stride,
+ &rhs_stride,
12 * 4,
);
assert_eq!(
@@ -1172,10 +1173,10 @@ fn gemm() {
"bgemm",
(b, m, n, k),
&lhs,
- lhs_stride,
+ &lhs_stride,
0,
&rhs,
- rhs_stride,
+ &rhs_stride,
0,
);
assert_eq!(
@@ -1194,10 +1195,10 @@ fn gemm() {
"hgemm",
(b, m, n, k),
&lhs,
- lhs_stride,
+ &lhs_stride,
0,
&rhs,
- rhs_stride,
+ &rhs_stride,
0,
);
assert_eq!(
@@ -1206,6 +1207,204 @@ fn gemm() {
);
}
+#[allow(clippy::too_many_arguments)]
+fn run_mlx_gemm<T: Clone>(
+ dtype: GemmDType,
+ (b, m, n, k): (usize, usize, usize, usize),
+ lhs: &[T],
+ lhs_stride: &[usize],
+ lhs_offset: usize,
+ rhs: &[T],
+ rhs_stride: &[usize],
+ rhs_offset: usize,
+) -> Vec<T> {
+ let device = device();
+ let kernels = Kernels::new();
+ let command_queue = device.new_command_queue();
+ let command_buffer = command_queue.new_command_buffer();
+ let options = MTLResourceOptions::StorageModeManaged;
+
+ let lhs = device.new_buffer_with_data(
+ lhs.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(lhs) as u64,
+ options,
+ );
+ let rhs = device.new_buffer_with_data(
+ rhs.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(rhs) as u64,
+ options,
+ );
+ let length = b * m * n;
+ let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+ call_mlx_gemm(
+ &device,
+ command_buffer,
+ &kernels,
+ dtype,
+ (b, m, n, k),
+ lhs_stride,
+ lhs_offset,
+ &lhs,
+ rhs_stride,
+ rhs_offset,
+ &rhs,
+ &output,
+ )
+ .unwrap();
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+
+ read_to_vec(&output, length)
+}
+
+fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) {
+ use rand::SeedableRng;
+ use rand_distr::Distribution;
+
+ let mut rng = rand::rngs::StdRng::seed_from_u64(42424242);
+ let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
+
+ let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect();
+ let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect();
+ let v1: Vec<f32> = run_mlx_gemm(
+ dtype,
+ (b, m, n, k),
+ &lhs,
+ &[m * k, k, 1],
+ 0,
+ &rhs,
+ &[k * n, n, 1],
+ 0,
+ );
+ let v2: Vec<f32> = run_gemm(
+ "sgemm",
+ (b, m, n, k),
+ &lhs,
+ &[m * k, k, 1],
+ 0,
+ &rhs,
+ &[k * n, n, 1],
+ 0,
+ );
+ for (a, b) in v1.iter().zip(v2.iter()) {
+ let diff = (a - b).abs();
+ assert_eq!((diff * 1e4).round(), 0.)
+ }
+}
+
+#[test]
+fn mlx_vs_mfa() {
+ mlx_vs_mfa_one(1, 32, 32, 25, GemmDType::F32);
+ mlx_vs_mfa_one(1, 128, 128, 100, GemmDType::F32);
+ mlx_vs_mfa_one(1, 256, 256, 256, GemmDType::F32);
+ mlx_vs_mfa_one(1, 192, 200, 75, GemmDType::F32);
+ mlx_vs_mfa_one(3, 27, 67, 64, GemmDType::F32);
+}
+
+#[test]
+fn mlx_gemm() {
+ let (b, m, n, k) = (1, 2, 4, 3);
+ let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
+ let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+ let results = run_mlx_gemm(
+ GemmDType::F32,
+ (b, m, n, k),
+ &lhs,
+ &[m * k, k, 1],
+ 0,
+ &rhs,
+ &[n * k, n, 1],
+ 0,
+ );
+ assert_eq!(
+ approx(results, 4),
+ vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+ );
+
+ let (b, m, n, k) = (2, 2, 4, 3);
+ let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
+ let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+ let results = run_mlx_gemm(
+ GemmDType::F32,
+ (b, m, n, k),
+ &lhs,
+ &[m * k, k, 1],
+ 0,
+ &rhs,
+ &[n * k, n, 1],
+ 0,
+ );
+ assert_eq!(
+ approx(results, 4),
+ vec![
+ 20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
+ 518.0, 548.0, 578.0
+ ]
+ );
+
+ // OFFSET
+ let (b, m, n, k) = (2, 2, 4, 3);
+ let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
+ let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+ // Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
+ let results = run_mlx_gemm(
+ GemmDType::F32,
+ (1, m, n, k),
+ &lhs,
+ &[m * k, k, 1],
+ 0,
+ &rhs,
+ &[n * k, n, 1],
+ 12 * 4,
+ );
+ assert_eq!(
+ approx(results, 4),
+ vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
+ );
+
+ // bgemm sanity test
+ {
+ let (b, m, n, k) = (1, 2, 4, 3);
+ let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
+ let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
+ let results = run_mlx_gemm(
+ GemmDType::BF16,
+ (b, m, n, k),
+ &lhs,
+ &[m * k, k, 1],
+ 0,
+ &rhs,
+ &[n * k, n, 1],
+ 0,
+ );
+ assert_eq!(
+ approx_bf16(results, 4),
+ vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+ );
+ }
+
+ {
+ // hgemm sanity test
+ let (b, m, n, k) = (1, 2, 4, 3);
+ let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
+ let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
+ let results = run_mlx_gemm(
+ GemmDType::F16,
+ (b, m, n, k),
+ &lhs,
+ &[m * k, k, 1],
+ 0,
+ &rhs,
+ &[n * k, n, 1],
+ 0,
+ );
+ assert_eq!(
+ approx_f16(results, 4),
+ vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+ );
+ }
+}
+
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
@@ -1280,7 +1479,7 @@ fn random() {
variance.sqrt()
}
- let shape = vec![1024, 10];
+ let shape = [1024, 10];
let length = shape.iter().product::<usize>();
let seed = 299792458;
@@ -1636,7 +1835,7 @@ fn max_pool2d_f16() {
&strides,
"max_pool2d_f16",
);
- let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
+ let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
.iter()
.map(|v| half::f16::from_f32(*v))
.collect::<Vec<_>>();
@@ -1656,7 +1855,7 @@ fn max_pool2d_f16() {
&strides,
"max_pool2d_f16",
);
- let expected = vec![5.0, 7.0, 13.0, 15.0]
+ let expected = [5.0, 7.0, 13.0, 15.0]
.iter()
.map(|v| half::f16::from_f32(*v))
.collect::<Vec<_>>();
@@ -1679,7 +1878,7 @@ fn max_pool2d_bf16() {
&strides,
"max_pool2d_bf16",
);
- let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
+ let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0]
.iter()
.map(|v| half::bf16::from_f32(*v))
.collect::<Vec<_>>();
@@ -1699,7 +1898,7 @@ fn max_pool2d_bf16() {
&strides,
"max_pool2d_bf16",
);
- let expected = vec![5.0, 7.0, 13.0, 15.0]
+ let expected = [5.0, 7.0, 13.0, 15.0]
.iter()
.map(|v| half::bf16::from_f32(*v))
.collect::<Vec<_>>();
@@ -1818,7 +2017,7 @@ fn avg_pool2d_f16() {
&strides,
"avg_pool2d_f16",
);
- let expected = vec![
+ let expected = [
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
]
.iter()
@@ -1843,7 +2042,7 @@ fn avg_pool2d_bf16() {
&strides,
"avg_pool2d_bf16",
);
- let expected = vec![
+ let expected = [
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
]
.iter()
@@ -1981,14 +2180,14 @@ fn conv_transpose1d_f32() {
#[test]
fn conv_transpose1d_f16() {
- let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
+ let input: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
let input_shape = &[1, 1, 4];
let input_stride = &[4, 4, 1];
- let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
+ let kernel: Vec<f16> = [1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
@@ -2009,7 +2208,7 @@ fn conv_transpose1d_f16() {
"conv_transpose1d_f16",
);
- let expected = vec![1., 4., 10., 20., 25., 24., 16.]
+ let expected = [1., 4., 10., 20., 25., 24., 16.]
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
@@ -2018,14 +2217,14 @@ fn conv_transpose1d_f16() {
#[test]
fn conv_transpose1d_bf16() {
- let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
+ let input: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect();
let input_shape = &[1, 1, 4];
let input_stride = &[4, 4, 1];
- let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
+ let kernel: Vec<bf16> = [1.0, 2.0, 3.0, 4.0]
.iter()
.map(|v| bf16::from_f32(*v))
.collect();
@@ -2046,7 +2245,7 @@ fn conv_transpose1d_bf16() {
"conv_transpose1d_bf16",
);
- let expected = vec![1., 4., 10., 20., 25., 24., 16.]
+ let expected = [1., 4., 10., 20., 25., 24., 16.]
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();