diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-02 14:59:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-02 14:59:53 +0100 |
commit | e04c789230c609c285991b78c29f1d6eef0d104f (patch) | |
tree | 718a61d3838c7ac82b56cb5a202ee4b172465aa4 /candle-transformers | |
parent | 263a1722021cdf24c801422c58887d93ad2e382a (diff) | |
download | candle-e04c789230c609c285991b78c29f1d6eef0d104f.tar.gz candle-e04c789230c609c285991b78c29f1d6eef0d104f.tar.bz2 candle-e04c789230c609c285991b78c29f1d6eef0d104f.zip |
Add a quantized variant of whisper (#1017)
* Add the quantized-whisper model.
* Quantized the whisper model.
* Adapt the whisper example to handle quantization.
* Add the quantized flag.
* Load the proper weights.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/whisper/mod.rs | 20 | ||||
-rw-r--r-- | candle-transformers/src/models/whisper/model.rs | 19 | ||||
-rw-r--r-- | candle-transformers/src/models/whisper/quantized_model.rs | 403 |
3 files changed, 424 insertions, 18 deletions
diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs index 7dc8107b..35d35e77 100644 --- a/candle-transformers/src/models/whisper/mod.rs +++ b/candle-transformers/src/models/whisper/mod.rs @@ -1,5 +1,25 @@ pub mod audio; pub mod model; +pub mod quantized_model; + +use serde::Deserialize; + +// The names in comments correspond to the original implementation: +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17 +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub num_mel_bins: usize, // n_mels + pub max_source_positions: usize, // n_audio_ctx + pub d_model: usize, // n_audio_state + pub encoder_attention_heads: usize, // n_audio_head + pub encoder_layers: usize, // n_audio_layer + pub vocab_size: usize, // n_vocab + pub max_target_positions: usize, // n_text_ctx + // pub n_text_state: usize, + pub decoder_attention_heads: usize, // n_text_head + pub decoder_layers: usize, // n_text_layer + pub suppress_tokens: Vec<u32>, +} pub const DTYPE: candle::DType = candle::DType::F32; diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs index d2eda796..2a58afaf 100644 --- a/candle-transformers/src/models/whisper/model.rs +++ b/candle-transformers/src/models/whisper/model.rs @@ -1,23 +1,6 @@ +use super::Config; use candle::{Device, IndexOp, Result, Tensor, D}; use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; -use serde::Deserialize; - -// The names in comments correspond to the original implementation: -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17 -#[derive(Debug, Clone, PartialEq, Deserialize)] -pub struct Config { - pub num_mel_bins: usize, // n_mels - pub max_source_positions: usize, // n_audio_ctx - pub d_model: usize, // n_audio_state - pub encoder_attention_heads: usize, // n_audio_head - pub encoder_layers: usize, // n_audio_layer - pub vocab_size: usize, // n_vocab - pub max_target_positions: usize, // n_text_ctx - // pub n_text_state: usize, - pub decoder_attention_heads: usize, // n_text_head - pub decoder_layers: usize, // n_text_layer - pub suppress_tokens: Vec<u32>, -} fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { let embeddings = vb.get((vocab_size, hidden_size), "weight")?; diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs new file mode 100644 index 00000000..59942cbf --- /dev/null +++ b/candle-transformers/src/models/whisper/quantized_model.rs @@ -0,0 +1,403 @@ +use super::Config; +use crate::models::{quantized_t5::Embedding, with_tracing::QMatMul}; +pub use crate::quantized_var_builder::VarBuilder; +use candle::{Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Conv1d, Conv1dConfig, LayerNorm, Module}; + +#[derive(Debug)] +struct Linear { + weight: QMatMul, + bias: Option<Tensor>, +} + +impl Module for Linear { + fn forward(&self, x: &Tensor) -> candle::Result<Tensor> { + let x = x.apply(&self.weight)?; + match &self.bias { + None => Ok(x), + Some(bias) => x.broadcast_add(bias), + } + } +} + +fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { + let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?; + let weight = QMatMul::new(in_dim, out_dim, vb)?; + Ok(Linear { + weight, + bias: Some(bias), + }) +} + +fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> { + let weight = QMatMul::new(in_dim, out_dim, vb)?; + Ok(Linear { weight, bias: None }) +} + +fn conv1d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + config: Conv1dConfig, + vb: VarBuilder, +) -> Result<Conv1d> { + let weight = vb + .get((out_channels, in_channels, kernel_size), "weight")? + .dequantize(vb.device())?; + let bias = vb.get(out_channels, "bias")?.dequantize(vb.device())?; + Ok(Conv1d::new(weight, Some(bias), config)) +} + +fn layer_norm(size: usize, vb: VarBuilder) -> Result<candle_nn::LayerNorm> { + let weight = vb.get(size, "weight")?.dequantize(vb.device())?; + let bias = vb.get(size, "bias")?.dequantize(vb.device())?; + Ok(candle_nn::LayerNorm::new(weight, bias, 1e-5)) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 +struct MultiHeadAttention { + query: Linear, + key: Linear, + value: Linear, + out: Linear, + n_head: usize, + span: tracing::Span, + softmax_span: tracing::Span, + matmul_span: tracing::Span, + kv_cache: Option<(Tensor, Tensor)>, +} + +impl MultiHeadAttention { + fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn"); + let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax"); + let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul"); + let query = linear(n_state, n_state, vb.pp("q_proj"))?; + let value = linear(n_state, n_state, vb.pp("v_proj"))?; + let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; + let out = linear(n_state, n_state, vb.pp("out_proj"))?; + Ok(Self { + query, + key, + value, + out, + n_head, + span, + softmax_span, + matmul_span, + kv_cache: None, + }) + } + + fn forward( + &mut self, + x: &Tensor, + xa: Option<&Tensor>, + mask: Option<&Tensor>, + flush_cache: bool, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let q = self.query.forward(x)?; + let (k, v) = match xa { + None => { + let k = self.key.forward(x)?; + let v = self.value.forward(x)?; + (k, v) + } + Some(x) => { + if flush_cache { + self.kv_cache = None; + } + if let Some((k, v)) = &self.kv_cache { + (k.clone(), v.clone()) + } else { + let k = self.key.forward(x)?; + let v = self.value.forward(x)?; + self.kv_cache = Some((k.clone(), v.clone())); + (k, v) + } + } + }; + let wv = self.qkv_attention(&q, &k, &v, mask)?; + let out = self.out.forward(&wv)?; + Ok(out) + } + + fn reshape_head(&self, x: &Tensor) -> Result<Tensor> { + let (n_batch, n_ctx, n_state) = x.dims3()?; + let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; + x.reshape(target_dims)?.transpose(1, 2) + } + + fn qkv_attention( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + mask: Option<&Tensor>, + ) -> Result<Tensor> { + let (_, n_ctx, n_state) = q.dims3()?; + let scale = ((n_state / self.n_head) as f64).powf(-0.25); + let q = (self.reshape_head(q)? * scale)?; + let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?; + let v = self.reshape_head(v)?.contiguous()?; + let mut qk = { + let _enter = self.matmul_span.enter(); + q.matmul(&k)? + }; + if let Some(mask) = mask { + let mask = mask.i((0..n_ctx, 0..n_ctx))?; + qk = qk.broadcast_add(&mask)? + } + let w = { + let _enter = self.softmax_span.enter(); + candle_nn::ops::softmax_last_dim(&qk)? + }; + let wv = { + let _enter = self.matmul_span.enter(); + w.matmul(&v)? + } + .transpose(1, 2)? + .flatten_from(2)?; + Ok(wv) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 +struct ResidualAttentionBlock { + attn: MultiHeadAttention, + attn_ln: LayerNorm, + cross_attn: Option<(MultiHeadAttention, LayerNorm)>, + mlp_linear1: Linear, + mlp_linear2: Linear, + mlp_ln: LayerNorm, + span: tracing::Span, +} + +impl ResidualAttentionBlock { + fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "residual-attn"); + let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; + let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; + let cross_attn = if ca { + let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; + let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; + Some((cross_attn, cross_attn_ln)) + } else { + None + }; + let n_mlp = n_state * 4; + let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; + let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; + let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; + Ok(Self { + attn, + attn_ln, + cross_attn, + mlp_linear1, + mlp_linear2, + mlp_ln, + span, + }) + } + + fn forward( + &mut self, + x: &Tensor, + xa: Option<&Tensor>, + mask: Option<&Tensor>, + flush_kv_cache: bool, + ) -> Result<Tensor> { + let _enter = self.span.enter(); + let attn = self + .attn + .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?; + let mut x = (x + attn)?; + if let Some((attn, ln)) = &mut self.cross_attn { + x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?; + } + let mlp = self.mlp_linear2.forward( + &self + .mlp_linear1 + .forward(&self.mlp_ln.forward(&x)?)? + .gelu()?, + )?; + x + mlp + } +} + +fn sinusoids(length: usize, channels: usize) -> Result<Tensor> { + let max_timescale = 10000f32; + let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; + let inv_timescales: Vec<_> = (0..channels / 2) + .map(|i| (i as f32 * (-log_timescale_increment)).exp()) + .collect(); + let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; + let arange = Tensor::arange(0, length as u32, &Device::Cpu)? + .to_dtype(candle::DType::F32)? + .unsqueeze(1)?; + let sh = (length, channels / 2); + let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?; + let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; + Ok(sincos) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 +pub struct AudioEncoder { + conv1: Conv1d, + conv2: Conv1d, + positional_embedding: Tensor, + blocks: Vec<ResidualAttentionBlock>, + ln_post: LayerNorm, + span: tracing::Span, + conv1_span: tracing::Span, + conv2_span: tracing::Span, +} + +impl AudioEncoder { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "audio-encoder"); + let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1"); + let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2"); + let n_state = cfg.d_model; + let n_head = cfg.encoder_attention_heads; + let n_ctx = cfg.max_source_positions; + let cfg1 = Conv1dConfig { + padding: 1, + stride: 1, + groups: 1, + dilation: 1, + }; + let cfg2 = Conv1dConfig { + padding: 1, + stride: 2, + groups: 1, + dilation: 1, + }; + let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; + let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; + let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; + let blocks = (0..cfg.encoder_layers) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}"))) + }) + .collect::<Result<Vec<_>>>()?; + let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; + Ok(Self { + conv1, + conv2, + positional_embedding, + blocks, + ln_post, + conv1_span, + conv2_span, + span, + }) + } + + pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { + let _enter = self.span.enter(); + let x = { + let _enter = self.conv1_span.enter(); + self.conv1.forward(x)?.gelu()? + }; + let x = { + let _enter = self.conv2_span.enter(); + self.conv2.forward(&x)?.gelu()? + }; + let x = x.transpose(1, 2)?; + let (_bsize, seq_len, _hidden) = x.dims3()?; + let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; + let mut x = x.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter_mut() { + x = block.forward(&x, None, None, flush_kv_cache)? + } + let x = self.ln_post.forward(&x)?; + Ok(x) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 +pub struct TextDecoder { + token_embedding: Embedding, + positional_embedding: Tensor, + blocks: Vec<ResidualAttentionBlock>, + ln: LayerNorm, + mask: Tensor, + span: tracing::Span, + span_final: tracing::Span, +} + +impl TextDecoder { + fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "text-decoder"); + let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final"); + let n_state = cfg.d_model; + let n_head = cfg.decoder_attention_heads; + let n_ctx = cfg.max_target_positions; + let token_embedding = Embedding::new(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?; + let positional_embedding = vb + .get((n_ctx, n_state), "embed_positions.weight")? + .dequantize(vb.device())?; + let blocks = (0..cfg.decoder_layers) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}"))) + }) + .collect::<Result<Vec<_>>>()?; + let ln = layer_norm(n_state, vb.pp("layer_norm"))?; + let mask: Vec<_> = (0..n_ctx) + .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) + .collect(); + let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?; + Ok(Self { + token_embedding, + positional_embedding, + blocks, + ln, + mask, + span, + span_final, + }) + } + + pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { + let _enter = self.span.enter(); + let last = x.dim(D::Minus1)?; + let token_embedding = self.token_embedding.forward(x)?; + let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; + let mut x = token_embedding.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter_mut() { + x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?; + } + self.ln.forward(&x) + } + + pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> { + let b_size = x.dim(0)?; + let w = self.token_embedding.embeddings().broadcast_left(b_size)?; + let logits = { + let _enter = self.span_final.enter(); + x.matmul(&w.t()?)? + }; + Ok(logits) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 +pub struct Whisper { + pub encoder: AudioEncoder, + pub decoder: TextDecoder, + pub config: Config, +} + +impl Whisper { + pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> { + let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?; + let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?; + Ok(Self { + encoder, + decoder, + config, + }) + } +} |