summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-12 10:29:18 +0200
committerGitHub <noreply@github.com>2023-10-12 09:29:18 +0100
commitc096f024113a64612bc66e4ccb6908f07c4d791c (patch)
tree7f2b3122e16a20a84e98d24a5d6a4c1aca1763dd /candle-nn
parente7560443e4680b7655d011948d3cf178268fcfff (diff)
downloadcandle-c096f024113a64612bc66e4ccb6908f07c4d791c.tar.gz
candle-c096f024113a64612bc66e4ccb6908f07c4d791c.tar.bz2
candle-c096f024113a64612bc66e4ccb6908f07c4d791c.zip
Add a matvec cpu benchmark. (#1076)
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/examples/cpu_benchmarks.rs25
1 files changed, 22 insertions, 3 deletions
diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
index 6007ff6c..9ded5f71 100644
--- a/candle-nn/examples/cpu_benchmarks.rs
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -180,8 +180,25 @@ impl Benchmark for Conv2dIm2Col {
const ITERS: usize = 5;
}
-struct Matmul;
-impl Benchmark for Matmul {
+struct MatMul;
+impl Benchmark for MatMul {
+ type PreProcessData = (Tensor, Tensor);
+ type RunResult = Tensor;
+ fn preprocess() -> Result<Self::PreProcessData> {
+ let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
+ let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
+ Ok((lhs, rhs))
+ }
+
+ fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+ d.0.matmul(&d.1)
+ }
+
+ const ITERS: usize = 100;
+}
+
+struct MatVec;
+impl Benchmark for MatVec {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
@@ -271,6 +288,7 @@ enum Task {
Conv2d,
Conv2dIm2Col,
Matmul,
+ Matvec,
Qmatmul,
Softmax,
SoftmaxLastDim,
@@ -293,7 +311,8 @@ fn main() -> Result<()> {
Task::Conv1d => run::<Conv1d>(args.iters)?,
Task::Conv2d => run::<Conv2d>(args.iters)?,
Task::Conv2dIm2Col => run::<Conv2dIm2Col>(args.iters)?,
- Task::Matmul => run::<Matmul>(args.iters)?,
+ Task::Matmul => run::<MatMul>(args.iters)?,
+ Task::Matvec => run::<MatVec>(args.iters)?,
Task::Softmax => run::<Softmax>(args.iters)?,
Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
Task::Qmatmul => run::<QMatMul>(args.iters)?,