summaryrefslogtreecommitdiff
path: root/candle-nn/benches/benchmarks/conv.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/benches/benchmarks/conv.rs')
-rw-r--r--candle-nn/benches/benchmarks/conv.rs54
1 files changed, 54 insertions, 0 deletions
diff --git a/candle-nn/benches/benchmarks/conv.rs b/candle-nn/benches/benchmarks/conv.rs
new file mode 100644
index 00000000..eb80645b
--- /dev/null
+++ b/candle-nn/benches/benchmarks/conv.rs
@@ -0,0 +1,54 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle::{DType, Device, Module, Tensor};
+use candle_nn::{Conv2d, Conv2dConfig};
+use criterion::{black_box, criterion_group, Criterion};
+use std::time::Instant;
+
+const B: usize = 1;
+const C: usize = 1;
+const M: usize = 128;
+const K: usize = 128;
+const K_SIZE: usize = 3;
+
+fn run(input: Tensor, weight: Tensor, bias: Tensor, config: Conv2dConfig) {
+ Conv2d::new(weight, Some(bias), config)
+ .forward(&input)
+ .unwrap();
+}
+
+fn run_conv2d_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+ let weight = Tensor::ones((1, 1, K_SIZE, K_SIZE), dtype, device)
+ .unwrap()
+ .to_dtype(dtype)
+ .unwrap();
+ let bias = Tensor::zeros(K, dtype, device).unwrap();
+ let input = Tensor::ones((B, C, M, K), dtype, device).unwrap();
+
+ let mut group = c.benchmark_group(device.bench_name(name));
+ group.bench_function("iter", move |b| {
+ b.iter_custom(|iters| {
+ let start = Instant::now();
+ for _i in 0..iters {
+ run(
+ black_box(input.clone()),
+ black_box(weight.clone()),
+ black_box(bias.clone()),
+ Default::default(),
+ );
+ }
+ device.sync().unwrap();
+ start.elapsed()
+ })
+ });
+ group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let device = BenchDeviceHandler::new().unwrap();
+ for d in device.devices {
+ run_conv2d_benchmark(c, &d, DType::F32, "conv2d_f32");
+ run_conv2d_benchmark(c, &d, DType::F16, "conv2d_f16");
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);