summary refs log tree commit diff
path: root/candle-examples
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples')
-rw-r--r--candle-examples/examples/whisper/main.rs78
1 files changed, 74 insertions, 4 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 4788385b..f0d7cf47 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -70,6 +70,7 @@ struct Decoder {
rng: rand::rngs::StdRng,
task: Option<Task>,
timestamps: bool,
+ verbose: bool,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
sot_token: u32,
@@ -82,6 +83,7 @@ struct Decoder {
}
impl Decoder {
+ #[allow(clippy::too_many_arguments)]
fn new(
model: Whisper,
tokenizer: Tokenizer,
@@ -90,10 +92,16 @@ impl Decoder {
language_token: Option<u32>,
task: Option<Task>,
timestamps: bool,
+ verbose: bool,
) -> Result<Self> {
+ let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
+ // Suppress the notimestamps token when in timestamps mode.
+ // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
.map(|i| {
- if model.config.suppress_tokens.contains(&i) {
+ if model.config.suppress_tokens.contains(&i)
+ || timestamps && i == no_timestamps_token
+ {
f32::NEG_INFINITY
} else {
0f32
@@ -104,7 +112,6 @@ impl Decoder {
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
- let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
Ok(Self {
@@ -113,6 +120,7 @@ impl Decoder {
tokenizer,
task,
timestamps,
+ verbose,
suppress_tokens,
sot_token,
transcribe_token,
@@ -127,7 +135,9 @@ impl Decoder {
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &mut self.model;
let audio_features = model.encoder.forward(mel, true)?;
- println!("audio features: {:?}", audio_features.dims());
+ if self.verbose {
+ println!("audio features: {:?}", audio_features.dims());
+ }
let sample_len = model.config.max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
@@ -168,6 +178,13 @@ impl Decoder {
.final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
+ // TODO: Besides suppress tokens, we should apply the heuristics from
+ // ApplyTimestampRules, i.e.:
+ // - Timestamps come in pairs, except before EOT.
+ // - Timestamps should be non-decreasing.
+ // - If the sum of the probabilities of timestamps is higher than any other tokens,
+ // only consider timestamps when sampling.
+ // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
let logits = logits.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
@@ -249,7 +266,55 @@ impl Decoder {
duration: segment_duration,
dr,
};
- println!("{seek}: {segment:?}, in {:?}", start.elapsed());
+ if self.timestamps {
+ println!(
+ "{:.1}s -- {:.1}s",
+ segment.start,
+ segment.start + segment.duration,
+ );
+ let mut tokens_to_decode = vec![];
+ let mut prev_timestamp_s = 0f32;
+ for &token in segment.dr.tokens.iter() {
+ if token == self.sot_token || token == self.eot_token {
+ continue;
+ }
+                // The no_timestamps_token is the last before the timestamp ones.
+ if token > self.no_timestamps_token {
+ let timestamp_s = (token - self.no_timestamps_token + 1) as f32 / 50.;
+ if !tokens_to_decode.is_empty() {
+ let text = self
+ .tokenizer
+ .decode(&tokens_to_decode, true)
+ .map_err(E::msg)?;
+ println!(" {:.1}s-{:.1}s: {}", prev_timestamp_s, timestamp_s, text);
+ tokens_to_decode.clear()
+ }
+ prev_timestamp_s = timestamp_s;
+ } else {
+ tokens_to_decode.push(token)
+ }
+ }
+ if !tokens_to_decode.is_empty() {
+ let text = self
+ .tokenizer
+ .decode(&tokens_to_decode, true)
+ .map_err(E::msg)?;
+ if !text.is_empty() {
+ println!(" {:.1}s-...: {}", prev_timestamp_s, text);
+ }
+ tokens_to_decode.clear()
+ }
+ } else {
+ println!(
+ "{:.1}s -- {:.1}s: {}",
+ segment.start,
+ segment.start + segment.duration,
+ segment.dr.text,
+ )
+ }
+ if self.verbose {
+ println!("{seek}: {segment:?}, in {:?}", start.elapsed());
+ }
segments.push(segment)
}
Ok(segments)
@@ -357,6 +422,10 @@ struct Args {
/// Timestamps mode, this is not fully implemented yet.
#[arg(long)]
timestamps: bool,
+
+ /// Print the full DecodingResult structure rather than just the text.
+ #[arg(long)]
+ verbose: bool,
}
fn main() -> Result<()> {
@@ -466,6 +535,7 @@ fn main() -> Result<()> {
language_token,
args.task,
args.timestamps,
+ args.verbose,
)?;
dc.run(&mel)?;
Ok(())