author     Laurent Mazare <laurent.mazare@gmail.com>   2023-07-28 13:13:01 +0100
committer  GitHub <noreply@github.com>                 2023-07-28 13:13:01 +0100
commit     3eb2bc6d07f192a5ce73ab6964745275f2c15213 (patch)
tree       e5a682d0e40f3c258f668652082ff7fa45918e32 /candle-examples/examples/whisper/main.rs
parent     68eab38de6e5cabf17159a5dcf45ec703fbea441 (diff)
Softmax numerical stability. (#267)

* Softmax numerical stability.
* Fix the flash-attn test.
Diffstat (limited to 'candle-examples/examples/whisper/main.rs')
-rw-r--r--  candle-examples/examples/whisper/main.rs | 11 ++++-------
1 file changed, 4 insertions(+), 7 deletions(-)
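
For context, this patch routes every softmax call in the example through `candle_nn::ops::softmax` instead of the tensor method, so the numerically stable formulation lives in one place. The usual way to make softmax numerically stable is to subtract the per-row maximum before exponentiating, which keeps `exp` from overflowing on large logits without changing the resulting probabilities. A minimal sketch of that trick on a plain `f32` slice (illustrative only, not the candle_nn implementation):

    // Illustrative only: subtract the maximum before exponentiating so that
    // exp() never sees a large positive argument; dividing by the sum then
    // yields the same probabilities as the naive formula.
    fn stable_softmax(xs: &[f32]) -> Vec<f32> {
        let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = xs.iter().map(|x| (x - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        exps.iter().map(|e| e / sum).collect()
    }
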
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index c03779e7..82c45348 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -11,7 +11,7 @@ extern crate intel_mkl_src;
 
 use anyhow::{Error as E, Result};
 use candle::{safetensors::Load, DType, Device, Tensor};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
 use clap::Parser;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use rand::{distributions::Distribution, SeedableRng};
@@ -120,9 +120,7 @@ impl Decoder {
             // Extract the no speech probability on the first iteration by looking at the first
             // token logits and the probability for the according token.
             if i == 0 {
-                no_speech_prob = logits
-                    .get(0)?
-                    .softmax(0)?
+                no_speech_prob = softmax(&logits.get(0)?, 0)?
                     .get(NO_SPEECH_TOKEN as usize)?
                     .to_scalar::<f32>()? as f64;
             }
@@ -132,7 +130,7 @@
                 .get(seq_len - 1)?
                 .broadcast_add(&self.suppress_tokens)?;
             let next_token = if t > 0f64 {
-                let prs = (&logits / t)?.softmax(0)?;
+                let prs = softmax(&(&logits / t)?, 0)?;
                 let logits_v: Vec<f32> = prs.to_vec1()?;
                 let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
                 distr.sample(&mut self.rng) as u32
@@ -146,8 +144,7 @@
                     .unwrap()
             };
             tokens.push(next_token);
-            let prob = logits
-                .softmax(candle::D::Minus1)?
+            let prob = softmax(&logits, candle::D::Minus1)?
                 .get(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
             if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
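
For readers skimming the sampling hunks: with a positive temperature `t`, the logits are divided by `t`, turned into probabilities with softmax, and the next token is drawn with `rand`'s `WeightedIndex`; with `t == 0` the example takes the most probable token directly. A standalone sketch of the positive-temperature branch (hypothetical helper, using the same `rand` 0.8-style API as the example, not the example's own code):

    use rand::distributions::{Distribution, WeightedIndex};
    use rand::rngs::StdRng;

    // Hypothetical helper mirroring the `t > 0` branch: scale the logits by
    // 1/t, apply a numerically stable softmax, then sample an index in
    // proportion to the resulting probabilities.
    fn sample_with_temperature(logits: &[f32], t: f32, rng: &mut StdRng) -> usize {
        let scaled: Vec<f32> = logits.iter().map(|x| x / t).collect();
        let max = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = scaled.iter().map(|x| (x - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
        WeightedIndex::new(&probs).unwrap().sample(rng)
    }

Seeding the generator, e.g. `StdRng::seed_from_u64(42)` via `rand::SeedableRng`, makes the draw reproducible; the example itself keeps its RNG in `self.rng`.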