summaryrefslogtreecommitdiff
path: root/candle-examples/examples/whisper/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/whisper/main.rs')
-rw-r--r--candle-examples/examples/whisper/main.rs11
1 files changed, 4 insertions, 7 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index c03779e7..82c45348 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -11,7 +11,7 @@ extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{safetensors::Load, DType, Device, Tensor};
-use candle_nn::VarBuilder;
+use candle_nn::{ops::softmax, VarBuilder};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
@@ -120,9 +120,7 @@ impl Decoder {
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
- no_speech_prob = logits
- .get(0)?
- .softmax(0)?
+ no_speech_prob = softmax(&logits.get(0)?, 0)?
.get(NO_SPEECH_TOKEN as usize)?
.to_scalar::<f32>()? as f64;
}
@@ -132,7 +130,7 @@ impl Decoder {
.get(seq_len - 1)?
.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
- let prs = (&logits / t)?.softmax(0)?;
+ let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
@@ -146,8 +144,7 @@ impl Decoder {
.unwrap()
};
tokens.push(next_token);
- let prob = logits
- .softmax(candle::D::Minus1)?
+ let prob = softmax(&logits, candle::D::Minus1)?
.get(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {