summaryrefslogtreecommitdiff
path: root/candle-metal-kernels/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-04-05 08:32:58 +0200
committerGitHub <noreply@github.com>2024-04-05 08:32:58 +0200
commit2ac302a5d170953a1d2fe850645563fc55d1567f (patch)
treeb35b32efe5c8eac25a9b5681fb0778ef84e57d0e /candle-metal-kernels/src
parentace282e5c2ef24ca2fb90683babb852936d4df17 (diff)
downloadcandle-2ac302a5d170953a1d2fe850645563fc55d1567f.tar.gz
candle-2ac302a5d170953a1d2fe850645563fc55d1567f.tar.bz2
candle-2ac302a5d170953a1d2fe850645563fc55d1567f.zip
Add the rope THD kernel. (#2014)
* Add the rope THD kernel. * Cuda kernel for rope-thd. * Add the metal kernels. * Add a dedicated test.
Diffstat (limited to 'candle-metal-kernels/src')
-rw-r--r--candle-metal-kernels/src/lib.rs45
-rw-r--r--candle-metal-kernels/src/reduce.metal48
2 files changed, 89 insertions, 4 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 5af48fae..4cff9bda 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -850,6 +850,51 @@ pub fn call_rope_i(
}
#[allow(clippy::too_many_arguments)]
+pub fn call_rope_thd(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ b: usize,
+ t: usize,
+ h: usize,
+ d: usize,
+ src: &Buffer,
+ src_offset: usize,
+ cos: &Buffer,
+ cos_offset: usize,
+ sin: &Buffer,
+ sin_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ b,
+ t,
+ h,
+ d,
+ (src, src_offset),
+ (cos, cos_offset),
+ (sin, sin_offset),
+ output
+ )
+ );
+ let (thread_group_count, thread_group_size) = linear_split(&pipeline, (b * t * h * d) / 2);
+ encoder.use_resource(src, metal::MTLResourceUsage::Read);
+ encoder.use_resource(cos, metal::MTLResourceUsage::Read);
+ encoder.use_resource(sin, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
pub fn call_rope(
device: &Device,
command_buffer: &CommandBufferRef,
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index acb69299..14bfb297 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -418,7 +418,34 @@ METAL_FUNC void rope(
dst[i2] = src[i1] * s + src[i2] * c;
}
-#define ROPEI(FN_NAME, FN_NAME_I, TYPENAME) \
+template<typename T>
+METAL_FUNC void rope_thd(
+ constant size_t &b,
+ constant size_t &t,
+ constant size_t &h,
+ constant size_t &d,
+ device const T *src,
+ device const T *cos,
+ device const T *sin,
+ device T *dst,
+ uint idx
+) {
+ if (2 * idx >= b * t * h * d) {
+ return;
+ }
+ const size_t i_bth = idx / (d / 2);
+ const size_t i_d = idx - (d / 2) * i_bth;
+ const size_t i_t = (i_bth / h) % t;
+ const size_t i1 = i_bth * d + i_d;
+ const size_t i2 = i1 + d / 2;
+ const size_t i_cs = i_t * (d / 2) + i_d;
+ T c = cos[i_cs];
+ T s = sin[i_cs];
+ dst[i1] = src[i1] * c - src[i2] * s;
+ dst[i2] = src[i1] * s + src[i2] * c;
+}
+
+#define ROPE(FN_NAME, FN_NAME_I, FN_NAME_THD, TYPENAME) \
kernel void FN_NAME_I( \
constant size_t &bh, \
constant size_t &td, \
@@ -442,6 +469,19 @@ kernel void FN_NAME( \
) { \
rope<TYPENAME>(bh, td, d, src, cos, sin, dst, idx); \
}\
+kernel void FN_NAME_THD( \
+ constant size_t &b, \
+ constant size_t &t, \
+ constant size_t &h, \
+ constant size_t &d, \
+ device const TYPENAME *src, \
+ device const TYPENAME *cos, \
+ device const TYPENAME *sin, \
+ device TYPENAME *dst, \
+ uint idx [[ thread_position_in_grid ]] \
+) { \
+ rope_thd<TYPENAME>(b, t, h, d, src, cos, sin, dst, idx); \
+}\
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
@@ -471,8 +511,8 @@ SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
-ROPEI(rope_f32, rope_i_f32, float)
-ROPEI(rope_f16, rope_i_f16, half)
+ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
+ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)
#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
@@ -495,5 +535,5 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
-ROPEI(rope_bf16, rope_i_bf16, bfloat)
+ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
#endif