summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorivarflakstad <69173633+ivarflakstad@users.noreply.github.com>2024-08-01 16:06:04 +0800
committerGitHub <noreply@github.com>2024-08-01 10:06:04 +0200
commitfea46cb719d5f59216f5b0a606400f1fd663190e (patch)
tree03bba987b321014e17f4c841d01c57b0eb705bf5 /candle-metal-kernels
parent8696cf64947a7f3b712297426078dcf6ab0d199e (diff)
downloadcandle-fea46cb719d5f59216f5b0a606400f1fd663190e.tar.gz
candle-fea46cb719d5f59216f5b0a606400f1fd663190e.tar.bz2
candle-fea46cb719d5f59216f5b0a606400f1fd663190e.zip
Metal bgemm min changes (#2364)
* Add updated mfa metallib * Add bgemm and tests
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs2
-rw-r--r--candle-metal-kernels/src/libMetalFlashAttention.metallibbin116184 -> 137444 bytes
-rw-r--r--candle-metal-kernels/src/tests.rs78
3 files changed, 76 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index e0c97962..743b9fe2 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -19,6 +19,7 @@ const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
+// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
const SORT: &str = include_str!("sort.metal");
@@ -1564,6 +1565,7 @@ pub fn call_gemm(
let bytes = match name {
"sgemm" => 4,
"hgemm" => 2,
+ "bgemm" => 2,
other => {
return Err(MetalKernelError::LoadLibraryError(format!(
"{other} is not a valid kernel for gemm"
diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib
index 1e2d1acf..d75db5bb 100644
--- a/candle-metal-kernels/src/libMetalFlashAttention.metallib
+++ b/candle-metal-kernels/src/libMetalFlashAttention.metallib
Binary files differ
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 8c38e74a..f70f773a 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1046,6 +1046,7 @@ fn where_cond_u32_f32() {
}
fn run_gemm<T: Clone>(
+ name: &'static str,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: Vec<usize>,
@@ -1076,7 +1077,7 @@ fn run_gemm<T: Clone>(
&device,
command_buffer,
&kernels,
- "sgemm",
+ name,
(b, m, n, k),
&lhs_stride,
lhs_offset,
@@ -1100,7 +1101,16 @@ fn gemm() {
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
- let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
+ let results = run_gemm(
+ "sgemm",
+ (b, m, n, k),
+ &lhs,
+ lhs_stride,
+ 0,
+ &rhs,
+ rhs_stride,
+ 0,
+ );
assert_eq!(
approx(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
@@ -1111,7 +1121,16 @@ fn gemm() {
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
- let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
+ let results = run_gemm(
+ "sgemm",
+ (b, m, n, k),
+ &lhs,
+ lhs_stride,
+ 0,
+ &rhs,
+ rhs_stride,
+ 0,
+ );
assert_eq!(
approx(results, 4),
vec![
@@ -1127,11 +1146,62 @@ fn gemm() {
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
- let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4);
+ let results = run_gemm(
+ "sgemm",
+ (1, m, n, k),
+ &lhs,
+ lhs_stride,
+ 0,
+ &rhs,
+ rhs_stride,
+ 12 * 4,
+ );
assert_eq!(
approx(results, 4),
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);
+
+ // bgemm sanity test
+ let (b, m, n, k) = (1, 2, 4, 3);
+ let lhs_stride = vec![m * k, k, 1];
+ let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
+ let rhs_stride = vec![n * k, n, 1];
+ let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
+ let results = run_gemm(
+ "bgemm",
+ (b, m, n, k),
+ &lhs,
+ lhs_stride,
+ 0,
+ &rhs,
+ rhs_stride,
+ 0,
+ );
+ assert_eq!(
+ approx_bf16(results, 4),
+ vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+ );
+
+ // hgemm sanity test
+ let (b, m, n, k) = (1, 2, 4, 3);
+ let lhs_stride = vec![m * k, k, 1];
+ let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
+ let rhs_stride = vec![n * k, n, 1];
+ let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
+ let results = run_gemm(
+ "hgemm",
+ (b, m, n, k),
+ &lhs,
+ lhs_stride,
+ 0,
+ &rhs,
+ rhs_stride,
+ 0,
+ );
+ assert_eq!(
+ approx_f16(results, 4),
+ vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+ );
}
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {