diff options
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | candle-core/Cargo.toml | 7 | ||||
-rw-r--r-- | candle-core/benches/matmul.rs | 43 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 20 |
4 files changed, 66 insertions, 5 deletions
@@ -32,6 +32,7 @@ accelerate-src = { version = "0.3.2" } anyhow = { version = "1", features = ["backtrace"] } byteorder = "1.4.3" clap = { version = "4.2.4", features = ["derive"] } +criterion = { version = "0.5.1", default-features=false } cudarc = { version = "0.9.14", features = ["f16"] } gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] } hf-hub = "0.3.0" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index e7d3ab6a..0f8c1a9f 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -34,6 +34,8 @@ zip = { workspace = true } [dev-dependencies] anyhow = { workspace = true } clap = { workspace = true } +criterion = { workspace = true } + [features] default = [] @@ -42,3 +44,8 @@ cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] metal = ["dep:metal", "dep:candle-metal-kernels"] + +[[bench]] +name = "matmul" +harness = false + diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/matmul.rs new file mode 100644 index 00000000..8732f451 --- /dev/null +++ b/candle-core/benches/matmul.rs @@ -0,0 +1,43 @@ +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor, b: &Tensor) { + a.matmul(&b.t().unwrap()).unwrap(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let b = 1; + let m = 1; + let n = 2048; + let k = 2048; + + let device = Device::new_metal(0).unwrap(); + let dtype = DType::F32; + let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap(); + let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap(); + + let flops = b * m * n * k; + + let mut group = c.benchmark_group("matmul_metal"); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&lhs), black_box(&rhs)); + } + if let Device::Metal(device) = &device { + device.wait_until_completed().unwrap(); + } else { + panic!("Expected metal device"); + } + start.elapsed() + }) + }); + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); + diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 0bd7d8cb..0418c96c 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1297,11 +1297,21 @@ pub fn call_gemm( let batched = b > 1; let fused_activation = false; let fused_bias = false; - let m_simd = 16; - let n_simd = 16; - let k_simd = 16; - let m_splits = 2; - let n_splits = 2; + let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { + let m_simd = 16; + let n_simd = 8; + let k_simd = 64; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + } else { + let m_simd = 40; + let n_simd = 40; + let k_simd = 8; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + }; let constants = Some(ConstantValues::new(vec![ (0, Value::USize(m)), (1, Value::USize(n)), |