summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/cuda_backend/mod.rs2
-rw-r--r--candle-core/src/cuda_backend/utils.rs38
-rw-r--r--candle-kernels/src/reduce.cu84
-rw-r--r--candle-metal-kernels/src/lib.rs63
-rw-r--r--candle-metal-kernels/src/reduce.metal81
-rw-r--r--candle-nn/src/ops.rs258
-rw-r--r--candle-nn/tests/ops.rs27
7 files changed, 547 insertions, 6 deletions
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index 88f325f4..9e72dcc8 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -16,7 +16,7 @@ mod error;
mod utils;
pub use device::{CudaDevice, DeviceId};
pub use error::{CudaError, WrapErr};
-pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
+pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S};
pub enum SlicePtrOrNull<T> {
Ptr(CudaSlice<T>),
diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs
index 8dd5be77..c1210727 100644
--- a/candle-core/src/cuda_backend/utils.rs
+++ b/candle-core/src/cuda_backend/utils.rs
@@ -54,6 +54,44 @@ pub trait Map2 {
}
}
+pub trait Map3 {
+ #[allow(clippy::too_many_arguments)]
+ fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+ &self,
+ src1: &CudaSlice<T>,
+ layout1: &Layout,
+ src2: &CudaSlice<T>,
+ layout2: &Layout,
+ src3: &CudaSlice<T>,
+ layout3: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>>;
+
+ #[allow(clippy::too_many_arguments)]
+ fn map(
+ &self,
+ s1: &S,
+ l1: &Layout,
+ s2: &S,
+ l2: &Layout,
+ s3: &S,
+ l3: &Layout,
+ d: &CudaDevice,
+ ) -> Result<S> {
+ let out = match (s1, s2, s3) {
+ (S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ (S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ (S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ (S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
+ _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
+ };
+ Ok(out)
+ }
+}
+
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 4dbd8dcc..aaac24a1 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -50,6 +50,15 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
dst[dst_id] = shr[0];
}
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+ a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+ }
+ return a;
+}
+
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
@@ -58,6 +67,70 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
return x;
}
+// LayerNorm implementation adapted from ggml, accumulation is made using f32.
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
+template <typename T>
+__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int tid = threadIdx.x;
+ const int block_size = blockDim.x;
+
+ float2 mean_var = make_float2(0.f, 0.f);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[row*ncols + col];
+ mean_var.x += xi;
+ mean_var.y += xi * xi;
+ }
+
+ // sum up partial sums
+ mean_var = warp_reduce_sum(mean_var);
+ if (block_size > WARP_SIZE) {
+ __shared__ float2 s_sum[32];
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = mean_var;
+ }
+ __syncthreads();
+ mean_var = s_sum[lane_id];
+ mean_var = warp_reduce_sum(mean_var);
+ }
+
+ const float mean = mean_var.x / ncols;
+ const float var = mean_var.y / ncols - mean * mean;
+ const float inv_std = rsqrtf(var + eps);
+
+ if (alpha == nullptr && beta == nullptr) {
+ for (int col = tid; col < ncols; col += block_size) {
+ float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+ dst[row*ncols + col] = static_cast<T>(lhs);
+ }
+ }
+ else if (alpha == nullptr && beta != nullptr) {
+ for (int col = tid; col < ncols; col += block_size) {
+ float b = static_cast<float>(beta[col]);
+ float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+ dst[row*ncols + col] = static_cast<T>(lhs + b);
+ }
+ }
+ else if (alpha != nullptr && beta == nullptr) {
+ for (int col = tid; col < ncols; col += block_size) {
+ float a = static_cast<float>(alpha[col]);
+ float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+ dst[row*ncols + col] = static_cast<T>(lhs * a);
+ }
+ }
+ else {
+ for (int col = tid; col < ncols; col += block_size) {
+ float a = static_cast<float>(alpha[col]);
+ float b = static_cast<float>(beta[col]);
+ float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+ dst[row*ncols + col] = static_cast<T>(lhs * a + b);
+ }
+ }
+}
+
// RmsNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
@@ -461,6 +534,13 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \
+#define LAYERNORM_OP(TYPENAME, FN_NAME) \
+ extern "C" __global__ void FN_NAME( \
+ const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
+ const TYPENAME *beta, const int n_cols, const float eps) { \
+ layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
+ } \
+
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
extern "C" __global__ void FN_NAME_I( \
const TYPENAME *src, \
@@ -496,6 +576,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
+LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
@@ -504,6 +585,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
+LAYERNORM_OP(__half, layernorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
@@ -516,6 +598,8 @@ SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
+LAYERNORM_OP(float, layernorm_f32)
+LAYERNORM_OP(double, layernorm_f64)
ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 814ca0b9..aa157a2f 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -739,6 +739,69 @@ pub fn call_rms_norm(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
+ encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+ encoder.end_encoding();
+ Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_layer_norm(
+ device: &Device,
+ command_buffer: &CommandBufferRef,
+ kernels: &Kernels,
+ kernel_name: &'static str,
+ length: usize,
+ elements_to_sum: usize,
+ eps: f32,
+ input: &Buffer,
+ input_offset: usize,
+ alpha: &Buffer,
+ alpha_offset: usize,
+ beta: &Buffer,
+ beta_offset: usize,
+ output: &Buffer,
+) -> Result<(), MetalKernelError> {
+ let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+ let encoder = command_buffer.new_compute_command_encoder();
+ encoder.set_compute_pipeline_state(&pipeline);
+
+ set_params!(
+ encoder,
+ (
+ length,
+ elements_to_sum,
+ (input, input_offset),
+ output,
+ (alpha, alpha_offset),
+ (beta, beta_offset),
+ eps
+ )
+ );
+
+ let out_length = length / elements_to_sum;
+
+ let thread_group_count = MTLSize {
+ width: out_length as u64,
+ height: 1,
+ depth: 1,
+ };
+
+ let width = std::cmp::min(
+ pipeline.max_total_threads_per_threadgroup(),
+ elements_to_sum as u64,
+ )
+ .next_power_of_two();
+
+ let thread_group_size = MTLSize {
+ width,
+ height: 1,
+ depth: 1,
+ };
+
+ encoder.use_resource(input, metal::MTLResourceUsage::Read);
+ encoder.use_resource(output, metal::MTLResourceUsage::Write);
+ encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 14bfb297..e009ca1d 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -353,6 +353,65 @@ METAL_FUNC void rmsnorm(
}
}
+template<typename T>
+METAL_FUNC void layernorm(
+ constant size_t & src_numel,
+ constant size_t & el_to_sum_per_block,
+ device const T * src,
+ device T * dst,
+ device const T * alpha,
+ device const T * beta,
+ constant float & eps,
+ uint id,
+ uint tid,
+ uint dst_id,
+ uint block_dim,
+ threadgroup float * shared_memory
+) {
+ size_t start_idx = dst_id * el_to_sum_per_block;
+ size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+ size_t idx = start_idx + tid;
+
+ float tmp1 = 0;
+ float tmp2 = 0;
+ while (idx < stop_idx) {
+ tmp1 += float(src[idx]);
+ tmp2 += float(src[idx]) * float(src[idx]);
+ idx += block_dim;
+ }
+ shared_memory[tid] = tmp1;
+ shared_memory[tid + block_dim] = tmp2;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (uint s = block_dim / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
+ shared_memory[block_dim + tid] = shared_memory[block_dim + tid] + shared_memory[block_dim + tid + s];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+ /* wait for shared_memory[0] to be filled */
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float mean = shared_memory[0] / float(el_to_sum_per_block);
+ float var = shared_memory[block_dim] / float(el_to_sum_per_block) - mean * mean;
+ float inv_norm = 1.0f / sqrt(var + eps);
+ idx = start_idx + tid;
+ while (idx < stop_idx) {
+ float val = (float(src[idx]) - mean) * inv_norm;
+ if (alpha != nullptr) {
+ val *= float(alpha[idx - start_idx]);
+ }
+ if (beta != nullptr) {
+ val += float(beta[idx - start_idx]);
+ }
+ dst[idx] = T(val);
+ idx += block_dim;
+ }
+}
+
#define RMSNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
@@ -371,6 +430,25 @@ kernel void NAME( \
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
} \
+#define LAYERNORM(NAME, T) \
+kernel void NAME( \
+ constant size_t &src_numel, \
+ constant size_t &el_to_sum_per_block, \
+ device const T *src, \
+ device T *dst, \
+ device const T *alpha, \
+ device const T *beta, \
+ constant float &eps, \
+ uint id [[ thread_position_in_grid ]], \
+ uint tid [[ thread_index_in_threadgroup ]], \
+ uint dst_id [[ threadgroup_position_in_grid ]], \
+ uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+ threadgroup float shared_memory[THREADGROUP_SIZE]; \
+ shared_memory[tid] = 0; \
+ layernorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, beta, eps, id, tid, dst_id, block_dim, shared_memory); \
+} \
+
template<typename T>
METAL_FUNC void ropei(
constant size_t &bh,
@@ -511,6 +589,8 @@ SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
+LAYERNORM(layernorm_f32, float)
+LAYERNORM(layernorm_f16, half)
ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)
@@ -535,5 +615,6 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
+LAYERNORM(layernorm_bf16, bfloat)
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
#endif
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index eabc95d8..2a76ee5e 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1,4 +1,4 @@
-use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D};
use rayon::prelude::*;
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
@@ -39,7 +39,7 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
}
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
- let xs = xs.chunk(2, candle::D::Minus1)?;
+ let xs = xs.chunk(2, D::Minus1)?;
&xs[0].silu()? * &xs[1]
}
@@ -620,15 +620,15 @@ pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
- let hidden_size = x.dim(candle::D::Minus1)?;
+ let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
- let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?;
+ let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
}
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
- let hidden_size_xs = xs.dim(candle::D::Minus1)?;
+ let hidden_size_xs = xs.dim(D::Minus1)?;
let hidden_size_alpha = alpha.dims1()?;
if hidden_size_xs != hidden_size_alpha {
candle::bail!(
@@ -640,6 +640,254 @@ pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
}
+#[derive(Debug, Clone)]
+struct LayerNorm {
+ eps: f32,
+}
+
+impl candle::CustomOp3 for LayerNorm {
+ fn name(&self) -> &'static str {
+ "layer-norm"
+ }
+
+ fn cpu_fwd(
+ &self,
+ s1: &CpuStorage,
+ l1: &Layout,
+ s2: &CpuStorage,
+ l2: &Layout,
+ s3: &CpuStorage,
+ l3: &Layout,
+ ) -> Result<(CpuStorage, Shape)> {
+ use candle::backend::BackendStorage;
+
+ let eps = self.eps;
+ fn inner<
+ T: candle::WithDType
+ + num_traits::Float
+ + num_traits::AsPrimitive<f32>
+ + num_traits::FromPrimitive,
+ >(
+ src: &[T],
+ layout: &Layout,
+ alpha: &[T],
+ alpha_layout: &Layout,
+ beta: &[T],
+ beta_layout: &Layout,
+ eps: f32,
+ ) -> Result<(CpuStorage, Shape)> {
+ let src = match layout.contiguous_offsets() {
+ None => candle::bail!("input has to be contiguous"),
+ Some((o1, o2)) => &src[o1..o2],
+ };
+ let alpha = match alpha_layout.contiguous_offsets() {
+ None => candle::bail!("alpha has to be contiguous"),
+ Some((o1, o2)) => &alpha[o1..o2],
+ };
+ let beta = match beta_layout.contiguous_offsets() {
+ None => candle::bail!("beta has to be contiguous"),
+ Some((o1, o2)) => &beta[o1..o2],
+ };
+ let el_count = layout.shape().elem_count();
+ let dims = layout.shape().dims();
+ let dim_m1 = dims[dims.len() - 1];
+ let mut dst = vec![T::zero(); el_count];
+ src.par_chunks(dim_m1)
+ .zip(dst.par_chunks_mut(dim_m1))
+ .for_each(|(src, dst)| {
+ let mut sum = 0f32;
+ let mut sum2 = 0f32;
+ for v in src {
+ let v = v.as_();
+ sum += v;
+ sum2 += v * v;
+ }
+ let mean = sum / dim_m1 as f32;
+ let var = sum2 / dim_m1 as f32 - mean * mean;
+ let inv_std = (var + eps).sqrt().recip();
+ for ((d, s), (alpha, beta)) in
+ dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
+ {
+ let alpha = alpha.as_();
+ let beta = beta.as_();
+ let d_ = (s.as_() - mean) * inv_std * alpha + beta;
+ *d = T::from_f32(d_).unwrap_or_else(T::nan);
+ }
+ });
+ let storage = candle::WithDType::to_cpu_storage_owned(dst);
+ Ok((storage, Shape::from_dims(dims)))
+ }
+
+ use CpuStorage as C;
+ match (s1, s2, s3) {
+ (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
+ inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
+ }
+ (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
+ (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
+            _ => candle::bail!("unsupported dtype for layer-norm {:?}", s1.dtype()),
+ }
+ }
+
+ #[cfg(feature = "cuda")]
+ fn cuda_fwd(
+ &self,
+ s1: &candle::CudaStorage,
+ l1: &Layout,
+ s2: &candle::CudaStorage,
+ l2: &Layout,
+ s3: &candle::CudaStorage,
+ l3: &Layout,
+ ) -> Result<(candle::CudaStorage, Shape)> {
+ use candle::cuda_backend::cudarc::driver::{
+ CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+ };
+ use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
+ use candle::{CudaDevice, WithDType};
+
+ struct S {
+ eps: f32,
+ }
+ impl Map3 for S {
+ fn f<T: DeviceRepr + WithDType>(
+ &self,
+ src: &CudaSlice<T>,
+ layout: &Layout,
+ alpha: &CudaSlice<T>,
+ alpha_layout: &Layout,
+ beta: &CudaSlice<T>,
+ beta_layout: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>> {
+ let src = match layout.contiguous_offsets() {
+ None => candle::bail!("input has to be contiguous"),
+ Some((o1, o2)) => src.slice(o1..o2),
+ };
+ let alpha = match alpha_layout.contiguous_offsets() {
+ None => candle::bail!("alpha has to be contiguous"),
+ Some((o1, o2)) => alpha.slice(o1..o2),
+ };
+ let beta = match beta_layout.contiguous_offsets() {
+ None => candle::bail!("beta has to be contiguous"),
+ Some((o1, o2)) => beta.slice(o1..o2),
+ };
+ let el = layout.shape().elem_count();
+ let dims = layout.shape().dims();
+ let dim_m1 = dims[dims.len() - 1];
+ let (n_rows, n_cols) = (el / dim_m1, dim_m1);
+
+ let cfg = LaunchConfig {
+ grid_dim: (n_rows as u32, 1, 1),
+ block_dim: (1024, 1, 1),
+ shared_mem_bytes: 0,
+ };
+ let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
+ // SAFETY: Set later by running the kernel.
+ let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+ let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(dst)
+ }
+ }
+
+ use candle::backend::BackendStorage;
+ let dev = s1.device();
+ let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
+ let dst = candle::cuda_backend::CudaStorage {
+ slice,
+ device: dev.clone(),
+ };
+ Ok((dst, l1.shape().clone()))
+ }
+
+ #[cfg(feature = "metal")]
+ fn metal_fwd(
+ &self,
+ s1: &candle::MetalStorage,
+ l1: &Layout,
+ s2: &candle::MetalStorage,
+ l2: &Layout,
+ s3: &candle::MetalStorage,
+ l3: &Layout,
+ ) -> Result<(candle::MetalStorage, Shape)> {
+ use candle::backend::BackendStorage;
+ let device = s1.device();
+ let command_buffer = device.command_buffer()?;
+ let kernels = device.kernels();
+ let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
+ (DType::F32, DType::F32, DType::F32) => "layernorm_f32",
+ (DType::F16, DType::F16, DType::F16) => "layernorm_f16",
+ (DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
+ (dt1, dt2, dt3) => {
+ candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
+ }
+ };
+
+ if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
+ candle::bail!("Non contiguous layernorm is not implemented");
+ }
+
+ let last_dim = l1.dims()[l1.shape().rank() - 1];
+ let elem_count = l1.shape().elem_count();
+ let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
+ candle_metal_kernels::call_layer_norm(
+ device.metal_device(),
+ &command_buffer,
+ kernels,
+ name,
+ elem_count,
+ last_dim,
+ self.eps,
+ s1.buffer(),
+ l1.start_offset() * s1.dtype().size_in_bytes(),
+ s2.buffer(),
+ l2.start_offset() * s2.dtype().size_in_bytes(),
+ s3.buffer(),
+ l3.start_offset() * s3.dtype().size_in_bytes(),
+ &output,
+ )
+ .map_err(candle::Error::wrap)?;
+ let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
+ Ok((newstorage, l1.shape().clone()))
+ }
+}
+
+pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
+ let x_dtype = x.dtype();
+ let internal_dtype = match x_dtype {
+ DType::F16 | DType::BF16 => DType::F32,
+ d => d,
+ };
+ let hidden_size = x.dim(D::Minus1)?;
+ let x = x.to_dtype(internal_dtype)?;
+ let x = {
+ let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ x.broadcast_sub(&mean_x)?
+ };
+ let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+ let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
+ x_normed
+ .to_dtype(x_dtype)?
+ .broadcast_mul(alpha)?
+ .broadcast_add(beta)
+}
+
+pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
+ let hidden_size_xs = xs.dim(D::Minus1)?;
+ let hidden_size_alpha = alpha.dims1()?;
+ let hidden_size_beta = beta.dims1()?;
+ if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
+ candle::bail!(
+ "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
+ xs.shape(),
+ alpha.shape(),
+ beta.shape()
+ )
+ }
+ xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
+}
+
// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
let (b_size, c, h, w) = xs.dims4()?;
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index f9cfe46d..65a8fbf2 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -77,6 +77,32 @@ fn rms_norm(device: &Device) -> Result<()> {
Ok(())
}
+fn layer_norm(device: &Device) -> Result<()> {
+ let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
+ let tensor = Tensor::new(data, device)?;
+ let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?;
+ let beta = Tensor::new(&[0.5f32, 0f32, -0.2f32], device)?;
+ let t = candle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?;
+ assert_eq!(
+ to_vec3_round(&t, 4)?,
+ &[
+ [[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]],
+ [[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]]
+ ]
+ );
+ let t2 = candle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?;
+ assert_eq!(
+ to_vec3_round(&t2, 4)?,
+ &[
+ [[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]],
+ [[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]]
+ ]
+ );
+ let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+ assert!(diff < 1e-5);
+ Ok(())
+}
+
#[test]
fn softmax_numerical_stability() -> Result<()> {
let dev = &Device::Cpu;
@@ -185,4 +211,5 @@ test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
+test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);