summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-08-01 22:08:47 +0100
committer GitHub <noreply@github.com> 2024-08-01 23:08:47 +0200
commit 0fcb40b229a3fc627cdc86513560d2c917b39550 (patch)
tree f2de574d7228c4084229a8befdff5257e796c110 /candle-metal-kernels
parent 6991a37b94fdcfb6c1d69b7ac4b6d6b96654111d (diff)
download candle-0fcb40b229a3fc627cdc86513560d2c917b39550.tar.gz
candle-0fcb40b229a3fc627cdc86513560d2c917b39550.tar.bz2
candle-0fcb40b229a3fc627cdc86513560d2c917b39550.zip
Revert the bf16 gemm metal changes for now. (#2386)
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- candle-metal-kernels/src/libMetalFlashAttention.metallib bin 137444 -> 116184 bytes
-rw-r--r-- candle-metal-kernels/src/tests.rs 40
2 files changed, 21 insertions, 19 deletions
diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib
index d75db5bb..1e2d1acf 100644
--- a/candle-metal-kernels/src/libMetalFlashAttention.metallib
+++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib
Binary files differ
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index f70f773a..30c454af 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1162,25 +1162,27 @@ fn gemm() {
);
// bgemm sanity test
- let (b, m, n, k) = (1, 2, 4, 3);
- let lhs_stride = vec![m * k, k, 1];
- let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
- let rhs_stride = vec![n * k, n, 1];
- let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
- let results = run_gemm(
- "bgemm",
- (b, m, n, k),
- &lhs,
- lhs_stride,
- 0,
- &rhs,
- rhs_stride,
- 0,
- );
- assert_eq!(
- approx_bf16(results, 4),
- vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
- );
+ if false {
+ let (b, m, n, k) = (1, 2, 4, 3);
+ let lhs_stride = vec![m * k, k, 1];
+ let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
+ let rhs_stride = vec![n * k, n, 1];
+ let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
+ let results = run_gemm(
+ "bgemm",
+ (b, m, n, k),
+ &lhs,
+ lhs_stride,
+ 0,
+ &rhs,
+ rhs_stride,
+ 0,
+ );
+ assert_eq!(
+ approx_bf16(results, 4),
+ vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+ );
+ }
// hgemm sanity test
let (b, m, n, k) = (1, 2, 4, 3);