diff options
author | ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> | 2024-01-12 11:18:11 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-12 11:18:11 +0100 |
commit | e90bcdcc7c51dd85037055b59f22568100d801f0 (patch) | |
tree | 2c7e89df98f44192d92b185682d40b71295fd704 /candle-core | |
parent | 8e06bfb4fd33f1229a03abee20cc1c07198408b5 (diff) | |
download | candle-e90bcdcc7c51dd85037055b59f22568100d801f0.tar.gz candle-e90bcdcc7c51dd85037055b59f22568100d801f0.tar.bz2 candle-e90bcdcc7c51dd85037055b59f22568100d801f0.zip |
Metal: f16 and bf16 where_cond + benchmark (#1545)
* Use cfg to seperate benchmark results based on features
* Add metal where_cond for f16 and bf16. Add benchmark
* Remove allow pragma
* Avoid some unnecessary returns.
* Improve benchmarks layout
* Updated feature separated benchmarks
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-core')
-rw-r--r-- | candle-core/benches/bench_main.rs | 2 | ||||
-rw-r--r-- | candle-core/benches/benchmarks/mod.rs | 1 | ||||
-rw-r--r-- | candle-core/benches/benchmarks/where_cond.rs | 64 | ||||
-rw-r--r-- | candle-core/src/metal_backend.rs | 1 |
4 files changed, 67 insertions, 1 deletions
diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index 4425f2fb..92c33a86 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -1,4 +1,4 @@ mod benchmarks; use criterion::criterion_main; -criterion_main!(benchmarks::matmul::benches); +criterion_main!(benchmarks::matmul::benches, benchmarks::where_cond::benches); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index 295bbabd..4e73ebb6 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod matmul; +pub(crate) mod where_cond; use candle_core::{Device, Result}; diff --git a/candle-core/benches/benchmarks/where_cond.rs b/candle-core/benches/benchmarks/where_cond.rs new file mode 100644 index 00000000..c517dcf5 --- /dev/null +++ b/candle-core/benches/benchmarks/where_cond.rs @@ -0,0 +1,64 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{black_box, criterion_group, Criterion, Throughput}; +use std::time::Instant; + +fn run(a: &Tensor, b: &Tensor, c: &Tensor) { + a.where_cond(b, c).unwrap(); +} + +const fn create_cond_arr<const N: usize>() -> [u8; N] { + let mut arr = [0u8; N]; + let mut i = 0; + while i < N { + arr[i] = (i % 2) as u8; + i += 1; + } + arr +} + +const B: usize = 1; +const M: usize = 1024; +const K: usize = 1024; +const SIZE: usize = B * M * K; + +const DATA: [u8; SIZE] = create_cond_arr::<SIZE>(); + +fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap(); + let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap(); + let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap(); + + let elements = B * M * K; + // E.g. 2 f32 tensors + 1 u8 tensor + let flops = (2 * elements * dtype.size_in_bytes()) + elements; + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run( + black_box(&tensor), + black_box(&on_true), + black_box(&on_false), + ); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let device = BenchDeviceHandler::new().unwrap(); + for d in device.devices { + run_where_cond_benchmark(c, &d, DType::F32, "where_cond_f32"); + run_where_cond_benchmark(c, &d, DType::BF16, "where_cond_bf16"); + run_where_cond_benchmark(c, &d, DType::F16, "where_cond_f16"); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index aa2898ff..38f909c8 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -806,6 +806,7 @@ impl BackendStorage for MetalStorage { } let name = match (self.dtype, t.dtype()) { (DType::U8, DType::F32) => "where_u8_f32", + (DType::U8, DType::BF16) => "where_u8_bf16", (DType::U8, DType::F16) => "where_u8_f16", (DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::U32) => "where_u8_u32", |