From ce9fbc368211815ef2dddff01575ca1f9d4eccd5 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 17 Mar 2024 10:49:13 +0100
Subject: Optimize the cat operation on contiguous tensors (#1855)

* Add a specialized kernel for copy2d.
* Move the cat operations.
* Avoid transpositions in cat.
* Bugfix.
* Bugfix for the cuda kernel.
* Add a benchmark.
* Add more testing.
* Test fix.
* Faster kernel.
* Add the missing kernel.
* Tweak the test.
* Add a metal kernel.
* Fix for the metal kernel.
* Get the tests to pass on metal.
* Also use this opportunity to fix the metal kernel for ELU.
* Add some bf16 kernels.
* Clippy fixes.
---
 candle-core/src/dummy_metal_backend.rs | 13 +++++++++++++
 1 file changed, 13 insertions(+)

(limited to 'candle-core/src/dummy_metal_backend.rs')

diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs
index e9d92331..791ec153 100644
--- a/candle-core/src/dummy_metal_backend.rs
+++ b/candle-core/src/dummy_metal_backend.rs
@@ -166,6 +166,19 @@ impl crate::backend::BackendStorage for MetalStorage {
         Err(Error::NotCompiledWithMetalSupport)
     }
 
+    fn copy2d(
+        &self,
+        _: &mut Self,
+        _: usize,
+        _: usize,
+        _: usize,
+        _: usize,
+        _: usize,
+        _: usize,
+    ) -> Result<()> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
     fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
         Err(Error::NotCompiledWithMetalSupport)
     }
-- 
cgit v1.2.3