diff options
-rw-r--r-- | candle-core/src/cuda_backend/mod.rs | 2 | ||||
-rw-r--r-- | candle-core/src/cuda_backend/utils.rs | 38 | ||||
-rw-r--r-- | candle-kernels/src/reduce.cu | 84 | ||||
-rw-r--r-- | candle-metal-kernels/src/lib.rs | 63 | ||||
-rw-r--r-- | candle-metal-kernels/src/reduce.metal | 81 | ||||
-rw-r--r-- | candle-nn/src/ops.rs | 258 | ||||
-rw-r--r-- | candle-nn/tests/ops.rs | 27 |
7 files changed, 547 insertions, 6 deletions
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 88f325f4..9e72dcc8 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -16,7 +16,7 @@ mod error; mod utils; pub use device::{CudaDevice, DeviceId}; pub use error::{CudaError, WrapErr}; -pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S}; +pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S}; pub enum SlicePtrOrNull<T> { Ptr(CudaSlice<T>), diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index 8dd5be77..c1210727 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -54,6 +54,44 @@ pub trait Map2 { } } +pub trait Map3 { + #[allow(clippy::too_many_arguments)] + fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( + &self, + src1: &CudaSlice<T>, + layout1: &Layout, + src2: &CudaSlice<T>, + layout2: &Layout, + src3: &CudaSlice<T>, + layout3: &Layout, + dev: &CudaDevice, + ) -> Result<CudaSlice<T>>; + + #[allow(clippy::too_many_arguments)] + fn map( + &self, + s1: &S, + l1: &Layout, + s2: &S, + l2: &Layout, + s3: &S, + l3: &Layout, + d: &CudaDevice, + ) -> Result<S> { + let out = match (s1, s2, s3) { + (S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?), + _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?, + }; + Ok(out) + } +} + pub trait Map2InPlace { fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( &self, diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index 4dbd8dcc..aaac24a1 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -50,6 +50,15 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block, dst[dst_id] = shr[0]; } +static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32); + a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32); + } + return a; +} + static __device__ __forceinline__ float warp_reduce_sum(float x) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { @@ -58,6 +67,70 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) { return x; } +// LayerNorm implementation adapted from ggml, accumulation is made using f32. +// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477 +template <typename T> +__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + const int block_size = blockDim.x; + + float2 mean_var = make_float2(0.f, 0.f); + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row*ncols + col]; + mean_var.x += xi; + mean_var.y += xi * xi; + } + + // sum up partial sums + mean_var = warp_reduce_sum(mean_var); + if (block_size > WARP_SIZE) { + __shared__ float2 s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = mean_var; + } + __syncthreads(); + mean_var = s_sum[lane_id]; + mean_var = warp_reduce_sum(mean_var); + } + + const float mean = mean_var.x / ncols; + const float var = mean_var.y / ncols - mean * mean; + const float inv_std = rsqrtf(var + eps); + + if (alpha == nullptr && beta == nullptr) { + for (int col = tid; col < ncols; col += block_size) { + float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std; + dst[row*ncols + col] = static_cast<T>(lhs); + } + } + else if (alpha == nullptr && beta != nullptr) { + for (int col = tid; col < ncols; col += block_size) { + float b = static_cast<float>(beta[col]); + float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std; + dst[row*ncols + col] = static_cast<T>(lhs + b); + } + } + else if (alpha != nullptr && beta == nullptr) { + for (int col = tid; col < ncols; col += block_size) { + float a = static_cast<float>(alpha[col]); + float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std; + dst[row*ncols + col] = static_cast<T>(lhs * a); + } + } + else { + for (int col = tid; col < ncols; col += block_size) { + float a = static_cast<float>(alpha[col]); + float b = static_cast<float>(beta[col]); + float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std; + dst[row*ncols + col] = static_cast<T>(lhs * a + b); + } + } +} + // RmsNorm implementation adapted from ggml, accumulation is made using f32. // https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523 template <typename T> @@ -461,6 +534,13 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \ } \ +#define LAYERNORM_OP(TYPENAME, FN_NAME) \ + extern "C" __global__ void FN_NAME( \ + const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \ + const TYPENAME *beta, const int n_cols, const float eps) { \ + layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \ + } \ + #define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \ extern "C" __global__ void FN_NAME_I( \ const TYPENAME *src, \ @@ -496,6 +576,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block, #if __CUDA_ARCH__ >= 800 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16) RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16) +LAYERNORM_OP(__nv_bfloat16, layernorm_bf16) ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16) SUM_OP(__nv_bfloat16, sum_bf16) FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16) @@ -504,6 +585,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm #if __CUDA_ARCH__ >= 530 SOFTMAX_OP(__half, float, softmax_f16) RMSNORM_OP(__half, rmsnorm_f16) +LAYERNORM_OP(__half, layernorm_f16) ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16) SUM_OP(__half, sum_f16) FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16) @@ -516,6 +598,8 @@ SOFTMAX_OP(float, float, softmax_f32) SOFTMAX_OP(double, double, softmax_f64) RMSNORM_OP(float, rmsnorm_f32) RMSNORM_OP(double, rmsnorm_f64) +LAYERNORM_OP(float, layernorm_f32) +LAYERNORM_OP(double, layernorm_f64) ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32) ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 814ca0b9..aa157a2f 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -739,6 +739,69 @@ pub fn call_rms_norm( encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +pub fn call_layer_norm( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + elements_to_sum: usize, + eps: f32, + input: &Buffer, + input_offset: usize, + alpha: &Buffer, + alpha_offset: usize, + beta: &Buffer, + beta_offset: usize, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + length, + elements_to_sum, + (input, input_offset), + output, + (alpha, alpha_offset), + (beta, beta_offset), + eps + ) + ); + + let out_length = length / elements_to_sum; + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + elements_to_sum as u64, + ) + .next_power_of_two(); + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.use_resource(input, metal::MTLResourceUsage::Read); + encoder.use_resource(output, metal::MTLResourceUsage::Write); + encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 14bfb297..e009ca1d 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -353,6 +353,65 @@ METAL_FUNC void rmsnorm( } } +template<typename T> +METAL_FUNC void layernorm( + constant size_t & src_numel, + constant size_t & el_to_sum_per_block, + device const T * src, + device T * dst, + device const T * alpha, + device const T * beta, + constant float & eps, + uint id, + uint tid, + uint dst_id, + uint block_dim, + threadgroup float * shared_memory +) { + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + float tmp1 = 0; + float tmp2 = 0; + while (idx < stop_idx) { + tmp1 += float(src[idx]); + tmp2 += float(src[idx]) * float(src[idx]); + idx += block_dim; + } + shared_memory[tid] = tmp1; + shared_memory[tid + block_dim] = tmp2; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint s = block_dim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s]; + shared_memory[block_dim + tid] = shared_memory[block_dim + tid] + shared_memory[block_dim + tid + s]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + /* wait for shared_memory[0] to be filled */ + threadgroup_barrier(mem_flags::mem_threadgroup); + + float mean = shared_memory[0] / float(el_to_sum_per_block); + float var = shared_memory[block_dim] / float(el_to_sum_per_block) - mean * mean; + float inv_norm = 1.0f / sqrt(var + eps); + idx = start_idx + tid; + while (idx < stop_idx) { + float val = (float(src[idx]) - mean) * inv_norm; + if (alpha != nullptr) { + val *= float(alpha[idx - start_idx]); + } + if (beta != nullptr) { + val += float(beta[idx - start_idx]); + } + dst[idx] = T(val); + idx += block_dim; + } +} + #define RMSNORM(NAME, T) \ kernel void NAME( \ constant size_t &src_numel, \ @@ -371,6 +430,25 @@ kernel void NAME( \ rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \ } \ +#define LAYERNORM(NAME, T) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const T *src, \ + device T *dst, \ + device const T *alpha, \ + device const T *beta, \ + constant float &eps, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint block_dim [[ threads_per_threadgroup ]] \ +) { \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + shared_memory[tid] = 0; \ + layernorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, beta, eps, id, tid, dst_id, block_dim, shared_memory); \ +} \ + template<typename T> METAL_FUNC void ropei( constant size_t &bh, @@ -511,6 +589,8 @@ SOFTMAX(softmax_f32, float) SOFTMAX(softmax_f16, half) RMSNORM(rmsnorm_f32, float) RMSNORM(rmsnorm_f16, half) +LAYERNORM(layernorm_f32, float) +LAYERNORM(layernorm_f16, half) ROPE(rope_f32, rope_i_f32, rope_thd_f32, float) ROPE(rope_f16, rope_i_f16, rope_thd_f16, half) @@ -535,5 +615,6 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF) ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF) SOFTMAX(softmax_bf16, bfloat) RMSNORM(rmsnorm_bf16, bfloat) +LAYERNORM(layernorm_bf16, bfloat) ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat) #endif diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index eabc95d8..2a76ee5e 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1,4 +1,4 @@ -use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor}; +use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D}; use rayon::prelude::*; /// Applies the softmax function to the input tensor, rescaling the element so that elements on @@ -39,7 +39,7 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> { } pub fn swiglu(xs: &Tensor) -> Result<Tensor> { - let xs = xs.chunk(2, candle::D::Minus1)?; + let xs = xs.chunk(2, D::Minus1)?; &xs[0].silu()? * &xs[1] } @@ -620,15 +620,15 @@ pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> { DType::F16 | DType::BF16 => DType::F32, d => d, }; - let hidden_size = x.dim(candle::D::Minus1)?; + let hidden_size = x.dim(D::Minus1)?; let x = x.to_dtype(internal_dtype)?; - let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?; + let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?; x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha) } pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> { - let hidden_size_xs = xs.dim(candle::D::Minus1)?; + let hidden_size_xs = xs.dim(D::Minus1)?; let hidden_size_alpha = alpha.dims1()?; if hidden_size_xs != hidden_size_alpha { candle::bail!( @@ -640,6 +640,254 @@ pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> { xs.apply_op2_no_bwd(alpha, &RmsNorm { eps }) } +#[derive(Debug, Clone)] +struct LayerNorm { + eps: f32, +} + +impl candle::CustomOp3 for LayerNorm { + fn name(&self) -> &'static str { + "layer-norm" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + s3: &CpuStorage, + l3: &Layout, + ) -> Result<(CpuStorage, Shape)> { + use candle::backend::BackendStorage; + + let eps = self.eps; + fn inner< + T: candle::WithDType + + num_traits::Float + + num_traits::AsPrimitive<f32> + + num_traits::FromPrimitive, + >( + src: &[T], + layout: &Layout, + alpha: &[T], + alpha_layout: &Layout, + beta: &[T], + beta_layout: &Layout, + eps: f32, + ) -> Result<(CpuStorage, Shape)> { + let src = match layout.contiguous_offsets() { + None => candle::bail!("input has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let alpha = match alpha_layout.contiguous_offsets() { + None => candle::bail!("alpha has to be contiguous"), + Some((o1, o2)) => &alpha[o1..o2], + }; + let beta = match beta_layout.contiguous_offsets() { + None => candle::bail!("beta has to be contiguous"), + Some((o1, o2)) => &beta[o1..o2], + }; + let el_count = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(dim_m1) + .zip(dst.par_chunks_mut(dim_m1)) + .for_each(|(src, dst)| { + let mut sum = 0f32; + let mut sum2 = 0f32; + for v in src { + let v = v.as_(); + sum += v; + sum2 += v * v; + } + let mean = sum / dim_m1 as f32; + let var = sum2 / dim_m1 as f32 - mean * mean; + let inv_std = (var + eps).sqrt().recip(); + for ((d, s), (alpha, beta)) in + dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta)) + { + let alpha = alpha.as_(); + let beta = beta.as_(); + let d_ = (s.as_() - mean) * inv_std * alpha + beta; + *d = T::from_f32(d_).unwrap_or_else(T::nan); + } + }); + let storage = candle::WithDType::to_cpu_storage_owned(dst); + Ok((storage, Shape::from_dims(dims))) + } + + use CpuStorage as C; + match (s1, s2, s3) { + (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => { + inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps) + } + (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps), + (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps), + _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &candle::CudaStorage, + l1: &Layout, + s2: &candle::CudaStorage, + l2: &Layout, + s3: &candle::CudaStorage, + l3: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + use candle::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr}; + use candle::{CudaDevice, WithDType}; + + struct S { + eps: f32, + } + impl Map3 for S { + fn f<T: DeviceRepr + WithDType>( + &self, + src: &CudaSlice<T>, + layout: &Layout, + alpha: &CudaSlice<T>, + alpha_layout: &Layout, + beta: &CudaSlice<T>, + beta_layout: &Layout, + dev: &CudaDevice, + ) -> Result<CudaSlice<T>> { + let src = match layout.contiguous_offsets() { + None => candle::bail!("input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let alpha = match alpha_layout.contiguous_offsets() { + None => candle::bail!("alpha has to be contiguous"), + Some((o1, o2)) => alpha.slice(o1..o2), + }; + let beta = match beta_layout.contiguous_offsets() { + None => candle::bail!("beta has to be contiguous"), + Some((o1, o2)) => beta.slice(o1..o2), + }; + let el = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let (n_rows, n_cols) = (el / dim_m1, dim_m1); + + let cfg = LaunchConfig { + grid_dim: (n_rows as u32, 1, 1), + block_dim: (1024, 1, 1), + shared_mem_bytes: 0, + }; + let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::<T>(el) }.w()?; + let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + } + + use candle::backend::BackendStorage; + let dev = s1.device(); + let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?; + let dst = candle::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, l1.shape().clone())) + } + + #[cfg(feature = "metal")] + fn metal_fwd( + &self, + s1: &candle::MetalStorage, + l1: &Layout, + s2: &candle::MetalStorage, + l2: &Layout, + s3: &candle::MetalStorage, + l3: &Layout, + ) -> Result<(candle::MetalStorage, Shape)> { + use candle::backend::BackendStorage; + let device = s1.device(); + let command_buffer = device.command_buffer()?; + let kernels = device.kernels(); + let name = match (s1.dtype(), s2.dtype(), s3.dtype()) { + (DType::F32, DType::F32, DType::F32) => "layernorm_f32", + (DType::F16, DType::F16, DType::F16) => "layernorm_f16", + (DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16", + (dt1, dt2, dt3) => { + candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}") + } + }; + + if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) { + candle::bail!("Non contiguous layernorm is not implemented"); + } + + let last_dim = l1.dims()[l1.shape().rank() - 1]; + let elem_count = l1.shape().elem_count(); + let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?; + candle_metal_kernels::call_layer_norm( + device.metal_device(), + &command_buffer, + kernels, + name, + elem_count, + last_dim, + self.eps, + s1.buffer(), + l1.start_offset() * s1.dtype().size_in_bytes(), + s2.buffer(), + l2.start_offset() * s2.dtype().size_in_bytes(), + s3.buffer(), + l3.start_offset() * s3.dtype().size_in_bytes(), + &output, + ) + .map_err(candle::Error::wrap)?; + let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype()); + Ok((newstorage, l1.shape().clone())) + } +} + +pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = x.dim(D::Minus1)?; + let x = x.to_dtype(internal_dtype)?; + let x = { + let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + x.broadcast_sub(&mean_x)? + }; + let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?; + x_normed + .to_dtype(x_dtype)? + .broadcast_mul(alpha)? + .broadcast_add(beta) +} + +pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> { + let hidden_size_xs = xs.dim(D::Minus1)?; + let hidden_size_alpha = alpha.dims1()?; + let hidden_size_beta = beta.dims1()?; + if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta { + candle::bail!( + "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}", + xs.shape(), + alpha.shape(), + beta.shape() + ) + } + xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps }) +} + // https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> { let (b_size, c, h, w) = xs.dims4()?; diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs index f9cfe46d..65a8fbf2 100644 --- a/candle-nn/tests/ops.rs +++ b/candle-nn/tests/ops.rs @@ -77,6 +77,32 @@ fn rms_norm(device: &Device) -> Result<()> { Ok(()) } +fn layer_norm(device: &Device) -> Result<()> { + let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]]; + let tensor = Tensor::new(data, device)?; + let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?; + let beta = Tensor::new(&[0.5f32, 0f32, -0.2f32], device)?; + let t = candle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?; + assert_eq!( + to_vec3_round(&t, 4)?, + &[ + [[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]], + [[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]] + ] + ); + let t2 = candle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?; + assert_eq!( + to_vec3_round(&t2, 4)?, + &[ + [[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]], + [[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]] + ] + ); + let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?; + assert!(diff < 1e-5); + Ok(()) +} + #[test] fn softmax_numerical_stability() -> Result<()> { let dev = &Device::Cpu; @@ -185,4 +211,5 @@ test_device!(rope, rope_cpu, rope_gpu, rope_metal); test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal); test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal); test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal); +test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal); test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal); |