From 931432ed55918886680e37a280c3ff25d7ee9839 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 13 Dec 2023 16:58:36 +0100 Subject: Fixing tests + matmul from MFA --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'Cargo.toml') diff --git a/Cargo.toml b/Cargo.toml index ba09b1d4..7c2e3a7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,7 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } -metal = { version = "0.27.1", features = ["mps"], package="candle-metal" } +metal = { version = "0.27.0", features = ["mps"]} [profile.release-with-debug] inherits = "release" -- cgit v1.2.3 From 9b5e4843a63180a2803b1e836b4ca90f14281d03 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 20 Dec 2023 09:54:19 +0100 Subject: Optimizing decode matmul (Phi at 28tok/s on M3). Adding some benchmark in order to help checking out matmul performance. --- Cargo.toml | 1 + candle-core/Cargo.toml | 7 +++++++ candle-core/benches/matmul.rs | 43 +++++++++++++++++++++++++++++++++++++++++ candle-metal-kernels/src/lib.rs | 20 ++++++++++++++----- 4 files changed, 66 insertions(+), 5 deletions(-) create mode 100644 candle-core/benches/matmul.rs (limited to 'Cargo.toml') diff --git a/Cargo.toml b/Cargo.toml index 7c2e3a7d..9fda5fba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ accelerate-src = { version = "0.3.2" } anyhow = { version = "1", features = ["backtrace"] } byteorder = "1.4.3" clap = { version = "4.2.4", features = ["derive"] } +criterion = { version = "0.5.1", default-features=false } cudarc = { version = "0.9.14", features = ["f16"] } gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] } hf-hub = "0.3.0" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index e7d3ab6a..0f8c1a9f 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -34,6 +34,8 @@ zip = { workspace = true } [dev-dependencies] anyhow = { workspace = true } clap = { workspace = true } +criterion = { workspace = true } + [features] default = [] @@ -42,3 +44,8 @@ cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] metal = ["dep:metal", "dep:candle-metal-kernels"] + +[[bench]] +name = "matmul" +harness = false + diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/matmul.rs new file mode 100644 index 00000000..8732f451 --- /dev/null +++ b/candle-core/benches/matmul.rs @@ -0,0 +1,43 @@ +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor, b: &Tensor) { + a.matmul(&b.t().unwrap()).unwrap(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let b = 1; + let m = 1; + let n = 2048; + let k = 2048; + + let device = Device::new_metal(0).unwrap(); + let dtype = DType::F32; + let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap(); + let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap(); + + let flops = b * m * n * k; + + let mut group = c.benchmark_group("matmul_metal"); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&lhs), black_box(&rhs)); + } + if let Device::Metal(device) = &device { + device.wait_until_completed().unwrap(); + } else { + panic!("Expected metal device"); + } + start.elapsed() + }) + }); + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); + diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 0bd7d8cb..0418c96c 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1297,11 +1297,21 @@ pub fn call_gemm( let batched = b > 1; let fused_activation = false; let fused_bias = false; - let m_simd = 16; - let n_simd = 16; - let k_simd = 16; - let m_splits = 2; - let n_splits = 2; + let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { + let m_simd = 16; + let n_simd = 8; + let k_simd = 64; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + } else { + let m_simd = 40; + let n_simd = 40; + let k_simd = 8; + let m_splits = 1; + let n_splits = 1; + (m_simd, n_simd, k_simd, m_splits, n_splits) + }; let constants = Some(ConstantValues::new(vec![ (0, Value::USize(m)), (1, Value::USize(n)), -- cgit v1.2.3