summaryrefslogtreecommitdiff
path: root/candle-examples/examples/whisper/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/whisper/main.rs')
-rw-r--r--candle-examples/examples/whisper/main.rs76
1 files changed, 52 insertions, 24 deletions
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index c77118f6..6e15fa8a 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -10,6 +10,8 @@ use candle::{DType, Device, Tensor};
use clap::Parser;
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
+
+mod audio;
mod model;
use model::{Config, VarBuilder, Whisper};
@@ -38,27 +40,6 @@ const EOT_TOKEN: u32 = 50256;
const NO_SPEECH_TOKEN: u32 = 50361;
const NO_TIMESTAMP_TOKEN: u32 = 50362;
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
- /// Run on CPU rather than on GPU.
- #[arg(long)]
- cpu: bool,
-
- #[arg(long)]
- weights: String,
-
- #[arg(long)]
- input: String,
-
- #[arg(long)]
- tokenizer_config: String,
-
- /// The seed to use when generating random samples.
- #[arg(long, default_value_t = 299792458)]
- seed: u64,
-}
-
#[derive(Debug, Clone)]
struct DecodingResult {
tokens: Vec<u32>,
@@ -176,6 +157,35 @@ impl Decode {
}
}
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ #[arg(long)]
+ weights: String,
+
+ /// The input to be processed, in wav formats.
+ #[arg(long)]
+ input: String,
+
+ #[arg(long)]
+ tokenizer_config: String,
+
+ /// The seed to use when generating random samples.
+ #[arg(long, default_value_t = 299792458)]
+ seed: u64,
+
+ /// The mel filters in safetensors format.
+ #[arg(
+ long,
+ default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
+ )]
+ filters: String,
+}
+
fn main() -> Result<()> {
let args = Args::parse();
let device = if args.cpu {
@@ -187,9 +197,27 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
- let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
- let input = input.deserialize()?;
- let mel = input.tensor("mel", &device)?;
+ let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
+ let mel_filters = mel_filters.deserialize()?;
+ let mel_filters = mel_filters.tensor("mel_80", &device)?;
+ println!("loaded mel filters {:?}", mel_filters.shape());
+ let (n_mel, n_fft) = mel_filters.shape().r2()?;
+ let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
+
+ let mut input = std::fs::File::open(args.input)?;
+ let (header, data) = wav::read(&mut input)?;
+ println!("loaded wav data: {header:?}");
+ if header.sampling_rate != SAMPLE_RATE as u32 {
+ anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
+ }
+ let data = data.as_sixteen().expect("expected 16 bit wav file");
+ let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
+ .iter()
+ .map(|v| *v as f32 / 32768.)
+ .collect();
+ let mel = audio::pcm_to_mel(&pcm_data, &mel_filters, n_mel, n_fft)?;
+ let mel_len = mel.len();
+ let mel = Tensor::from_vec(mel, (1, n_mel, mel_len / n_mel), &device)?;
println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };