path: root/candle-nn/src
author     Laurent Mazare <laurent.mazare@gmail.com>  2023-09-05 15:20:23 +0200
committer  GitHub <noreply@github.com>  2023-09-05 14:20:23 +0100
commit     1c9e5394a5056aadc948f9330ea31fea4972e65e (patch)
tree       afabffd5e6663ee1b6231020981ab50273154ba6 /candle-nn/src
parent     a8410bf35ea3ad8eb973f48d301e65309d232377 (diff)
Add a custom softmax implementation. (#744)
* Add a custom softmax implementation.
* Add softmaxlastdim to the benchmarks.
* And add a test.
* Support more dtypes.
* Polish the code.
* Use the slow implementation on cuda.
* Add a todo for the cuda kernel.
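For reference, the per-row computation that the new CPU kernel in the diff below parallelizes with rayon is the usual numerically stable softmax: subtract the row maximum, exponentiate, then normalize by the sum of exponentials. A minimal scalar sketch; the helper name softmax_row is illustrative and not part of the patch:

    // Scalar sketch of one row of softmax; the patch applies this per chunk of the
    // last dimension, in parallel via rayon's par_chunks.
    fn softmax_row(src: &[f32], dst: &mut [f32]) {
        // Subtract the row maximum so exp() cannot overflow.
        let max = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut sum_exp = 0f32;
        for (s, d) in src.iter().zip(dst.iter_mut()) {
            *d = (s - max).exp();
            sum_exp += *d;
        }
        // Normalize so the row sums to 1.
        for d in dst.iter_mut() {
            *d /= sum_exp;
        }
    }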
Diffstat (limited to 'candle-nn/src')
-rw-r--r--  candle-nn/src/ops.rs  69
1 file changed, 68 insertions(+), 1 deletion(-)
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index c3b6ffa2..55da46f8 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1,4 +1,5 @@
-use candle::{Result, Tensor};
+use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use rayon::prelude::*;
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@@ -77,3 +78,69 @@ impl Dropout {
        }
    }
}
+
+struct SoftmaxLastDim;
+
+impl candle::CustomOp1 for SoftmaxLastDim {
+    fn name(&self) -> &'static str {
+        "softmax-last-dim"
+    }
+
+    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
+        fn softmax<T: candle::WithDType + num_traits::Float>(
+            src: &[T],
+            layout: &Layout,
+        ) -> Result<(CpuStorage, Shape)> {
+            let src = match layout.contiguous_offsets() {
+                None => candle::bail!("input has to be contiguous"),
+                Some((o1, o2)) => &src[o1..o2],
+            };
+            let el_count = layout.shape().elem_count();
+            let dims = layout.shape().dims();
+            let dim_m1 = dims[dims.len() - 1];
+            let mut dst = vec![T::zero(); el_count];
+            src.par_chunks(dim_m1)
+                .zip(dst.par_chunks_mut(dim_m1))
+                .for_each(|(src, dst)| {
+                    let mut max = T::neg_infinity();
+                    for &s in src.iter() {
+                        max = T::max(s, max)
+                    }
+                    let mut sum_exp = T::zero();
+                    for (s, d) in src.iter().zip(dst.iter_mut()) {
+                        *d = (*s - max).exp();
+                        sum_exp += *d
+                    }
+                    for d in dst.iter_mut() {
+                        *d /= sum_exp
+                    }
+                });
+            let storage = candle::WithDType::to_cpu_storage_owned(dst);
+            Ok((storage, Shape::from_dims(dims)))
+        }
+
+        match storage {
+            CpuStorage::BF16(slice) => softmax::<half::bf16>(slice, layout),
+            CpuStorage::F16(slice) => softmax::<half::f16>(slice, layout),
+            CpuStorage::F32(slice) => softmax::<f32>(slice, layout),
+            CpuStorage::F64(slice) => softmax::<f64>(slice, layout),
+            _ => candle::bail!("unsupported dtype for softmax {:?}", storage),
+        }
+    }
+
+    fn cuda_fwd(
+        &self,
+        _storage: &candle::CudaStorage,
+        _layout: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        candle::bail!("TODO: implement a cuda kernel")
+    }
+}
+
+pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
+    if xs.device().is_cpu() {
+        xs.apply_op1_no_bwd(&SoftmaxLastDim)
+    } else {
+        softmax(xs, candle::D::Minus1)
+    }
+}
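
A usage sketch of the new entry point, assuming the workspace crate names used in this repository (candle, candle_nn): on a CPU tensor it dispatches to the custom op, while other devices fall back to the generic softmax, so the two calls below should agree.

    use candle::{Device, Result, Tensor, D};

    fn main() -> Result<()> {
        let xs = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.]], &Device::Cpu)?;
        // Custom CPU kernel over the last dimension.
        let ys = candle_nn::ops::softmax_last_dim(&xs)?;
        // Generic implementation used as a reference.
        let expected = candle_nn::ops::softmax(&xs, D::Minus1)?;
        println!("{ys}");
        println!("{expected}");
        Ok(())
    }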