Diffstat (limited to 'candle-examples/examples/whisper')
-rw-r--r-- | candle-examples/examples/whisper/audio.rs        | 214
-rw-r--r-- | candle-examples/examples/whisper/main.rs         |  71
-rw-r--r-- | candle-examples/examples/whisper/model.rs        | 416
-rw-r--r-- | candle-examples/examples/whisper/multilingual.rs |   2
4 files changed, 24 insertions, 679 deletions
diff --git a/candle-examples/examples/whisper/audio.rs b/candle-examples/examples/whisper/audio.rs
deleted file mode 100644
index 2ceed065..00000000
--- a/candle-examples/examples/whisper/audio.rs
+++ /dev/null
@@ -1,214 +0,0 @@
-// Audio processing code, adapted from whisper.cpp
-// https://github.com/ggerganov/whisper.cpp
-
-pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
-
-impl Float for f32 {}
-impl Float for f64 {}
-
-// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
-fn fft<T: Float>(inp: &[T]) -> Vec<T> {
-    let n = inp.len();
-    let zero = T::zero();
-    if n == 1 {
-        return vec![inp[0], zero];
-    }
-    if n % 2 == 1 {
-        return dft(inp);
-    }
-    let mut out = vec![zero; n * 2];
-
-    let mut even = Vec::with_capacity(n / 2);
-    let mut odd = Vec::with_capacity(n / 2);
-
-    for (i, &inp) in inp.iter().enumerate() {
-        if i % 2 == 0 {
-            even.push(inp)
-        } else {
-            odd.push(inp);
-        }
-    }
-
-    let even_fft = fft(&even);
-    let odd_fft = fft(&odd);
-
-    let two_pi = T::PI() + T::PI();
-    let n_t = T::from(n).unwrap();
-    for k in 0..n / 2 {
-        let k_t = T::from(k).unwrap();
-        let theta = two_pi * k_t / n_t;
-        let re = theta.cos();
-        let im = -theta.sin();
-
-        let re_odd = odd_fft[2 * k];
-        let im_odd = odd_fft[2 * k + 1];
-
-        out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
-        out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;
-
-        out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
-        out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
-    }
-    out
-}
-
-// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337
-fn dft<T: Float>(inp: &[T]) -> Vec<T> {
-    let zero = T::zero();
-    let n = inp.len();
-    let two_pi = T::PI() + T::PI();
-
-    let mut out = Vec::new();
-    out.reserve(2 * n);
-    let n_t = T::from(n).unwrap();
-    for k in 0..n {
-        let k_t = T::from(k).unwrap();
-        let mut re = zero;
-        let mut im = zero;
-
-        for (j, &inp) in inp.iter().enumerate() {
-            let j_t = T::from(j).unwrap();
-            let angle = two_pi * k_t * j_t / n_t;
-            re += inp * angle.cos();
-            im -= inp * angle.sin();
-        }
-
-        out.push(re);
-        out.push(im);
-    }
-    out
-}
-
-#[allow(clippy::too_many_arguments)]
-// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414
-fn log_mel_spectrogram_w<T: Float>(
-    ith: usize,
-    hann: &[T],
-    samples: &[T],
-    filters: &[T],
-    fft_size: usize,
-    fft_step: usize,
-    speed_up: bool,
-    n_len: usize,
-    n_mel: usize,
-    n_threads: usize,
-) -> Vec<T> {
-    let n_fft = if speed_up {
-        1 + fft_size / 4
-    } else {
-        1 + fft_size / 2
-    };
-
-    let zero = T::zero();
-    let half = T::from(0.5).unwrap();
-    let mut fft_in = vec![zero; fft_size];
-    let mut mel = vec![zero; n_len * n_mel];
-
-    for i in (ith..n_len).step_by(n_threads) {
-        let offset = i * fft_step;
-
-        // apply Hanning window
-        for j in 0..fft_size {
-            fft_in[j] = if offset + j < samples.len() {
-                hann[j] * samples[offset + j]
-            } else {
-                zero
-            }
-        }
-
-        // FFT -> mag^2
-        let mut fft_out: Vec<T> = fft(&fft_in);
-
-        for j in 0..fft_size {
-            fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
-        }
-        for j in 1..fft_size / 2 {
-            let v = fft_out[fft_size - j];
-            fft_out[j] += v;
-        }
-
-        if speed_up {
-            // scale down in the frequency domain results in a speed up in the time domain
-            for j in 0..n_fft {
-                fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
-            }
-        }
-
-        // mel spectrogram
-        for j in 0..n_mel {
-            let mut sum = zero;
-            for k in 0..n_fft {
-                sum += fft_out[k] * filters[j * n_fft + k];
-            }
-            mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
-        }
-    }
-    mel
-}
-
-fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
-    samples: &[T],
-    filters: &[T],
-    fft_size: usize,
-    fft_step: usize,
-    n_mel: usize,
-    speed_up: bool,
-) -> Vec<T> {
-    let zero = T::zero();
-    let two_pi = T::PI() + T::PI();
-    let half = T::from(0.5).unwrap();
-    let one = T::from(1.0).unwrap();
-    let four = T::from(4.0).unwrap();
-    let fft_size_t = T::from(fft_size).unwrap();
-
-    let hann: Vec<T> = (0..fft_size)
-        .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
-        .collect();
-    let n_len = samples.len() / fft_step;
-
-    // pad audio with at least one extra chunk of zeros
-    let pad = 100 * super::CHUNK_LENGTH / 2;
-    let n_len = if n_len % pad != 0 {
-        (n_len / pad + 1) * pad
-    } else {
-        n_len
-    };
-    let n_len = n_len + pad;
-    let samples = {
-        let mut samples_padded = samples.to_vec();
-        let to_add = n_len * fft_step - samples.len();
-        samples_padded.extend(std::iter::repeat(zero).take(to_add));
-        samples_padded
-    };
-
-    // Use a single thread for now.
-    let mut mel = log_mel_spectrogram_w(
-        0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
-    );
-    let mmax = mel
-        .iter()
-        .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
-        .copied()
-        .unwrap_or(zero)
-        - T::from(8).unwrap();
-    for m in mel.iter_mut() {
-        let v = T::max(*m, mmax);
-        *m = v / four + one
-    }
-    mel
-}
-
-pub fn pcm_to_mel<T: Float + std::fmt::Display>(
-    samples: &[T],
-    filters: &[T],
-) -> anyhow::Result<Vec<T>> {
-    let mel = log_mel_spectrogram_(
-        samples,
-        filters,
-        super::N_FFT,
-        super::HOP_LENGTH,
-        super::N_MELS,
-        false,
-    );
-    Ok(mel)
-}
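Note: the deleted audio.rs computed a log-mel spectrogram per frame as Hann window -> FFT -> power spectrum -> mel filter bank -> log10. For context, here is a minimal standalone sketch of just the window + spectrum step, simplified to f32 and the naive DFT that the deleted code used as its odd-length fallback; the 400-sample frame size matches N_FFT, but the 40-cycle test tone is purely illustrative:

// Power spectrum of one windowed frame via a naive DFT (O(n^2), for clarity).
fn dft_power(frame: &[f32]) -> Vec<f32> {
    let n = frame.len();
    let two_pi = 2.0 * std::f32::consts::PI;
    (0..n / 2 + 1)
        .map(|k| {
            let (mut re, mut im) = (0.0f32, 0.0f32);
            for (j, &x) in frame.iter().enumerate() {
                let angle = two_pi * (k as f32) * (j as f32) / n as f32;
                re += x * angle.cos();
                im -= x * angle.sin();
            }
            re * re + im * im // magnitude squared, as in log_mel_spectrogram_w
        })
        .collect()
}

fn main() {
    let n_fft = 400; // same frame size as the Whisper N_FFT constant
    let frame: Vec<f32> = (0..n_fft)
        .map(|i| {
            // Hann window, same formula as in log_mel_spectrogram_.
            let hann =
                0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / n_fft as f32).cos());
            // A pure tone with 40 cycles per frame.
            hann * (2.0 * std::f32::consts::PI * 40.0 * i as f32 / n_fft as f32).sin()
        })
        .collect();
    let power = dft_power(&frame);
    // The energy should concentrate in frequency bin 40.
    let peak = power
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(k, _)| k);
    println!("peak bin: {peak:?}"); // expected: Some(40)
}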
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index dbe9cc8d..c71d562a 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -10,41 +10,16 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor};
+use candle::{Device, IndexOp, Tensor};
 use candle_nn::{ops::softmax, VarBuilder};
 use clap::{Parser, ValueEnum};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use rand::{distributions::Distribution, SeedableRng};
 use tokenizers::Tokenizer;
 
-mod audio;
-mod model;
-use model::{Config, Whisper};
 mod multilingual;
-
-const DTYPE: DType = DType::F32;
-
-// Audio parameters.
-const SAMPLE_RATE: usize = 16000;
-const N_FFT: usize = 400;
-const N_MELS: usize = 80;
-const HOP_LENGTH: usize = 160;
-const CHUNK_LENGTH: usize = 30;
-const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
-const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
-
-const NO_SPEECH_THRESHOLD: f64 = 0.6;
-const LOGPROB_THRESHOLD: f64 = -1.0;
-const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
-const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
-
-// Tokenizer dependent bits.
-const SOT_TOKEN: &str = "<|startoftranscript|>";
-const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
-const TRANSLATE_TOKEN: &str = "<|translate|>";
-const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
-const EOT_TOKEN: &str = "<|endoftext|>";
-const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
+use candle_transformers::models::whisper::{self as m, audio, model};
+use model::{Config, Whisper};
 
 #[allow(dead_code)]
 #[derive(Debug, Clone)]
@@ -94,7 +69,7 @@ impl Decoder {
         timestamps: bool,
         verbose: bool,
     ) -> Result<Self> {
-        let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
+        let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
         // Suppress the notimestamps token when in timestamps mode.
         // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
         let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
@@ -109,11 +84,11 @@ impl Decoder {
             })
             .collect();
         let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
-        let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
-        let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
-        let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
-        let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
-        let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
+        let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
+        let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
+        let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
+        let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
+        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
         Ok(Self {
             model,
             rng: rand::rngs::StdRng::seed_from_u64(seed),
@@ -220,17 +195,17 @@ impl Decoder {
     }
 
     fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
-        for (i, &t) in TEMPERATURES.iter().enumerate() {
+        for (i, &t) in m::TEMPERATURES.iter().enumerate() {
             let dr: Result<DecodingResult> = self.decode(segment, t);
-            if i == TEMPERATURES.len() - 1 {
+            if i == m::TEMPERATURES.len() - 1 {
                 return dr;
             }
             // On errors, we try again with a different temperature.
             match dr {
                 Ok(dr) => {
-                    let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
-                        || dr.avg_logprob < LOGPROB_THRESHOLD;
-                    if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+                    let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
+                        || dr.avg_logprob < m::LOGPROB_THRESHOLD;
+                    if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
                         return Ok(dr);
                     }
                 }
@@ -248,13 +223,13 @@ impl Decoder {
         let mut segments = vec![];
         while seek < content_frames {
             let start = std::time::Instant::now();
-            let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
-            let segment_size = usize::min(content_frames - seek, N_FRAMES);
+            let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+            let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
             let mel_segment = mel.narrow(2, seek, segment_size)?;
-            let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
+            let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
             let dr = self.decode_with_fallback(&mel_segment)?;
             seek += segment_size;
-            if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
+            if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
                 println!("no speech detected, skipping {seek} {dr:?}");
                 continue;
             }
@@ -492,8 +467,8 @@ fn main() -> Result<()> {
     let mut input = std::fs::File::open(input)?;
    let (header, data) = wav::read(&mut input)?;
     println!("loaded wav data: {header:?}");
-    if header.sampling_rate != SAMPLE_RATE as u32 {
-        anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
+    if header.sampling_rate != m::SAMPLE_RATE as u32 {
+        anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
     }
     let data = data.as_sixteen().expect("expected 16 bit wav file");
     let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
@@ -501,14 +476,14 @@ fn main() -> Result<()> {
         .map(|v| *v as f32 / 32768.)
         .collect();
     println!("pcm data loaded {}", pcm_data.len());
-    let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
+    let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
     let mel_len = mel.len();
-    let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
+    let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
     println!("loaded mel: {:?}", mel.dims());
 
     let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
     let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+    let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
     let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
     let mut model = Whisper::load(&vb, config)?;
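Note: the example now pulls its constants from candle_transformers::models::whisper (aliased as m) instead of defining them locally. To make the seek arithmetic in Decoder::run concrete, here is a tiny standalone version of that bookkeeping, with the constant values copied from the deleted block; the 7200-frame clip length is made up for the example:

const SAMPLE_RATE: usize = 16000;
const HOP_LENGTH: usize = 160;
const N_FRAMES: usize = 3000; // 30 s * 16000 / 160, i.e. one 30-second chunk

fn main() {
    // Mel frames arrive at 100 per second (16000 / 160), so 7200 frames = 72 s.
    let content_frames = 7200;
    let mut seek = 0;
    while seek < content_frames {
        let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
        let segment_size = usize::min(content_frames - seek, N_FRAMES);
        let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
        println!("segment at {time_offset}s, {segment_duration}s long");
        seek += segment_size;
    }
    // Prints segments at 0s/30s, 30s/30s, 60s/12s.
}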
"multi-head-attn-softmax"); - let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul"); - let query = linear(n_state, n_state, vb.pp("q_proj"))?; - let value = linear(n_state, n_state, vb.pp("v_proj"))?; - let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; - let out = linear(n_state, n_state, vb.pp("out_proj"))?; - Ok(Self { - query, - key, - value, - out, - n_head, - span, - softmax_span, - matmul_span, - kv_cache: None, - }) - } - - fn forward( - &mut self, - x: &Tensor, - xa: Option<&Tensor>, - mask: Option<&Tensor>, - flush_cache: bool, - ) -> Result<Tensor> { - let _enter = self.span.enter(); - let q = self.query.forward(x)?; - let (k, v) = match xa { - None => { - let k = self.key.forward(x)?; - let v = self.value.forward(x)?; - (k, v) - } - Some(x) => { - if flush_cache { - self.kv_cache = None; - } - if let Some((k, v)) = &self.kv_cache { - (k.clone(), v.clone()) - } else { - let k = self.key.forward(x)?; - let v = self.value.forward(x)?; - self.kv_cache = Some((k.clone(), v.clone())); - (k, v) - } - } - }; - let wv = self.qkv_attention(&q, &k, &v, mask)?; - let out = self.out.forward(&wv)?; - Ok(out) - } - - fn reshape_head(&self, x: &Tensor) -> Result<Tensor> { - let (n_batch, n_ctx, n_state) = x.dims3()?; - let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; - x.reshape(target_dims)?.transpose(1, 2) - } - - fn qkv_attention( - &self, - q: &Tensor, - k: &Tensor, - v: &Tensor, - mask: Option<&Tensor>, - ) -> Result<Tensor> { - let (_, n_ctx, n_state) = q.dims3()?; - let scale = ((n_state / self.n_head) as f64).powf(-0.25); - let q = (self.reshape_head(q)? * scale)?; - let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?; - let v = self.reshape_head(v)?.contiguous()?; - let mut qk = { - let _enter = self.matmul_span.enter(); - q.matmul(&k)? - }; - if let Some(mask) = mask { - let mask = mask.i((0..n_ctx, 0..n_ctx))?; - qk = qk.broadcast_add(&mask)? - } - let w = { - let _enter = self.softmax_span.enter(); - softmax(&qk, D::Minus1)? - }; - let wv = { - let _enter = self.matmul_span.enter(); - w.matmul(&v)? - } - .transpose(1, 2)? 
- .flatten_from(2)?; - Ok(wv) - } -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 -struct ResidualAttentionBlock { - attn: MultiHeadAttention, - attn_ln: LayerNorm, - cross_attn: Option<(MultiHeadAttention, LayerNorm)>, - mlp_linear1: Linear, - mlp_linear2: Linear, - mlp_ln: LayerNorm, - span: tracing::Span, -} - -impl ResidualAttentionBlock { - fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "residual-attn"); - let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; - let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; - let cross_attn = if ca { - let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; - let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; - Some((cross_attn, cross_attn_ln)) - } else { - None - }; - let n_mlp = n_state * 4; - let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; - let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; - let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; - Ok(Self { - attn, - attn_ln, - cross_attn, - mlp_linear1, - mlp_linear2, - mlp_ln, - span, - }) - } - - fn forward( - &mut self, - x: &Tensor, - xa: Option<&Tensor>, - mask: Option<&Tensor>, - flush_kv_cache: bool, - ) -> Result<Tensor> { - let _enter = self.span.enter(); - let attn = self - .attn - .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?; - let mut x = (x + attn)?; - if let Some((attn, ln)) = &mut self.cross_attn { - x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?; - } - let mlp = self.mlp_linear2.forward( - &self - .mlp_linear1 - .forward(&self.mlp_ln.forward(&x)?)? - .gelu()?, - )?; - x + mlp - } -} - -fn sinusoids(length: usize, channels: usize) -> Result<Tensor> { - let max_timescale = 10000f32; - let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; - let inv_timescales: Vec<_> = (0..channels / 2) - .map(|i| (i as f32 * (-log_timescale_increment)).exp()) - .collect(); - let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; - let arange = Tensor::arange(0, length as u32, &Device::Cpu)? - .to_dtype(candle::DType::F32)? - .unsqueeze(1)?; - let sh = (length, channels / 2); - let scaled_time = (arange.broadcast_as(sh)? 
* inv_timescales.broadcast_as(sh)?)?; - let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; - Ok(sincos) -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 -pub struct AudioEncoder { - conv1: Conv1d, - conv2: Conv1d, - positional_embedding: Tensor, - blocks: Vec<ResidualAttentionBlock>, - ln_post: LayerNorm, - span: tracing::Span, - conv1_span: tracing::Span, - conv2_span: tracing::Span, -} - -impl AudioEncoder { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "audio-encoder"); - let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1"); - let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2"); - let n_state = cfg.d_model; - let n_head = cfg.encoder_attention_heads; - let n_ctx = cfg.max_source_positions; - let cfg1 = Conv1dConfig { - padding: 1, - stride: 1, - groups: 1, - dilation: 1, - }; - let cfg2 = Conv1dConfig { - padding: 1, - stride: 2, - groups: 1, - dilation: 1, - }; - let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; - let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; - let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; - let blocks = (0..cfg.encoder_layers) - .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) - }) - .collect::<Result<Vec<_>>>()?; - let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; - Ok(Self { - conv1, - conv2, - positional_embedding, - blocks, - ln_post, - conv1_span, - conv2_span, - span, - }) - } - - pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { - let _enter = self.span.enter(); - let x = { - let _enter = self.conv1_span.enter(); - self.conv1.forward(x)?.gelu()? - }; - let x = { - let _enter = self.conv2_span.enter(); - self.conv2.forward(&x)?.gelu()? - }; - let x = x.transpose(1, 2)?; - let (_bsize, seq_len, _hidden) = x.dims3()?; - let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; - let mut x = x.broadcast_add(&positional_embedding)?; - for block in self.blocks.iter_mut() { - x = block.forward(&x, None, None, flush_kv_cache)? 
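Note: in the deleted TextDecoder::load, the causal mask is built once as an n_ctx x n_ctx tensor of 0 / -inf values that gets broadcast-added to the attention logits. Here is that construction reproduced as a runnable sketch over a plain Vec instead of a candle Tensor:

// Causal mask: position i may attend to positions j <= i only.
fn causal_mask(n_ctx: usize) -> Vec<f32> {
    (0..n_ctx)
        .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
        .collect()
}

fn main() {
    let n_ctx = 4;
    let mask = causal_mask(n_ctx);
    for i in 0..n_ctx {
        let row: Vec<_> = mask[i * n_ctx..(i + 1) * n_ctx]
            .iter()
            .map(|&v| if v == 0.0 { 0 } else { 1 }) // 1 marks a masked (future) position
            .collect();
        println!("{row:?}");
    }
}

Also worth noting: in qkv_attention, q and k are each pre-scaled by (n_state / n_head)^-0.25, so their product carries the usual 1/sqrt(d_head) attention scaling.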
diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs
index bc0bae1f..a82b09ef 100644
--- a/candle-examples/examples/whisper/multilingual.rs
+++ b/candle-examples/examples/whisper/multilingual.rs
@@ -113,7 +113,7 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
         .iter()
         .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
         .collect::<Result<Vec<_>>>()?;
-    let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
+    let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
     let audio_features = model.encoder.forward(&mel, true)?;
     let tokens = Tensor::new(&[[sot_token]], device)?;
     let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
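Note: both main.rs and multilingual.rs resolve special-token strings such as m::SOT_TOKEN through a token_id helper that is outside this diff. A plausible sketch consistent with the call sites, assuming an anyhow error type (the actual helper in main.rs may differ):

use anyhow::{anyhow, Result};
use tokenizers::Tokenizer;

// Resolve a special-token string (e.g. "<|startoftranscript|>") to its vocab id.
// Hypothetical reconstruction for illustration only.
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> Result<u32> {
    tokenizer
        .token_to_id(token)
        .ok_or_else(|| anyhow!("no token-id found for {token}"))
}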