summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/whisper/main.rs52
1 files changed, 49 insertions, 3 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 9f8810a7..4ea60fb4 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -41,6 +41,8 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
// Tokenizer dependent bits.
const SOT_TOKEN: &str = "<|startoftranscript|>";
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
+const TRANSLATE_TOKEN: &str = "<|translate|>";
+const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
const EOT_TOKEN: &str = "<|endoftext|>";
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
@@ -66,12 +68,16 @@ struct Segment {
struct Decoder {
model: Whisper,
rng: rand::rngs::StdRng,
+ task: Option<Task>,
+ timestamps: bool,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
sot_token: u32,
transcribe_token: u32,
+ translate_token: u32,
eot_token: u32,
no_speech_token: u32,
+ no_timestamps_token: u32,
language_token: Option<u32>,
}
@@ -82,6 +88,8 @@ impl Decoder {
seed: u64,
device: &Device,
language_token: Option<u32>,
+ task: Option<Task>,
+ timestamps: bool,
) -> Result<Self> {
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
.map(|i| {
@@ -95,18 +103,24 @@ impl Decoder {
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
+ let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
+ let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
tokenizer,
+ task,
+ timestamps,
suppress_tokens,
sot_token,
transcribe_token,
+ translate_token,
eot_token,
no_speech_token,
language_token,
+ no_timestamps_token,
})
}
@@ -118,10 +132,19 @@ impl Decoder {
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token];
+ match self.task {
+ Some(Task::Transcribe) => tokens.push(self.transcribe_token),
+ Some(Task::Translate) => tokens.push(self.translate_token),
+ None => {
+ // Nothing in this case, same as the Python implementation.
+ }
+ }
if let Some(language_token) = self.language_token {
- tokens.push(language_token)
+ tokens.push(language_token);
+ }
+ if !self.timestamps {
+ tokens.push(self.no_timestamps_token);
}
- tokens.push(self.transcribe_token);
for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
@@ -241,6 +264,12 @@ pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
}
#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Task {
+ Transcribe,
+ Translate,
+}
+
+#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
Tiny,
#[value(name = "tiny.en")]
@@ -313,6 +342,15 @@ struct Args {
/// Language.
#[arg(long)]
language: Option<String>,
+
+ /// Task, when no task is specified, the input tokens contain only the sot token which can
+ /// improve things when in no-timestamp mode.
+ #[arg(long)]
+ task: Option<Task>,
+
+ /// Timestamps mode, this is not fully implemented yet.
+ #[arg(long)]
+ timestamps: bool,
}
fn main() -> Result<()> {
@@ -414,7 +452,15 @@ fn main() -> Result<()> {
anyhow::bail!("a language cannot be set for non-multilingual models")
}
};
- let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
+ let mut dc = Decoder::new(
+ model,
+ tokenizer,
+ args.seed,
+ &device,
+ language_token,
+ args.task,
+ args.timestamps,
+ )?;
dc.run(&mel)?;
Ok(())
}