diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-21 06:36:28 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-21 06:36:28 +0100 |
commit | af7f8b87d35e2ee595cf871c3401beed4dc9b3d8 (patch) | |
tree | 840983301567c487424dbd4b0db33e2d1033f247 /candle-nn/src | |
parent | b219903d0f9ee52f70397c7e9aa4df323b89a700 (diff) | |
download | candle-af7f8b87d35e2ee595cf871c3401beed4dc9b3d8.tar.gz candle-af7f8b87d35e2ee595cf871c3401beed4dc9b3d8.tar.bz2 candle-af7f8b87d35e2ee595cf871c3401beed4dc9b3d8.zip |
Custom op for RmsNorm (#1890)
* Trying out a custom RmsNorm cuda kernel.
* CPU implementation for rms-norm.
* Cuda wrappers.
* Add some validation.
* Add some testing.
* More testing.
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/ops.rs | 171 |
1 files changed, 167 insertions, 4 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 88d1b3d6..d725bdc2 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1,4 +1,4 @@ -use candle::{CpuStorage, Layout, Result, Shape, Tensor}; +use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor}; use rayon::prelude::*; /// Applies the softmax function to the input tensor, rescaling the element so that elements on @@ -180,11 +180,10 @@ impl candle::CustomOp1 for SoftmaxLastDim { block_dim: (1, 32, 1), shared_mem_bytes: 0, }; - let src = &src.slice(layout.start_offset()..); let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?; // SAFETY: Set later by running the kernel. let dst = unsafe { dev.alloc::<T>(el) }.w()?; - let params = (src, &dst, n_cols as i32); + let params = (&src, &dst, n_cols as i32); // SAFETY: ffi. unsafe { func.launch(cfg, params) }.w()?; Ok(dst) @@ -207,7 +206,7 @@ impl candle::CustomOp1 for SoftmaxLastDim { storage: &candle::MetalStorage, layout: &Layout, ) -> Result<(candle::MetalStorage, Shape)> { - use candle::{backend::BackendStorage, DType}; + use candle::backend::BackendStorage; let device = storage.device(); let command_buffer = device.command_buffer()?; let kernels = device.kernels(); @@ -248,6 +247,170 @@ pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> { xs.apply_op1_no_bwd(&SoftmaxLastDim) } +#[derive(Debug, Clone)] +struct RmsNorm { + eps: f32, +} + +impl candle::CustomOp2 for RmsNorm { + fn name(&self) -> &'static str { + "rms-norm" + } + + fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + ) -> Result<(CpuStorage, Shape)> { + use candle::backend::BackendStorage; + + let eps = self.eps; + fn inner< + T: candle::WithDType + + num_traits::Float + + num_traits::AsPrimitive<f32> + + num_traits::FromPrimitive, + >( + src: &[T], + layout: &Layout, + alpha: &[T], + alpha_layout: &Layout, + eps: f32, + ) -> Result<(CpuStorage, Shape)> { + let src = match layout.contiguous_offsets() { + None => candle::bail!("input has to be contiguous"), + Some((o1, o2)) => &src[o1..o2], + }; + let alpha = match alpha_layout.contiguous_offsets() { + None => candle::bail!("alpha has to be contiguous"), + Some((o1, o2)) => &alpha[o1..o2], + }; + let el_count = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let mut dst = vec![T::zero(); el_count]; + src.par_chunks(dim_m1) + .zip(dst.par_chunks_mut(dim_m1)) + .for_each(|(src, dst)| { + let sum2 = src + .iter() + .map(|&v| { + let v = v.as_(); + v * v + }) + .sum::<f32>(); + let m = (sum2 / dim_m1 as f32 + eps).sqrt(); + let m = T::from_f32(m).unwrap_or_else(T::nan); + for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) { + *d = *s / m * *alpha + } + }); + let storage = candle::WithDType::to_cpu_storage_owned(dst); + Ok((storage, Shape::from_dims(dims))) + } + + use CpuStorage as C; + match (s1, s2) { + (C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps), + (C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps), + (C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps), + _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()), + } + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &candle::CudaStorage, + l1: &Layout, + s2: &candle::CudaStorage, + l2: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + use candle::cuda_backend::cudarc::driver::{ + CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, + }; + use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr}; + use candle::{CudaDevice, WithDType}; + + struct S { + eps: f32, + } + impl Map2 for S { + fn f<T: DeviceRepr + WithDType>( + &self, + src: &CudaSlice<T>, + layout: &Layout, + alpha: &CudaSlice<T>, + alpha_layout: &Layout, + dev: &CudaDevice, + ) -> Result<CudaSlice<T>> { + let src = match layout.contiguous_offsets() { + None => candle::bail!("input has to be contiguous"), + Some((o1, o2)) => src.slice(o1..o2), + }; + let alpha = match alpha_layout.contiguous_offsets() { + None => candle::bail!("alpha has to be contiguous"), + Some((o1, o2)) => alpha.slice(o1..o2), + }; + let el = layout.shape().elem_count(); + let dims = layout.shape().dims(); + let dim_m1 = dims[dims.len() - 1]; + let (n_rows, n_cols) = (el / dim_m1, dim_m1); + + let cfg = LaunchConfig { + grid_dim: (n_rows as u32, 1, 1), + block_dim: (1024, 1, 1), + shared_mem_bytes: 0, + }; + let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { dev.alloc::<T>(el) }.w()?; + let params = (&src, &dst, &alpha, n_cols as i32, self.eps); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()?; + Ok(dst) + } + } + + use candle::backend::BackendStorage; + let dev = s1.device(); + let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?; + let dst = candle::cuda_backend::CudaStorage { + slice, + device: dev.clone(), + }; + Ok((dst, l1.shape().clone())) + } +} + +pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> { + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let hidden_size = x.dim(candle::D::Minus1)?; + let x = x.to_dtype(internal_dtype)?; + let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?; + x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha) +} + +pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> { + let hidden_size_xs = xs.dim(candle::D::Minus1)?; + let hidden_size_alpha = alpha.dims1()?; + if hidden_size_xs != hidden_size_alpha { + candle::bail!( + "shape mismatch in rms-norm {:?} {:?}", + xs.shape(), + alpha.shape() + ) + } + xs.apply_op2_no_bwd(alpha, &RmsNorm { eps }) +} + // https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> { let (b_size, c, h, w) = xs.dims4()?; |