summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-09-11 21:52:37 +0100
committerGitHub <noreply@github.com>2024-09-11 22:52:37 +0200
commit0cb0bd1dfaac97425a0805a2e7f24ea7992236c0 (patch)
tree4bd20985fa2333efa01dc6ade4789812283979bf /candle-metal-kernels
parentafb6575835599938248c027f50a8100c289a1a96 (diff)
downloadcandle-0cb0bd1dfaac97425a0805a2e7f24ea7992236c0.tar.gz
candle-0cb0bd1dfaac97425a0805a2e7f24ea7992236c0.tar.bz2
candle-0cb0bd1dfaac97425a0805a2e7f24ea7992236c0.zip
Add some metal gemm benchark. (#2471)
* Add some metal gemm benchark. * More benchmarks.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/Cargo.toml2
-rw-r--r--candle-metal-kernels/examples/metal_benchmarks.rs136
2 files changed, 138 insertions, 0 deletions
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index 772452c9..8f92099d 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -17,10 +17,12 @@ thiserror = "1"
tracing = "0.1.37"
[dev-dependencies]
+clap = { version = "4.2.4", features = ["derive"] }
half = { version = "2.3.1", features = [
"num-traits",
"use-intrinsics",
"rand_distr",
] }
+anyhow = "1"
rand = "0.8.5"
rand_distr = "0.4.3"
diff --git a/candle-metal-kernels/examples/metal_benchmarks.rs b/candle-metal-kernels/examples/metal_benchmarks.rs
new file mode 100644
index 00000000..c9c27997
--- /dev/null
+++ b/candle-metal-kernels/examples/metal_benchmarks.rs
@@ -0,0 +1,136 @@
+use anyhow::Result;
+use candle_metal_kernels::GemmDType;
+/// This example contains some simple benchmarks so that it's easy to run them in perf etc.
+use clap::{Parser, Subcommand};
+use half::f16;
+
+fn run_gemm(f32: bool, n: usize) -> Result<()> {
+ const WARMUP_ITERS: usize = 2;
+ const MIN_DUR: f64 = 4.;
+
+ let device = metal::Device::system_default().unwrap();
+
+ let (b, m, n, k) = (1, n, n, n);
+ let kernels = candle_metal_kernels::Kernels::new();
+ let command_queue = device.new_command_queue();
+ let options = metal::MTLResourceOptions::StorageModeManaged;
+
+ let (lhs, rhs) = if f32 {
+ let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
+ let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+ let lhs = device.new_buffer_with_data(
+ lhs.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(&lhs) as u64,
+ options,
+ );
+ let rhs = device.new_buffer_with_data(
+ rhs.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(&rhs) as u64,
+ options,
+ );
+ (lhs, rhs)
+ } else {
+ let lhs: Vec<f16> = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect();
+ let rhs: Vec<f16> = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect();
+ let lhs = device.new_buffer_with_data(
+ lhs.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(&lhs) as u64,
+ options,
+ );
+ let rhs = device.new_buffer_with_data(
+ rhs.as_ptr() as *const core::ffi::c_void,
+ std::mem::size_of_val(&rhs) as u64,
+ options,
+ );
+ (lhs, rhs)
+ };
+ let (dtype, name, sizeof) = if f32 {
+ (GemmDType::F32, "sgemm", core::mem::size_of::<f32>())
+ } else {
+ (GemmDType::F16, "hgemm", core::mem::size_of::<f16>())
+ };
+ let output = device.new_buffer((b * m * n * sizeof) as u64, options);
+
+ for mlx in [false, true] {
+ let mut sum_dt = 0f64;
+ let mut iters = 0usize;
+ for idx in 0.. {
+ let command_buffer = command_queue.new_command_buffer();
+ let start_time = std::time::Instant::now();
+ if mlx {
+ candle_metal_kernels::call_mlx_gemm(
+ &device,
+ command_buffer,
+ &kernels,
+ dtype,
+ (b, m, n, k),
+ &[m * k, k, 1],
+ 0,
+ &lhs,
+ &[n * k, n, 1],
+ 0,
+ &rhs,
+ &output,
+ )?;
+ } else {
+ candle_metal_kernels::call_gemm(
+ &device,
+ command_buffer,
+ &kernels,
+ name,
+ (b, m, n, k),
+ &[m * k, k, 1],
+ 0,
+ &lhs,
+ &[n * k, n, 1],
+ 0,
+ &rhs,
+ &output,
+ )?;
+ }
+ command_buffer.commit();
+ command_buffer.wait_until_completed();
+ let dt = start_time.elapsed().as_secs_f64();
+ if idx < WARMUP_ITERS {
+ continue;
+ }
+ sum_dt += dt;
+ iters += 1;
+ if sum_dt > MIN_DUR {
+ break;
+ }
+ }
+ let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt);
+ let mlx = if mlx { "MLX" } else { "MFA" };
+ println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}");
+ }
+
+ Ok(())
+}
+
+#[derive(Subcommand, Debug, Clone)]
+enum Task {
+ Gemm,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+ /// The benchmark to be run.
+ #[command(subcommand)]
+ task: Task,
+}
+
+fn main() -> Result<()> {
+ let args = Args::parse();
+ match args.task {
+ Task::Gemm => {
+ for f32 in [false, true] {
+ for n in [512, 1024, 2048, 4096] {
+ run_gemm(f32, n)?;
+ }
+ }
+ }
+ }
+ Ok(())
+}