From 0fcb40b229a3fc627cdc86513560d2c917b39550 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 1 Aug 2024 22:08:47 +0100 Subject: Revert the bf16 gemm metal changes for now. (#2386) --- .../src/libMetalFlashAttention.metallib | Bin 137444 -> 116184 bytes candle-metal-kernels/src/tests.rs | 40 +++++++++++---------- 2 files changed, 21 insertions(+), 19 deletions(-) (limited to 'candle-metal-kernels') diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib index d75db5bb..1e2d1acf 100644 Binary files a/candle-metal-kernels/src/libMetalFlashAttention.metallib and b/candle-metal-kernels/src/libMetalFlashAttention.metallib differ diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index f70f773a..30c454af 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1162,25 +1162,27 @@ fn gemm() { ); // bgemm sanity test - let (b, m, n, k) = (1, 2, 4, 3); - let lhs_stride = vec![m * k, k, 1]; - let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); - let rhs_stride = vec![n * k, n, 1]; - let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); - let results = run_gemm( - "bgemm", - (b, m, n, k), - &lhs, - lhs_stride, - 0, - &rhs, - rhs_stride, - 0, - ); - assert_eq!( - approx_bf16(results, 4), - vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] - ); + if false { + let (b, m, n, k) = (1, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); + let results = run_gemm( + "bgemm", + (b, m, n, k), + &lhs, + lhs_stride, + 0, + &rhs, + rhs_stride, + 0, + ); + assert_eq!( + approx_bf16(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + } // hgemm sanity test let (b, m, n, k) = (1, 2, 4, 3); -- 
cgit v1.2.3