summary | refs | log | tree | commit | diff
path: root/candle-metal-kernels/src/tests.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-metal-kernels/src/tests.rs')
-rw-r--r-- candle-metal-kernels/src/tests.rs 78
1 files changed, 74 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 8c38e74a..f70f773a 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1046,6 +1046,7 @@ fn where_cond_u32_f32() {
}
fn run_gemm<T: Clone>(
+ name: &'static str,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: Vec<usize>,
@@ -1076,7 +1077,7 @@ fn run_gemm<T: Clone>(
&device,
command_buffer,
&kernels,
- "sgemm",
+ name,
(b, m, n, k),
&lhs_stride,
lhs_offset,
@@ -1100,7 +1101,16 @@ fn gemm() {
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
- let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
+ let results = run_gemm(
+ "sgemm",
+ (b, m, n, k),
+ &lhs,
+ lhs_stride,
+ 0,
+ &rhs,
+ rhs_stride,
+ 0,
+ );
assert_eq!(
approx(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
@@ -1111,7 +1121,16 @@ fn gemm() {
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
- let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
+ let results = run_gemm(
+ "sgemm",
+ (b, m, n, k),
+ &lhs,
+ lhs_stride,
+ 0,
+ &rhs,
+ rhs_stride,
+ 0,
+ );
assert_eq!(
approx(results, 4),
vec![
@@ -1127,11 +1146,62 @@ fn gemm() {
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
- let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4);
+ let results = run_gemm(
+ "sgemm",
+ (1, m, n, k),
+ &lhs,
+ lhs_stride,
+ 0,
+ &rhs,
+ rhs_stride,
+ 12 * 4,
+ );
assert_eq!(
approx(results, 4),
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);
+
+ // bgemm sanity test
+ let (b, m, n, k) = (1, 2, 4, 3);
+ let lhs_stride = vec![m * k, k, 1];
+ let lhs: Vec<bf16> = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect();
+ let rhs_stride = vec![n * k, n, 1];
+ let rhs: Vec<bf16> = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect();
+ let results = run_gemm(
+ "bgemm",
+ (b, m, n, k),
+ &lhs,
+ lhs_stride,
+ 0,
+ &rhs,
+ rhs_stride,
+ 0,
+ );
+ assert_eq!(
+ approx_bf16(results, 4),
+ vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+ );
+
+ // hgemm sanity test
+ let (b, m, n, k) = (1, 2, 4, 3);
+ let lhs_stride = vec![m * k, k, 1];
+ let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
+ let rhs_stride = vec![n * k, n, 1];
+ let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
+ let results = run_gemm(
+ "hgemm",
+ (b, m, n, k),
+ &lhs,
+ lhs_stride,
+ 0,
+ &rhs,
+ rhs_stride,
+ 0,
+ );
+ assert_eq!(
+ approx_f16(results, 4),
+ vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+ );
}
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {