summaryrefslogtreecommitdiff
path: root/candle-metal-kernels
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-03-21 09:48:56 +0100
committerGitHub <noreply@github.com>2024-03-21 09:48:56 +0100
commit0fddec762e3c17c56be5b6356478b9565dd628bb (patch)
tree49a1e09d3b397f97187f60739e80f528ae4b083a /candle-metal-kernels
parent74b7f59261c72010e329fd8eb467c088673671f5 (diff)
downloadcandle-0fddec762e3c17c56be5b6356478b9565dd628bb.tar.gz
candle-0fddec762e3c17c56be5b6356478b9565dd628bb.tar.bz2
candle-0fddec762e3c17c56be5b6356478b9565dd628bb.zip
RmsNorm kernel for metal. (#1895)
* RmsNorm kernel for metal. * Wrapper for the metal kernel. * Get the ops to actually work. * Fix, get the tests to pass.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--candle-metal-kernels/src/lib.rs58
-rw-r--r--candle-metal-kernels/src/reduce.metal56
2 files changed, 114 insertions, 0 deletions
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index f12463a4..bab44a05 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -751,6 +751,64 @@ pub fn call_last_softmax(
}
#[allow(clippy::too_many_arguments)]
+pub fn call_rms_norm(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ length: usize,
+ elements_to_sum: usize,
+ eps: f32,
+ input: &Buffer,
+ input_offset: usize,
+ alpha: &Buffer,
+ alpha_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ length,
+ elements_to_sum,
+ (input, input_offset),
+ output,
+ (alpha, alpha_offset),
+ eps
+ )
+ );
+
+ let out_length = length / elements_to_sum;
+
+ let thread_group_count = MTLSize {
+ width: out_length as u64,
+ height: 1,
+ depth: 1,
+ };
+
+ let width = std::cmp::min(
+ pipeline.max_total_threads_per_threadgroup(),
+ elements_to_sum as u64,
+ )
+ .next_power_of_two();
+
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
pub fn call_affine(
device: &Device,
command_buffer: &CommandBufferRef,
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 93dac662..3c3cbc14 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -260,6 +260,59 @@ kernel void NAME(
} \
} \
+#define RMSNORM(NAME, T) \
+kernel void NAME( \
+ constant size_t &src_numel, \
+ constant size_t &el_to_sum_per_block, \
+ device const T *src, \
+ device T *dst, \
+ device const T *alpha, \
+ constant float &eps, \
+ \
+ uint id [[ thread_position_in_grid ]], \
+ uint tid [[ thread_index_in_threadgroup ]], \
+ uint dst_id [[ threadgroup_position_in_grid ]], \
+ uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+ threadgroup float shared_memory[THREADGROUP_SIZE]; \
+ shared_memory[tid] = 0; \
+ size_t start_idx = dst_id * el_to_sum_per_block; \
+ size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
+ size_t idx = start_idx + tid; \
+ \
+ \
+ float tmp = 0; \
+ while (idx < stop_idx) { \
+ tmp = tmp + float(src[idx]) * float(src[idx]); \
+ idx += block_dim; \
+ } \
+ shared_memory[tid] = tmp; \
+ \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
+ \
+ for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+ if (tid < s) { \
+ shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; \
+ } \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
+ } \
+ \
+ /* wait for shared_memory[0] to be filled */ \
+ threadgroup_barrier(mem_flags::mem_threadgroup); \
+ \
+ float norm = sqrt(shared_memory[0] / float(el_to_sum_per_block) + eps); \
+ float inv_norm = 1.0f / norm; \
+ idx = start_idx + tid; \
+ while (idx < stop_idx) { \
+ float val = float(src[idx]) * inv_norm; \
+ if (alpha != nullptr) { \
+ val *= float(alpha[idx - start_idx]); \
+ } \
+ dst[idx] = T(val); \
+ idx += block_dim; \
+ } \
+} \
+
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
@@ -286,6 +339,8 @@ ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
+RMSNORM(rmsnorm_f32, float)
+RMSNORM(rmsnorm_f16, half)
#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
@@ -303,4 +358,5 @@ REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
+RMSNORM(rmsnorm_bf16, bfloat)
#endif