summaryrefslogtreecommitdiff
path: root/candle-examples/examples/whisper
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/whisper')
-rw-r--r--candle-examples/examples/whisper/audio.rs214
-rw-r--r--candle-examples/examples/whisper/main.rs71
-rw-r--r--candle-examples/examples/whisper/model.rs416
-rw-r--r--candle-examples/examples/whisper/multilingual.rs2
4 files changed, 24 insertions, 679 deletions
diff --git a/candle-examples/examples/whisper/audio.rs b/candle-examples/examples/whisper/audio.rs
deleted file mode 100644
index 2ceed065..00000000
--- a/candle-examples/examples/whisper/audio.rs
+++ /dev/null
@@ -1,214 +0,0 @@
-// Audio processing code, adapted from whisper.cpp
-// https://github.com/ggerganov/whisper.cpp
-
-pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}
-
-impl Float for f32 {}
-impl Float for f64 {}
-
-// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
-fn fft<T: Float>(inp: &[T]) -> Vec<T> {
- let n = inp.len();
- let zero = T::zero();
- if n == 1 {
- return vec![inp[0], zero];
- }
- if n % 2 == 1 {
- return dft(inp);
- }
- let mut out = vec![zero; n * 2];
-
- let mut even = Vec::with_capacity(n / 2);
- let mut odd = Vec::with_capacity(n / 2);
-
- for (i, &inp) in inp.iter().enumerate() {
- if i % 2 == 0 {
- even.push(inp)
- } else {
- odd.push(inp);
- }
- }
-
- let even_fft = fft(&even);
- let odd_fft = fft(&odd);
-
- let two_pi = T::PI() + T::PI();
- let n_t = T::from(n).unwrap();
- for k in 0..n / 2 {
- let k_t = T::from(k).unwrap();
- let theta = two_pi * k_t / n_t;
- let re = theta.cos();
- let im = -theta.sin();
-
- let re_odd = odd_fft[2 * k];
- let im_odd = odd_fft[2 * k + 1];
-
- out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
- out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;
-
- out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
- out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
- }
- out
-}
-
-// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337
-fn dft<T: Float>(inp: &[T]) -> Vec<T> {
- let zero = T::zero();
- let n = inp.len();
- let two_pi = T::PI() + T::PI();
-
- let mut out = Vec::new();
- out.reserve(2 * n);
- let n_t = T::from(n).unwrap();
- for k in 0..n {
- let k_t = T::from(k).unwrap();
- let mut re = zero;
- let mut im = zero;
-
- for (j, &inp) in inp.iter().enumerate() {
- let j_t = T::from(j).unwrap();
- let angle = two_pi * k_t * j_t / n_t;
- re += inp * angle.cos();
- im -= inp * angle.sin();
- }
-
- out.push(re);
- out.push(im);
- }
- out
-}
-
-#[allow(clippy::too_many_arguments)]
-// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414
-fn log_mel_spectrogram_w<T: Float>(
- ith: usize,
- hann: &[T],
- samples: &[T],
- filters: &[T],
- fft_size: usize,
- fft_step: usize,
- speed_up: bool,
- n_len: usize,
- n_mel: usize,
- n_threads: usize,
-) -> Vec<T> {
- let n_fft = if speed_up {
- 1 + fft_size / 4
- } else {
- 1 + fft_size / 2
- };
-
- let zero = T::zero();
- let half = T::from(0.5).unwrap();
- let mut fft_in = vec![zero; fft_size];
- let mut mel = vec![zero; n_len * n_mel];
-
- for i in (ith..n_len).step_by(n_threads) {
- let offset = i * fft_step;
-
- // apply Hanning window
- for j in 0..fft_size {
- fft_in[j] = if offset + j < samples.len() {
- hann[j] * samples[offset + j]
- } else {
- zero
- }
- }
-
- // FFT -> mag^2
- let mut fft_out: Vec<T> = fft(&fft_in);
-
- for j in 0..fft_size {
- fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
- }
- for j in 1..fft_size / 2 {
- let v = fft_out[fft_size - j];
- fft_out[j] += v;
- }
-
- if speed_up {
- // scale down in the frequency domain results in a speed up in the time domain
- for j in 0..n_fft {
- fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
- }
- }
-
- // mel spectrogram
- for j in 0..n_mel {
- let mut sum = zero;
- for k in 0..n_fft {
- sum += fft_out[k] * filters[j * n_fft + k];
- }
- mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
- }
- }
- mel
-}
-
-fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
- samples: &[T],
- filters: &[T],
- fft_size: usize,
- fft_step: usize,
- n_mel: usize,
- speed_up: bool,
-) -> Vec<T> {
- let zero = T::zero();
- let two_pi = T::PI() + T::PI();
- let half = T::from(0.5).unwrap();
- let one = T::from(1.0).unwrap();
- let four = T::from(4.0).unwrap();
- let fft_size_t = T::from(fft_size).unwrap();
-
- let hann: Vec<T> = (0..fft_size)
- .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
- .collect();
- let n_len = samples.len() / fft_step;
-
- // pad audio with at least one extra chunk of zeros
- let pad = 100 * super::CHUNK_LENGTH / 2;
- let n_len = if n_len % pad != 0 {
- (n_len / pad + 1) * pad
- } else {
- n_len
- };
- let n_len = n_len + pad;
- let samples = {
- let mut samples_padded = samples.to_vec();
- let to_add = n_len * fft_step - samples.len();
- samples_padded.extend(std::iter::repeat(zero).take(to_add));
- samples_padded
- };
-
- // Use a single thread for now.
- let mut mel = log_mel_spectrogram_w(
- 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
- );
- let mmax = mel
- .iter()
- .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
- .copied()
- .unwrap_or(zero)
- - T::from(8).unwrap();
- for m in mel.iter_mut() {
- let v = T::max(*m, mmax);
- *m = v / four + one
- }
- mel
-}
-
-pub fn pcm_to_mel<T: Float + std::fmt::Display>(
- samples: &[T],
- filters: &[T],
-) -> anyhow::Result<Vec<T>> {
- let mel = log_mel_spectrogram_(
- samples,
- filters,
- super::N_FFT,
- super::HOP_LENGTH,
- super::N_MELS,
- false,
- );
- Ok(mel)
-}
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index dbe9cc8d..c71d562a 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -10,41 +10,16 @@ extern crate accelerate_src;
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor};
+use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
-mod audio;
-mod model;
-use model::{Config, Whisper};
mod multilingual;
-
-const DTYPE: DType = DType::F32;
-
-// Audio parameters.
-const SAMPLE_RATE: usize = 16000;
-const N_FFT: usize = 400;
-const N_MELS: usize = 80;
-const HOP_LENGTH: usize = 160;
-const CHUNK_LENGTH: usize = 30;
-const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
-const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
-
-const NO_SPEECH_THRESHOLD: f64 = 0.6;
-const LOGPROB_THRESHOLD: f64 = -1.0;
-const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
-const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
-
-// Tokenizer dependent bits.
-const SOT_TOKEN: &str = "<|startoftranscript|>";
-const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
-const TRANSLATE_TOKEN: &str = "<|translate|>";
-const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
-const EOT_TOKEN: &str = "<|endoftext|>";
-const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
+use candle_transformers::models::whisper::{self as m, audio, model};
+use model::{Config, Whisper};
#[allow(dead_code)]
#[derive(Debug, Clone)]
@@ -94,7 +69,7 @@ impl Decoder {
timestamps: bool,
verbose: bool,
) -> Result<Self> {
- let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
+ let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
// Suppress the notimestamps token when in timestamps mode.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
@@ -109,11 +84,11 @@ impl Decoder {
})
.collect();
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
- let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
- let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
- let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
- let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
- let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
+ let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
+ let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
+ let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
+ let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
+ let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
@@ -220,17 +195,17 @@ impl Decoder {
}
fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
- for (i, &t) in TEMPERATURES.iter().enumerate() {
+ for (i, &t) in m::TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult> = self.decode(segment, t);
- if i == TEMPERATURES.len() - 1 {
+ if i == m::TEMPERATURES.len() - 1 {
return dr;
}
// On errors, we try again with a different temperature.
match dr {
Ok(dr) => {
- let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
- || dr.avg_logprob < LOGPROB_THRESHOLD;
- if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+ let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
+ || dr.avg_logprob < m::LOGPROB_THRESHOLD;
+ if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
return Ok(dr);
}
}
@@ -248,13 +223,13 @@ impl Decoder {
let mut segments = vec![];
while seek < content_frames {
let start = std::time::Instant::now();
- let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
- let segment_size = usize::min(content_frames - seek, N_FRAMES);
+ let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+ let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
- let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
+ let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?;
seek += segment_size;
- if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
+ if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
println!("no speech detected, skipping {seek} {dr:?}");
continue;
}
@@ -492,8 +467,8 @@ fn main() -> Result<()> {
let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;
println!("loaded wav data: {header:?}");
- if header.sampling_rate != SAMPLE_RATE as u32 {
- anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
+ if header.sampling_rate != m::SAMPLE_RATE as u32 {
+ anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
}
let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
@@ -501,14 +476,14 @@ fn main() -> Result<()> {
.map(|v| *v as f32 / 32768.)
.collect();
println!("pcm data loaded {}", pcm_data.len());
- let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
+ let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
let mel_len = mel.len();
- let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
+ let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
- let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+ let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = Whisper::load(&vb, config)?;
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs
deleted file mode 100644
index e58ab2ca..00000000
--- a/candle-examples/examples/whisper/model.rs
+++ /dev/null
@@ -1,416 +0,0 @@
-use candle::{Device, IndexOp, Result, Tensor, D};
-use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
-use serde::Deserialize;
-
-// The names in comments correspond to the original implementation:
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-pub struct Config {
- pub num_mel_bins: usize, // n_mels
- pub max_source_positions: usize, // n_audio_ctx
- pub d_model: usize, // n_audio_state
- pub encoder_attention_heads: usize, // n_audio_head
- pub encoder_layers: usize, // n_audio_layer
- pub vocab_size: usize, // n_vocab
- pub max_target_positions: usize, // n_text_ctx
- // pub n_text_state: usize,
- pub decoder_attention_heads: usize, // n_text_head
- pub decoder_layers: usize, // n_text_layer
- pub suppress_tokens: Vec<u32>,
-}
-
-fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
- let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
- Ok(Embedding::new(embeddings, hidden_size))
-}
-//
-// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
-// model.
-#[derive(Debug)]
-pub struct Linear {
- inner: candle_nn::Linear,
- span: tracing::Span,
-}
-
-impl Linear {
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let _enter = self.span.enter();
- self.inner.forward(x)
- }
-}
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let span = tracing::span!(tracing::Level::TRACE, "linear");
- let inner = candle_nn::linear(size1, size2, vb)?;
- Ok(Linear { inner, span })
-}
-
-fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
- let span = tracing::span!(tracing::Level::TRACE, "linear");
- let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
- Ok(Linear { inner, span })
-}
-
-fn conv1d(
- in_channels: usize,
- out_channels: usize,
- kernel_size: usize,
- config: Conv1dConfig,
- vb: VarBuilder,
-) -> Result<Conv1d> {
- let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
- let bias = vb.get(out_channels, "bias")?;
- Ok(Conv1d::new(weight, Some(bias), config))
-}
-
-fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
- let weight = vb.get(size, "weight")?;
- let bias = vb.get(size, "bias")?;
- Ok(LayerNorm::new(weight, bias, 1e-5))
-}
-
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
-struct MultiHeadAttention {
- query: Linear,
- key: Linear,
- value: Linear,
- out: Linear,
- n_head: usize,
- span: tracing::Span,
- softmax_span: tracing::Span,
- matmul_span: tracing::Span,
- kv_cache: Option<(Tensor, Tensor)>,
-}
-
-impl MultiHeadAttention {
- fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
- let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
- let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
- let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
- let query = linear(n_state, n_state, vb.pp("q_proj"))?;
- let value = linear(n_state, n_state, vb.pp("v_proj"))?;
- let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
- let out = linear(n_state, n_state, vb.pp("out_proj"))?;
- Ok(Self {
- query,
- key,
- value,
- out,
- n_head,
- span,
- softmax_span,
- matmul_span,
- kv_cache: None,
- })
- }
-
- fn forward(
- &mut self,
- x: &Tensor,
- xa: Option<&Tensor>,
- mask: Option<&Tensor>,
- flush_cache: bool,
- ) -> Result<Tensor> {
- let _enter = self.span.enter();
- let q = self.query.forward(x)?;
- let (k, v) = match xa {
- None => {
- let k = self.key.forward(x)?;
- let v = self.value.forward(x)?;
- (k, v)
- }
- Some(x) => {
- if flush_cache {
- self.kv_cache = None;
- }
- if let Some((k, v)) = &self.kv_cache {
- (k.clone(), v.clone())
- } else {
- let k = self.key.forward(x)?;
- let v = self.value.forward(x)?;
- self.kv_cache = Some((k.clone(), v.clone()));
- (k, v)
- }
- }
- };
- let wv = self.qkv_attention(&q, &k, &v, mask)?;
- let out = self.out.forward(&wv)?;
- Ok(out)
- }
-
- fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
- let (n_batch, n_ctx, n_state) = x.dims3()?;
- let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
- x.reshape(target_dims)?.transpose(1, 2)
- }
-
- fn qkv_attention(
- &self,
- q: &Tensor,
- k: &Tensor,
- v: &Tensor,
- mask: Option<&Tensor>,
- ) -> Result<Tensor> {
- let (_, n_ctx, n_state) = q.dims3()?;
- let scale = ((n_state / self.n_head) as f64).powf(-0.25);
- let q = (self.reshape_head(q)? * scale)?;
- let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
- let v = self.reshape_head(v)?.contiguous()?;
- let mut qk = {
- let _enter = self.matmul_span.enter();
- q.matmul(&k)?
- };
- if let Some(mask) = mask {
- let mask = mask.i((0..n_ctx, 0..n_ctx))?;
- qk = qk.broadcast_add(&mask)?
- }
- let w = {
- let _enter = self.softmax_span.enter();
- softmax(&qk, D::Minus1)?
- };
- let wv = {
- let _enter = self.matmul_span.enter();
- w.matmul(&v)?
- }
- .transpose(1, 2)?
- .flatten_from(2)?;
- Ok(wv)
- }
-}
-
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
-struct ResidualAttentionBlock {
- attn: MultiHeadAttention,
- attn_ln: LayerNorm,
- cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
- mlp_linear1: Linear,
- mlp_linear2: Linear,
- mlp_ln: LayerNorm,
- span: tracing::Span,
-}
-
-impl ResidualAttentionBlock {
- fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
- let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
- let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
- let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
- let cross_attn = if ca {
- let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
- let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
- Some((cross_attn, cross_attn_ln))
- } else {
- None
- };
- let n_mlp = n_state * 4;
- let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
- let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
- let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
- Ok(Self {
- attn,
- attn_ln,
- cross_attn,
- mlp_linear1,
- mlp_linear2,
- mlp_ln,
- span,
- })
- }
-
- fn forward(
- &mut self,
- x: &Tensor,
- xa: Option<&Tensor>,
- mask: Option<&Tensor>,
- flush_kv_cache: bool,
- ) -> Result<Tensor> {
- let _enter = self.span.enter();
- let attn = self
- .attn
- .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
- let mut x = (x + attn)?;
- if let Some((attn, ln)) = &mut self.cross_attn {
- x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
- }
- let mlp = self.mlp_linear2.forward(
- &self
- .mlp_linear1
- .forward(&self.mlp_ln.forward(&x)?)?
- .gelu()?,
- )?;
- x + mlp
- }
-}
-
-fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
- let max_timescale = 10000f32;
- let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
- let inv_timescales: Vec<_> = (0..channels / 2)
- .map(|i| (i as f32 * (-log_timescale_increment)).exp())
- .collect();
- let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
- let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
- .to_dtype(candle::DType::F32)?
- .unsqueeze(1)?;
- let sh = (length, channels / 2);
- let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
- let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
- Ok(sincos)
-}
-
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
-pub struct AudioEncoder {
- conv1: Conv1d,
- conv2: Conv1d,
- positional_embedding: Tensor,
- blocks: Vec<ResidualAttentionBlock>,
- ln_post: LayerNorm,
- span: tracing::Span,
- conv1_span: tracing::Span,
- conv2_span: tracing::Span,
-}
-
-impl AudioEncoder {
- fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
- let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
- let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
- let n_state = cfg.d_model;
- let n_head = cfg.encoder_attention_heads;
- let n_ctx = cfg.max_source_positions;
- let cfg1 = Conv1dConfig {
- padding: 1,
- stride: 1,
- groups: 1,
- dilation: 1,
- };
- let cfg2 = Conv1dConfig {
- padding: 1,
- stride: 2,
- groups: 1,
- dilation: 1,
- };
- let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
- let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
- let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
- let blocks = (0..cfg.encoder_layers)
- .map(|i| {
- ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
- })
- .collect::<Result<Vec<_>>>()?;
- let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
- Ok(Self {
- conv1,
- conv2,
- positional_embedding,
- blocks,
- ln_post,
- conv1_span,
- conv2_span,
- span,
- })
- }
-
- pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
- let _enter = self.span.enter();
- let x = {
- let _enter = self.conv1_span.enter();
- self.conv1.forward(x)?.gelu()?
- };
- let x = {
- let _enter = self.conv2_span.enter();
- self.conv2.forward(&x)?.gelu()?
- };
- let x = x.transpose(1, 2)?;
- let (_bsize, seq_len, _hidden) = x.dims3()?;
- let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
- let mut x = x.broadcast_add(&positional_embedding)?;
- for block in self.blocks.iter_mut() {
- x = block.forward(&x, None, None, flush_kv_cache)?
- }
- let x = self.ln_post.forward(&x)?;
- Ok(x)
- }
-}
-
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
-pub struct TextDecoder {
- token_embedding: Embedding,
- positional_embedding: Tensor,
- blocks: Vec<ResidualAttentionBlock>,
- ln: LayerNorm,
- mask: Tensor,
- span: tracing::Span,
- span_final: tracing::Span,
-}
-
-impl TextDecoder {
- fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
- let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
- let n_state = cfg.d_model;
- let n_head = cfg.decoder_attention_heads;
- let n_ctx = cfg.max_target_positions;
- let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
- let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
- let blocks = (0..cfg.decoder_layers)
- .map(|i| {
- ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
- })
- .collect::<Result<Vec<_>>>()?;
- let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
- let mask: Vec<_> = (0..n_ctx)
- .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
- .collect();
- let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
- Ok(Self {
- token_embedding,
- positional_embedding,
- blocks,
- ln,
- mask,
- span,
- span_final,
- })
- }
-
- pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
- let _enter = self.span.enter();
- let last = x.dim(D::Minus1)?;
- let token_embedding = self.token_embedding.forward(x)?;
- let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
- let mut x = token_embedding.broadcast_add(&positional_embedding)?;
- for block in self.blocks.iter_mut() {
- x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
- }
- self.ln.forward(&x)
- }
-
- pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
- let b_size = x.dim(0)?;
- let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
- let logits = {
- let _enter = self.span_final.enter();
- x.matmul(&w.t()?)?
- };
- Ok(logits)
- }
-}
-
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
-pub struct Whisper {
- pub encoder: AudioEncoder,
- pub decoder: TextDecoder,
- pub config: Config,
-}
-
-impl Whisper {
- pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
- let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
- let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
- Ok(Self {
- encoder,
- decoder,
- config,
- })
- }
-}
diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs
index bc0bae1f..a82b09ef 100644
--- a/candle-examples/examples/whisper/multilingual.rs
+++ b/candle-examples/examples/whisper/multilingual.rs
@@ -113,7 +113,7 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
.iter()
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
.collect::<Result<Vec<_>>>()?;
- let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
+ let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
let audio_features = model.encoder.forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;