summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-24 22:48:52 +0100
committerGitHub <noreply@github.com>2024-03-24 22:48:52 +0100
commit1b98f84a2baa23192b97e36131011da658bfa1c2 (patch)
tree92c4e9e8a263edfc8d3fedeab2cc02271d87d51e /candle-metal-kernels
parentcf7d7fcf2f20c24aae633483c3a107c1219a7f9a (diff)
downloadcandle-1b98f84a2baa23192b97e36131011da658bfa1c2.tar.gz
candle-1b98f84a2baa23192b97e36131011da658bfa1c2.tar.bz2
candle-1b98f84a2baa23192b97e36131011da658bfa1c2.zip
Fast kernels for rotary embeddings. (#1928)
* Fast kernels for rotary embeddings. * Add a test for the fast CPU kernel. * Rope cuda bindings. * Cuda kernel. * Metal kernel (part 1). * Cuda kernels. * Finish the metal kernel. * Use the new kernels in the quantized example. * Fix warning.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs41
-rw-r--r--candle-metal-kernels/src/reduce.metal23
2 files changed, 64 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index e17365a0..e83814a8 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -809,6 +809,47 @@ pub fn call_rms_norm(
}
#[allow(clippy::too_many_arguments)]
+pub fn call_rope_i(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ bh: usize,
+ td: usize,
+ src: &Buffer,
+ src_offset: usize,
+ cos: &Buffer,
+ cos_offset: usize,
+ sin: &Buffer,
+ sin_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ bh,
+ td,
+ (src, src_offset),
+ (cos, cos_offset),
+ (sin, sin_offset),
+ output
+ )
+ );
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, (bh * td) / 2);
+ encoder.use_resource(src, metal::MTLResourceUsage::Read);
+ encoder.use_resource(cos, metal::MTLResourceUsage::Read);
+ encoder.use_resource(sin, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
pub fn call_affine(
device: &Device,
command_buffer: &CommandBufferRef,
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 3c3cbc14..fa980dea 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -313,6 +313,26 @@ kernel void NAME(
} \
} \
+#define ROPEI(FN_NAME, TYPENAME) \
+kernel void FN_NAME( \
+ constant size_t &bh, \
+ constant size_t &td, \
+ device const TYPENAME *src, \
+ device const TYPENAME *cos, \
+ device const TYPENAME *sin, \
+ device TYPENAME *dst, \
+ uint tid [[ thread_position_in_grid ]] \
+) { \
+ if (2 * tid >= bh * td) { \
+ return; \
+ } \
+ size_t rope_idx = tid % (td / 2); \
+ TYPENAME c = cos[rope_idx]; \
+ TYPENAME s = sin[rope_idx]; \
+ dst[2 * tid] = src[2 * tid] * c - src[2 * tid + 1] * s; \
+ dst[2 * tid + 1] = src[2 * tid] * s + src[2 * tid + 1] * c; \
+}\
+
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
@@ -341,6 +361,8 @@ SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
+ROPEI(rope_i_f32, float)
+ROPEI(rope_i_f16, half)
#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
@@ -359,4 +381,5 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
+ROPEI(rope_i_bf16, bfloat)
#endif