summary | refs | log | tree | commit | diff
path: root/candle-nn/src/ops.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-nn/src/ops.rs')
-rw-r--r-- candle-nn/src/ops.rs 171
1 files changed, 167 insertions, 4 deletions
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 88d1b3d6..d725bdc2 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1,4 +1,4 @@
-use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use rayon::prelude::*;
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
@@ -180,11 +180,10 @@ impl candle::CustomOp1 for SoftmaxLastDim {
block_dim: (1, 32, 1),
shared_mem_bytes: 0,
};
- let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
- let params = (src, &dst, n_cols as i32);
+ let params = (&src, &dst, n_cols as i32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
@@ -207,7 +206,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
- use candle::{backend::BackendStorage, DType};
+ use candle::backend::BackendStorage;
let device = storage.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
@@ -248,6 +247,170 @@ pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
xs.apply_op1_no_bwd(&SoftmaxLastDim)
}
+#[derive(Debug, Clone)] // State for the "rms-norm" custom op implemented via `candle::CustomOp2`.
+struct RmsNorm {
+ eps: f32, // Added to the mean of squares before the square root is taken.
+}
+
+impl candle::CustomOp2 for RmsNorm { // Two-operand op: the input tensor and the per-column scale `alpha`.
+ fn name(&self) -> &'static str { // Static identifier string for this op.
+ "rms-norm"
+ }
+
+ fn cpu_fwd( // CPU forward: each row becomes x / sqrt(mean(x^2) + eps) * alpha.
+ &self,
+ s1: &CpuStorage, // Storage of the input being normalised.
+ l1: &Layout, // Layout of the input; must be contiguous.
+ s2: &CpuStorage, // Storage of the scale vector (`alpha`).
+ l2: &Layout, // Layout of `alpha`; must be contiguous.
+ ) -> Result<(CpuStorage, Shape)> {
+ use candle::backend::BackendStorage; // Presumably what provides `dtype()` for the error arm below.
+
+ let eps = self.eps;
+ fn inner< // Dtype-generic row kernel; the sum of squares is accumulated in f32.
+ T: candle::WithDType
+ + num_traits::Float
+ + num_traits::AsPrimitive<f32>
+ + num_traits::FromPrimitive,
+ >(
+ src: &[T],
+ layout: &Layout,
+ alpha: &[T],
+ alpha_layout: &Layout,
+ eps: f32,
+ ) -> Result<(CpuStorage, Shape)> {
+ let src = match layout.contiguous_offsets() { // Narrow to the contiguous element range, or bail.
+ None => candle::bail!("input has to be contiguous"),
+ Some((o1, o2)) => &src[o1..o2],
+ };
+ let alpha = match alpha_layout.contiguous_offsets() { // Same contiguity requirement for alpha.
+ None => candle::bail!("alpha has to be contiguous"),
+ Some((o1, o2)) => &alpha[o1..o2],
+ };
+ let el_count = layout.shape().elem_count();
+ let dims = layout.shape().dims();
+ let dim_m1 = dims[dims.len() - 1]; // Size of the last dimension, i.e. the row length.
+ let mut dst = vec![T::zero(); el_count];
+ src.par_chunks(dim_m1) // Rows are processed in parallel. NOTE(review): an empty last dim gives chunk size 0 — confirm callers rule this out.
+ .zip(dst.par_chunks_mut(dim_m1))
+ .for_each(|(src, dst)| {
+ let sum2 = src // Sum of squares of the row, computed in f32.
+ .iter()
+ .map(|&v| {
+ let v = v.as_();
+ v * v
+ })
+ .sum::<f32>();
+ let m = (sum2 / dim_m1 as f32 + eps).sqrt(); // sqrt(mean of squares + eps).
+ let m = T::from_f32(m).unwrap_or_else(T::nan); // Back to T; NaN when the conversion fails.
+ for ((d, s), alpha) in dst.iter_mut().zip(src.iter()).zip(alpha) { // zip stops at the shortest input; `rms_norm` checks alpha's length against the last dim.
+ *d = *s / m * *alpha // The divide and the scale are carried out in T, not in f32.
+ }
+ });
+ let storage = candle::WithDType::to_cpu_storage_owned(dst);
+ Ok((storage, Shape::from_dims(dims))) // Output shape equals the input shape.
+ }
+
+ use CpuStorage as C;
+ match (s1, s2) { // Both operands must share one of the three handled float dtypes.
+ (C::BF16(s1), C::BF16(s2)) => inner::<half::bf16>(s1, l1, s2, l2, eps),
+ (C::F16(s1), C::F16(s2)) => inner::<half::f16>(s1, l1, s2, l2, eps),
+ (C::F32(s1), C::F32(s2)) => inner::<f32>(s1, l1, s2, l2, eps),
+ _ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()), // NOTE(review): also hit on an s1/s2 dtype mismatch, yet only s1's dtype is printed.
+ }
+ }
+
+ #[cfg(feature = "cuda")] // Only compiled when the `cuda` feature is enabled.
+ fn cuda_fwd( // CUDA forward: launches the dtype-specific "rmsnorm" kernel from `kernels::REDUCE`.
+ &self,
+ s1: &candle::CudaStorage, // Input storage.
+ l1: &Layout,
+ s2: &candle::CudaStorage, // Scale (`alpha`) storage.
+ l2: &Layout,
+ ) -> Result<(candle::CudaStorage, Shape)> {
+ use candle::cuda_backend::cudarc::driver::{
+ CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+ };
+ use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
+ use candle::{CudaDevice, WithDType};
+
+ struct S { // Local carrier that hands `eps` to the `Map2` implementation below.
+ eps: f32,
+ }
+ impl Map2 for S {
+ fn f<T: DeviceRepr + WithDType>(
+ &self,
+ src: &CudaSlice<T>,
+ layout: &Layout,
+ alpha: &CudaSlice<T>,
+ alpha_layout: &Layout,
+ dev: &CudaDevice,
+ ) -> Result<CudaSlice<T>> {
+ let src = match layout.contiguous_offsets() { // Contiguous inputs only, as on the CPU path.
+ None => candle::bail!("input has to be contiguous"),
+ Some((o1, o2)) => src.slice(o1..o2),
+ };
+ let alpha = match alpha_layout.contiguous_offsets() {
+ None => candle::bail!("alpha has to be contiguous"),
+ Some((o1, o2)) => alpha.slice(o1..o2),
+ };
+ let el = layout.shape().elem_count();
+ let dims = layout.shape().dims();
+ let dim_m1 = dims[dims.len() - 1];
+ let (n_rows, n_cols) = (el / dim_m1, dim_m1); // Every dim except the last is flattened into rows.
+
+ let cfg = LaunchConfig {
+ grid_dim: (n_rows as u32, 1, 1), // n_rows blocks, presumably one per row. NOTE(review): `as u32` truncates silently for huge inputs.
+ block_dim: (1024, 1, 1), // Fixed 1024 threads per block along x.
+ shared_mem_bytes: 0,
+ };
+ let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
+ // SAFETY: Set later by running the kernel.
+ let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+ let params = (&src, &dst, &alpha, n_cols as i32, self.eps); // Order must match the kernel signature, which is not visible here — verify.
+ // SAFETY: ffi.
+ unsafe { func.launch(cfg, params) }.w()?;
+ Ok(dst)
+ }
+ }
+
+ use candle::backend::BackendStorage;
+ let dev = s1.device();
+ let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, dev)?;
+ let dst = candle::cuda_backend::CudaStorage { // Wrap the kernel output together with a clone of the input's device.
+ slice,
+ device: dev.clone(),
+ };
+ Ok((dst, l1.shape().clone())) // Output shape equals the input shape.
+ }
+}
+
+pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> { // RMS-norm composed from individual tensor ops rather than the `RmsNorm` custom op.
+ let x_dtype = x.dtype();
+ let internal_dtype = match x_dtype { // Half-precision inputs are computed in f32; other dtypes are used as-is.
+ DType::F16 | DType::BF16 => DType::F32,
+ d => d,
+ };
+ let hidden_size = x.dim(candle::D::Minus1)?; // Size of the last dimension.
+ let x = x.to_dtype(internal_dtype)?;
+ let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?; // Mean of squares over the last dim, kept as a size-1 dim.
+ let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?; // x / sqrt(mean + eps).
+ x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha) // Cast back to the input dtype first, then scale by alpha.
+}
+
+pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> { // RMS-norm over the last dim of `xs`, scaled by `alpha`, via the `RmsNorm` custom op.
+ let hidden_size_xs = xs.dim(candle::D::Minus1)?; // Size of the last dimension of the input.
+ let hidden_size_alpha = alpha.dims1()?; // Presumably errors unless alpha is 1-D — confirm against `Tensor::dims1`.
+ if hidden_size_xs != hidden_size_alpha { // alpha must have exactly one entry per column of xs.
+ candle::bail!(
+ "shape mismatch in rms-norm {:?} {:?}",
+ xs.shape(),
+ alpha.shape()
+ )
+ }
+ xs.apply_op2_no_bwd(alpha, &RmsNorm { eps }) // Applied through the `_no_bwd` entry point, so presumably no backward pass is registered.
+}
+
// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
let (b_size, c, h, w) = xs.dims4()?;