summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- Cargo.toml 1
-rw-r--r-- candle-core/Cargo.toml 7
-rw-r--r-- candle-core/benches/matmul.rs 43
-rw-r--r-- candle-metal-kernels/src/lib.rs 20
4 files changed, 66 insertions, 5 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 7c2e3a7d..9fda5fba 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
+criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.9.14", features = ["f16"] }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index e7d3ab6a..0f8c1a9f 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -34,6 +34,8 @@ zip = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
+criterion = { workspace = true }
+
[features]
default = []
@@ -42,3 +44,8 @@ cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
+
+[[bench]]
+name = "matmul"
+harness = false
+
diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/matmul.rs
new file mode 100644
index 00000000..8732f451
--- /dev/null
+++ b/candle-core/benches/matmul.rs
@@ -0,0 +1,43 @@
+use candle_core::{DType, Device, Tensor};
+use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(a: &Tensor, b: &Tensor) { // Benchmark payload: a single `a.matmul(..)` against `b.t()` (presumably the transpose of `b` — confirm).
+ a.matmul(&b.t().unwrap()).unwrap(); // Panics if either call returns an error; the product itself is discarded.
+}
+
+fn criterion_benchmark(c: &mut Criterion) { // Registers the "matmul_metal" group: F32 matmul with operands built as (b, m, k) and (b, n, k).
+ let b = 1;
+ let m = 1;
+ let n = 2048;
+ let k = 2048;
+
+ let device = Device::new_metal(0).unwrap(); // Panics if Metal device 0 cannot be created.
+ let dtype = DType::F32;
+ let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
+ let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap(); // Built as (b, n, k); `run` applies `.t()` to it before the matmul.
+
+ let flops = b * m * n * k; // Product of all four problem dimensions.
+
+ let mut group = c.benchmark_group("matmul_metal");
+ group.throughput(Throughput::Bytes(flops as u64)); // NOTE(review): an operation count is reported through the `Bytes` unit — confirm this is intended.
+ group.bench_function("iter", move |b| { // This closure's `b` is the bencher; it shadows the batch size `b` above.
+ b.iter_custom(|iters| { // Times all `iters` runs by hand and returns the total elapsed duration.
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(black_box(&lhs), black_box(&rhs));
+ }
+ if let Device::Metal(device) = &device {
+ device.wait_until_completed().unwrap(); // Inside the timed region; presumably blocks until queued GPU work finishes — verify.
+ } else {
+ panic!("Expected metal device");
+ }
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
+
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 0bd7d8cb..0418c96c 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1297,11 +1297,21 @@ pub fn call_gemm(
let batched = b > 1;
let fused_activation = false;
let fused_bias = false;
- let m_simd = 16;
- let n_simd = 16;
- let k_simd = 16;
- let m_splits = 2;
- let n_splits = 2;
+ let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
+ let m_simd = 16;
+ let n_simd = 8;
+ let k_simd = 64;
+ let m_splits = 1;
+ let n_splits = 1;
+ (m_simd, n_simd, k_simd, m_splits, n_splits)
+ } else {
+ let m_simd = 40;
+ let n_simd = 40;
+ let k_simd = 8;
+ let m_splits = 1;
+ let n_splits = 1;
+ (m_simd, n_simd, k_simd, m_splits, n_splits)
+ };
let constants = Some(ConstantValues::new(vec![
(0, Value::USize(m)),
(1, Value::USize(n)),