diff options
-rw-r--r-- | candle-examples/Cargo.toml | 2 | ||||
-rw-r--r-- | candle-examples/examples/whisper/audio.rs | 219 | ||||
-rw-r--r-- | candle-examples/examples/whisper/main.rs | 76 | ||||
-rw-r--r-- | candle-examples/examples/whisper/mel_filters.safetensors | bin | 0 -> 64400 bytes |
4 files changed, 273 insertions, 24 deletions
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index a71ca17b..a3e64a17 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -12,6 +12,7 @@ readme = "README.md" [dependencies] candle = { path = "../candle-core", default-features=false } +num-traits = "0.2.15" [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } @@ -20,6 +21,7 @@ clap = { version = "4.2.4", features = ["derive"] } rand = "0.8.5" tokenizers = { version = "0.13.3", default-features=false, features=["onig"] } tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] } +wav = "1.0.0" [features] default = ["cuda"] diff --git a/candle-examples/examples/whisper/audio.rs b/candle-examples/examples/whisper/audio.rs new file mode 100644 index 00000000..d50b7923 --- /dev/null +++ b/candle-examples/examples/whisper/audio.rs @@ -0,0 +1,219 @@ +// Audio processing code, adapted from whisper.cpp +// https://github.com/ggerganov/whisper.cpp + +pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {} + +impl Float for f32 {} +impl Float for f64 {} + +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357 +fn fft<T: Float>(inp: &[T]) -> Vec<T> { + let n = inp.len(); + let zero = T::zero(); + if n == 1 { + return vec![inp[0], zero]; + } + if n % 2 == 1 { + return dft(inp); + } + let mut out = vec![zero; n * 2]; + + let mut even = vec![]; + even.reserve(n / 2); + let mut odd = vec![]; + odd.reserve(n / 2); + + for (i, &inp) in inp.iter().enumerate() { + if i % 2 == 0 { + even.push(inp) + } else { + odd.push(inp); + } + } + + let even_fft = fft(&even); + let odd_fft = fft(&odd); + + let two_pi = T::PI() + T::PI(); + let n_t = T::from(n).unwrap(); + for k in 0..n / 2 { + let k_t = T::from(k).unwrap(); + let theta = two_pi * k_t / n_t; + let re = theta.cos(); + let im = -theta.sin(); + + let re_odd = odd_fft[2 * k]; + let im_odd = odd_fft[2 * k + 1]; + + out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd; + out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd; + + out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd; + out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd; + } + out +} + +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337 +fn dft<T: Float>(inp: &[T]) -> Vec<T> { + let zero = T::zero(); + let n = inp.len(); + let two_pi = T::PI() + T::PI(); + + let mut out = Vec::new(); + out.reserve(2 * n); + let n_t = T::from(n).unwrap(); + for k in 0..n { + let k_t = T::from(k).unwrap(); + let mut re = zero; + let mut im = zero; + + for (j, &inp) in inp.iter().enumerate() { + let j_t = T::from(j).unwrap(); + let angle = two_pi * k_t * j_t / n_t; + re += inp * angle.cos(); + im -= inp * angle.sin(); + } + + out.push(re); + out.push(im); + } + out +} + +#[allow(clippy::too_many_arguments)] +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414 +fn log_mel_spectrogram_w<T: Float>( + ith: usize, + hann: &[T], + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + speed_up: bool, + n_len: usize, + n_mel: usize, + n_threads: usize, +) -> Vec<T> { + let n_fft = if speed_up { + 1 + fft_size / 4 + } else { + 1 + fft_size / 2 + }; + + let zero = T::zero(); + let half = T::from(0.5).unwrap(); + let mut fft_in = vec![zero; fft_size]; + let mut mel = vec![zero; n_len * n_mel]; + + for i in (ith..n_len).step_by(n_threads) { + let offset = i * fft_step; + + // apply Hanning window + for j in 0..fft_size { + fft_in[j] = if offset + j < samples.len() { + hann[j] * samples[offset + j] + } else { + zero + } + } + + // FFT -> mag^2 + let mut fft_out: Vec<T> = fft(&fft_in); + + for j in 0..fft_size { + fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1]; + } + for j in 1..fft_size / 2 { + let v = fft_out[fft_size - j]; + fft_out[j] += v; + } + + if speed_up { + // scale down in the frequency domain results in a speed up in the time domain + for j in 0..n_fft { + fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]); + } + } + + // mel spectrogram + for j in 0..n_mel { + let mut sum = zero; + for k in 0..n_fft { + sum += fft_out[k] * filters[j * n_fft + k]; + } + mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10(); + } + } + mel +} + +fn log_mel_spectrogram_<T: Float>( + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + n_mel: usize, + speed_up: bool, +) -> Vec<T> { + let zero = T::zero(); + let two_pi = T::PI() + T::PI(); + let half = T::from(0.5).unwrap(); + let one = T::from(1.0).unwrap(); + let four = T::from(4.0).unwrap(); + let fft_size_t = T::from(fft_size).unwrap(); + + let hann: Vec<T> = (0..fft_size) + .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos())) + .collect(); + let n_len = samples.len() / fft_step; + + // pad audio with at least one extra chunk of zeros + let pad = 100 * super::CHUNK_LENGTH / 2; + let n_len = if n_len % pad != 0 { + (n_len / pad + 1) * pad + } else { + n_len + }; + let n_len = n_len + pad; + let samples = { + let mut samples_padded = samples.to_vec(); + let to_add = n_len * fft_step - samples.len(); + samples_padded.extend(std::iter::repeat(zero).take(to_add)); + samples_padded + }; + + // Use a single thread for now. + let mut mel = log_mel_spectrogram_w( + 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1, + ); + let mmax = mel + .iter() + .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater)) + .copied() + .unwrap_or(zero) + - T::from(8).unwrap(); + for m in mel.iter_mut() { + let v = T::max(*m, mmax); + *m = v / four + one + } + mel +} + +pub fn pcm_to_mel<T: Float>( + samples: &[T], + filters: &[T], + n_mel: usize, + n_fft: usize, +) -> anyhow::Result<Vec<T>> { + if filters.len() != n_mel * n_fft { + anyhow::bail!( + "unexpected filter length {} (n_mel: {}, n_fft: {})", + filters.len(), + n_mel, + n_fft + ) + } + let mel = log_mel_spectrogram_(samples, filters, n_fft, super::HOP_LENGTH, n_mel, false); + Ok(mel) +} diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index c77118f6..6e15fa8a 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -10,6 +10,8 @@ use candle::{DType, Device, Tensor}; use clap::Parser; use rand::{distributions::Distribution, SeedableRng}; use tokenizers::Tokenizer; + +mod audio; mod model; use model::{Config, VarBuilder, Whisper}; @@ -38,27 +40,6 @@ const EOT_TOKEN: u32 = 50256; const NO_SPEECH_TOKEN: u32 = 50361; const NO_TIMESTAMP_TOKEN: u32 = 50362; -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - /// Run on CPU rather than on GPU. - #[arg(long)] - cpu: bool, - - #[arg(long)] - weights: String, - - #[arg(long)] - input: String, - - #[arg(long)] - tokenizer_config: String, - - /// The seed to use when generating random samples. - #[arg(long, default_value_t = 299792458)] - seed: u64, -} - #[derive(Debug, Clone)] struct DecodingResult { tokens: Vec<u32>, @@ -176,6 +157,35 @@ impl Decode { } } +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(long)] + weights: String, + + /// The input to be processed, in wav formats. + #[arg(long)] + input: String, + + #[arg(long)] + tokenizer_config: String, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The mel filters in safetensors format. + #[arg( + long, + default_value = "candle-examples/examples/whisper/mel_filters.safetensors" + )] + filters: String, +} + fn main() -> Result<()> { let args = Args::parse(); let device = if args.cpu { @@ -187,9 +197,27 @@ fn main() -> Result<()> { let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?; - let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? }; - let input = input.deserialize()?; - let mel = input.tensor("mel", &device)?; + let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? }; + let mel_filters = mel_filters.deserialize()?; + let mel_filters = mel_filters.tensor("mel_80", &device)?; + println!("loaded mel filters {:?}", mel_filters.shape()); + let (n_mel, n_fft) = mel_filters.shape().r2()?; + let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?; + + let mut input = std::fs::File::open(args.input)?; + let (header, data) = wav::read(&mut input)?; + println!("loaded wav data: {header:?}"); + if header.sampling_rate != SAMPLE_RATE as u32 { + anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE) + } + let data = data.as_sixteen().expect("expected 16 bit wav file"); + let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize] + .iter() + .map(|v| *v as f32 / 32768.) + .collect(); + let mel = audio::pcm_to_mel(&pcm_data, &mel_filters, n_mel, n_fft)?; + let mel_len = mel.len(); + let mel = Tensor::from_vec(mel, (1, n_mel, mel_len / n_mel), &device)?; println!("loaded mel: {:?}", mel.dims()); let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? }; diff --git a/candle-examples/examples/whisper/mel_filters.safetensors b/candle-examples/examples/whisper/mel_filters.safetensors Binary files differnew file mode 100644 index 00000000..98f3af44 --- /dev/null +++ b/candle-examples/examples/whisper/mel_filters.safetensors |