diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-17 10:49:13 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-17 10:49:13 +0100 |
commit | ce9fbc368211815ef2dddff01575ca1f9d4eccd5 (patch) | |
tree | e260edd957ab716d1789da05059c9a79696b0730 /candle-metal-kernels | |
parent | db8b24ae92419377283821ee0a65fb224a4f3c4d (diff) | |
download | candle-ce9fbc368211815ef2dddff01575ca1f9d4eccd5.tar.gz candle-ce9fbc368211815ef2dddff01575ca1f9d4eccd5.tar.bz2 candle-ce9fbc368211815ef2dddff01575ca1f9d4eccd5.zip |
Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d.
* Move the cat operations.
* Avoid transpositions in cat.
* Bugfix.
* Bugfix for the cuda kernel.
* Add a benchmark.
* Add more testing.
* Test fix.
* Faster kernel.
* Add the missing kernel.
* Tweak the test.
* Add a metal kernel.
* Fix for the metal kernel.
* Get the tests to pass on metal.
* Also use this opportunity to fix the metal kernel for ELU.
* Add some bf16 kernels.
* Clippy fixes.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r-- | candle-metal-kernels/src/affine.metal | 2 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 50 | ||||
-rw-r--r-- | candle-metal-kernels/src/unary.metal | 27 |
3 files changed, 78 insertions, 1 deletions
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
index a4484998..76c0365a 100644
--- a/candle-metal-kernels/src/affine.metal
+++ b/candle-metal-kernels/src/affine.metal
@@ -89,7 +89,7 @@ kernel void FN_NAME( \
         return; \
     } \
     const TYPENAME x = input[id]; \
-    output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
+    output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
 } \
 kernel void FN_NAME##_strided( \
     constant size_t &dim, \
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 47ce7e96..a879c86a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -127,6 +127,16 @@ pub enum Source {
     Quantized,
 }
 
+pub mod copy2d {
+    pub struct Kernel(pub &'static str);
+    pub const FLOAT: Kernel = Kernel("copy2d_f32");
+    pub const HALF: Kernel = Kernel("copy2d_f16");
+    pub const BFLOAT: Kernel = Kernel("copy2d_bf16");
+    pub const I64: Kernel = Kernel("copy2d_i64");
+    pub const U32: Kernel = Kernel("copy2d_u32");
+    pub const U8: Kernel = Kernel("copy2d_u8");
+}
+
 macro_rules! ops{
     ($($name:ident),+) => {
 
@@ -366,6 +376,46 @@ pub fn call_unary_contiguous(
 }
 
 #[allow(clippy::too_many_arguments)]
+pub fn call_copy2d(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: copy2d::Kernel,
+    input: &Buffer,
+    output: &Buffer,
+    d1: usize,
+    d2: usize,
+    src_s: usize,
+    dst_s: usize,
+    src_o_in_bytes: usize,
+    dst_o_in_bytes: usize,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+    set_params!(
+        encoder,
+        (
+            d1,
+            d2,
+            src_s,
+            dst_s,
+            (input, src_o_in_bytes),
+            (output, dst_o_in_bytes)
+        )
+    );
+
+    let width: usize = d1 * d2;
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
+
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
 pub fn call_unary_strided(
     device: &Device,
     command_buffer: &CommandBufferRef,
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 1e0d5526..bdc13f9e 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -102,6 +102,30 @@ UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);
 #define BFLOAT_UNARY_OP(NAME) \
 UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
 
+#define COPY2D(FN_NAME, TYPENAME) \
+kernel void FN_NAME( \
+    constant size_t &d1, \
+    constant size_t &d2, \
+    constant size_t &src_s, \
+    constant size_t &dst_s, \
+    device const TYPENAME *input, \
+    device TYPENAME *output, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+    if (tid >= d1 * d2) { \
+        return; \
+    } \
+    size_t idx1 = tid / d2; \
+    size_t idx2 = tid - idx1 * d2; \
+    size_t src_idx = idx1 * src_s + idx2; \
+    size_t dst_idx = idx1 * dst_s + idx2; \
+    output[dst_idx] = input[src_idx]; \
+}
+
+COPY2D(copy2d_f32, float)
+COPY2D(copy2d_f16, half)
+COPY2D(copy2d_u8, uint8_t)
+COPY2D(copy2d_u32, uint32_t)
 
 UNARY_OP(cos)
 UNARY_OP(sin)
@@ -128,6 +152,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
 
 #if __METAL_VERSION__ >= 220
 UNARY(id, int64_t, copy_i64, copy_i64_strided)
+COPY2D(copy2d_i64, int64_t)
 #endif
 
 #if defined(__HAVE_BFLOAT__)
@@ -151,4 +176,6 @@ BFLOAT_UNARY_OP(recip)
 BFLOAT_UNARY_OP(relu)
 
 UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
+
+COPY2D(copy2d_bf16, bfloat)
 #endif