summaryrefslogtreecommitdiff
path: root/candle-examples/examples/whisper/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/whisper/main.rs')
-rw-r--r--candle-examples/examples/whisper/main.rs71
1 files changed, 23 insertions, 48 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index dbe9cc8d..c71d562a 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -10,41 +10,16 @@ extern crate accelerate_src;
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor};
+use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
-mod audio;
-mod model;
-use model::{Config, Whisper};
mod multilingual;
-
-const DTYPE: DType = DType::F32;
-
-// Audio parameters.
-const SAMPLE_RATE: usize = 16000;
-const N_FFT: usize = 400;
-const N_MELS: usize = 80;
-const HOP_LENGTH: usize = 160;
-const CHUNK_LENGTH: usize = 30;
-const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
-const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
-
-const NO_SPEECH_THRESHOLD: f64 = 0.6;
-const LOGPROB_THRESHOLD: f64 = -1.0;
-const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
-const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
-
-// Tokenizer dependent bits.
-const SOT_TOKEN: &str = "<|startoftranscript|>";
-const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
-const TRANSLATE_TOKEN: &str = "<|translate|>";
-const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
-const EOT_TOKEN: &str = "<|endoftext|>";
-const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
+use candle_transformers::models::whisper::{self as m, audio, model};
+use model::{Config, Whisper};
#[allow(dead_code)]
#[derive(Debug, Clone)]
@@ -94,7 +69,7 @@ impl Decoder {
timestamps: bool,
verbose: bool,
) -> Result<Self> {
- let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
+ let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
// Suppress the notimestamps token when in timestamps mode.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
@@ -109,11 +84,11 @@ impl Decoder {
})
.collect();
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
- let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
- let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
- let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
- let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
- let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
+ let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
+ let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
+ let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
+ let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
+ let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
@@ -220,17 +195,17 @@ impl Decoder {
}
fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
- for (i, &t) in TEMPERATURES.iter().enumerate() {
+ for (i, &t) in m::TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult> = self.decode(segment, t);
- if i == TEMPERATURES.len() - 1 {
+ if i == m::TEMPERATURES.len() - 1 {
return dr;
}
// On errors, we try again with a different temperature.
match dr {
Ok(dr) => {
- let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
- || dr.avg_logprob < LOGPROB_THRESHOLD;
- if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+ let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
+ || dr.avg_logprob < m::LOGPROB_THRESHOLD;
+ if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
return Ok(dr);
}
}
@@ -248,13 +223,13 @@ impl Decoder {
let mut segments = vec![];
while seek < content_frames {
let start = std::time::Instant::now();
- let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
- let segment_size = usize::min(content_frames - seek, N_FRAMES);
+ let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+ let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
- let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
+ let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?;
seek += segment_size;
- if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
+ if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
println!("no speech detected, skipping {seek} {dr:?}");
continue;
}
@@ -492,8 +467,8 @@ fn main() -> Result<()> {
let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;
println!("loaded wav data: {header:?}");
- if header.sampling_rate != SAMPLE_RATE as u32 {
- anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
+ if header.sampling_rate != m::SAMPLE_RATE as u32 {
+ anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
}
let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
@@ -501,14 +476,14 @@ fn main() -> Result<()> {
.map(|v| *v as f32 / 32768.)
.collect();
println!("pcm data loaded {}", pcm_data.len());
- let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
+ let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
let mel_len = mel.len();
- let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
+ let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+ let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = Whisper::load(&vb, config)?;