diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-17 10:49:13 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-17 10:49:13 +0100 |
commit | ce9fbc368211815ef2dddff01575ca1f9d4eccd5 (patch) | |
tree | e260edd957ab716d1789da05059c9a79696b0730 /candle-nn | |
parent | db8b24ae92419377283821ee0a65fb224a4f3c4d (diff) | |
download | candle-ce9fbc368211815ef2dddff01575ca1f9d4eccd5.tar.gz candle-ce9fbc368211815ef2dddff01575ca1f9d4eccd5.tar.bz2 candle-ce9fbc368211815ef2dddff01575ca1f9d4eccd5.zip |
Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d.
* Move the cat operations.
* Avoid transpositions in cat.
* Bugfix.
* Bugfix for the cuda kernel.
* Add a benchmark.
* Add more testing.
* Test fix.
* Faster kernel.
* Add the missing kernel.
* Tweak the test.
* Add a metal kernel.
* Fix for the metal kernel.
* Get the tests to pass on metal.
* Also use this opportunity to fix the metal kernel for ELU.
* Add some bf16 kernels.
* Clippy fixes.
Diffstat (limited to 'candle-nn')
-rw-r--r-- | candle-nn/examples/cpu_benchmarks.rs | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs index 001be116..430316b8 100644 --- a/candle-nn/examples/cpu_benchmarks.rs +++ b/candle-nn/examples/cpu_benchmarks.rs @@ -238,6 +238,23 @@ impl Benchmark for QMatMul { const ITERS: usize = 100; } +struct Cat; +impl Benchmark for Cat { + type PreProcessData = (Tensor, Tensor); + type RunResult = Tensor; + fn preprocess() -> Result<Self::PreProcessData> { + let lhs = Tensor::randn(0f32, 1., (1, 32, 2000, 128), &Device::Cpu)?; + let rhs = Tensor::randn(0f32, 1., (1, 32, 1, 128), &Device::Cpu)?; + Ok((lhs, rhs)) + } + + fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> { + Tensor::cat(&[&d.0, &d.1], 2) + } + + const ITERS: usize = 1000; +} + struct Softmax; impl Benchmark for Softmax { type PreProcessData = Tensor; @@ -295,6 +312,7 @@ enum Task { Qmatmul, Softmax, SoftmaxLastDim, + Cat, } #[derive(Parser, Debug)] @@ -319,6 +337,7 @@ fn main() -> Result<()> { Task::Softmax => run::<Softmax>(args.iters)?, Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?, Task::Qmatmul => run::<QMatMul>(args.iters)?, + Task::Cat => run::<Cat>(args.iters)?, } Ok(()) } |