diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-09-11 21:52:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-11 22:52:37 +0200 |
commit | 0cb0bd1dfaac97425a0805a2e7f24ea7992236c0 (patch) | |
tree | 4bd20985fa2333efa01dc6ade4789812283979bf /candle-metal-kernels | |
parent | afb6575835599938248c027f50a8100c289a1a96 (diff) | |
download | candle-0cb0bd1dfaac97425a0805a2e7f24ea7992236c0.tar.gz candle-0cb0bd1dfaac97425a0805a2e7f24ea7992236c0.tar.bz2 candle-0cb0bd1dfaac97425a0805a2e7f24ea7992236c0.zip |
Add some metal gemm benchark. (#2471)
* Add some metal gemm benchark.
* More benchmarks.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/Cargo.toml | 2 | ||||
-rw-r--r-- | candle-metal-kernels/examples/metal_benchmarks.rs | 136 |
2 files changed, 138 insertions, 0 deletions
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 772452c9..8f92099d 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -17,10 +17,12 @@ thiserror = "1" tracing = "0.1.37" [dev-dependencies] +clap = { version = "4.2.4", features = ["derive"] } half = { version = "2.3.1", features = [ "num-traits", "use-intrinsics", "rand_distr", ] } +anyhow = "1" rand = "0.8.5" rand_distr = "0.4.3" diff --git a/candle-metal-kernels/examples/metal_benchmarks.rs b/candle-metal-kernels/examples/metal_benchmarks.rs new file mode 100644 index 00000000..c9c27997 --- /dev/null +++ b/candle-metal-kernels/examples/metal_benchmarks.rs @@ -0,0 +1,136 @@ +use anyhow::Result; +use candle_metal_kernels::GemmDType; +/// This example contains some simple benchmarks so that it's easy to run them in perf etc. +use clap::{Parser, Subcommand}; +use half::f16; + +fn run_gemm(f32: bool, n: usize) -> Result<()> { + const WARMUP_ITERS: usize = 2; + const MIN_DUR: f64 = 4.; + + let device = metal::Device::system_default().unwrap(); + + let (b, m, n, k) = (1, n, n, n); + let kernels = candle_metal_kernels::Kernels::new(); + let command_queue = device.new_command_queue(); + let options = metal::MTLResourceOptions::StorageModeManaged; + + let (lhs, rhs) = if f32 { + let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect(); + let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect(); + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&rhs) as u64, + options, + ); + (lhs, rhs) + } else { + let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect(); + let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect(); + let lhs = device.new_buffer_with_data( + lhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&lhs) as u64, + options, + ); + let rhs = device.new_buffer_with_data( + rhs.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(&rhs) as u64, + options, + ); + (lhs, rhs) + }; + let (dtype, name, sizeof) = if f32 { + (GemmDType::F32, "sgemm", core::mem::size_of::<f32>()) + } else { + (GemmDType::F16, "hgemm", core::mem::size_of::<f16>()) + }; + let output = device.new_buffer((b * m * n * sizeof) as u64, options); + + for mlx in [false, true] { + let mut sum_dt = 0f64; + let mut iters = 0usize; + for idx in 0.. { + let command_buffer = command_queue.new_command_buffer(); + let start_time = std::time::Instant::now(); + if mlx { + candle_metal_kernels::call_mlx_gemm( + &device, + command_buffer, + &kernels, + dtype, + (b, m, n, k), + &[m * k, k, 1], + 0, + &lhs, + &[n * k, n, 1], + 0, + &rhs, + &output, + )?; + } else { + candle_metal_kernels::call_gemm( + &device, + command_buffer, + &kernels, + name, + (b, m, n, k), + &[m * k, k, 1], + 0, + &lhs, + &[n * k, n, 1], + 0, + &rhs, + &output, + )?; + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + let dt = start_time.elapsed().as_secs_f64(); + if idx < WARMUP_ITERS { + continue; + } + sum_dt += dt; + iters += 1; + if sum_dt > MIN_DUR { + break; + } + } + let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt); + let mlx = if mlx { "MLX" } else { "MFA" }; + println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}"); + } + + Ok(()) +} + +#[derive(Subcommand, Debug, Clone)] +enum Task { + Gemm, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +pub struct Args { + /// The benchmark to be run. + #[command(subcommand)] + task: Task, +} + +fn main() -> Result<()> { + let args = Args::parse(); + match args.task { + Task::Gemm => { + for f32 in [false, true] { + for n in [512, 1024, 2048, 4096] { + run_gemm(f32, n)?; + } + } + } + } + Ok(()) +} |