diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-12-20 15:37:31 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-20 15:37:31 +0100 |
commit | 9fc210fae8175a180dba8c28aa8e5975868a237c (patch) | |
tree | 5c009b11e1c11f20c99d1546849a00a063e068c0 /candle-core/benches/matmul.rs | |
parent | 96f1a28e390fceeaa12b3272c8ac5dcccc8eb5fa (diff) | |
parent | 9b5e4843a63180a2803b1e836b4ca90f14281d03 (diff) | |
download | candle-9fc210fae8175a180dba8c28aa8e5975868a237c.tar.gz candle-9fc210fae8175a180dba8c28aa8e5975868a237c.tar.bz2 candle-9fc210fae8175a180dba8c28aa8e5975868a237c.zip |
Merge pull request #1318 from huggingface/metal4
Starting to fix some tests.
Diffstat (limited to 'candle-core/benches/matmul.rs')
-rw-r--r-- | candle-core/benches/matmul.rs | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/candle-core/benches/matmul.rs b/candle-core/benches/matmul.rs new file mode 100644 index 00000000..8732f451 --- /dev/null +++ b/candle-core/benches/matmul.rs @@ -0,0 +1,43 @@ +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor, b: &Tensor) { + a.matmul(&b.t().unwrap()).unwrap(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let b = 1; + let m = 1; + let n = 2048; + let k = 2048; + + let device = Device::new_metal(0).unwrap(); + let dtype = DType::F32; + let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap(); + let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap(); + + let flops = b * m * n * k; + + let mut group = c.benchmark_group("matmul_metal"); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&lhs), black_box(&rhs)); + } + if let Device::Metal(device) = &device { + device.wait_until_completed().unwrap(); + } else { + panic!("Expected metal device"); + } + start.elapsed() + }) + }); + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); + |