diff options
Diffstat (limited to 'candle-examples/examples/whisper/main.rs')
-rw-r--r-- | candle-examples/examples/whisper/main.rs | 76 |
1 files changed, 52 insertions, 24 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index c77118f6..6e15fa8a 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -10,6 +10,8 @@ use candle::{DType, Device, Tensor}; use clap::Parser; use rand::{distributions::Distribution, SeedableRng}; use tokenizers::Tokenizer; + +mod audio; mod model; use model::{Config, VarBuilder, Whisper}; @@ -38,27 +40,6 @@ const EOT_TOKEN: u32 = 50256; const NO_SPEECH_TOKEN: u32 = 50361; const NO_TIMESTAMP_TOKEN: u32 = 50362; -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - /// Run on CPU rather than on GPU. - #[arg(long)] - cpu: bool, - - #[arg(long)] - weights: String, - - #[arg(long)] - input: String, - - #[arg(long)] - tokenizer_config: String, - - /// The seed to use when generating random samples. - #[arg(long, default_value_t = 299792458)] - seed: u64, -} - #[derive(Debug, Clone)] struct DecodingResult { tokens: Vec<u32>, @@ -176,6 +157,35 @@ impl Decode { } } +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(long)] + weights: String, + + /// The input to be processed, in wav formats. + #[arg(long)] + input: String, + + #[arg(long)] + tokenizer_config: String, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The mel filters in safetensors format. + #[arg( + long, + default_value = "candle-examples/examples/whisper/mel_filters.safetensors" + )] + filters: String, +} + fn main() -> Result<()> { let args = Args::parse(); let device = if args.cpu { @@ -187,9 +197,27 @@ fn main() -> Result<()> { let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?; - let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? }; - let input = input.deserialize()?; - let mel = input.tensor("mel", &device)?; + let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? }; + let mel_filters = mel_filters.deserialize()?; + let mel_filters = mel_filters.tensor("mel_80", &device)?; + println!("loaded mel filters {:?}", mel_filters.shape()); + let (n_mel, n_fft) = mel_filters.shape().r2()?; + let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?; + + let mut input = std::fs::File::open(args.input)?; + let (header, data) = wav::read(&mut input)?; + println!("loaded wav data: {header:?}"); + if header.sampling_rate != SAMPLE_RATE as u32 { + anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE) + } + let data = data.as_sixteen().expect("expected 16 bit wav file"); + let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize] + .iter() + .map(|v| *v as f32 / 32768.) + .collect(); + let mel = audio::pcm_to_mel(&pcm_data, &mel_filters, n_mel, n_fft)?; + let mel_len = mel.len(); + let mel = Tensor::from_vec(mel, (1, n_mel, mel_len / n_mel), &device)?; println!("loaded mel: {:?}", mel.dims()); let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? }; |