summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-17 10:49:13 +0100
committerGitHub <noreply@github.com>2024-03-17 10:49:13 +0100
commitce9fbc368211815ef2dddff01575ca1f9d4eccd5 (patch)
treee260edd957ab716d1789da05059c9a79696b0730 /candle-metal-kernels
parentdb8b24ae92419377283821ee0a65fb224a4f3c4d (diff)
downloadcandle-ce9fbc368211815ef2dddff01575ca1f9d4eccd5.tar.gz
candle-ce9fbc368211815ef2dddff01575ca1f9d4eccd5.tar.bz2
candle-ce9fbc368211815ef2dddff01575ca1f9d4eccd5.zip
Optimize the cat operation on contiguous tensors (#1855)
* Add a specialized kernel for copy2d. * Move the cat operations. * Avoid transpositions in cat. * Bugfix. * Bugfix for the cuda kernel. * Add a benchmark. * Add more testing. * Test fix. * Faster kernel. * Add the missing kernel. * Tweak the test. * Add a metal kernel. * Fix for the metal kernel. * Get the tests to pass on metal. * Also use this opportunity to fix the metal kernel for ELU. * Add some bf16 kernels. * Clippy fixes.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/affine.metal2
-rw-r--r--candle-metal-kernels/src/lib.rs50
-rw-r--r--candle-metal-kernels/src/unary.metal27
3 files changed, 78 insertions, 1 deletions
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
index a4484998..76c0365a 100644
--- a/candle-metal-kernels/src/affine.metal
+++ b/candle-metal-kernels/src/affine.metal
@@ -89,7 +89,7 @@ kernel void FN_NAME( \
return; \
} \
const TYPENAME x = input[id]; \
- output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
+ output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 47ce7e96..a879c86a 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -127,6 +127,16 @@ pub enum Source {
Quantized,
}
+pub mod copy2d { // Names of the copy2d kernels, one constant per element type.
+ pub struct Kernel(pub &'static str); // Wraps a kernel function name; `call_copy2d` hands `.0` to `load_pipeline`.
+ pub const FLOAT: Kernel = Kernel("copy2d_f32");
+ pub const HALF: Kernel = Kernel("copy2d_f16");
+ pub const BFLOAT: Kernel = Kernel("copy2d_bf16");
+ pub const I64: Kernel = Kernel("copy2d_i64");
+ pub const U32: Kernel = Kernel("copy2d_u32");
+ pub const U8: Kernel = Kernel("copy2d_u8");
+}
+
macro_rules! ops{
($($name:ident),+) => {
@@ -366,6 +376,46 @@ pub fn call_unary_contiguous(
}
#[allow(clippy::too_many_arguments)]
+pub fn call_copy2d( // Encodes one compute pass on `command_buffer` that runs the copy2d kernel selected by `name`.
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ name: copy2d::Kernel, // which kernel to run, e.g. `copy2d::FLOAT`
+ input: &Buffer,
+ output: &Buffer,
+ d1: usize, // number of rows (the COPY2D kernel computes the row as `tid / d2`)
+ d2: usize, // number of elements copied per row
+ src_s: usize, // source row stride in elements (the kernel reads `input[row * src_s + col]`)
+ dst_s: usize, // destination row stride in elements (the kernel writes `output[row * dst_s + col]`)
+ src_o_in_bytes: usize, // paired with `input` in `set_params!`; presumably the buffer binding offset, in bytes per its name
+ dst_o_in_bytes: usize, // paired with `output` in `set_params!`; presumably the buffer binding offset, in bytes per its name
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?; // copy2d kernels are looked up in the `Unary` source
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+ set_params!( // NOTE(review): order mirrors the COPY2D kernel's parameter list; presumably bound by position - confirm in the macro
+ encoder,
+ (
+ d1,
+ d2,
+ src_s,
+ dst_s,
+ (input, src_o_in_bytes),
+ (output, dst_o_in_bytes)
+ )
+ );
+
+ let width: usize = d1 * d2; // total element count; the kernel returns early for `tid >= d1 * d2`
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); // presumably covers at least `width` threads - confirm in `linear_split`
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(()) // only `load_pipeline` can fail here; the work itself is just encoded, not awaited
+}
+
+#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 1e0d5526..bdc13f9e 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -102,6 +102,30 @@ UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
+#define COPY2D(FN_NAME, TYPENAME) /* Defines kernel FN_NAME: copies d1 rows of d2 TYPENAME elements between row-strided buffers. */ \
+kernel void FN_NAME( \
+ constant size_t &d1, /* number of rows */ \
+ constant size_t &d2, /* elements copied per row */ \
+ constant size_t &src_s, /* elements between consecutive row starts in input */ \
+ constant size_t &dst_s, /* elements between consecutive row starts in output */ \
+ device const TYPENAME *input, \
+ device TYPENAME *output, \
+ uint tid [[ thread_position_in_grid ]] /* each thread copies at most one element */ \
+) { \
+ if (tid >= d1 * d2) { /* threads beyond the last element do nothing */ \
+ return; \
+ } \
+ size_t idx1 = tid / d2; /* row index */ \
+ size_t idx2 = tid - idx1 * d2; /* position inside the row, i.e. tid % d2 */ \
+ size_t src_idx = idx1 * src_s + idx2; \
+ size_t dst_idx = idx1 * dst_s + idx2; \
+ output[dst_idx] = input[src_idx]; \
+}
+
+COPY2D(copy2d_f32, float)
+COPY2D(copy2d_f16, half)
+COPY2D(copy2d_u8, uint8_t)
+COPY2D(copy2d_u32, uint32_t)
UNARY_OP(cos)
UNARY_OP(sin)
@@ -128,6 +152,7 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)
#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
+COPY2D(copy2d_i64, int64_t)
#endif
#if defined(__HAVE_BFLOAT__)
@@ -151,4 +176,6 @@ BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
+
+COPY2D(copy2d_bf16, bfloat) // name must match copy2d::BFLOAT ("copy2d_bf16") in lib.rs, which loads the kernel by this string
#endif