Diffstat (limited to 'candle-nn/src/ops.rs')
-rw-r--r-- | candle-nn/src/ops.rs | 258 |
1 file changed, 253 insertions, 5 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index eabc95d8..2a76ee5e 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1,4 +1,4 @@
-use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D};
 use rayon::prelude::*;
 
 /// Applies the softmax function to the input tensor, rescaling the element so that elements on
@@ -39,7 +39,7 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
 }
 
 pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
-    let xs = xs.chunk(2, candle::D::Minus1)?;
+    let xs = xs.chunk(2, D::Minus1)?;
     &xs[0].silu()? * &xs[1]
 }
 
@@ -620,15 +620,15 @@ pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
         DType::F16 | DType::BF16 => DType::F32,
         d => d,
     };
-    let hidden_size = x.dim(candle::D::Minus1)?;
+    let hidden_size = x.dim(D::Minus1)?;
     let x = x.to_dtype(internal_dtype)?;
-    let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?;
+    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
     let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
     x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
 }
 
 pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
-    let hidden_size_xs = xs.dim(candle::D::Minus1)?;
+    let hidden_size_xs = xs.dim(D::Minus1)?;
     let hidden_size_alpha = alpha.dims1()?;
     if hidden_size_xs != hidden_size_alpha {
         candle::bail!(
@@ -640,6 +640,254 @@ pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
     xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
 }
 
+#[derive(Debug, Clone)]
+struct LayerNorm {
+    eps: f32,
+}
+
+impl candle::CustomOp3 for LayerNorm {
+    fn name(&self) -> &'static str {
+        "layer-norm"
+    }
+
+    fn cpu_fwd(
+        &self,
+        s1: &CpuStorage,
+        l1: &Layout,
+        s2: &CpuStorage,
+        l2: &Layout,
+        s3: &CpuStorage,
+        l3: &Layout,
+    ) -> Result<(CpuStorage, Shape)> {
+        use candle::backend::BackendStorage;
+
+        let eps = self.eps;
+        fn inner<
+            T: candle::WithDType
+                + num_traits::Float
+                + num_traits::AsPrimitive<f32>
+                + num_traits::FromPrimitive,
+        >(
+            src: &[T],
+            layout: &Layout,
+            alpha: &[T],
+            alpha_layout: &Layout,
+            beta: &[T],
+            beta_layout: &Layout,
+            eps: f32,
+        ) -> Result<(CpuStorage, Shape)> {
+            let src = match layout.contiguous_offsets() {
+                None => candle::bail!("input has to be contiguous"),
+                Some((o1, o2)) => &src[o1..o2],
+            };
+            let alpha = match alpha_layout.contiguous_offsets() {
+                None => candle::bail!("alpha has to be contiguous"),
+                Some((o1, o2)) => &alpha[o1..o2],
+            };
+            let beta = match beta_layout.contiguous_offsets() {
+                None => candle::bail!("beta has to be contiguous"),
+                Some((o1, o2)) => &beta[o1..o2],
+            };
+            let el_count = layout.shape().elem_count();
+            let dims = layout.shape().dims();
+            let dim_m1 = dims[dims.len() - 1];
+            let mut dst = vec![T::zero(); el_count];
+            src.par_chunks(dim_m1)
+                .zip(dst.par_chunks_mut(dim_m1))
+                .for_each(|(src, dst)| {
+                    let mut sum = 0f32;
+                    let mut sum2 = 0f32;
+                    for v in src {
+                        let v = v.as_();
+                        sum += v;
+                        sum2 += v * v;
+                    }
+                    let mean = sum / dim_m1 as f32;
+                    let var = sum2 / dim_m1 as f32 - mean * mean;
+                    let inv_std = (var + eps).sqrt().recip();
+                    for ((d, s), (alpha, beta)) in
+                        dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
+                    {
+                        let alpha = alpha.as_();
+                        let beta = beta.as_();
+                        let d_ = (s.as_() - mean) * inv_std * alpha + beta;
+                        *d = T::from_f32(d_).unwrap_or_else(T::nan);
+                    }
+                });
+            let storage = candle::WithDType::to_cpu_storage_owned(dst);
+            Ok((storage, Shape::from_dims(dims)))
+        }
+
+        use CpuStorage as C;
+        match (s1, s2, s3) {
+            (C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
+                inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
+            }
+            (C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
+            (C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
+            _ => candle::bail!("unsupported dtype for layernorm {:?}", s1.dtype()),
+        }
+    }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        s1: &candle::CudaStorage,
+        l1: &Layout,
+        s2: &candle::CudaStorage,
+        l2: &Layout,
+        s3: &candle::CudaStorage,
+        l3: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        use candle::cuda_backend::cudarc::driver::{
+            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+        };
+        use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
+        use candle::{CudaDevice, WithDType};
+
+        struct S {
+            eps: f32,
+        }
+        impl Map3 for S {
+            fn f<T: DeviceRepr + WithDType>(
+                &self,
+                src: &CudaSlice<T>,
+                layout: &Layout,
+                alpha: &CudaSlice<T>,
+                alpha_layout: &Layout,
+                beta: &CudaSlice<T>,
+                beta_layout: &Layout,
+                dev: &CudaDevice,
+            ) -> Result<CudaSlice<T>> {
+                let src = match layout.contiguous_offsets() {
+                    None => candle::bail!("input has to be contiguous"),
+                    Some((o1, o2)) => src.slice(o1..o2),
+                };
+                let alpha = match alpha_layout.contiguous_offsets() {
+                    None => candle::bail!("alpha has to be contiguous"),
+                    Some((o1, o2)) => alpha.slice(o1..o2),
+                };
+                let beta = match beta_layout.contiguous_offsets() {
+                    None => candle::bail!("beta has to be contiguous"),
+                    Some((o1, o2)) => beta.slice(o1..o2),
+                };
+                let el = layout.shape().elem_count();
+                let dims = layout.shape().dims();
+                let dim_m1 = dims[dims.len() - 1];
+                let (n_rows, n_cols) = (el / dim_m1, dim_m1);
+
+                let cfg = LaunchConfig {
+                    grid_dim: (n_rows as u32, 1, 1),
+                    block_dim: (1024, 1, 1),
+                    shared_mem_bytes: 0,
+                };
+                let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
+                // SAFETY: Set later by running the kernel.
+                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+                let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }.w()?;
+                Ok(dst)
+            }
+        }
+
+        use candle::backend::BackendStorage;
+        let dev = s1.device();
+        let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
+        let dst = candle::cuda_backend::CudaStorage {
+            slice,
+            device: dev.clone(),
+        };
+        Ok((dst, l1.shape().clone()))
+    }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        s1: &candle::MetalStorage,
+        l1: &Layout,
+        s2: &candle::MetalStorage,
+        l2: &Layout,
+        s3: &candle::MetalStorage,
+        l3: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        let device = s1.device();
+        let command_buffer = device.command_buffer()?;
+        let kernels = device.kernels();
+        let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
+            (DType::F32, DType::F32, DType::F32) => "layernorm_f32",
+            (DType::F16, DType::F16, DType::F16) => "layernorm_f16",
+            (DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
+            (dt1, dt2, dt3) => {
+                candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
+            }
+        };
+
+        if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
+            candle::bail!("Non contiguous layernorm is not implemented");
+        }
+
+        let last_dim = l1.dims()[l1.shape().rank() - 1];
+        let elem_count = l1.shape().elem_count();
+        let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
+        candle_metal_kernels::call_layer_norm(
+            device.metal_device(),
+            &command_buffer,
+            kernels,
+            name,
+            elem_count,
+            last_dim,
+            self.eps,
+            s1.buffer(),
+            l1.start_offset() * s1.dtype().size_in_bytes(),
+            s2.buffer(),
+            l2.start_offset() * s2.dtype().size_in_bytes(),
+            s3.buffer(),
+            l3.start_offset() * s3.dtype().size_in_bytes(),
+            &output,
+        )
+        .map_err(candle::Error::wrap)?;
+        let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
+        Ok((newstorage, l1.shape().clone()))
+    }
+}
+
+pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
+    let x_dtype = x.dtype();
+    let internal_dtype = match x_dtype {
+        DType::F16 | DType::BF16 => DType::F32,
+        d => d,
+    };
+    let hidden_size = x.dim(D::Minus1)?;
+    let x = x.to_dtype(internal_dtype)?;
+    let x = {
+        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        x.broadcast_sub(&mean_x)?
+    };
+    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
+    x_normed
+        .to_dtype(x_dtype)?
+        .broadcast_mul(alpha)?
+        .broadcast_add(beta)
+}
+
+pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
+    let hidden_size_xs = xs.dim(D::Minus1)?;
+    let hidden_size_alpha = alpha.dims1()?;
+    let hidden_size_beta = beta.dims1()?;
+    if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
+        candle::bail!(
+            "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
+            xs.shape(),
+            alpha.shape(),
+            beta.shape()
+        )
+    }
+    xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
+}
+
 // https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
 pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
     let (b_size, c, h, w) = xs.dims4()?;
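For context, here is a minimal sketch of how the ops added above might be exercised from user code, checking the fused `layer_norm` custom op against the pure tensor-op `layer_norm_slow` reference on the CPU backend. The shapes, epsilon, and tolerance are illustrative assumptions, not values taken from the candle test suite.

```rust
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::ops::{layer_norm, layer_norm_slow};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // A (batch, hidden) input; both ops normalize over the last dimension.
    let xs = Tensor::randn(0f32, 1f32, (4, 64), &device)?;
    let alpha = Tensor::ones(64, DType::F32, &device)?; // scale (gamma)
    let beta = Tensor::zeros(64, DType::F32, &device)?; // shift (beta)

    // Fused custom op vs. the slower composition of tensor ops.
    let fast = layer_norm(&xs, &alpha, &beta, 1e-5)?;
    let slow = layer_norm_slow(&xs, &alpha, &beta, 1e-5)?;

    // The two paths should agree up to floating-point rounding.
    let max_diff = (&fast - &slow)?
        .abs()?
        .flatten_all()?
        .max(D::Minus1)?
        .to_scalar::<f32>()?;
    assert!(max_diff < 1e-4, "layer_norm mismatch: {max_diff}");
    println!("max |fast - slow| = {max_diff:e}");
    Ok(())
}
```

On a build with the `cuda` or `metal` feature enabled, the same `layer_norm` call dispatches to the `cuda_fwd` or `metal_fwd` paths shown in the diff instead of the CPU kernel.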