summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/examples/cpu_benchmarks.rs19
1 files changed, 19 insertions, 0 deletions
diff --git a/candle-core/examples/cpu_benchmarks.rs b/candle-core/examples/cpu_benchmarks.rs
index 4cc710fb..421a4d50 100644
--- a/candle-core/examples/cpu_benchmarks.rs
+++ b/candle-core/examples/cpu_benchmarks.rs
@@ -55,6 +55,23 @@ impl Benchmark for Conv2d {
const ITERS: usize = 1;
}
+struct Matmul; // Unit type that selects the matrix-multiplication benchmark.
+impl Benchmark for Matmul {
+ type PreProcessData = (Tensor, Tensor); // (lhs, rhs) operand pair produced once by preprocess().
+ type RunResult = Tensor; // Value returned by each run_one() call.
+ fn preprocess() -> Result<Self::PreProcessData> { // Builds both operands on Device::Cpu; any randn error is propagated via `?`.
+ let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?; // Shape (1024, 1024); 0f32 and 1. are presumably mean and std dev -- confirm against Tensor::randn.
+ let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?; // Built with the same arguments as lhs.
+ Ok((lhs, rhs))
+ }
+
+ fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { // Performs a single lhs.matmul(&rhs) on the prepared pair.
+ d.0.matmul(&d.1)
+ }
+
+ const ITERS: usize = 100; // NOTE(review): presumably the iteration count used when run() is given no explicit `iters`; the Conv2d impl above sets 1 -- confirm in run().
+}
+
fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
use std::hint::black_box;
@@ -72,6 +89,7 @@ fn run<B: Benchmark>(iters: Option<usize>) -> Result<()> {
enum Task {
Conv1d,
Conv2d,
+ Matmul,
}
#[derive(Parser, Debug)]
@@ -90,6 +108,7 @@ fn main() -> Result<()> {
match args.task {
Task::Conv1d => run::<Conv1d>(args.iters)?,
Task::Conv2d => run::<Conv2d>(args.iters)?,
+ Task::Matmul => run::<Matmul>(args.iters)?,
}
Ok(())
}