diff options
-rw-r--r-- | candle-examples/examples/whisper/main.rs | 52 |
1 files changed, 49 insertions, 3 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 9f8810a7..4ea60fb4 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -41,6 +41,8 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; // Tokenizer dependent bits. const SOT_TOKEN: &str = "<|startoftranscript|>"; const TRANSCRIBE_TOKEN: &str = "<|transcribe|>"; +const TRANSLATE_TOKEN: &str = "<|translate|>"; +const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>"; const EOT_TOKEN: &str = "<|endoftext|>"; const NO_SPEECH_TOKEN: &str = "<|nocaptions|>"; @@ -66,12 +68,16 @@ struct Segment { struct Decoder { model: Whisper, rng: rand::rngs::StdRng, + task: Option<Task>, + timestamps: bool, tokenizer: Tokenizer, suppress_tokens: Tensor, sot_token: u32, transcribe_token: u32, + translate_token: u32, eot_token: u32, no_speech_token: u32, + no_timestamps_token: u32, language_token: Option<u32>, } @@ -82,6 +88,8 @@ impl Decoder { seed: u64, device: &Device, language_token: Option<u32>, + task: Option<Task>, + timestamps: bool, ) -> Result<Self> { let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32) .map(|i| { @@ -95,18 +103,24 @@ impl Decoder { let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?; let sot_token = token_id(&tokenizer, SOT_TOKEN)?; let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?; + let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?; + let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?; let eot_token = token_id(&tokenizer, EOT_TOKEN)?; let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?; Ok(Self { model, rng: rand::rngs::StdRng::seed_from_u64(seed), tokenizer, + task, + timestamps, suppress_tokens, sot_token, transcribe_token, + translate_token, eot_token, no_speech_token, language_token, + no_timestamps_token, }) } @@ -118,10 +132,19 @@ impl Decoder { let mut sum_logprob = 0f64; let mut no_speech_prob = f64::NAN; let mut tokens = vec![self.sot_token]; + match self.task { + Some(Task::Transcribe) => tokens.push(self.transcribe_token), + Some(Task::Translate) => tokens.push(self.translate_token), + None => { + // Nothing in this case, same as the Python implementation. + } + } if let Some(language_token) = self.language_token { - tokens.push(language_token) + tokens.push(language_token); + } + if !self.timestamps { + tokens.push(self.no_timestamps_token); } - tokens.push(self.transcribe_token); for i in 0..sample_len { let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?; @@ -241,6 +264,12 @@ pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> { } #[derive(Clone, Copy, Debug, ValueEnum)] +enum Task { + Transcribe, + Translate, +} + +#[derive(Clone, Copy, Debug, ValueEnum)] enum WhichModel { Tiny, #[value(name = "tiny.en")] @@ -313,6 +342,15 @@ struct Args { /// Language. #[arg(long)] language: Option<String>, + + /// Task, when no task is specified, the input tokens contain only the sot token which can + /// improve things when in no-timestamp mode. + #[arg(long)] + task: Option<Task>, + + /// Timestamps mode, this is not fully implemented yet. + #[arg(long)] + timestamps: bool, } fn main() -> Result<()> { @@ -414,7 +452,15 @@ fn main() -> Result<()> { anyhow::bail!("a language cannot be set for non-multilingual models") } }; - let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?; + let mut dc = Decoder::new( + model, + tokenizer, + args.seed, + &device, + language_token, + args.task, + args.timestamps, + )?; dc.run(&mel)?; Ok(()) } |