-rw-r--r--  candle-core/src/backend.rs                             2
-rw-r--r--  candle-core/src/backprop.rs                            2
-rw-r--r--  candle-core/src/cpu_backend.rs                        50
-rw-r--r--  candle-core/src/cuda_backend.rs                        4
-rw-r--r--  candle-core/src/dummy_cuda_backend.rs                  4
-rw-r--r--  candle-core/src/op.rs                                  1
-rw-r--r--  candle-core/src/storage.rs                             9
-rw-r--r--  candle-core/src/tensor.rs                             34
-rw-r--r--  candle-core/tests/tensor_tests.rs                     38
-rw-r--r--  candle-examples/examples/bert/model.rs                 2
-rw-r--r--  candle-examples/examples/bigcode/model.rs             12
-rw-r--r--  candle-examples/examples/falcon/model.rs              12
-rw-r--r--  candle-examples/examples/llama/model.rs                2
-rw-r--r--  candle-examples/examples/llama2-c/model.rs             2
-rw-r--r--  candle-examples/examples/llama_multiprocess/model.rs   2
-rw-r--r--  candle-examples/examples/musicgen/musicgen_model.rs    2
-rw-r--r--  candle-examples/examples/musicgen/t5_model.rs          2
-rw-r--r--  candle-examples/examples/whisper/main.rs              11
-rw-r--r--  candle-examples/examples/whisper/model.rs              4
-rw-r--r--  candle-flash-attn/Cargo.toml                           1
-rw-r--r--  candle-flash-attn/tests/flash_attn_tests.rs            2
-rw-r--r--  candle-nn/src/ops.rs                                  24
-rw-r--r--  candle-nn/tests/ops.rs                                62
-rw-r--r--  candle-transformers/src/generation/mod.rs              2
-rw-r--r--  candle-wasm-examples/llama2-c/src/model.rs             2
-rw-r--r--  candle-wasm-examples/llama2-c/src/worker.rs            4
-rw-r--r--  candle-wasm-examples/whisper/src/model.rs              2
-rw-r--r--  candle-wasm-examples/whisper/src/worker.rs            11
28 files changed, 117 insertions(+), 188 deletions(-)
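
This commit removes the ad-hoc `Tensor::softmax` method (and the `divide_by_sum_over_dim` machinery backing it) from candle-core and replaces it with a single, numerically stable `softmax` in `candle_nn::ops`; every call site across the examples and wasm demos is updated accordingly. A minimal sketch of the new call pattern, assuming the candle and candle-nn crates from this tree:

    use candle::{Device, Result, Tensor, D};

    fn main() -> Result<()> {
        let logits = Tensor::new(&[[0f32, 1., 0., 1.]], &Device::Cpu)?;
        // Softmax now lives in candle-nn rather than on Tensor itself.
        let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
        println!("{probs}");
        Ok(())
    }
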
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 8815c08d..cee1cad0 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -21,8 +21,6 @@ pub trait BackendStorage: Sized {
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
- fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
-
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index d6beb70e..fd1650bb 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -90,7 +90,6 @@ impl Tensor {
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Narrow(node, _, _, _)
- | Op::Softmax(node, _)
| Op::Unary(node, _)
| Op::Elu(node, _)
| Op::CustomOp1(node, _) => {
@@ -324,7 +323,6 @@ impl Tensor {
}
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
- Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;
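
With the dedicated `Softmax` op gone from the graph, backprop no longer needs a special case: the new `candle_nn::ops::softmax` below is composed of `max_keepdim`, `broadcast_sub`, `exp`, `sum_keepdim`, and `broadcast_div`, so gradients flow through those primitives' backward rules instead of hitting the old `BackwardNotSupported` error. A sketch of differentiating through it, assuming each of those primitives has a backward rule in this tree:

    use candle::{Device, Result, Var, D};

    fn main() -> Result<()> {
        let x = Var::new(&[[1f32, 2., 3.]], &Device::Cpu)?;
        let y = candle_nn::ops::softmax(x.as_tensor(), D::Minus1)?;
        // A non-trivial loss, so the gradient through softmax is non-zero.
        let grads = y.sqr()?.sum_all()?.backward()?;
        let dx = grads.get(&x).expect("gradient for x");
        println!("{dx}");
        Ok(())
    }
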
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 27d0f7da..c39cb9f7 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1236,45 +1236,6 @@ impl Map2 for MatMul {
}
}
-fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
- // [self] stores data in a contiguous way starting at offset 0.
- let dims = shape.dims();
- let elem_per_slice = dims[dim];
- let prod_pre_dim = dims[..dim].iter().product();
- let prod_post_dim = dims[dim + 1..].iter().product();
- if prod_post_dim == 1 {
- for pre_idx in 0..prod_pre_dim {
- let mut sum = 0f64;
- let idx = pre_idx * elem_per_slice;
- for v in s[idx..idx + elem_per_slice].iter() {
- sum += v.to_f64();
- }
- let sum = T::from_f64(sum);
- for v in s[idx..idx + elem_per_slice].iter_mut() {
- *v /= sum
- }
- }
- } else {
- for pre_idx in 0..prod_pre_dim {
- for post_idx in 0..prod_post_dim {
- let mut sum = 0f64;
- let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
- for _ in 0..elem_per_slice {
- sum += s[idx].to_f64();
- idx += prod_post_dim
- }
- let sum = T::from_f64(sum);
- let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
- for _ in 0..elem_per_slice {
- s[idx] /= sum;
- idx += prod_post_dim
- }
- }
- }
- }
- Ok(())
-}
-
fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
if v.is_sign_positive() {
v
@@ -1513,17 +1474,6 @@ impl BackendStorage for CpuStorage {
Cmp(op).map(self, lhs_l, rhs, rhs_l)
}
- fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
- // [self] stores data in a contiguous way starting at offset 0.
- match self {
- Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
- Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
- Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
- Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
- Self::U8(_) | Self::U32(_) => Ok(()),
- }
- }
-
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
Affine(mul, add).map(self, layout)
}
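
The deleted helper normalized a contiguous buffer in place, dividing each slice along `dim` by that slice's sum, with a fast loop for the case where `dim` is the innermost axis and a strided loop otherwise. In tensor-op terms the same effect is now obtained compositionally; a sketch of the equivalent (the function name here is illustrative, not the removed code):

    use candle::{Result, Tensor};

    // Divide each slice along `dim` by that slice's sum.
    fn divide_by_sum_over_dim(t: &Tensor, dim: usize) -> Result<Tensor> {
        let sum = t.sum_keepdim(dim)?; // per-slice sums, keeping `dim`
        t.broadcast_div(&sum)          // broadcast the division back out
    }
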
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index b3d542b9..4050b595 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1303,10 +1303,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
- fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
- Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
- }
-
fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device().clone();
let slice = U::V.map(&self.slice, &device, layout)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index c195cade..1213c502 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -49,10 +49,6 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
- fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
- Err(Error::NotCompiledWithCudaSupport)
- }
-
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 525383b2..4f489f30 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -93,7 +93,6 @@ pub enum Op {
Broadcast(Tensor),
Narrow(Tensor, usize, usize, usize),
Reshape(Tensor),
- Softmax(Tensor, usize),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
Elu(Tensor, f64),
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 52af5861..545f549b 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -125,15 +125,6 @@ impl Storage {
}
}
- // This assumes a contiguous layout and no offset.
- pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
- match self {
- Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
- Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
- }
- Ok(())
- }
-
pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 09f61340..8ae92c2e 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -553,40 +553,6 @@ impl Tensor {
}
}
- /// Applies the softmax function to the input tensor, rescaling the element so that elements on
- /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
- ///
- /// ```rust
- /// use candle::{Tensor, Device};
- /// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
- /// let a = a.softmax(1)?;
- /// assert_eq!(
- /// a.to_vec2::<f32>()?,
- /// &[
- /// [0.13447072, 0.3655293, 0.13447072, 0.3655293],
- /// [0.004892866, 0.26714143, 0.7261657, 0.0017999847],
- /// ]);
- /// # Ok::<(), candle::Error>(())
- /// ```
- pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> {
- let dim = dim.to_index(self.shape(), "softmax")?;
- // TODO: unify the two branches.
- if self.device().is_cuda() {
- // We do not have a cuda kernel for divide_by_sum_over_dim so split
- // the operation.
- let exp = self.exp()?;
- let sum_exp = exp.sum_keepdim(dim)?;
- exp.broadcast_div(&sum_exp)
- } else {
- let shape = self.shape();
- let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
- // The resulting storage is contiguous.
- storage.divide_by_sum_over_dim(shape, dim)?;
- let op = BackpropOp::new1(self, |arg| Op::Softmax(arg, dim));
- Ok(from_storage(storage, shape.clone(), op, false))
- }
- }
-
fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
match dims {
[] => Ok(self),
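
For downstream code the migration is mechanical: the free function takes the tensor by reference as its first argument. A hypothetical helper showing the call-site shape, mirroring the attention examples below:

    use candle::{Result, Tensor, D};

    fn attend(att: &Tensor) -> Result<Tensor> {
        // Before this commit: att.softmax(D::Minus1)
        candle_nn::ops::softmax(att, D::Minus1)
    }
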
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index a38b6d3d..a439ba30 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1,6 +1,5 @@
mod test_utils;
use candle::{DType, Device, IndexOp, Result, Tensor};
-use test_utils::to_vec3_round;
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -68,42 +67,6 @@ fn transpose(device: &Device) -> Result<()> {
Ok(())
}
-fn softmax(device: &Device) -> Result<()> {
- let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
- let tensor = Tensor::new(data, device)?;
- let t0 = tensor.log()?.softmax(0)?;
- let t1 = tensor.log()?.softmax(1)?;
- let t2 = tensor.log()?.softmax(2)?;
- assert_eq!(
- to_vec3_round(t0, 4)?,
- &[
- // 3/5, 1/2, 4/11
- [[0.6, 0.5, 0.3636], [0.1111, 0.7143, 0.5294]],
- // 2/5, 1/2, 7/11
- [[0.4, 0.5, 0.6364], [0.8889, 0.2857, 0.4706]]
- ]
- );
- assert_eq!(
- to_vec3_round(t1, 4)?,
- &[
- // 3/4, 1/6, 4/13
- [[0.75, 0.1667, 0.3077], [0.25, 0.8333, 0.6923]],
- // 2/10, 1/3, 7/15
- [[0.2, 0.3333, 0.4667], [0.8, 0.6667, 0.5333]]
- ]
- );
- assert_eq!(
- to_vec3_round(t2, 4)?,
- &[
- // (3, 1, 4) / 8, (1, 5, 9) / 15
- [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
- // (2, 1, 7) / 10, (8, 2, 8) / 18
- [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
- ]
- );
- Ok(())
-}
-
fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, device)?;
@@ -620,7 +583,6 @@ test_device!(cat, cat_cpu, cat_gpu);
test_device!(sum, sum_cpu, sum_gpu);
test_device!(transpose, transpose_cpu, transpose_gpu);
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
-test_device!(softmax, softmax_cpu, softmax_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
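
The deleted softmax test is not lost: it reappears essentially verbatim further down as `candle-nn/tests/ops.rs`, where it exercises the new `candle_nn::ops::softmax` instead of the removed tensor method.
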
diff --git a/candle-examples/examples/bert/model.rs b/candle-examples/examples/bert/model.rs
index 3bf412b2..b2438e71 100644
--- a/candle-examples/examples/bert/model.rs
+++ b/candle-examples/examples/bert/model.rs
@@ -333,7 +333,7 @@ impl BertSelfAttention {
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
let attention_probs = {
let _enter_sm = self.span_softmax.enter();
- attention_scores.softmax(candle::D::Minus1)?
+ candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
};
let attention_probs = self.dropout.forward(&attention_probs)?;
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs
index 3f68a5be..12993e2d 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-examples/examples/bigcode/model.rs
@@ -30,16 +30,6 @@ fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
Ok(mask)
}
-// TODO: Use a numerically stable implementation by default.
-fn softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
- let d = d.to_index(xs.shape(), "log-softmax")?;
- let max = xs.max_keepdim(d)?;
- let diff = xs.broadcast_sub(&max)?;
- let num = diff.exp()?;
- let den = num.sum_keepdim(d)?;
- num.broadcast_div(&den)
-}
-
#[derive(Debug)]
pub struct Config {
pub vocab_size: usize,
@@ -192,7 +182,7 @@ impl Attention {
let mask_value =
Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
- let attn_weights = softmax(&attn_weights, D::Minus1)?;
+ let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
let value = value.contiguous()?;
let attn_output = if self.multi_query {
attn_weights
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index bce93c81..cab0b314 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -309,11 +309,13 @@ impl FalconAttention {
// Only handle the case where alibi is None here, and non-flash attention.
let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
- let attention_scores = attention_scores
- .broadcast_add(&mask.squeeze(1)?)?
- .to_dtype(DType::F32)?
- .softmax(D::Minus1)?
- .to_dtype(x.dtype())?;
+ let attention_scores = candle_nn::ops::softmax(
+ &attention_scores
+ .broadcast_add(&mask.squeeze(1)?)?
+ .to_dtype(DType::F32)?,
+ D::Minus1,
+ )?
+ .to_dtype(x.dtype())?;
let attn_output = attention_scores
.matmul(&value)?
.reshape((b_sz, self.num_heads, seq_len, head_dim))?
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index d519cafe..c4d33f0b 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -233,7 +233,7 @@ impl CausalSelfAttention {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = att.softmax(D::Minus1)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
};
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 13f939db..6d9e4bcd 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -158,7 +158,7 @@ impl CausalSelfAttention {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = att.softmax(D::Minus1)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index bcf6ed2b..ae2ef3e7 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -323,7 +323,7 @@ impl CausalSelfAttention {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = att.softmax(D::Minus1)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 212f6818..01266e63 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -187,7 +187,7 @@ impl MusicgenAttention {
let attn_weights = attn_weights
.reshape((b_sz, self.num_heads, tgt_len, src_len))?
.broadcast_add(attention_mask)?;
- let attn_weights = attn_weights.softmax(D::Minus1)?;
+ let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
// TODO: layer_head_mask?
let attn_output = attn_weights
.matmul(&value_states)?
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 61c0a1bb..ef65df39 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -223,7 +223,7 @@ impl T5Attention {
.transpose(1, 2)?;
let scores = q.matmul(&k.t()?)?;
// TODO: position_bias_masked
- let attn_weights = scores.softmax(D::Minus1)?;
+ let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
let attn_output = attn_weights.matmul(&v)?;
let attn_output = self.o.forward(&attn_output)?;
Ok(attn_output)
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index c03779e7..82c45348 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -11,7 +11,7 @@ extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{safetensors::Load, DType, Device, Tensor};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
@@ -120,9 +120,7 @@ impl Decoder {
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
- no_speech_prob = logits
- .get(0)?
- .softmax(0)?
+ no_speech_prob = softmax(&logits.get(0)?, 0)?
.get(NO_SPEECH_TOKEN as usize)?
.to_scalar::<f32>()? as f64;
}
@@ -132,7 +130,7 @@ impl Decoder {
.get(seq_len - 1)?
.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
- let prs = (&logits / t)?.softmax(0)?;
+ let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
@@ -146,8 +144,7 @@ impl Decoder {
.unwrap()
};
tokens.push(next_token);
- let prob = logits
- .softmax(candle::D::Minus1)?
+ let prob = softmax(&logits, candle::D::Minus1)?
.get(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 330b2a00..4d80c0c8 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -2,7 +2,7 @@
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{Device, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
use serde::Deserialize;
// The names in comments correspond to the original implementation:
@@ -154,7 +154,7 @@ impl MultiHeadAttention {
let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
qk = qk.broadcast_add(&mask)?
}
- let w = qk.softmax(candle::D::Minus1)?;
+ let w = softmax(&qk, candle::D::Minus1)?;
let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
Ok(wv)
}
diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml
index 013da854..1bd6ec32 100644
--- a/candle-flash-attn/Cargo.toml
+++ b/candle-flash-attn/Cargo.toml
@@ -21,3 +21,4 @@ rayon = "1.7.0"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
+candle-nn = { path = "../candle-nn", features = ["cuda"] }
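
candle-nn becomes a dev-dependency here because the CPU reference path in the flash-attention test (`fa_acausal`, just below) now calls `candle_nn::ops::softmax` instead of the removed tensor method.
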
diff --git a/candle-flash-attn/tests/flash_attn_tests.rs b/candle-flash-attn/tests/flash_attn_tests.rs
index c6780659..43cb324f 100644
--- a/candle-flash-attn/tests/flash_attn_tests.rs
+++ b/candle-flash-attn/tests/flash_attn_tests.rs
@@ -21,7 +21,7 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
let k = k.to_dtype(DType::F32)?;
let v = v.to_dtype(DType::F32)?;
let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
- let att = att.softmax(D::Minus1)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
Ok(output)
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 88196aa7..611c66d8 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1,5 +1,29 @@
use candle::{Result, Tensor};
+/// Applies the softmax function to the input tensor, rescaling the elements so that all
+/// elements on a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
+///
+/// ```rust
+/// use candle::{Tensor, Device};
+/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
+/// let a = candle_nn::ops::softmax(&a, 1)?;
+/// assert_eq!(
+/// a.to_vec2::<f32>()?,
+/// &[
+/// [0.13447072, 0.3655293, 0.13447072, 0.3655293],
+/// [0.0048928666, 0.26714146, 0.7261658, 0.0017999851]
+/// ]);
+/// # Ok::<(), candle::Error>(())
+/// ```
+pub fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
+ let dim = dim.to_index(xs.shape(), "softmax")?;
+ let max = xs.max_keepdim(dim)?;
+ let diff = xs.broadcast_sub(&max)?;
+ let num = diff.exp()?;
+ let den = num.sum_keepdim(dim)?;
+ num.broadcast_div(&den)
+}
+
pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
let d = d.to_index(xs.shape(), "log-softmax")?;
let max = xs.max_keepdim(d)?;
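
The key difference from the removed `Tensor::softmax` is the `max_keepdim`/`broadcast_sub` step, the usual log-sum-exp shift: softmax(x) = softmax(x - c) for any constant c, and choosing c = max(x) keeps every exponent at or below zero, so `exp` cannot overflow. A scalar illustration of why the unshifted form fails in f32:

    fn main() {
        // Naive: exp(1234) overflows f32 to +inf, and inf / inf is NaN.
        let naive = (1234f32).exp() / ((1234f32).exp() + (0f32).exp());
        assert!(naive.is_nan());
        // Shifted by the max (1234): exp(-1234) underflows to 0, so the
        // probability of the large logit comes out as exactly 1.
        let stable = (0f32).exp() / ((0f32).exp() + (-1234f32).exp());
        assert_eq!(stable, 1.0);
    }
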
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
new file mode 100644
index 00000000..ca82dd1f
--- /dev/null
+++ b/candle-nn/tests/ops.rs
@@ -0,0 +1,62 @@
+use candle::{Device, Result, Tensor};
+
+pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
+ let b = 10f32.powi(digits);
+ let t = t.to_vec3::<f32>()?;
+ let t = t
+ .iter()
+ .map(|t| {
+ t.iter()
+ .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+ .collect()
+ })
+ .collect();
+ Ok(t)
+}
+
+#[test]
+fn softmax() -> Result<()> {
+ let device = &Device::Cpu;
+ let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
+ let tensor = Tensor::new(data, device)?;
+ let t0 = candle_nn::ops::softmax(&tensor.log()?, 0)?;
+ let t1 = candle_nn::ops::softmax(&tensor.log()?, 1)?;
+ let t2 = candle_nn::ops::softmax(&tensor.log()?, 2)?;
+ assert_eq!(
+ to_vec3_round(t0, 4)?,
+ &[
+ // 3/5, 1/2, 4/11
+ [[0.6, 0.5, 0.3636], [0.1111, 0.7143, 0.5294]],
+ // 2/5, 1/2, 7/11
+ [[0.4, 0.5, 0.6364], [0.8889, 0.2857, 0.4706]]
+ ]
+ );
+ assert_eq!(
+ to_vec3_round(t1, 4)?,
+ &[
+ // 3/4, 1/6, 4/13
+ [[0.75, 0.1667, 0.3077], [0.25, 0.8333, 0.6923]],
+ // 2/10, 1/3, 7/15
+ [[0.2, 0.3333, 0.4667], [0.8, 0.6667, 0.5333]]
+ ]
+ );
+ assert_eq!(
+ to_vec3_round(t2, 4)?,
+ &[
+ // (3, 1, 4) / 8, (1, 5, 9) / 15
+ [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
+ // (2, 1, 7) / 10, (8, 2, 8) / 18
+ [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
+ ]
+ );
+ Ok(())
+}
+
+#[test]
+fn softmax_numerical_stability() -> Result<()> {
+ let dev = &Device::Cpu;
+ let xs = Tensor::new(&[1234f32, 0.], dev)?;
+ let softmax = candle_nn::ops::softmax(&xs, 0)?;
+ assert_eq!(softmax.to_vec1::<f32>()?, &[1f32, 0.]);
+ Ok(())
+}
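
The `softmax_numerical_stability` test pins this behaviour down: after the max shift the computation is exp(0) / (exp(0) + exp(-1234)), and since exp(-1234) underflows to zero in f32 the result is exactly [1., 0.], where the old unshifted implementation would have produced NaN.
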
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index f954f322..d2ac33e9 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -17,7 +17,7 @@ impl LogitsProcessor {
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let next_token = if let Some(temperature) = self.temperature {
- let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+ let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
let prs: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
distr.sample(&mut self.rng) as u32
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 8b0b3c3e..d95672b9 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -158,7 +158,7 @@ impl CausalSelfAttention {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = att.softmax(D::Minus1)?;
+ let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index d64da8c6..79f7c1fd 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -1,7 +1,7 @@
use crate::model::{Cache, Config, Llama};
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
use rand::{distributions::Distribution, SeedableRng};
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
@@ -88,7 +88,7 @@ impl LogitsProcessor {
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let next_token = if let Some(temperature) = self.temperature {
- let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+ let prs = softmax(&(&logits / temperature)?, D::Minus1)?;
let prs: Vec<f32> = prs.to_vec1()?;
let distr =
rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;
diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs
index 97eff839..9f3d92f5 100644
--- a/candle-wasm-examples/whisper/src/model.rs
+++ b/candle-wasm-examples/whisper/src/model.rs
@@ -200,7 +200,7 @@ impl MultiHeadAttention {
}
let w = {
let _timer = crate::Timer::new("qk::softmax");
- qk.softmax(candle::D::Minus1)?
+ candle_nn::ops::softmax(&qk, candle::D::Minus1)?
};
let wv = {
let _timer = crate::Timer::new("wv::matmul");
diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs
index 62eaa16f..139755cb 100644
--- a/candle-wasm-examples/whisper/src/worker.rs
+++ b/candle-wasm-examples/whisper/src/worker.rs
@@ -1,7 +1,7 @@
use crate::model::{Config, Whisper};
use anyhow::Error as E;
use candle::{safetensors::Load, DType, Device, Tensor};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
@@ -127,9 +127,7 @@ impl Decoder {
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
- no_speech_prob = logits
- .get(0)?
- .softmax(0)?
+ no_speech_prob = softmax(&logits.get(0)?, 0)?
.get(NO_SPEECH_TOKEN as usize)?
.to_scalar::<f32>()? as f64;
}
@@ -139,7 +137,7 @@ impl Decoder {
.get(seq_len - 1)?
.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
- let prs = (&logits / t)?.softmax(0)?;
+ let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(rng) as u32
@@ -153,8 +151,7 @@ impl Decoder {
.unwrap()
};
tokens.push(next_token);
- let prob = logits
- .softmax(candle::D::Minus1)?
+ let prob = softmax(&logits, candle::D::Minus1)?
.get(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {