author     Laurent Mazare <laurent.mazare@gmail.com>  2024-05-24 15:58:01 +0200
committer  GitHub <noreply@github.com>                2024-05-24 15:58:01 +0200
commit     1df2bddccfbb4ab511a8cc3a87476d1fa72416bc
tree       3633bc51e3bac3d542d9dfe06d509db20f5374e9
parent     6f0b807ffd553fed27325a2a118b0e30bb6d9cbd
Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.
* Dedicated layer norm op.
* Add the slower variant.
* Plug the cuda implementation.
* Add the metal variant.
* Add a dedicated test.
* Bugfix.
Diffstat (limited to 'candle-metal-kernels')
-rw-r--r--  candle-metal-kernels/src/lib.rs        63
-rw-r--r--  candle-metal-kernels/src/reduce.metal  81
2 files changed, 144 insertions, 0 deletions
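For reference, the op added here is a standard layer normalization over the last dimension: each row is shifted by its mean, scaled by 1/sqrt(var + eps), then multiplied by alpha (gamma) and offset by beta. A minimal CPU sketch of that computation, not part of this commit and using hypothetical slice inputs:

/// Minimal CPU reference for the layer-norm op added in this commit
/// (illustrative only; `src` holds rows of `cols` contiguous elements).
fn layer_norm_ref(src: &[f32], alpha: &[f32], beta: &[f32], cols: usize, eps: f32) -> Vec<f32> {
    let mut dst = vec![0f32; src.len()];
    for (row_in, row_out) in src.chunks(cols).zip(dst.chunks_mut(cols)) {
        let n = cols as f32;
        let mean = row_in.iter().sum::<f32>() / n;
        // Same formulation as the Metal kernel below: E[x^2] - E[x]^2.
        let var = row_in.iter().map(|x| x * x).sum::<f32>() / n - mean * mean;
        let inv_norm = 1.0 / (var + eps).sqrt();
        for (i, (x, y)) in row_in.iter().zip(row_out.iter_mut()).enumerate() {
            *y = (x - mean) * inv_norm * alpha[i] + beta[i];
        }
    }
    dst
}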
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 814ca0b9..aa157a2f 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -739,6 +739,69 @@ pub fn call_rms_norm(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_layer_norm(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ length: usize,
+ elements_to_sum: usize,
+ eps: f32,
+ input: &Buffer,
+ input_offset: usize,
+ alpha: &Buffer,
+ alpha_offset: usize,
+ beta: &Buffer,
+ beta_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ length,
+ elements_to_sum,
+ (input, input_offset),
+ output,
+ (alpha, alpha_offset),
+ (beta, beta_offset),
+ eps
+ )
+ );
+
+ let out_length = length / elements_to_sum;
+
+ let thread_group_count = MTLSize {
+ width: out_length as u64,
+ height: 1,
+ depth: 1,
+ };
+
+ let width = std::cmp::min(
+ pipeline.max_total_threads_per_threadgroup(),
+ elements_to_sum as u64,
+ )
+ .next_power_of_two();
+
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
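The new call_layer_norm mirrors call_rms_norm but takes an extra beta buffer and reserves width * 8 bytes of threadgroup memory (two f32 accumulators per thread, for the sum and the sum of squares) instead of width * 4. A hedged sketch of how a caller might wrap it for a contiguous (rows, cols) f32 tensor; the wrapper itself is hypothetical, only the call_layer_norm signature comes from the hunk above:

use candle_metal_kernels::{call_layer_norm, Kernels, MetalKernelError};

/// Hypothetical convenience wrapper: layer-norm over the last dimension of a
/// contiguous (rows, cols) f32 tensor whose buffers already live on the device.
fn layer_norm_f32(
    device: &metal::Device,
    command_buffer: &metal::CommandBufferRef,
    kernels: &Kernels,
    rows: usize,
    cols: usize,
    eps: f32,
    input: &metal::Buffer,
    alpha: &metal::Buffer,
    beta: &metal::Buffer,
    output: &metal::Buffer,
) -> Result<(), MetalKernelError> {
    call_layer_norm(
        device,
        command_buffer,
        kernels,
        "layernorm_f32", // instantiated by the LAYERNORM macro in reduce.metal
        rows * cols,     // total element count
        cols,            // elements reduced per output row (one threadgroup each)
        eps,
        input,
        0, // input offset (start of the buffer)
        alpha,
        0, // alpha offset
        beta,
        0, // beta offset
        output,
    )
}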
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 14bfb297..e009ca1d 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -353,6 +353,65 @@ METAL_FUNC void rmsnorm(
}
}
+template<typename T>
+METAL_FUNC void layernorm(
+ constant size_t & src_numel,
+ constant size_t & el_to_sum_per_block,
+ device const T * src,
+ device T * dst,
+ device const T * alpha,
+ device const T * beta,
+ constant float & eps,
+ uint id,
+ uint tid,
+ uint dst_id,
+ uint block_dim,
+ threadgroup float * shared_memory
+) {
+ size_t start_idx = dst_id * el_to_sum_per_block;
+ size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+ size_t idx = start_idx + tid;
+
+ float tmp1 = 0;
+ float tmp2 = 0;
+ while (idx < stop_idx) {
+ tmp1 += float(src[idx]);
+ tmp2 += float(src[idx]) * float(src[idx]);
+ idx += block_dim;
+ }
+ shared_memory[tid] = tmp1;
+ shared_memory[tid + block_dim] = tmp2;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (uint s = block_dim / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
+ shared_memory[block_dim + tid] = shared_memory[block_dim + tid] + shared_memory[block_dim + tid + s];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ /* wait for shared_memory[0] to be filled */
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float mean = shared_memory[0] / float(el_to_sum_per_block);
+ float var = shared_memory[block_dim] / float(el_to_sum_per_block) - mean * mean;
+ float inv_norm = 1.0f / sqrt(var + eps);
+ idx = start_idx + tid;
+ while (idx < stop_idx) {
+ float val = (float(src[idx]) - mean) * inv_norm;
+ if (alpha != nullptr) {
+ val *= float(alpha[idx - start_idx]);
+ }
+ if (beta != nullptr) {
+ val += float(beta[idx - start_idx]);
+ }
+ dst[idx] = T(val);
+ idx += block_dim;
+ }
+}
+
#define RMSNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
@@ -371,6 +430,25 @@ kernel void NAME( \
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
} \
+#define LAYERNORM(NAME, T) \
+kernel void NAME( \
+ constant size_t &src_numel, \
+ constant size_t &el_to_sum_per_block, \
+ device const T *src, \
+ device T *dst, \
+ device const T *alpha, \
+ device const T *beta, \
+ constant float &eps, \
+ uint id [[ thread_position_in_grid ]], \
+ uint tid [[ thread_index_in_threadgroup ]], \
+ uint dst_id [[ threadgroup_position_in_grid ]], \
+ uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+ threadgroup float shared_memory[THREADGROUP_SIZE]; \
+ shared_memory[tid] = 0; \
+ layernorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, beta, eps, id, tid, dst_id, block_dim, shared_memory); \
+} \
+
template<typename T>
METAL_FUNC void ropei(
constant size_t &bh,
@@ -511,6 +589,8 @@ SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
+LAYERNORM(layernorm_f32, float)
+LAYERNORM(layernorm_f16, half)
ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)
@@ -535,5 +615,6 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
+LAYERNORM(layernorm_bf16, bfloat)
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
#endif
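The LAYERNORM macro instantiates layernorm_f32, layernorm_f16 and, behind the bfloat guard, layernorm_bf16. On the host side the kernel name passed to call_layer_norm has to match the tensor's element type; a hypothetical helper illustrating that mapping (the enum is made up for the example, only the kernel names come from this diff):

/// Hypothetical element-type enum for the example; candle proper has its own DType.
enum Elem {
    F32,
    F16,
    BF16,
}

/// Map an element type to the layer-norm kernel instantiated in reduce.metal.
fn layer_norm_kernel_name(elem: Elem) -> &'static str {
    match elem {
        Elem::F32 => "layernorm_f32",
        Elem::F16 => "layernorm_f16",
        // Only compiled when the Metal toolchain supports bfloat.
        Elem::BF16 => "layernorm_bf16",
    }
}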