author    drbh <david.richard.holtz@gmail.com>  2024-02-08 15:54:12 -0500
committer GitHub <noreply@github.com>           2024-02-08 21:54:12 +0100
commit    9cadd4e64441f33e1a0eee9f3ae3bcf508250e38 (patch)
tree      63f5e663bc114019b3cd093845331d168567290f /candle-transformers/src/models/whisper
parent    020a979de2e0d17b5c1a38bb9f40f4de73954cd5 (diff)
feat: support multithread spectrogram and small perf tweaks (#1674)
* feat: support multithread spectrogram and small perf tweaks
* feat: clippy improvement for loop variable
* fix: add back speed up scale down logic
* fix: readd mirroring logic
* feat: prefer scoped thread and simplify/improve logic/traits
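The core pattern in this change is strided work-splitting over std::thread::scope: each worker computes every n_threads-th spectrogram frame into its own full-size buffer, and the per-thread buffers are then summed element-wise (each index is written by exactly one worker, so the sum is a plain merge). Below is a minimal standalone sketch of that pattern; the names worker and parallel_rows are illustrative stand-ins, not part of the patch.

    use std::thread;

    // Each worker fills every `n_threads`-th row of its own full-size buffer;
    // rows it does not own stay zero, so summing the buffers element-wise
    // yields the complete result. `worker` stands in for log_mel_spectrogram_w.
    fn worker(thread_id: usize, n_threads: usize, n_rows: usize, row_len: usize) -> Vec<f32> {
        let mut out = vec![0f32; n_rows * row_len];
        for i in (thread_id..n_rows).step_by(n_threads) {
            for j in 0..row_len {
                out[i * row_len + j] = (i + j) as f32; // placeholder per-frame work
            }
        }
        out
    }

    fn parallel_rows(n_threads: usize, n_rows: usize, row_len: usize) -> Vec<f32> {
        // Scoped threads may borrow local data without 'static lifetimes.
        let partials = thread::scope(|s| {
            (0..n_threads)
                .map(|id| s.spawn(move || worker(id, n_threads, n_rows, row_len)))
                .collect::<Vec<_>>()
                .into_iter()
                .map(|h| h.join().expect("worker thread panicked"))
                .collect::<Vec<_>>()
        });
        // Element-wise sum of the per-thread buffers.
        let mut out = vec![0f32; n_rows * row_len];
        for part in &partials {
            for (o, p) in out.iter_mut().zip(part) {
                *o += *p;
            }
        }
        out
    }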
Diffstat (limited to 'candle-transformers/src/models/whisper')
-rw-r--r--  candle-transformers/src/models/whisper/audio.rs            162
-rw-r--r--  candle-transformers/src/models/whisper/model.rs              8
-rw-r--r--  candle-transformers/src/models/whisper/quantized_model.rs    8
3 files changed, 150 insertions, 28 deletions
diff --git a/candle-transformers/src/models/whisper/audio.rs b/candle-transformers/src/models/whisper/audio.rs
index 6dbff650..eb795f18 100644
--- a/candle-transformers/src/models/whisper/audio.rs
+++ b/candle-transformers/src/models/whisper/audio.rs
@@ -1,7 +1,14 @@
// Audio processing code, adapted from whisper.cpp
// https://github.com/ggerganov/whisper.cpp
-pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
+use candle::utils::get_num_threads;
+use std::sync::Arc;
+use std::thread;
+
+pub trait Float:
+ num_traits::Float + num_traits::FloatConst + num_traits::NumAssign + Send + Sync
+{
+}
impl Float for f32 {}
impl Float for f64 {}
@@ -102,22 +109,26 @@ fn log_mel_spectrogram_w<T: Float>(
let half = T::from(0.5).unwrap();
let mut fft_in = vec![zero; fft_size];
let mut mel = vec![zero; n_len * n_mel];
+ let n_samples = samples.len();
+ let end = std::cmp::min(n_samples / fft_step + 1, n_len);
- for i in (ith..n_len).step_by(n_threads) {
+ for i in (ith..end).step_by(n_threads) {
let offset = i * fft_step;
// apply Hanning window
- for j in 0..fft_size {
- fft_in[j] = if offset + j < samples.len() {
- hann[j] * samples[offset + j]
- } else {
- zero
- }
+ for j in 0..std::cmp::min(fft_size, n_samples - offset) {
+ fft_in[j] = hann[j] * samples[offset + j];
}
- // FFT -> mag^2
+ // fill the rest with zeros
+ if n_samples - offset < fft_size {
+ fft_in[n_samples - offset..].fill(zero);
+ }
+
+ // FFT
let mut fft_out: Vec<T> = fft(&fft_in);
+ // Calculate modulus^2 of complex numbers
for j in 0..fft_size {
fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
}
@@ -136,8 +147,19 @@ fn log_mel_spectrogram_w<T: Float>(
// mel spectrogram
for j in 0..n_mel {
let mut sum = zero;
- for k in 0..n_fft {
+ let mut k = 0;
+ // Unroll loop
+ while k < n_fft.saturating_sub(3) {
+ sum += fft_out[k] * filters[j * n_fft + k]
+ + fft_out[k + 1] * filters[j * n_fft + k + 1]
+ + fft_out[k + 2] * filters[j * n_fft + k + 2]
+ + fft_out[k + 3] * filters[j * n_fft + k + 3];
+ k += 4;
+ }
+ // Handle remainder
+ while k < n_fft {
sum += fft_out[k] * filters[j * n_fft + k];
+ k += 1;
}
mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
}
@@ -145,7 +167,7 @@ fn log_mel_spectrogram_w<T: Float>(
mel
}
-fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
+fn log_mel_spectrogram_<T: Float>(
samples: &[T],
filters: &[T],
fft_size: usize,
@@ -180,10 +202,55 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
samples_padded
};
- // Use a single thread for now.
- let mut mel = log_mel_spectrogram_w(
- 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
- );
+ // ensure that the number of threads is even and less than 12
+ let n_threads = std::cmp::min(get_num_threads() - get_num_threads() % 2, 12);
+
+ let hann = Arc::new(hann);
+ let samples = Arc::new(samples);
+ let filters = Arc::new(filters);
+
+ // use scope to allow for non static references to be passed to the threads
+ // and directly collect the results into a single vector
+ let all_outputs = thread::scope(|s| {
+ (0..n_threads)
+ // create threads and return their handles
+ .map(|thread_id| {
+ let hann = Arc::clone(&hann);
+ let samples = Arc::clone(&samples);
+ let filters = Arc::clone(&filters);
+ // spawn new thread and start work
+ s.spawn(move || {
+ log_mel_spectrogram_w(
+ thread_id, &hann, &samples, &filters, fft_size, fft_step, speed_up, n_len,
+ n_mel, n_threads,
+ )
+ })
+ })
+ .collect::<Vec<_>>()
+ .into_iter()
+ // wait for each thread to finish and collect their results
+ .map(|handle| handle.join().expect("Thread failed"))
+ .collect::<Vec<_>>()
+ });
+
+ let l = all_outputs[0].len();
+ let mut mel = vec![zero; l];
+
+ // iterate over mel spectrogram segments, dividing work by threads.
+ for segment_start in (0..l).step_by(n_threads) {
+ // go through each thread's output.
+ for thread_output in all_outputs.iter() {
+ // add each thread's piece to our mel spectrogram.
+ for offset in 0..n_threads {
+ let mel_index = segment_start + offset; // find location in mel.
+ if mel_index < mel.len() {
+ // Make sure we don't go out of bounds.
+ mel[mel_index] += thread_output[mel_index];
+ }
+ }
+ }
+ }
+
let mmax = mel
.iter()
.max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
@@ -197,11 +264,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
mel
}
-pub fn pcm_to_mel<T: Float + std::fmt::Display>(
- cfg: &super::Config,
- samples: &[T],
- filters: &[T],
-) -> Vec<T> {
+pub fn pcm_to_mel<T: Float>(cfg: &super::Config, samples: &[T], filters: &[T]) -> Vec<T> {
log_mel_spectrogram_(
samples,
filters,
@@ -211,3 +274,62 @@ pub fn pcm_to_mel<T: Float + std::fmt::Display>(
false,
)
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_fft() {
+ let input = vec![0.0, 1.0, 0.0, 0.0];
+ let output = fft(&input);
+ assert_eq!(
+ output,
+ vec![
+ 1.0,
+ 0.0,
+ 6.123233995736766e-17,
+ -1.0,
+ -1.0,
+ 0.0,
+ -6.123233995736766e-17,
+ 1.0
+ ]
+ );
+ }
+
+ #[test]
+ fn test_dft() {
+ let input = vec![0.0, 1.0, 0.0, 0.0];
+ let output = dft(&input);
+ assert_eq!(
+ output,
+ vec![
+ 1.0,
+ 0.0,
+ 6.123233995736766e-17,
+ -1.0,
+ -1.0,
+ -1.2246467991473532e-16,
+ -1.8369701987210297e-16,
+ 1.0
+ ]
+ );
+ }
+
+ #[test]
+ fn test_log_mel_spectrogram() {
+ let samples = vec![0.0; 1000];
+ let filters = vec![0.0; 1000];
+ let output = log_mel_spectrogram_(&samples, &filters, 100, 10, 10, false);
+ assert_eq!(output.len(), 30_000);
+ }
+
+ #[test]
+ fn test_tiny_log_mel_spectrogram() {
+ let samples = vec![0.0; 100];
+ let filters = vec![0.0; 100];
+ let output = log_mel_spectrogram_(&samples, &filters, 20, 2, 2, false);
+ assert_eq!(output.len(), 6_000);
+ }
+}
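The 4-way unroll in the mel accumulation above amounts to a dot product with a scalar remainder loop. A self-contained sketch of that pattern, purely for illustration (the helper name dot_unrolled is hypothetical and not in the crate):

    // Illustrative only: 4-way unrolled dot product with a scalar tail,
    // mirroring the mel-filter accumulation in log_mel_spectrogram_w.
    fn dot_unrolled(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len());
        let n = a.len();
        let mut sum = 0.0f32;
        let mut k = 0;
        // Four elements per iteration (equivalent to `k < n.saturating_sub(3)`).
        while k + 4 <= n {
            sum += a[k] * b[k]
                + a[k + 1] * b[k + 1]
                + a[k + 2] * b[k + 2]
                + a[k + 3] * b[k + 3];
            k += 4;
        }
        // Scalar tail for the remaining 0..=3 elements.
        while k < n {
            sum += a[k] * b[k];
            k += 1;
        }
        sum
    }

    fn main() {
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = [5.0f32, 4.0, 3.0, 2.0, 1.0];
        // 1*5 + 2*4 + 3*3 + 4*2 + 5*1 = 35
        assert_eq!(dot_unrolled(&a, &b), 35.0);
    }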
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index ea2a59b9..74f708e6 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -195,14 +195,14 @@ impl ResidualAttentionBlock {
}
}
-fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
+fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
let max_timescale = 10000f32;
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
let inv_timescales: Vec<_> = (0..channels / 2)
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
.collect();
- let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
- let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
+ let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
+ let arange = Tensor::arange(0, length as u32, device)?
.to_dtype(candle::DType::F32)?
.unsqueeze(1)?;
let sh = (length, channels / 2);
@@ -246,7 +246,7 @@ impl AudioEncoder {
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
- let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
+ let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
let blocks = (0..cfg.encoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs
index 43ea4177..dac78be9 100644
--- a/candle-transformers/src/models/whisper/quantized_model.rs
+++ b/candle-transformers/src/models/whisper/quantized_model.rs
@@ -191,14 +191,14 @@ impl ResidualAttentionBlock {
}
}
-fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
+fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
let max_timescale = 10000f32;
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
let inv_timescales: Vec<_> = (0..channels / 2)
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
.collect();
- let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
- let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
+ let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
+ let arange = Tensor::arange(0, length as u32, device)?
.to_dtype(candle::DType::F32)?
.unsqueeze(1)?;
let sh = (length, channels / 2);
@@ -242,7 +242,7 @@ impl AudioEncoder {
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
- let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
+ let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
let blocks = (0..cfg.encoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
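
Passing the target device into sinusoids (instead of building on Device::Cpu and copying with to_device) creates the positional embedding directly where the encoder runs. For reference, the embedding it builds can be sketched in plain Rust, independent of candle's Tensor API. This assumes the usual Whisper layout of sin values followed by cos values along the channel axis (the concatenation step is not visible in the hunks above); the helper name sinusoids_vec and the row-major Vec<f32> layout are illustrative assumptions.

    // Illustrative sketch (not candle code): a (length, channels) sinusoidal
    // embedding as a row-major Vec<f32>. Row t holds sin(t * w_i) over the
    // first channels/2 columns and cos(t * w_i) over the second half.
    fn sinusoids_vec(length: usize, channels: usize) -> Vec<f32> {
        let half = channels / 2;
        let max_timescale = 10_000f32;
        let log_timescale_increment = max_timescale.ln() / (half - 1) as f32;
        let inv_timescales: Vec<f32> = (0..half)
            .map(|i| (i as f32 * -log_timescale_increment).exp())
            .collect();
        let mut out = vec![0f32; length * channels];
        for t in 0..length {
            for (i, &w) in inv_timescales.iter().enumerate() {
                let scaled = t as f32 * w;
                out[t * channels + i] = scaled.sin(); // first half: sin
                out[t * channels + half + i] = scaled.cos(); // second half: cos
            }
        }
        out
    }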