summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--  candle-examples/examples/whisper/audio.rs                | 219
-rw-r--r--  candle-examples/examples/whisper/main.rs                 |  76
-rw-r--r--  candle-examples/examples/whisper/mel_filters.safetensors | bin 0 -> 64400 bytes
3 files changed, 271 insertions, 24 deletions
diff --git a/candle-examples/examples/whisper/audio.rs b/candle-examples/examples/whisper/audio.rs
new file mode 100644
index 00000000..d50b7923
--- /dev/null
+++ b/candle-examples/examples/whisper/audio.rs
@@ -0,0 +1,219 @@
+// Audio processing code, adapted from whisper.cpp
+// https://github.com/ggerganov/whisper.cpp
+
/// Trait alias for floating-point scalar types usable by the DSP routines
/// below: requires `num_traits` float ops, float constants (e.g. `PI()`),
/// and compound assignment (`+=`, `-=`).
pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}

impl Float for f32 {}
impl Float for f64 {}
+
+// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
+fn fft<T: Float>(inp: &[T]) -> Vec<T> {
+ let n = inp.len();
+ let zero = T::zero();
+ if n == 1 {
+ return vec![inp[0], zero];
+ }
+ if n % 2 == 1 {
+ return dft(inp);
+ }
+ let mut out = vec![zero; n * 2];
+
+ let mut even = vec![];
+ even.reserve(n / 2);
+ let mut odd = vec![];
+ odd.reserve(n / 2);
+
+ for (i, &inp) in inp.iter().enumerate() {
+ if i % 2 == 0 {
+ even.push(inp)
+ } else {
+ odd.push(inp);
+ }
+ }
+
+ let even_fft = fft(&even);
+ let odd_fft = fft(&odd);
+
+ let two_pi = T::PI() + T::PI();
+ let n_t = T::from(n).unwrap();
+ for k in 0..n / 2 {
+ let k_t = T::from(k).unwrap();
+ let theta = two_pi * k_t / n_t;
+ let re = theta.cos();
+ let im = -theta.sin();
+
+ let re_odd = odd_fft[2 * k];
+ let im_odd = odd_fft[2 * k + 1];
+
+ out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
+ out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;
+
+ out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
+ out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
+ }
+ out
+}
+
+// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337
+fn dft<T: Float>(inp: &[T]) -> Vec<T> {
+ let zero = T::zero();
+ let n = inp.len();
+ let two_pi = T::PI() + T::PI();
+
+ let mut out = Vec::new();
+ out.reserve(2 * n);
+ let n_t = T::from(n).unwrap();
+ for k in 0..n {
+ let k_t = T::from(k).unwrap();
+ let mut re = zero;
+ let mut im = zero;
+
+ for (j, &inp) in inp.iter().enumerate() {
+ let j_t = T::from(j).unwrap();
+ let angle = two_pi * k_t * j_t / n_t;
+ re += inp * angle.cos();
+ im -= inp * angle.sin();
+ }
+
+ out.push(re);
+ out.push(im);
+ }
+ out
+}
+
#[allow(clippy::too_many_arguments)]
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414
/// Computes the log-mel spectrogram frames assigned to worker `ith`.
///
/// Frames are distributed round-robin: this worker handles frames
/// `ith, ith + n_threads, ith + 2 * n_threads, ...` (the caller in this
/// file runs a single worker with `ith = 0`, `n_threads = 1`).
///
/// * `hann` - precomputed Hann window of length `fft_size`.
/// * `samples` - input PCM; frames extending past the end are zero-padded.
/// * `filters` - mel filter bank, row-major `n_mel x n_fft`.
/// * `fft_size` / `fft_step` - analysis window length and hop size.
/// * `speed_up` - halves the spectral resolution by averaging bin pairs.
/// * `n_len` - total number of output frames.
///
/// Returns a flattened `n_mel x n_len` matrix of log10 mel energies
/// (column `i` holds frame `i`); entries for frames owned by other workers
/// remain zero.
fn log_mel_spectrogram_w<T: Float>(
    ith: usize,
    hann: &[T],
    samples: &[T],
    filters: &[T],
    fft_size: usize,
    fft_step: usize,
    speed_up: bool,
    n_len: usize,
    n_mel: usize,
    n_threads: usize,
) -> Vec<T> {
    // Number of useful spectral bins (real input => symmetric spectrum).
    let n_fft = if speed_up {
        1 + fft_size / 4
    } else {
        1 + fft_size / 2
    };

    let zero = T::zero();
    let half = T::from(0.5).unwrap();
    let mut fft_in = vec![zero; fft_size];
    let mut mel = vec![zero; n_len * n_mel];

    for i in (ith..n_len).step_by(n_threads) {
        let offset = i * fft_step;

        // apply Hanning window (samples past the end count as silence)
        for j in 0..fft_size {
            fft_in[j] = if offset + j < samples.len() {
                hann[j] * samples[offset + j]
            } else {
                zero
            }
        }

        // FFT -> mag^2: overwrite the first `fft_size` slots of the
        // interleaved complex output with squared magnitudes.
        let mut fft_out: Vec<T> = fft(&fft_in);

        for j in 0..fft_size {
            fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
        }
        // Fold the mirrored upper half of the spectrum onto the lower half,
        // matching the upstream whisper.cpp code.
        for j in 1..fft_size / 2 {
            let v = fft_out[fft_size - j];
            fft_out[j] += v;
        }

        if speed_up {
            // scaling down in the frequency domain results in a speed-up in
            // the time domain
            // NOTE(review): this reads fft_out[2 * j], which the mag^2 pass
            // above already overwrote for indices < fft_size — kept as-is to
            // mirror upstream whisper.cpp; confirm before enabling speed_up.
            for j in 0..n_fft {
                fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
            }
        }

        // mel spectrogram: project the power spectrum through the filter
        // bank, clamping at 1e-10 before log10 to avoid -inf.
        for j in 0..n_mel {
            let mut sum = zero;
            for k in 0..n_fft {
                sum += fft_out[k] * filters[j * n_fft + k];
            }
            mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
        }
    }
    mel
}
+
/// Computes the full log-mel spectrogram for `samples`.
///
/// Builds a Hann window, zero-pads the signal up to a whole number of
/// half-chunks plus one extra half-chunk, runs the single-threaded worker,
/// then clamps values to at most 8 (in log10 units) below the global max
/// and rescales each value with `v / 4 + 1`, as in whisper.cpp.
fn log_mel_spectrogram_<T: Float>(
    samples: &[T],
    filters: &[T],
    fft_size: usize,
    fft_step: usize,
    n_mel: usize,
    speed_up: bool,
) -> Vec<T> {
    let zero = T::zero();
    let two_pi = T::PI() + T::PI();
    let half = T::from(0.5).unwrap();
    let one = T::from(1.0).unwrap();
    let four = T::from(4.0).unwrap();
    let fft_size_t = T::from(fft_size).unwrap();

    // Hann window: 0.5 * (1 - cos(2*pi*i / fft_size)).
    let hann: Vec<T> = (0..fft_size)
        .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
        .collect();
    let n_len = samples.len() / fft_step;

    // pad audio with at least one extra chunk of zeros
    // NOTE(review): `pad` is presumably frames-per-second * CHUNK_LENGTH / 2,
    // i.e. half a chunk worth of frames — confirm against super::CHUNK_LENGTH
    // and the hop size.
    let pad = 100 * super::CHUNK_LENGTH / 2;
    // Round the frame count up to a multiple of `pad`...
    let n_len = if n_len % pad != 0 {
        (n_len / pad + 1) * pad
    } else {
        n_len
    };
    // ...then append one more half-chunk of silence.
    let n_len = n_len + pad;
    let samples = {
        let mut samples_padded = samples.to_vec();
        let to_add = n_len * fft_step - samples.len();
        samples_padded.extend(std::iter::repeat(zero).take(to_add));
        samples_padded
    };

    // Use a single thread for now.
    let mut mel = log_mel_spectrogram_w(
        0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
    );
    // Dynamic-range compression: everything below (max - 8) is clamped.
    // The comparator treats NaN as Greater so partial_cmp cannot panic.
    let mmax = mel
        .iter()
        .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
        .copied()
        .unwrap_or(zero)
        - T::from(8).unwrap();
    for m in mel.iter_mut() {
        let v = T::max(*m, mmax);
        *m = v / four + one
    }
    mel
}
+
+pub fn pcm_to_mel<T: Float>(
+ samples: &[T],
+ filters: &[T],
+ n_mel: usize,
+ n_fft: usize,
+) -> anyhow::Result<Vec<T>> {
+ if filters.len() != n_mel * n_fft {
+ anyhow::bail!(
+ "unexpected filter length {} (n_mel: {}, n_fft: {})",
+ filters.len(),
+ n_mel,
+ n_fft
+ )
+ }
+ let mel = log_mel_spectrogram_(samples, filters, n_fft, super::HOP_LENGTH, n_mel, false);
+ Ok(mel)
+}
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index c77118f6..6e15fa8a 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -10,6 +10,8 @@ use candle::{DType, Device, Tensor};
use clap::Parser;
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
+
+mod audio;
mod model;
use model::{Config, VarBuilder, Whisper};
@@ -38,27 +40,6 @@ const EOT_TOKEN: u32 = 50256;
const NO_SPEECH_TOKEN: u32 = 50361;
const NO_TIMESTAMP_TOKEN: u32 = 50362;
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
- /// Run on CPU rather than on GPU.
- #[arg(long)]
- cpu: bool,
-
- #[arg(long)]
- weights: String,
-
- #[arg(long)]
- input: String,
-
- #[arg(long)]
- tokenizer_config: String,
-
- /// The seed to use when generating random samples.
- #[arg(long, default_value_t = 299792458)]
- seed: u64,
-}
-
#[derive(Debug, Clone)]
struct DecodingResult {
tokens: Vec<u32>,
@@ -176,6 +157,35 @@ impl Decode {
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Path to the model weights, in safetensors format.
    #[arg(long)]
    weights: String,

    /// The input to be processed, in wav format.
    #[arg(long)]
    input: String,

    /// Path to the tokenizer configuration file.
    #[arg(long)]
    tokenizer_config: String,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The mel filters in safetensors format.
    #[arg(
        long,
        default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
    )]
    filters: String,
}
+
fn main() -> Result<()> {
let args = Args::parse();
let device = if args.cpu {
@@ -187,9 +197,27 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
- let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
- let input = input.deserialize()?;
- let mel = input.tensor("mel", &device)?;
+ let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
+ let mel_filters = mel_filters.deserialize()?;
+ let mel_filters = mel_filters.tensor("mel_80", &device)?;
+ println!("loaded mel filters {:?}", mel_filters.shape());
+ let (n_mel, n_fft) = mel_filters.shape().r2()?;
+ let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
+
+ let mut input = std::fs::File::open(args.input)?;
+ let (header, data) = wav::read(&mut input)?;
+ println!("loaded wav data: {header:?}");
+ if header.sampling_rate != SAMPLE_RATE as u32 {
+ anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
+ }
+ let data = data.as_sixteen().expect("expected 16 bit wav file");
+ let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
+ .iter()
+ .map(|v| *v as f32 / 32768.)
+ .collect();
+ let mel = audio::pcm_to_mel(&pcm_data, &mel_filters, n_mel, n_fft)?;
+ let mel_len = mel.len();
+ let mel = Tensor::from_vec(mel, (1, n_mel, mel_len / n_mel), &device)?;
println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
diff --git a/candle-examples/examples/whisper/mel_filters.safetensors b/candle-examples/examples/whisper/mel_filters.safetensors
new file mode 100644
index 00000000..98f3af44
--- /dev/null
+++ b/candle-examples/examples/whisper/mel_filters.safetensors
Binary files differ