diff options
Diffstat (limited to 'candle-transformers/src/models/whisper/model.rs')
-rw-r--r-- | candle-transformers/src/models/whisper/model.rs | 416 |
1 files changed, 416 insertions, 0 deletions
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs new file mode 100644 index 00000000..e58ab2ca --- /dev/null +++ b/candle-transformers/src/models/whisper/model.rs @@ -0,0 +1,416 @@ +use candle::{Device, IndexOp, Result, Tensor, D}; +use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; +use serde::Deserialize; + +// The names in comments correspond to the original implementation: +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17 +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub num_mel_bins: usize, // n_mels + pub max_source_positions: usize, // n_audio_ctx + pub d_model: usize, // n_audio_state + pub encoder_attention_heads: usize, // n_audio_head + pub encoder_layers: usize, // n_audio_layer + pub vocab_size: usize, // n_vocab + pub max_target_positions: usize, // n_text_ctx + // pub n_text_state: usize, + pub decoder_attention_heads: usize, // n_text_head + pub decoder_layers: usize, // n_text_layer + pub suppress_tokens: Vec<u32>, +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} +// +// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting +// model. +#[derive(Debug)] +pub struct Linear { + inner: candle_nn::Linear, + span: tracing::Span, +} + +impl Linear { + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let span = tracing::span!(tracing::Level::TRACE, "linear"); + let inner = candle_nn::linear(size1, size2, vb)?; + Ok(Linear { inner, span }) +} + +fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { + let span = tracing::span!(tracing::Level::TRACE, "linear"); + let inner = candle_nn::linear_no_bias(size1, size2, vb)?; + Ok(Linear { inner, span }) +} + +fn conv1d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + config: Conv1dConfig, + vb: VarBuilder, +) -> Result<Conv1d> { + let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; + let bias = vb.get(out_channels, "bias")?; + Ok(Conv1d::new(weight, Some(bias), config)) +} + +fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> { + let weight = vb.get(size, "weight")?; + let bias = vb.get(size, "bias")?; + Ok(LayerNorm::new(weight, bias, 1e-5)) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 +struct MultiHeadAttention { + query: Linear, + key: Linear, + value: Linear, + out: Linear, + n_head: usize, + span: tracing::Span, + softmax_span: tracing::Span, + matmul_span: tracing::Span, + kv_cache: Option<(Tensor, Tensor)>, +} + +impl MultiHeadAttention { + fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn"); + let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax"); + let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul"); + let query = linear(n_state, n_state, vb.pp("q_proj"))?; + let value = linear(n_state, n_state, vb.pp("v_proj"))?; + let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; + let out = linear(n_state, n_state, vb.pp("out_proj"))?; + Ok(Self { + query, + key, + value, + out, + n_head, + span, + softmax_span, + matmul_span, + kv_cache: None, + }) + } + + fn forward( + &mut self, + x: &Tensor, + xa: Option<&Tensor>, + mask: Option<&Tensor>, + flush_cache: bool, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let q = self.query.forward(x)?; + let (k, v) = match xa { + None => { + let k = self.key.forward(x)?; + let v = self.value.forward(x)?; + (k, v) + } + Some(x) => { + if flush_cache { + self.kv_cache = None; + } + if let Some((k, v)) = &self.kv_cache { + (k.clone(), v.clone()) + } else { + let k = self.key.forward(x)?; + let v = self.value.forward(x)?; + self.kv_cache = Some((k.clone(), v.clone())); + (k, v) + } + } + }; + let wv = self.qkv_attention(&q, &k, &v, mask)?; + let out = self.out.forward(&wv)?; + Ok(out) + } + + fn reshape_head(&self, x: &Tensor) -> Result<Tensor> { + let (n_batch, n_ctx, n_state) = x.dims3()?; + let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; + x.reshape(target_dims)?.transpose(1, 2) + } + + fn qkv_attention( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + mask: Option<&Tensor>, + ) -> Result<Tensor> { + let (_, n_ctx, n_state) = q.dims3()?; + let scale = ((n_state / self.n_head) as f64).powf(-0.25); + let q = (self.reshape_head(q)? * scale)?; + let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?; + let v = self.reshape_head(v)?.contiguous()?; + let mut qk = { + let _enter = self.matmul_span.enter(); + q.matmul(&k)? + }; + if let Some(mask) = mask { + let mask = mask.i((0..n_ctx, 0..n_ctx))?; + qk = qk.broadcast_add(&mask)? + } + let w = { + let _enter = self.softmax_span.enter(); + softmax(&qk, D::Minus1)? + }; + let wv = { + let _enter = self.matmul_span.enter(); + w.matmul(&v)? + } + .transpose(1, 2)? + .flatten_from(2)?; + Ok(wv) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 +struct ResidualAttentionBlock { + attn: MultiHeadAttention, + attn_ln: LayerNorm, + cross_attn: Option<(MultiHeadAttention, LayerNorm)>, + mlp_linear1: Linear, + mlp_linear2: Linear, + mlp_ln: LayerNorm, + span: tracing::Span, +} + +impl ResidualAttentionBlock { + fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "residual-attn"); + let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; + let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; + let cross_attn = if ca { + let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; + let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; + Some((cross_attn, cross_attn_ln)) + } else { + None + }; + let n_mlp = n_state * 4; + let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; + let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; + let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; + Ok(Self { + attn, + attn_ln, + cross_attn, + mlp_linear1, + mlp_linear2, + mlp_ln, + span, + }) + } + + fn forward( + &mut self, + x: &Tensor, + xa: Option<&Tensor>, + mask: Option<&Tensor>, + flush_kv_cache: bool, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let attn = self + .attn + .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?; + let mut x = (x + attn)?; + if let Some((attn, ln)) = &mut self.cross_attn { + x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?; + } + let mlp = self.mlp_linear2.forward( + &self + .mlp_linear1 + .forward(&self.mlp_ln.forward(&x)?)? + .gelu()?, + )?; + x + mlp + } +} + +fn sinusoids(length: usize, channels: usize) -> Result<Tensor> { + let max_timescale = 10000f32; + let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; + let inv_timescales: Vec<_> = (0..channels / 2) + .map(|i| (i as f32 * (-log_timescale_increment)).exp()) + .collect(); + let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; + let arange = Tensor::arange(0, length as u32, &Device::Cpu)? + .to_dtype(candle::DType::F32)? + .unsqueeze(1)?; + let sh = (length, channels / 2); + let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?; + let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; + Ok(sincos) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 +pub struct AudioEncoder { + conv1: Conv1d, + conv2: Conv1d, + positional_embedding: Tensor, + blocks: Vec<ResidualAttentionBlock>, + ln_post: LayerNorm, + span: tracing::Span, + conv1_span: tracing::Span, + conv2_span: tracing::Span, +} + +impl AudioEncoder { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "audio-encoder"); + let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1"); + let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2"); + let n_state = cfg.d_model; + let n_head = cfg.encoder_attention_heads; + let n_ctx = cfg.max_source_positions; + let cfg1 = Conv1dConfig { + padding: 1, + stride: 1, + groups: 1, + dilation: 1, + }; + let cfg2 = Conv1dConfig { + padding: 1, + stride: 2, + groups: 1, + dilation: 1, + }; + let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; + let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; + let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; + let blocks = (0..cfg.encoder_layers) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) + }) + .collect::<Result<Vec<_>>>()?; + let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; + Ok(Self { + conv1, + conv2, + positional_embedding, + blocks, + ln_post, + conv1_span, + conv2_span, + span, + }) + } + + pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { + let _enter = self.span.enter(); + let x = { + let _enter = self.conv1_span.enter(); + self.conv1.forward(x)?.gelu()? + }; + let x = { + let _enter = self.conv2_span.enter(); + self.conv2.forward(&x)?.gelu()? + }; + let x = x.transpose(1, 2)?; + let (_bsize, seq_len, _hidden) = x.dims3()?; + let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; + let mut x = x.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter_mut() { + x = block.forward(&x, None, None, flush_kv_cache)? + } + let x = self.ln_post.forward(&x)?; + Ok(x) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 +pub struct TextDecoder { + token_embedding: Embedding, + positional_embedding: Tensor, + blocks: Vec<ResidualAttentionBlock>, + ln: LayerNorm, + mask: Tensor, + span: tracing::Span, + span_final: tracing::Span, +} + +impl TextDecoder { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "text-decoder"); + let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final"); + let n_state = cfg.d_model; + let n_head = cfg.decoder_attention_heads; + let n_ctx = cfg.max_target_positions; + let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?; + let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?; + let blocks = (0..cfg.decoder_layers) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}"))) + }) + .collect::<Result<Vec<_>>>()?; + let ln = layer_norm(n_state, vb.pp("layer_norm"))?; + let mask: Vec<_> = (0..n_ctx) + .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) + .collect(); + let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?; + Ok(Self { + token_embedding, + positional_embedding, + blocks, + ln, + mask, + span, + span_final, + }) + } + + pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { + let _enter = self.span.enter(); + let last = x.dim(D::Minus1)?; + let token_embedding = self.token_embedding.forward(x)?; + let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; + let mut x = token_embedding.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter_mut() { + x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?; + } + self.ln.forward(&x) + } + + pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> { + let b_size = x.dim(0)?; + let w = self.token_embedding.embeddings().broadcast_left(b_size)?; + let logits = { + let _enter = self.span_final.enter(); + x.matmul(&w.t()?)? + }; + Ok(logits) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 +pub struct Whisper { + pub encoder: AudioEncoder, + pub decoder: TextDecoder, + pub config: Config, +} + +impl Whisper { + pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> { + let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?; + let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?; + Ok(Self { + encoder, + decoder, + config, + }) + } +} |