author | drbh <david.richard.holtz@gmail.com> | 2024-02-08 15:54:12 -0500
---|---|---
committer | GitHub <noreply@github.com> | 2024-02-08 21:54:12 +0100
commit | 9cadd4e64441f33e1a0eee9f3ae3bcf508250e38 (patch) |
tree | 63f5e663bc114019b3cd093845331d168567290f /candle-transformers/src/models/whisper |
parent | 020a979de2e0d17b5c1a38bb9f40f4de73954cd5 (diff) |
download | candle-9cadd4e64441f33e1a0eee9f3ae3bcf508250e38.tar.gz candle-9cadd4e64441f33e1a0eee9f3ae3bcf508250e38.tar.bz2 candle-9cadd4e64441f33e1a0eee9f3ae3bcf508250e38.zip |
feat: support multithread spectrogram and small perf tweaks (#1674)
* feat: support multithread spectrogram and small perf tweaks
* feat: clippy improvement for loop variable
* fix: add back speed up scale down logic
* fix: readd mirroring logic
* feat: prefer scoped thread and simplify/improve logic/traits
Diffstat (limited to 'candle-transformers/src/models/whisper')
-rw-r--r-- | candle-transformers/src/models/whisper/audio.rs | 162
-rw-r--r-- | candle-transformers/src/models/whisper/model.rs | 8
-rw-r--r-- | candle-transformers/src/models/whisper/quantized_model.rs | 8
3 files changed, 150 insertions, 28 deletions
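The heart of the change is in `candle-transformers/src/models/whisper/audio.rs`: instead of computing the log-mel spectrogram on a single thread, the frames are split across workers created with `std::thread::scope`, each worker taking every `n_threads`-th frame, and the per-thread buffers are summed back into one spectrogram. The sketch below is a minimal, standalone illustration of that split-and-merge pattern; the `strided_double` function and its doubling workload are made up for the example and are not part of the commit, whose actual code appears in the diff further down.

```rust
use std::thread;

// Illustrative only: each worker handles indices ith, ith + n_threads, ith + 2 * n_threads, ...
// and writes into its own zero-initialized buffer, mirroring the
// `for i in (ith..end).step_by(n_threads)` loop in log_mel_spectrogram_w.
fn strided_double(data: &[f32], n_threads: usize) -> Vec<f32> {
    let n_threads = n_threads.max(1);
    let partials: Vec<Vec<f32>> = thread::scope(|s| {
        (0..n_threads)
            .map(|ith| {
                // scoped threads may borrow `data`, no 'static bound needed
                s.spawn(move || {
                    let mut out = vec![0.0f32; data.len()];
                    for i in (ith..data.len()).step_by(n_threads) {
                        out[i] = data[i] * 2.0;
                    }
                    out
                })
            })
            // collect the handles first so every thread is spawned before any join
            .collect::<Vec<_>>()
            .into_iter()
            .map(|handle| handle.join().expect("thread failed"))
            .collect()
    });

    // Each index is written by exactly one worker and stays zero in the others,
    // so an elementwise sum over the partial buffers rebuilds the full result.
    let mut merged = vec![0.0f32; data.len()];
    for part in &partials {
        for (acc, v) in merged.iter_mut().zip(part) {
            *acc += *v;
        }
    }
    merged
}
```

The same sharing requirement explains why the `Float` trait gains `Send + Sync` bounds in the diff: the window, sample, and filter buffers are borrowed by every scoped worker (wrapped in `Arc` in the real code), so their element type must be safe to share across threads.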
diff --git a/candle-transformers/src/models/whisper/audio.rs b/candle-transformers/src/models/whisper/audio.rs
index 6dbff650..eb795f18 100644
--- a/candle-transformers/src/models/whisper/audio.rs
+++ b/candle-transformers/src/models/whisper/audio.rs
@@ -1,7 +1,14 @@
 // Audio processing code, adapted from whisper.cpp
 // https://github.com/ggerganov/whisper.cpp
-pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
+use candle::utils::get_num_threads;
+use std::sync::Arc;
+use std::thread;
+
+pub trait Float:
+    num_traits::Float + num_traits::FloatConst + num_traits::NumAssign + Send + Sync
+{
+}
 
 impl Float for f32 {}
 impl Float for f64 {}
@@ -102,22 +109,26 @@ fn log_mel_spectrogram_w<T: Float>(
     let half = T::from(0.5).unwrap();
     let mut fft_in = vec![zero; fft_size];
     let mut mel = vec![zero; n_len * n_mel];
+    let n_samples = samples.len();
+    let end = std::cmp::min(n_samples / fft_step + 1, n_len);
 
-    for i in (ith..n_len).step_by(n_threads) {
+    for i in (ith..end).step_by(n_threads) {
         let offset = i * fft_step;
 
         // apply Hanning window
-        for j in 0..fft_size {
-            fft_in[j] = if offset + j < samples.len() {
-                hann[j] * samples[offset + j]
-            } else {
-                zero
-            }
+        for j in 0..std::cmp::min(fft_size, n_samples - offset) {
+            fft_in[j] = hann[j] * samples[offset + j];
         }
 
-        // FFT -> mag^2
+        // fill the rest with zeros
+        if n_samples - offset < fft_size {
+            fft_in[n_samples - offset..].fill(zero);
+        }
+
+        // FFT
         let mut fft_out: Vec<T> = fft(&fft_in);
 
+        // Calculate modulus^2 of complex numbers
         for j in 0..fft_size {
             fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
         }
@@ -136,8 +147,19 @@ fn log_mel_spectrogram_w<T: Float>(
         // mel spectrogram
         for j in 0..n_mel {
             let mut sum = zero;
-            for k in 0..n_fft {
+            let mut k = 0;
+            // Unroll loop
+            while k < n_fft.saturating_sub(3) {
+                sum += fft_out[k] * filters[j * n_fft + k]
+                    + fft_out[k + 1] * filters[j * n_fft + k + 1]
+                    + fft_out[k + 2] * filters[j * n_fft + k + 2]
+                    + fft_out[k + 3] * filters[j * n_fft + k + 3];
+                k += 4;
+            }
+            // Handle remainder
+            while k < n_fft {
                 sum += fft_out[k] * filters[j * n_fft + k];
+                k += 1;
             }
             mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
         }
@@ -145,7 +167,7 @@
     mel
 }
 
-fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
+fn log_mel_spectrogram_<T: Float>(
    samples: &[T],
    filters: &[T],
    fft_size: usize,
@@ -180,10 +202,55 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
         samples_padded
     };
 
-    // Use a single thread for now.
-    let mut mel = log_mel_spectrogram_w(
-        0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
-    );
+    // ensure that the number of threads is even and less than 12
+    let n_threads = std::cmp::min(get_num_threads() - get_num_threads() % 2, 12);
+
+    let hann = Arc::new(hann);
+    let samples = Arc::new(samples);
+    let filters = Arc::new(filters);
+
+    // use scope to allow for non static references to be passed to the threads
+    // and directly collect the results into a single vector
+    let all_outputs = thread::scope(|s| {
+        (0..n_threads)
+            // create threads and return their handles
+            .map(|thread_id| {
+                let hann = Arc::clone(&hann);
+                let samples = Arc::clone(&samples);
+                let filters = Arc::clone(&filters);
+                // spawn new thread and start work
+                s.spawn(move || {
+                    log_mel_spectrogram_w(
+                        thread_id, &hann, &samples, &filters, fft_size, fft_step, speed_up, n_len,
+                        n_mel, n_threads,
+                    )
+                })
+            })
+            .collect::<Vec<_>>()
+            .into_iter()
+            // wait for each thread to finish and collect their results
+            .map(|handle| handle.join().expect("Thread failed"))
+            .collect::<Vec<_>>()
+    });
+
+    let l = all_outputs[0].len();
+    let mut mel = vec![zero; l];
+
+    // iterate over mel spectrogram segments, dividing work by threads.
+    for segment_start in (0..l).step_by(n_threads) {
+        // go through each thread's output.
+        for thread_output in all_outputs.iter() {
+            // add each thread's piece to our mel spectrogram.
+            for offset in 0..n_threads {
+                let mel_index = segment_start + offset; // find location in mel.
+                if mel_index < mel.len() {
+                    // Make sure we don't go out of bounds.
+                    mel[mel_index] += thread_output[mel_index];
+                }
+            }
+        }
+    }
+
     let mmax = mel
         .iter()
         .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
@@ -197,11 +264,7 @@
     mel
 }
 
-pub fn pcm_to_mel<T: Float + std::fmt::Display>(
-    cfg: &super::Config,
-    samples: &[T],
-    filters: &[T],
-) -> Vec<T> {
+pub fn pcm_to_mel<T: Float>(cfg: &super::Config, samples: &[T], filters: &[T]) -> Vec<T> {
     log_mel_spectrogram_(
         samples,
         filters,
@@ -211,3 +274,62 @@ pub fn pcm_to_mel<T: Float + std::fmt::Display>(
         false,
     )
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_fft() {
+        let input = vec![0.0, 1.0, 0.0, 0.0];
+        let output = fft(&input);
+        assert_eq!(
+            output,
+            vec![
+                1.0,
+                0.0,
+                6.123233995736766e-17,
+                -1.0,
+                -1.0,
+                0.0,
+                -6.123233995736766e-17,
+                1.0
+            ]
+        );
+    }
+
+    #[test]
+    fn test_dft() {
+        let input = vec![0.0, 1.0, 0.0, 0.0];
+        let output = dft(&input);
+        assert_eq!(
+            output,
+            vec![
+                1.0,
+                0.0,
+                6.123233995736766e-17,
+                -1.0,
+                -1.0,
+                -1.2246467991473532e-16,
+                -1.8369701987210297e-16,
+                1.0
+            ]
+        );
+    }
+
+    #[test]
+    fn test_log_mel_spectrogram() {
+        let samples = vec![0.0; 1000];
+        let filters = vec![0.0; 1000];
+        let output = log_mel_spectrogram_(&samples, &filters, 100, 10, 10, false);
+        assert_eq!(output.len(), 30_000);
+    }
+
+    #[test]
+    fn test_tiny_log_mel_spectrogram() {
+        let samples = vec![0.0; 100];
+        let filters = vec![0.0; 100];
+        let output = log_mel_spectrogram_(&samples, &filters, 20, 2, 2, false);
+        assert_eq!(output.len(), 6_000);
+    }
+}
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index ea2a59b9..74f708e6 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -195,14 +195,14 @@ impl ResidualAttentionBlock {
     }
 }
 
-fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
+fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
     let max_timescale = 10000f32;
     let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
     let inv_timescales: Vec<_> = (0..channels / 2)
         .map(|i| (i as f32 * (-log_timescale_increment)).exp())
         .collect();
-    let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
-    let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
+    let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
+    let arange = Tensor::arange(0, length as u32, device)?
         .to_dtype(candle::DType::F32)?
         .unsqueeze(1)?;
     let sh = (length, channels / 2);
@@ -246,7 +246,7 @@ impl AudioEncoder {
         };
         let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
         let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
-        let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
+        let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
         let blocks = (0..cfg.encoder_layers)
             .map(|i| {
                 ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs
index 43ea4177..dac78be9 100644
--- a/candle-transformers/src/models/whisper/quantized_model.rs
+++ b/candle-transformers/src/models/whisper/quantized_model.rs
@@ -191,14 +191,14 @@ impl ResidualAttentionBlock {
     }
 }
 
-fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
+fn sinusoids(length: usize, channels: usize, device: &Device) -> Result<Tensor> {
     let max_timescale = 10000f32;
     let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
     let inv_timescales: Vec<_> = (0..channels / 2)
         .map(|i| (i as f32 * (-log_timescale_increment)).exp())
         .collect();
-    let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
-    let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
+    let inv_timescales = Tensor::new(inv_timescales.as_slice(), device)?.unsqueeze(0)?;
+    let arange = Tensor::arange(0, length as u32, device)?
         .to_dtype(candle::DType::F32)?
         .unsqueeze(1)?;
     let sh = (length, channels / 2);
@@ -242,7 +242,7 @@ impl AudioEncoder {
         };
         let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
         let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
-        let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
+        let positional_embedding = sinusoids(n_ctx, n_state, vb.device())?;
         let blocks = (0..cfg.encoder_layers)
             .map(|i| {
                 ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
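The "small perf tweaks" in the title are mostly the manually unrolled filter-bank dot product inside `log_mel_spectrogram_w`. Below is a minimal standalone sketch of that pattern; the `dot_unrolled` name and its inputs are illustrative and not taken from the commit.

```rust
// Illustrative sketch of the 4-way unrolled accumulation used for the
// mel filter-bank dot product, plus a scalar loop for the remainder.
fn dot_unrolled(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());
    let mut sum = 0.0f32;
    let mut k = 0;
    // Process four elements per iteration; `saturating_sub` keeps the bound
    // from underflowing when fewer than four elements are available.
    while k < n.saturating_sub(3) {
        sum += a[k] * b[k]
            + a[k + 1] * b[k + 1]
            + a[k + 2] * b[k + 2]
            + a[k + 3] * b[k + 3];
        k += 4;
    }
    // Handle the remaining 0..=3 elements.
    while k < n {
        sum += a[k] * b[k];
        k += 1;
    }
    sum
}
```

The remaining changes are mechanical: `sinusoids` in `model.rs` and `quantized_model.rs` now accepts the target `Device`, so the positional embedding is created directly on that device instead of being built on the CPU and then moved with `to_device`.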