From 0cb0bd1dfaac97425a0805a2e7f24ea7992236c0 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 11 Sep 2024 21:52:37 +0100 Subject: Add some metal gemm benchmark. (#2471) * Add some metal gemm benchmark. * More benchmarks. --- candle-metal-kernels/Cargo.toml | 2 + candle-metal-kernels/examples/metal_benchmarks.rs | 136 ++++++++++++++++++++++ 2 files changed, 138 insertions(+) create mode 100644 candle-metal-kernels/examples/metal_benchmarks.rs (limited to 'candle-metal-kernels') diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 772452c9..8f92099d 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -17,10 +17,12 @@ thiserror = "1" tracing = "0.1.37" [dev-dependencies] +clap = { version = "4.2.4", features = ["derive"] } half = { version = "2.3.1", features = [ "num-traits", "use-intrinsics", "rand_distr", ] } +anyhow = "1" rand = "0.8.5" rand_distr = "0.4.3" diff --git a/candle-metal-kernels/examples/metal_benchmarks.rs b/candle-metal-kernels/examples/metal_benchmarks.rs new file mode 100644 index 00000000..c9c27997 --- /dev/null +++ b/candle-metal-kernels/examples/metal_benchmarks.rs @@ -0,0 +1,136 @@ +use anyhow::Result; +use candle_metal_kernels::GemmDType; +/// This example contains some simple benchmarks so that it's easy to run them in perf etc. 
+use clap::{Parser, Subcommand}; +use half::f16; + +fn run_gemm(f32: bool, n: usize) -> Result<()> { + const WARMUP_ITERS: usize = 2; + const MIN_DUR: f64 = 4.; + + let device = metal::Device::system_default().unwrap(); + + let (b, m, n, k) = (1, n, n, n); + let kernels = candle_metal_kernels::Kernels::new(); + let command_queue = device.new_command_queue(); + let options = metal::MTLResourceOptions::StorageModeManaged; + + let (lhs, rhs) = if f32 { + let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); + let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&rhs) as u64, + options, + ); + (lhs, rhs) + } else { + let lhs: Vec = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect(); + let rhs: Vec = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect(); + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&rhs) as u64, + options, + ); + (lhs, rhs) + }; + let (dtype, name, sizeof) = if f32 { + (GemmDType::F32, "sgemm", core::mem::size_of::()) + } else { + (GemmDType::F16, "hgemm", core::mem::size_of::()) + }; + let output = device.new_buffer((b * m * n * sizeof) as u64, options); + + for mlx in [false, true] { + let mut sum_dt = 0f64; + let mut iters = 0usize; + for idx in 0.. 
{ + let command_buffer = command_queue.new_command_buffer(); + let start_time = std::time::Instant::now(); + if mlx { + candle_metal_kernels::call_mlx_gemm( + &device, + command_buffer, + &kernels, + dtype, + (b, m, n, k), + &[m * k, k, 1], + 0, + &lhs, + &[n * k, n, 1], + 0, + &rhs, + &output, + )?; + } else { + candle_metal_kernels::call_gemm( + &device, + command_buffer, + &kernels, + name, + (b, m, n, k), + &[m * k, k, 1], + 0, + &lhs, + &[n * k, n, 1], + 0, + &rhs, + &output, + )?; + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + let dt = start_time.elapsed().as_secs_f64(); + if idx < WARMUP_ITERS { + continue; + } + sum_dt += dt; + iters += 1; + if sum_dt > MIN_DUR { + break; + } + } + let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt); + let mlx = if mlx { "MLX" } else { "MFA" }; + println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}"); + } + + Ok(()) +} + +#[derive(Subcommand, Debug, Clone)] +enum Task { + Gemm, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +pub struct Args { + /// The benchmark to be run. + #[command(subcommand)] + task: Task, +} + +fn main() -> Result<()> { + let args = Args::parse(); + match args.task { + Task::Gemm => { + for f32 in [false, true] { + for n in [512, 1024, 2048, 4096] { + run_gemm(f32, n)?; + } + } + } + } + Ok(()) +} -- cgit v1.2.3