diff options
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 269 |
1 files changed, 234 insertions, 35 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 30c454af..8b1adbde 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -329,7 +329,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> { #[test] fn cast_f32() { - let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f64 = [1.0f64, 2.0, 3.0]; let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); @@ -360,7 +360,7 @@ fn cast_f32() { #[test] fn cast_f16() { - let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f64 = [1.0f64, 2.0, 3.0]; let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); @@ -391,7 +391,7 @@ fn cast_f16() { #[test] fn cast_bf16() { - let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f64 = [1.0f64, 2.0, 3.0]; let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); @@ -422,7 +422,7 @@ fn cast_bf16() { #[test] fn cast_u32() { - let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f64 = [1.0f64, 2.0, 3.0]; let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); @@ -453,7 +453,7 @@ fn cast_u32() { #[test] fn cast_u8() { - let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f64 = [1.0f64, 2.0, 3.0]; let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); @@ -484,7 +484,7 @@ fn cast_u8() { #[test] fn cast_i64() { - let v_f64 = vec![1.0f64, 2.0, 3.0]; + let v_f64 = [1.0f64, 2.0, 3.0]; let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect(); let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect(); let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect(); @@ -911,7 +911,7 @@ fn softmax() { vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] ); - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] .iter() .map(|v| f16::from_f32(*v)) .collect::<Vec<_>>(); @@ -922,7 +922,7 @@ fn softmax() { vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338] ); - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] + let v = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] .iter() .map(|v| bf16::from_f32(*v)) .collect::<Vec<_>>(); @@ -1045,14 +1045,15 @@ fn where_cond_u32_f32() { assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); } +#[allow(clippy::too_many_arguments)] fn run_gemm<T: Clone>( name: &'static str, (b, m, n, k): (usize, usize, usize, usize), lhs: &[T], - lhs_stride: Vec<usize>, + lhs_stride: &[usize], lhs_offset: usize, rhs: &[T], - rhs_stride: Vec<usize>, + rhs_stride: &[usize], rhs_offset: usize, ) -> Vec<T> { let device = device(); @@ -1079,10 +1080,10 @@ fn run_gemm<T: Clone>( &kernels, name, (b, m, n, k), - &lhs_stride, + lhs_stride, lhs_offset, &lhs, - &rhs_stride, + rhs_stride, rhs_offset, &rhs, &output, @@ -1105,10 +1106,10 @@ fn gemm() { "sgemm", (b, m, n, k), &lhs, - lhs_stride, + &lhs_stride, 0, &rhs, - rhs_stride, + &rhs_stride, 0, ); assert_eq!( @@ -1125,10 +1126,10 @@ fn gemm() { "sgemm", (b, m, n, k), &lhs, - lhs_stride, + &lhs_stride, 0, &rhs, - rhs_stride, + &rhs_stride, 0, ); assert_eq!( @@ -1150,10 +1151,10 @@ fn gemm() { "sgemm", (1, m, n, k), &lhs, - lhs_stride, + &lhs_stride, 0, &rhs, - rhs_stride, + &rhs_stride, 12 * 4, ); assert_eq!( @@ -1172,10 +1173,10 @@ fn gemm() { "bgemm", (b, m, n, k), &lhs, - lhs_stride, + &lhs_stride, 0, &rhs, - rhs_stride, + &rhs_stride, 0, ); assert_eq!( @@ -1194,10 +1195,10 @@ fn gemm() { "hgemm", (b, m, n, k), &lhs, - lhs_stride, + &lhs_stride, 0, &rhs, - rhs_stride, + &rhs_stride, 0, ); assert_eq!( @@ -1206,6 +1207,204 @@ fn gemm() { ); } +#[allow(clippy::too_many_arguments)] +fn run_mlx_gemm<T: Clone>( + dtype: GemmDType, + (b, m, n, k): (usize, usize, usize, usize), + lhs: &[T], + lhs_stride: &[usize], + lhs_offset: usize, + rhs: &[T], + rhs_stride: &[usize], + rhs_offset: usize, +) -> Vec<T> { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(rhs) as u64, + options, + ); + let length = b * m * n; + let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options); + call_mlx_gemm( + &device, + command_buffer, + &kernels, + dtype, + (b, m, n, k), + lhs_stride, + lhs_offset, + &lhs, + rhs_stride, + rhs_offset, + &rhs, + &output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + read_to_vec(&output, length) +} + +fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) { + use rand::SeedableRng; + use rand_distr::Distribution; + + let mut rng = rand::rngs::StdRng::seed_from_u64(42424242); + let normal = rand_distr::Normal::new(0.0, 1.0).unwrap(); + + let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect(); + let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect(); + let v1: Vec<f32> = run_mlx_gemm( + dtype, + (b, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[k * n, n, 1], + 0, + ); + let v2: Vec<f32> = run_gemm( + "sgemm", + (b, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[k * n, n, 1], + 0, + ); + for (a, b) in v1.iter().zip(v2.iter()) { + let diff = (a - b).abs(); + assert_eq!((diff * 1e4).round(), 0.) + } +} + +#[test] +fn mlx_vs_mfa() { + mlx_vs_mfa_one(1, 32, 32, 25, GemmDType::F32); + mlx_vs_mfa_one(1, 128, 128, 100, GemmDType::F32); + mlx_vs_mfa_one(1, 256, 256, 256, GemmDType::F32); + mlx_vs_mfa_one(1, 192, 200, 75, GemmDType::F32); + mlx_vs_mfa_one(3, 27, 67, 64, GemmDType::F32); +} + +#[test] +fn mlx_gemm() { + let (b, m, n, k) = (1, 2, 4, 3); + let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect(); + let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_mlx_gemm( + GemmDType::F32, + (b, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[n * k, n, 1], + 0, + ); + assert_eq!( + approx(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + + let (b, m, n, k) = (2, 2, 4, 3); + let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect(); + let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); + let results = run_mlx_gemm( + GemmDType::F32, + (b, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[n * k, n, 1], + 0, + ); + assert_eq!( + approx(results, 4), + vec![ + 20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0, + 518.0, 548.0, 578.0 + ] + ); + + // OFFSET + let (b, m, n, k) = (2, 2, 4, 3); + let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect(); + let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); + // Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32 + let results = run_mlx_gemm( + GemmDType::F32, + (1, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[n * k, n, 1], + 12 * 4, + ); + assert_eq!( + approx(results, 4), + vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] + ); + + // bgemm sanity test + { + let (b, m, n, k) = (1, 2, 4, 3); + let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); + let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); + let results = run_mlx_gemm( + GemmDType::BF16, + (b, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[n * k, n, 1], + 0, + ); + assert_eq!( + approx_bf16(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + } + + { + // hgemm sanity test + let (b, m, n, k) = (1, 2, 4, 3); + let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect(); + let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect(); + let results = run_mlx_gemm( + GemmDType::F16, + (b, m, n, k), + &lhs, + &[m * k, k, 1], + 0, + &rhs, + &[n * k, n, 1], + 0, + ); + assert_eq!( + approx_f16(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + } +} + fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> { let device = device(); let kernels = Kernels::new(); @@ -1280,7 +1479,7 @@ fn random() { variance.sqrt() } - let shape = vec![1024, 10]; + let shape = [1024, 10]; let length = shape.iter().product::<usize>(); let seed = 299792458; @@ -1636,7 +1835,7 @@ fn max_pool2d_f16() { &strides, "max_pool2d_f16", ); - let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0] + let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0] .iter() .map(|v| half::f16::from_f32(*v)) .collect::<Vec<_>>(); @@ -1656,7 +1855,7 @@ fn max_pool2d_f16() { &strides, "max_pool2d_f16", ); - let expected = vec![5.0, 7.0, 13.0, 15.0] + let expected = [5.0, 7.0, 13.0, 15.0] .iter() .map(|v| half::f16::from_f32(*v)) .collect::<Vec<_>>(); @@ -1679,7 +1878,7 @@ fn max_pool2d_bf16() { &strides, "max_pool2d_bf16", ); - let expected = vec![5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0] + let expected = [5.0, 6.0, 7.0, 9.0, 10.0, 11.0, 13.0, 14.0, 15.0] .iter() .map(|v| half::bf16::from_f32(*v)) .collect::<Vec<_>>(); @@ -1699,7 +1898,7 @@ fn max_pool2d_bf16() { &strides, "max_pool2d_bf16", ); - let expected = vec![5.0, 7.0, 13.0, 15.0] + let expected = [5.0, 7.0, 13.0, 15.0] .iter() .map(|v| half::bf16::from_f32(*v)) .collect::<Vec<_>>(); @@ -1818,7 +2017,7 @@ fn avg_pool2d_f16() { &strides, "avg_pool2d_f16", ); - let expected = vec![ + let expected = [ 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, ] .iter() @@ -1843,7 +2042,7 @@ fn avg_pool2d_bf16() { &strides, "avg_pool2d_bf16", ); - let expected = vec![ + let expected = [ 2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000, ] .iter() @@ -1981,14 +2180,14 @@ fn conv_transpose1d_f32() { #[test] fn conv_transpose1d_f16() { - let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0] + let input: Vec<f16> = [1.0, 2.0, 3.0, 4.0] .iter() .map(|v| f16::from_f32(*v)) .collect(); let input_shape = &[1, 1, 4]; let input_stride = &[4, 4, 1]; - let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0] + let kernel: Vec<f16> = [1.0, 2.0, 3.0, 4.0] .iter() .map(|v| f16::from_f32(*v)) .collect(); @@ -2009,7 +2208,7 @@ fn conv_transpose1d_f16() { "conv_transpose1d_f16", ); - let expected = vec![1., 4., 10., 20., 25., 24., 16.] + let expected = [1., 4., 10., 20., 25., 24., 16.] .iter() .map(|v| f16::from_f32(*v)) .collect::<Vec<_>>(); @@ -2018,14 +2217,14 @@ fn conv_transpose1d_f16() { #[test] fn conv_transpose1d_bf16() { - let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0] + let input: Vec<bf16> = [1.0, 2.0, 3.0, 4.0] .iter() .map(|v| bf16::from_f32(*v)) .collect(); let input_shape = &[1, 1, 4]; let input_stride = &[4, 4, 1]; - let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0] + let kernel: Vec<bf16> = [1.0, 2.0, 3.0, 4.0] .iter() .map(|v| bf16::from_f32(*v)) .collect(); @@ -2046,7 +2245,7 @@ fn conv_transpose1d_bf16() { "conv_transpose1d_bf16", ); - let expected = vec![1., 4., 10., 20., 25., 24., 16.] + let expected = [1., 4., 10., 20., 25., 24., 16.] .iter() .map(|v| bf16::from_f32(*v)) .collect::<Vec<_>>(); |