28 files changed, 117 insertions, 188 deletions
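In short: this patch removes `Tensor::softmax` from candle-core together with the `divide_by_sum_over_dim` backend op it was built on, and softmax now lives in candle-nn as `candle_nn::ops::softmax`, a numerically stable composition of existing tensor ops (`max_keepdim`, `broadcast_sub`, `exp`, `sum_keepdim`, `broadcast_div`). Because every building block already has a gradient, the dedicated `Op::Softmax` variant and its `BackwardNotSupported` error path are dropped from backprop, and the CUDA backend no longer needs a `divide_by_sum_over_dim` kernel. A minimal call-site migration sketch (the tensor values here are illustrative, not taken from this patch):

```rust
use candle::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let scores = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // Before this patch the call was a tensor method: scores.softmax(D::Minus1)?
    // After it, softmax is a free function provided by candle-nn:
    let probs = candle_nn::ops::softmax(&scores, D::Minus1)?;
    println!("{probs}");
    Ok(())
}
```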
diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs
index 8815c08d..cee1cad0 100644
--- a/candle-core/src/backend.rs
+++ b/candle-core/src/backend.rs
@@ -21,8 +21,6 @@ pub trait BackendStorage: Sized {
 
     fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self>;
 
-    fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
-
     fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self>;
 
     fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self>;
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index d6beb70e..fd1650bb 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -90,7 +90,6 @@ impl Tensor {
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
                 | Op::Narrow(node, _, _, _)
-                | Op::Softmax(node, _)
                 | Op::Unary(node, _)
                 | Op::Elu(node, _)
                 | Op::CustomOp1(node, _) => {
@@ -324,7 +323,6 @@ impl Tensor {
                 }
                 Op::Reduce(_, ReduceOp::ArgMin, _) => {}
                 Op::Reduce(_, ReduceOp::ArgMax, _) => {}
-                Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
                 Op::Reshape(arg) => {
                     let arg_grad = grad.reshape(arg.dims())?;
                     let sum_grad = grads.or_insert(arg)?;
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 27d0f7da..c39cb9f7 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1236,45 +1236,6 @@ impl Map2 for MatMul {
     }
 }
 
-fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
-    // [self] stores data in a contiguous way starting at offset 0.
-    let dims = shape.dims();
-    let elem_per_slice = dims[dim];
-    let prod_pre_dim = dims[..dim].iter().product();
-    let prod_post_dim = dims[dim + 1..].iter().product();
-    if prod_post_dim == 1 {
-        for pre_idx in 0..prod_pre_dim {
-            let mut sum = 0f64;
-            let idx = pre_idx * elem_per_slice;
-            for v in s[idx..idx + elem_per_slice].iter() {
-                sum += v.to_f64();
-            }
-            let sum = T::from_f64(sum);
-            for v in s[idx..idx + elem_per_slice].iter_mut() {
-                *v /= sum
-            }
-        }
-    } else {
-        for pre_idx in 0..prod_pre_dim {
-            for post_idx in 0..prod_post_dim {
-                let mut sum = 0f64;
-                let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                for _ in 0..elem_per_slice {
-                    sum += s[idx].to_f64();
-                    idx += prod_post_dim
-                }
-                let sum = T::from_f64(sum);
-                let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                for _ in 0..elem_per_slice {
-                    s[idx] /= sum;
-                    idx += prod_post_dim
-                }
-            }
-        }
-    }
-    Ok(())
-}
-
 fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
     if v.is_sign_positive() {
         v
@@ -1513,17 +1474,6 @@ impl BackendStorage for CpuStorage {
         Cmp(op).map(self, lhs_l, rhs, rhs_l)
     }
 
-    fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
-        // [self] stores data in a contiguous way starting at offset 0.
-        match self {
-            Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
-            Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
-            Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
-            Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
-            Self::U8(_) | Self::U32(_) => Ok(()),
-        }
-    }
-
     fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
         Affine(mul, add).map(self, layout)
     }
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index b3d542b9..4050b595 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1303,10 +1303,6 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }
 
-    fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
-        Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
-    }
-
     fn unary_impl<U: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
         let device = self.device().clone();
         let slice = U::V.map(&self.slice, &device, layout)?;
diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs
index c195cade..1213c502 100644
--- a/candle-core/src/dummy_cuda_backend.rs
+++ b/candle-core/src/dummy_cuda_backend.rs
@@ -49,10 +49,6 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
     fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index 525383b2..4f489f30 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -93,7 +93,6 @@ pub enum Op {
     Broadcast(Tensor),
     Narrow(Tensor, usize, usize, usize),
     Reshape(Tensor),
-    Softmax(Tensor, usize),
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
     Elu(Tensor, f64),
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index 52af5861..545f549b 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -125,15 +125,6 @@ impl Storage {
         }
     }
 
-    // This assumes a contiguous layout and no offset.
-    pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
-        match self {
-            Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
-            Self::Cuda(storage) => storage.divide_by_sum_over_dim(shape, dim)?,
-        }
-        Ok(())
-    }
-
     pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 09f61340..8ae92c2e 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -553,40 +553,6 @@ impl Tensor {
         }
     }
 
-    /// Applies the softmax function to the input tensor, rescaling the element so that elements on
-    /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
-    ///
-    /// ```rust
-    /// use candle::{Tensor, Device};
-    /// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
-    /// let a = a.softmax(1)?;
-    /// assert_eq!(
-    ///     a.to_vec2::<f32>()?,
-    ///     &[
-    ///         [0.13447072, 0.3655293, 0.13447072, 0.3655293],
-    ///         [0.004892866, 0.26714143, 0.7261657, 0.0017999847],
-    ///     ]);
-    /// # Ok::<(), candle::Error>(())
-    /// ```
-    pub fn softmax<D: Dim>(&self, dim: D) -> Result<Self> {
-        let dim = dim.to_index(self.shape(), "softmax")?;
-        // TODO: unify the two branches.
-        if self.device().is_cuda() {
-            // We do not have a cuda kernel for divide_by_sum_over_dim so split
-            // the operation.
-            let exp = self.exp()?;
-            let sum_exp = exp.sum_keepdim(dim)?;
-            exp.broadcast_div(&sum_exp)
-        } else {
-            let shape = self.shape();
-            let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
-            // The resulting storage is contiguous.
-            storage.divide_by_sum_over_dim(shape, dim)?;
-            let op = BackpropOp::new1(self, |arg| Op::Softmax(arg, dim));
-            Ok(from_storage(storage, shape.clone(), op, false))
-        }
-    }
-
     fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
         match dims {
             [] => Ok(self),
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index a38b6d3d..a439ba30 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1,6 +1,5 @@
 mod test_utils;
 use candle::{DType, Device, IndexOp, Result, Tensor};
-use test_utils::to_vec3_round;
 
 fn zeros(device: &Device) -> Result<()> {
     let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@@ -68,42 +67,6 @@ fn transpose(device: &Device) -> Result<()> {
     Ok(())
 }
 
-fn softmax(device: &Device) -> Result<()> {
-    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
-    let tensor = Tensor::new(data, device)?;
-    let t0 = tensor.log()?.softmax(0)?;
-    let t1 = tensor.log()?.softmax(1)?;
-    let t2 = tensor.log()?.softmax(2)?;
-    assert_eq!(
-        to_vec3_round(t0, 4)?,
-        &[
-            // 3/5, 1/2, 4/11
-            [[0.6, 0.5, 0.3636], [0.1111, 0.7143, 0.5294]],
-            // 2/5, 1/2, 7/11
-            [[0.4, 0.5, 0.6364], [0.8889, 0.2857, 0.4706]]
-        ]
-    );
-    assert_eq!(
-        to_vec3_round(t1, 4)?,
-        &[
-            // 3/4, 1/6, 4/13
-            [[0.75, 0.1667, 0.3077], [0.25, 0.8333, 0.6923]],
-            // 2/10, 1/3, 7/15
-            [[0.2, 0.3333, 0.4667], [0.8, 0.6667, 0.5333]]
-        ]
-    );
-    assert_eq!(
-        to_vec3_round(t2, 4)?,
-        &[
-            // (3, 1, 4) / 8, (1, 5, 9) / 15
-            [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
-            // (2, 1, 7) / 10, (8, 2, 8) / 18
-            [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
-        ]
-    );
-    Ok(())
-}
-
 fn sum(device: &Device) -> Result<()> {
     let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
     let tensor = Tensor::new(data, device)?;
@@ -620,7 +583,6 @@ test_device!(cat, cat_cpu, cat_gpu);
 test_device!(sum, sum_cpu, sum_gpu);
 test_device!(transpose, transpose_cpu, transpose_gpu);
 test_device!(binary_op, binary_op_cpu, binary_op_gpu);
-test_device!(softmax, softmax_cpu, softmax_gpu);
 test_device!(embeddings, embeddings_cpu, embeddings_gpu);
 test_device!(cmp, cmp_cpu, cmp_gpu);
 test_device!(matmul, matmul_cpu, matmul_gpu);
diff --git a/candle-examples/examples/bert/model.rs b/candle-examples/examples/bert/model.rs
index 3bf412b2..b2438e71 100644
--- a/candle-examples/examples/bert/model.rs
+++ b/candle-examples/examples/bert/model.rs
@@ -333,7 +333,7 @@ impl BertSelfAttention {
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
         let attention_probs = {
             let _enter_sm = self.span_softmax.enter();
-            attention_scores.softmax(candle::D::Minus1)?
+            candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
         };
         let attention_probs = self.dropout.forward(&attention_probs)?;
 
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs
index 3f68a5be..12993e2d 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-examples/examples/bigcode/model.rs
@@ -30,16 +30,6 @@ fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
     Ok(mask)
 }
 
-// TODO: Use a numerically stable implementation by default.
-fn softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
-    let d = d.to_index(xs.shape(), "log-softmax")?;
-    let max = xs.max_keepdim(d)?;
-    let diff = xs.broadcast_sub(&max)?;
-    let num = diff.exp()?;
-    let den = num.sum_keepdim(d)?;
-    num.broadcast_div(&den)
-}
-
 #[derive(Debug)]
 pub struct Config {
     pub vocab_size: usize,
@@ -192,7 +182,7 @@ impl Attention {
            let mask_value =
                Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
            let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
-           let attn_weights = softmax(&attn_weights, D::Minus1)?;
+           let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
            let value = value.contiguous()?;
            let attn_output = if self.multi_query {
                attn_weights
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index bce93c81..cab0b314 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -309,11 +309,13 @@ impl FalconAttention {
 
         // Only handle the case where alibi is None here, and non-flash attention.
         let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
-        let attention_scores = attention_scores
-            .broadcast_add(&mask.squeeze(1)?)?
-            .to_dtype(DType::F32)?
-            .softmax(D::Minus1)?
-            .to_dtype(x.dtype())?;
+        let attention_scores = candle_nn::ops::softmax(
+            &attention_scores
+                .broadcast_add(&mask.squeeze(1)?)?
+                .to_dtype(DType::F32)?,
+            D::Minus1,
+        )?
+        .to_dtype(x.dtype())?;
         let attn_output = attention_scores
             .matmul(&value)?
             .reshape((b_sz, self.num_heads, seq_len, head_dim))?
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index d519cafe..c4d33f0b 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -233,7 +233,7 @@ impl CausalSelfAttention {
            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
            let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
            let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-           let att = att.softmax(D::Minus1)?;
+           let att = candle_nn::ops::softmax(&att, D::Minus1)?;
            // Convert to contiguous as matmul doesn't support strided vs for now.
            att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
        };
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
index 13f939db..6d9e4bcd 100644
--- a/candle-examples/examples/llama2-c/model.rs
+++ b/candle-examples/examples/llama2-c/model.rs
@@ -158,7 +158,7 @@ impl CausalSelfAttention {
        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-       let att = att.softmax(D::Minus1)?;
+       let att = candle_nn::ops::softmax(&att, D::Minus1)?;
        // Convert to contiguous as matmul doesn't support strided vs for now.
        let y = att.matmul(&v.contiguous()?)?;
        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index bcf6ed2b..ae2ef3e7 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -323,7 +323,7 @@ impl CausalSelfAttention {
        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-       let att = att.softmax(D::Minus1)?;
+       let att = candle_nn::ops::softmax(&att, D::Minus1)?;
        // Convert to contiguous as matmul doesn't support strided vs for now.
        let y = att.matmul(&v.contiguous()?)?;
        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
diff --git a/candle-examples/examples/musicgen/musicgen_model.rs b/candle-examples/examples/musicgen/musicgen_model.rs
index 212f6818..01266e63 100644
--- a/candle-examples/examples/musicgen/musicgen_model.rs
+++ b/candle-examples/examples/musicgen/musicgen_model.rs
@@ -187,7 +187,7 @@ impl MusicgenAttention {
        let attn_weights = attn_weights
            .reshape((b_sz, self.num_heads, tgt_len, src_len))?
            .broadcast_add(attention_mask)?;
-       let attn_weights = attn_weights.softmax(D::Minus1)?;
+       let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
        // TODO: layer_head_mask?
        let attn_output = attn_weights
            .matmul(&value_states)?
diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 61c0a1bb..ef65df39 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -223,7 +223,7 @@ impl T5Attention {
            .transpose(1, 2)?;
        let scores = q.matmul(&k.t()?)?;
        // TODO: position_bias_masked
-       let attn_weights = scores.softmax(D::Minus1)?;
+       let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
        let attn_output = attn_weights.matmul(&v)?;
        let attn_output = self.o.forward(&attn_output)?;
        Ok(attn_output)
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index c03779e7..82c45348 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -11,7 +11,7 @@ extern crate intel_mkl_src;
 
 use anyhow::{Error as E, Result};
 use candle::{safetensors::Load, DType, Device, Tensor};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
 use clap::Parser;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use rand::{distributions::Distribution, SeedableRng};
@@ -120,9 +120,7 @@ impl Decoder {
            // Extract the no speech probability on the first iteration by looking at the first
            // token logits and the probability for the according token.
            if i == 0 {
-               no_speech_prob = logits
-                   .get(0)?
-                   .softmax(0)?
+               no_speech_prob = softmax(&logits.get(0)?, 0)?
                    .get(NO_SPEECH_TOKEN as usize)?
                    .to_scalar::<f32>()? as f64;
            }
@@ -132,7 +130,7 @@
            .get(seq_len - 1)?
            .broadcast_add(&self.suppress_tokens)?;
        let next_token = if t > 0f64 {
-           let prs = (&logits / t)?.softmax(0)?;
+           let prs = softmax(&(&logits / t)?, 0)?;
            let logits_v: Vec<f32> = prs.to_vec1()?;
            let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
            distr.sample(&mut self.rng) as u32
@@ -146,8 +144,7 @@
                    .unwrap()
            };
            tokens.push(next_token);
-           let prob = logits
-               .softmax(candle::D::Minus1)?
+           let prob = softmax(&logits, candle::D::Minus1)?
                .get(next_token as usize)?
                .to_scalar::<f32>()? as f64;
            if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
index 330b2a00..4d80c0c8 100644
--- a/candle-examples/examples/whisper/model.rs
+++ b/candle-examples/examples/whisper/model.rs
@@ -2,7 +2,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{Device, Tensor};
-use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
 use serde::Deserialize;
 
 // The names in comments correspond to the original implementation:
@@ -154,7 +154,7 @@ impl MultiHeadAttention {
            let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
            qk = qk.broadcast_add(&mask)?
        }
-       let w = qk.softmax(candle::D::Minus1)?;
+       let w = softmax(&qk, candle::D::Minus1)?;
        let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
        Ok(wv)
    }
diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml
index 013da854..1bd6ec32 100644
--- a/candle-flash-attn/Cargo.toml
+++ b/candle-flash-attn/Cargo.toml
@@ -21,3 +21,4 @@ rayon = "1.7.0"
 
 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
+candle-nn = { path = "../candle-nn", features = ["cuda"] }
diff --git a/candle-flash-attn/tests/flash_attn_tests.rs b/candle-flash-attn/tests/flash_attn_tests.rs
index c6780659..43cb324f 100644
--- a/candle-flash-attn/tests/flash_attn_tests.rs
+++ b/candle-flash-attn/tests/flash_attn_tests.rs
@@ -21,7 +21,7 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
    let k = k.to_dtype(DType::F32)?;
    let v = v.to_dtype(DType::F32)?;
    let att = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
-   let att = att.softmax(D::Minus1)?;
+   let att = candle_nn::ops::softmax(&att, D::Minus1)?;
    // Convert to contiguous as matmul doesn't support strided vs for now.
    let output = att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?;
    Ok(output)
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 88196aa7..611c66d8 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -1,5 +1,29 @@
 use candle::{Result, Tensor};
 
+/// Applies the softmax function to the input tensor, rescaling the element so that elements on
+/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
+///
+/// ```rust
+/// use candle::{Tensor, Device};
+/// let a = Tensor::new(&[[0f32, 1., 0., 1.], [-2., 2., 3., -3.]], &Device::Cpu)?;
+/// let a = candle_nn::ops::softmax(&a, 1)?;
+/// assert_eq!(
+///     a.to_vec2::<f32>()?,
+///     &[
+///         [0.13447072, 0.3655293, 0.13447072, 0.3655293],
+///         [0.0048928666, 0.26714146, 0.7261658, 0.0017999851]
+///     ]);
+/// # Ok::<(), candle::Error>(())
+/// ```
+pub fn softmax<D: candle::shape::Dim>(xs: &Tensor, dim: D) -> Result<Tensor> {
+    let dim = dim.to_index(xs.shape(), "softmax")?;
+    let max = xs.max_keepdim(dim)?;
+    let diff = xs.broadcast_sub(&max)?;
+    let num = diff.exp()?;
+    let den = num.sum_keepdim(dim)?;
+    num.broadcast_div(&den)
+}
+
 pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
     let d = d.to_index(xs.shape(), "log-softmax")?;
     let max = xs.max_keepdim(d)?;
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
new file mode 100644
index 00000000..ca82dd1f
--- /dev/null
+++ b/candle-nn/tests/ops.rs
@@ -0,0 +1,62 @@
+use candle::{Device, Result, Tensor};
+
+pub fn to_vec3_round(t: Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
+    let b = 10f32.powi(digits);
+    let t = t.to_vec3::<f32>()?;
+    let t = t
+        .iter()
+        .map(|t| {
+            t.iter()
+                .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
+                .collect()
+        })
+        .collect();
+    Ok(t)
+}
+
+#[test]
+fn softmax() -> Result<()> {
+    let device = &Device::Cpu;
+    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
+    let tensor = Tensor::new(data, device)?;
+    let t0 = candle_nn::ops::softmax(&tensor.log()?, 0)?;
+    let t1 = candle_nn::ops::softmax(&tensor.log()?, 1)?;
+    let t2 = candle_nn::ops::softmax(&tensor.log()?, 2)?;
+    assert_eq!(
+        to_vec3_round(t0, 4)?,
+        &[
+            // 3/5, 1/2, 4/11
+            [[0.6, 0.5, 0.3636], [0.1111, 0.7143, 0.5294]],
+            // 2/5, 1/2, 7/11
+            [[0.4, 0.5, 0.6364], [0.8889, 0.2857, 0.4706]]
+        ]
+    );
+    assert_eq!(
+        to_vec3_round(t1, 4)?,
+        &[
+            // 3/4, 1/6, 4/13
+            [[0.75, 0.1667, 0.3077], [0.25, 0.8333, 0.6923]],
+            // 2/10, 1/3, 7/15
+            [[0.2, 0.3333, 0.4667], [0.8, 0.6667, 0.5333]]
+        ]
+    );
+    assert_eq!(
+        to_vec3_round(t2, 4)?,
+        &[
+            // (3, 1, 4) / 8, (1, 5, 9) / 15
+            [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
+            // (2, 1, 7) / 10, (8, 2, 8) / 18
+            [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
+        ]
+    );
+    Ok(())
+}
+
+#[test]
+fn softmax_numerical_stability() -> Result<()> {
+    let dev = &Device::Cpu;
+    let xs = Tensor::new(&[1234f32, 0.], dev)?;
+    let softmax = candle_nn::ops::softmax(&xs, 0)?;
+    assert_eq!(softmax.to_vec1::<f32>()?, &[1f32, 0.]);
+    Ok(())
+}
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index f954f322..d2ac33e9 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -17,7 +17,7 @@ impl LogitsProcessor {
    pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
        let logits = logits.to_dtype(DType::F32)?;
        let next_token = if let Some(temperature) = self.temperature {
-           let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+           let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
            let prs: Vec<f32> = prs.to_vec1()?;
            let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
            distr.sample(&mut self.rng) as u32
diff --git a/candle-wasm-examples/llama2-c/src/model.rs b/candle-wasm-examples/llama2-c/src/model.rs
index 8b0b3c3e..d95672b9 100644
--- a/candle-wasm-examples/llama2-c/src/model.rs
+++ b/candle-wasm-examples/llama2-c/src/model.rs
@@ -158,7 +158,7 @@ impl CausalSelfAttention {
        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-       let att = att.softmax(D::Minus1)?;
+       let att = candle_nn::ops::softmax(&att, D::Minus1)?;
        // Convert to contiguous as matmul doesn't support strided vs for now.
        let y = att.matmul(&v.contiguous()?)?;
        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index d64da8c6..79f7c1fd 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -1,7 +1,7 @@
 use crate::model::{Cache, Config, Llama};
 use byteorder::{LittleEndian, ReadBytesExt};
 use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
 use rand::{distributions::Distribution, SeedableRng};
 use serde::{Deserialize, Serialize};
 use wasm_bindgen::prelude::*;
@@ -88,7 +88,7 @@ impl LogitsProcessor {
    pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
        let logits = logits.to_dtype(DType::F32)?;
        let next_token = if let Some(temperature) = self.temperature {
-           let prs = (&logits / temperature)?.softmax(D::Minus1)?;
+           let prs = softmax(&(&logits / temperature)?, D::Minus1)?;
            let prs: Vec<f32> = prs.to_vec1()?;
            let distr =
                rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;
diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs
index 97eff839..9f3d92f5 100644
--- a/candle-wasm-examples/whisper/src/model.rs
+++ b/candle-wasm-examples/whisper/src/model.rs
@@ -200,7 +200,7 @@ impl MultiHeadAttention {
        }
        let w = {
            let _timer = crate::Timer::new("qk::softmax");
-           qk.softmax(candle::D::Minus1)?
+           candle_nn::ops::softmax(&qk, candle::D::Minus1)?
        };
        let wv = {
            let _timer = crate::Timer::new("wv::matmul");
diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs
index 62eaa16f..139755cb 100644
--- a/candle-wasm-examples/whisper/src/worker.rs
+++ b/candle-wasm-examples/whisper/src/worker.rs
@@ -1,7 +1,7 @@
 use crate::model::{Config, Whisper};
 use anyhow::Error as E;
 use candle::{safetensors::Load, DType, Device, Tensor};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
 use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
 use serde::{Deserialize, Serialize};
 use tokenizers::Tokenizer;
@@ -127,9 +125,7 @@ impl Decoder {
            // Extract the no speech probability on the first iteration by looking at the first
            // token logits and the probability for the according token.
            if i == 0 {
-               no_speech_prob = logits
-                   .get(0)?
-                   .softmax(0)?
+               no_speech_prob = softmax(&logits.get(0)?, 0)?
                    .get(NO_SPEECH_TOKEN as usize)?
                    .to_scalar::<f32>()? as f64;
            }
@@ -139,7 +137,7 @@
            .get(seq_len - 1)?
            .broadcast_add(&self.suppress_tokens)?;
        let next_token = if t > 0f64 {
-           let prs = (&logits / t)?.softmax(0)?;
+           let prs = softmax(&(&logits / t)?, 0)?;
            let logits_v: Vec<f32> = prs.to_vec1()?;
            let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
            distr.sample(rng) as u32
@@ -153,8 +151,7 @@
                    .unwrap()
            };
            tokens.push(next_token);
-           let prob = logits
-               .softmax(candle::D::Minus1)?
+           let prob = softmax(&logits, candle::D::Minus1)?
                .get(next_token as usize)?
                .to_scalar::<f32>()? as f64;
            if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
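A note on why the replacement is numerically stable: the removed `Tensor::softmax` exponentiated the raw inputs and then normalized, so a single large logit overflows `exp` and turns the whole slice into NaN. The candle-nn version subtracts the per-slice maximum before exponentiating, which is exactly what the new `softmax_numerical_stability` test exercises. A standalone sketch of the same check, assuming `candle` and `candle-nn` as dependencies:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // exp(1234) overflows f32 to infinity, so the old exp-then-normalize
    // softmax would evaluate inf / inf = NaN on this slice.
    let xs = Tensor::new(&[1234f32, 0.], &Device::Cpu)?;
    // The max-subtracted version computes exp(0) and exp(-1234) instead,
    // which stays finite: the result is exactly [1.0, 0.0].
    let probs = candle_nn::ops::softmax(&xs, 0)?;
    assert_eq!(probs.to_vec1::<f32>()?, &[1f32, 0.]);
    Ok(())
}
```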