diff options
author | ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-08-01 16:06:04 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-01 10:06:04 +0200 |
commit | fea46cb719d5f59216f5b0a606400f1fd663190e (patch) | |
tree | 03bba987b321014e17f4c841d01c57b0eb705bf5 /candle-metal-kernels | |
parent | 8696cf64947a7f3b712297426078dcf6ab0d199e (diff) | |
download | candle-fea46cb719d5f59216f5b0a606400f1fd663190e.tar.gz candle-fea46cb719d5f59216f5b0a606400f1fd663190e.tar.bz2 candle-fea46cb719d5f59216f5b0a606400f1fd663190e.zip |
Metal bgemm min changes (#2364)
* Add updated mfa metallib
* Add bgemm and tests
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 2 | ||||
-rw-r--r-- | candle-metal-kernels/src/libMetalFlashAttention.metallib | bin | 116184 -> 137444 bytes | |||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 78 |
3 files changed, 76 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index e0c97962..743b9fe2 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -19,6 +19,7 @@ const CAST: &str = include_str!("cast.metal"); const CONV: &str = include_str!("conv.metal"); const REDUCE: &str = include_str!("reduce.metal"); const RANDOM: &str = include_str!("random.metal"); +// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); const QUANTIZED: &str = include_str!("quantized.metal"); const SORT: &str = include_str!("sort.metal"); @@ -1564,6 +1565,7 @@ pub fn call_gemm( let bytes = match name { "sgemm" => 4, "hgemm" => 2, + "bgemm" => 2, other => { return Err(MetalKernelError::LoadLibraryError(format!( "{other} is not a valid kernel for gemm" diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib Binary files differ index 1e2d1acf..d75db5bb 100644 --- a/candle-metal-kernels/src/libMetalFlashAttention.metallib +++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 8c38e74a..f70f773a 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1046,6 +1046,7 @@ fn where_cond_u32_f32() { } fn run_gemm<T: Clone>( + name: &'static str, (b, m, n, k): (usize, usize, usize, usize), lhs: &[T], lhs_stride: Vec<usize>, @@ -1076,7 +1077,7 @@ fn run_gemm<T: Clone>( &device, command_buffer, &kernels, - "sgemm", + name, (b, m, n, k), &lhs_stride, lhs_offset, @@ -1100,7 +1101,16 @@ fn gemm() { let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect(); let rhs_stride = vec![n * k, n, 1]; let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); - let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0); + let results = run_gemm( + 
"sgemm", + (b, m, n, k), + &lhs, + lhs_stride, + 0, + &rhs, + rhs_stride, + 0, + ); assert_eq!( approx(results, 4), vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] @@ -1111,7 +1121,16 @@ fn gemm() { let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect(); let rhs_stride = vec![n * k, n, 1]; let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); - let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0); + let results = run_gemm( + "sgemm", + (b, m, n, k), + &lhs, + lhs_stride, + 0, + &rhs, + rhs_stride, + 0, + ); assert_eq!( approx(results, 4), vec![ @@ -1127,11 +1146,62 @@ fn gemm() { let rhs_stride = vec![n * k, n, 1]; let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); // Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32 - let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4); + let results = run_gemm( + "sgemm", + (1, m, n, k), + &lhs, + lhs_stride, + 0, + &rhs, + rhs_stride, + 12 * 4, + ); assert_eq!( approx(results, 4), vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] ); + + // bgemm sanity test + let (b, m, n, k) = (1, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); + let results = run_gemm( + "bgemm", + (b, m, n, k), + &lhs, + lhs_stride, + 0, + &rhs, + rhs_stride, + 0, + ); + assert_eq!( + approx_bf16(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + + // hgemm sanity test + let (b, m, n, k) = (1, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect(); + let results = run_gemm( + "hgemm", + (b, m, n, k), + &lhs, + 
lhs_stride, + 0, + &rhs, + rhs_stride, + 0, + ); + assert_eq!( + approx_f16(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); } fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> { |