summaryrefslogtreecommitdiff
path: root/candle-nn
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-17 10:49:13 +0100
committerGitHub <noreply@github.com>2024-03-17 10:49:13 +0100
commitce9fbc368211815ef2dddff01575ca1f9d4eccd5 (patch)
treee260edd957ab716d1789da05059c9a79696b0730 /candle-nn
parentdb8b24ae92419377283821ee0a65fb224a4f3c4d (diff)
downloadcandle-ce9fbc368211815ef2dddff01575ca1f9d4eccd5.tar.gz
candle-ce9fbc368211815ef2dddff01575ca1f9d4eccd5.tar.bz2
candle-ce9fbc368211815ef2dddff01575ca1f9d4eccd5.zip
Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d. * Move the cat operations. * Avoid transpositions in cat. * Bugfix. * Bugfix for the cuda kernel. * Add a benchmark. * Add more testing. * Test fix. * Faster kernel. * Add the missing kernel. * Tweak the test. * Add a metal kernel. * Fix for the metal kernel. * Get the tests to pass on metal. * Also use this opportunity to fix the metal kernel for ELU. * Add some bf16 kernels. * Clippy fixes.
Diffstat (limited to 'candle-nn')
-rw-r--r--candle-nn/examples/cpu_benchmarks.rs19
1 file changed, 19 insertions, 0 deletions
diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
index 001be116..430316b8 100644
--- a/candle-nn/examples/cpu_benchmarks.rs
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -238,6 +238,23 @@ impl Benchmark for QMatMul {
const ITERS: usize = 100;
}
+struct Cat;
+impl Benchmark for Cat {
+ type PreProcessData = (Tensor, Tensor);
+ type RunResult = Tensor;
+ fn preprocess() -> Result<Self::PreProcessData> {
+ let lhs = Tensor::randn(0f32, 1., (1, 32, 2000, 128), &Device::Cpu)?;
+ let rhs = Tensor::randn(0f32, 1., (1, 32, 1, 128), &Device::Cpu)?;
+ Ok((lhs, rhs))
+ }
+
+ fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
+ Tensor::cat(&[&d.0, &d.1], 2)
+ }
+
+ const ITERS: usize = 1000;
+}
+
struct Softmax;
impl Benchmark for Softmax {
type PreProcessData = Tensor;
@@ -295,6 +312,7 @@ enum Task {
Qmatmul,
Softmax,
SoftmaxLastDim,
+ Cat,
}
#[derive(Parser, Debug)]
@@ -319,6 +337,7 @@ fn main() -> Result<()> {
Task::Softmax => run::<Softmax>(args.iters)?,
Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
Task::Qmatmul => run::<QMatMul>(args.iters)?,
+ Task::Cat => run::<Cat>(args.iters)?,
}
Ok(())
}