summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- candle-examples/examples/whisper/main.rs 14
-rw-r--r-- candle-transformers/src/models/whisper/mod.rs 2
-rw-r--r-- candle-wasm-examples/whisper/src/worker.rs 8
3 files changed, 16 insertions, 8 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index d2caebcd..5be81f2d 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -128,7 +128,13 @@ impl Decoder {
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
- let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
+ let no_speech_token = m::NO_SPEECH_TOKENS
+ .iter()
+ .find_map(|token| token_id(&tokenizer, token).ok());
+ let no_speech_token = match no_speech_token {
+ None => anyhow::bail!("unable to find any non-speech token"),
+ Some(n) => n,
+ };
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
@@ -512,11 +518,7 @@ fn main() -> Result<()> {
)
} else {
let config = repo.get("config.json")?;
- let tokenizer = if args.model == WhichModel::LargeV3 {
- panic!("openai/whisper-large-v3 does not provide a compatible tokenizer.json config at the moment")
- } else {
- repo.get("tokenizer.json")?
- };
+ let tokenizer = repo.get("tokenizer.json")?;
let model = repo.get("model.safetensors")?;
(config, tokenizer, model)
};
diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs
index bf24045a..8028cf2c 100644
--- a/candle-transformers/src/models/whisper/mod.rs
+++ b/candle-transformers/src/models/whisper/mod.rs
@@ -43,4 +43,4 @@ pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
pub const TRANSLATE_TOKEN: &str = "<|translate|>";
pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
pub const EOT_TOKEN: &str = "<|endoftext|>";
-pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
+pub const NO_SPEECH_TOKENS: [&str; 2] = ["<|nocaptions|>", "<|nospeech|>"];
diff --git a/candle-wasm-examples/whisper/src/worker.rs b/candle-wasm-examples/whisper/src/worker.rs
index 09d4f580..db5e8bb1 100644
--- a/candle-wasm-examples/whisper/src/worker.rs
+++ b/candle-wasm-examples/whisper/src/worker.rs
@@ -129,7 +129,13 @@ impl Decoder {
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
- let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
+ let no_speech_token = m::NO_SPEECH_TOKENS
+ .iter()
+ .find_map(|token| token_id(&tokenizer, token).ok());
+ let no_speech_token = match no_speech_token {
+ None => anyhow::bail!("unable to find any non-speech token"),
+ Some(n) => n,
+ };
let seed = 299792458;
Ok(Self {
model,