diff options
Diffstat (limited to 'candle-examples/examples/whisper/main.rs')
-rw-r--r-- | candle-examples/examples/whisper/main.rs | 11 |
1 files changed, 4 insertions, 7 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index c03779e7..82c45348 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -11,7 +11,7 @@ extern crate intel_mkl_src; use anyhow::{Error as E, Result}; use candle::{safetensors::Load, DType, Device, Tensor}; -use candle_nn::VarBuilder; +use candle_nn::{ops::softmax, VarBuilder}; use clap::Parser; use hf_hub::{api::sync::Api, Repo, RepoType}; use rand::{distributions::Distribution, SeedableRng}; @@ -120,9 +120,7 @@ impl Decoder { // Extract the no speech probability on the first iteration by looking at the first // token logits and the probability for the according token. if i == 0 { - no_speech_prob = logits - .get(0)? - .softmax(0)? + no_speech_prob = softmax(&logits.get(0)?, 0)? .get(NO_SPEECH_TOKEN as usize)? .to_scalar::<f32>()? as f64; } @@ -132,7 +130,7 @@ impl Decoder { .get(seq_len - 1)? .broadcast_add(&self.suppress_tokens)?; let next_token = if t > 0f64 { - let prs = (&logits / t)?.softmax(0)?; + let prs = softmax(&(&logits / t)?, 0)?; let logits_v: Vec<f32> = prs.to_vec1()?; let distr = rand::distributions::WeightedIndex::new(&logits_v)?; distr.sample(&mut self.rng) as u32 @@ -146,8 +144,7 @@ impl Decoder { .unwrap() }; tokens.push(next_token); - let prob = logits - .softmax(candle::D::Minus1)? + let prob = softmax(&logits, candle::D::Minus1)? .get(next_token as usize)? .to_scalar::<f32>()? as f64; if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions { |