diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-08-01 22:08:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-01 23:08:47 +0200 |
commit | 0fcb40b229a3fc627cdc86513560d2c917b39550 (patch) | |
tree | f2de574d7228c4084229a8befdff5257e796c110 /candle-metal-kernels | |
parent | 6991a37b94fdcfb6c1d69b7ac4b6d6b96654111d (diff) | |
download | candle-0fcb40b229a3fc627cdc86513560d2c917b39550.tar.gz candle-0fcb40b229a3fc627cdc86513560d2c917b39550.tar.bz2 candle-0fcb40b229a3fc627cdc86513560d2c917b39550.zip |
Revert the bf16 gemm metal changes for now. (#2386)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/libMetalFlashAttention.metallib | bin | 137444 -> 116184 bytes | |||
-rw-r--r-- | candle-metal-kernels/src/tests.rs | 40 |
2 files changed, 21 insertions, 19 deletions
```diff
diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib
Binary files differ
index d75db5bb..1e2d1acf 100644
--- a/candle-metal-kernels/src/libMetalFlashAttention.metallib
+++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index f70f773a..30c454af 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1162,25 +1162,27 @@ fn gemm() {
     );
 
     // bgemm sanity test
-    let (b, m, n, k) = (1, 2, 4, 3);
-    let lhs_stride = vec![m * k, k, 1];
-    let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
-    let rhs_stride = vec![n * k, n, 1];
-    let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
-    let results = run_gemm(
-        "bgemm",
-        (b, m, n, k),
-        &lhs,
-        lhs_stride,
-        0,
-        &rhs,
-        rhs_stride,
-        0,
-    );
-    assert_eq!(
-        approx_bf16(results, 4),
-        vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
-    );
+    if false {
+        let (b, m, n, k) = (1, 2, 4, 3);
+        let lhs_stride = vec![m * k, k, 1];
+        let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
+        let rhs_stride = vec![n * k, n, 1];
+        let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
+        let results = run_gemm(
+            "bgemm",
+            (b, m, n, k),
+            &lhs,
+            lhs_stride,
+            0,
+            &rhs,
+            rhs_stride,
+            0,
+        );
+        assert_eq!(
+            approx_bf16(results, 4),
+            vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+        );
+    }
 
     // hgemm sanity test
     let (b, m, n, k) = (1, 2, 4, 3);
```