author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-05 15:20:23 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-05 14:20:23 +0100 |
commit | 1c9e5394a5056aadc948f9330ea31fea4972e65e (patch) | |
tree | afabffd5e6663ee1b6231020981ab50273154ba6 /candle-nn/src | |
parent | a8410bf35ea3ad8eb973f48d301e65309d232377 (diff) | |
download | candle-1c9e5394a5056aadc948f9330ea31fea4972e65e.tar.gz candle-1c9e5394a5056aadc948f9330ea31fea4972e65e.tar.bz2 candle-1c9e5394a5056aadc948f9330ea31fea4972e65e.zip |
Add a custom softmax implementation. (#744)
* Add a custom softmax implementation.
* Add softmaxlastdim to the benchmarks.
* And add a test.
* Support more dtypes.
* Polish the code.
* Use the slow implementation on cuda.
* Add a todo for the cuda kernel.
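For context, a minimal usage sketch of the new entry point (not part of the commit itself; it assumes the `candle-core` and `candle-nn` crates at this revision, with `candle-core` imported under its usual alias `candle`):

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // A 2x3 input; softmax_last_dim normalizes each row of 3 values.
    let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // On CPU this dispatches to the custom kernel added by this commit; on
    // CUDA it falls back to the generic softmax over the last dimension
    // until a dedicated kernel lands.
    let ys = candle_nn::ops::softmax_last_dim(&xs)?;
    println!("{ys}");
    Ok(())
}
```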
Diffstat (limited to 'candle-nn/src')
-rw-r--r-- | candle-nn/src/ops.rs | 69 |
1 file changed, 68 insertions, 1 deletion
```diff
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index c3b6ffa2..55da46f8 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1,4 +1,5 @@
-use candle::{Result, Tensor};
+use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use rayon::prelude::*;
 
 /// Applies the softmax function to the input tensor, rescaling the element so that elements on
 /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@@ -77,3 +78,69 @@ impl Dropout {
         }
     }
 }
+
+struct SoftmaxLastDim;
+
+impl candle::CustomOp1 for SoftmaxLastDim {
+    fn name(&self) -> &'static str {
+        "softmax-last-dim"
+    }
+
+    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
+        fn softmax<T: candle::WithDType + num_traits::Float>(
+            src: &[T],
+            layout: &Layout,
+        ) -> Result<(CpuStorage, Shape)> {
+            let src = match layout.contiguous_offsets() {
+                None => candle::bail!("input has to be contiguous"),
+                Some((o1, o2)) => &src[o1..o2],
+            };
+            let el_count = layout.shape().elem_count();
+            let dims = layout.shape().dims();
+            let dim_m1 = dims[dims.len() - 1];
+            let mut dst = vec![T::zero(); el_count];
+            src.par_chunks(dim_m1)
+                .zip(dst.par_chunks_mut(dim_m1))
+                .for_each(|(src, dst)| {
+                    let mut max = T::neg_infinity();
+                    for &s in src.iter() {
+                        max = T::max(s, max)
+                    }
+                    let mut sum_exp = T::zero();
+                    for (s, d) in src.iter().zip(dst.iter_mut()) {
+                        *d = (*s - max).exp();
+                        sum_exp += *d
+                    }
+                    for d in dst.iter_mut() {
+                        *d /= sum_exp
+                    }
+                });
+            let storage = candle::WithDType::to_cpu_storage_owned(dst);
+            Ok((storage, Shape::from_dims(dims)))
+        }
+
+        match storage {
+            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
+            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
+            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
+            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
+            _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
+        }
+    }
+
+    fn cuda_fwd(
+        &self,
+        _storage: &candle::CudaStorage,
+        _layout: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        candle::bail!("TODO: implement a cuda kernel")
+    }
+
+pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
+    if xs.device().is_cpu() {
+        xs.apply_op1_no_bwd(&SoftmaxLastDim)
+    } else {
+        softmax(xs, candle::D::Minus1)
+    }
+}
```
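A note on the CPU kernel above: it computes softmax per last-dimension row in a numerically stable way, subtracting the row maximum before exponentiating and then dividing by the sum of exponentials. The sketch below restates that per-row step on plain `f32` slices; `softmax_row` is a hypothetical helper for illustration only, while the committed version is generic over `WithDType + num_traits::Float` and runs the rows in parallel with rayon's `par_chunks`.

```rust
// Illustrative restatement of the per-row softmax kernel (not part of the
// commit). Subtracting the row maximum keeps exp() from overflowing without
// changing the result, since softmax is invariant to shifting all inputs.
fn softmax_row(src: &[f32], dst: &mut [f32]) {
    let max = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut sum_exp = 0f32;
    for (s, d) in src.iter().zip(dst.iter_mut()) {
        *d = (*s - max).exp();
        sum_exp += *d;
    }
    for d in dst.iter_mut() {
        *d /= sum_exp;
    }
}

fn main() {
    let src = [1f32, 2., 3., 4.];
    let mut dst = [0f32; 4];
    softmax_row(&src, &mut dst);
    // Each output row sums to 1.
    assert!((dst.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    println!("{dst:?}");
}
```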