summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-02 14:59:53 +0100
committerGitHub <noreply@github.com>2023-10-02 14:59:53 +0100
commite04c789230c609c285991b78c29f1d6eef0d104f (patch)
tree718a61d3838c7ac82b56cb5a202ee4b172465aa4 /candle-transformers
parent263a1722021cdf24c801422c58887d93ad2e382a (diff)
downloadcandle-e04c789230c609c285991b78c29f1d6eef0d104f.tar.gz
candle-e04c789230c609c285991b78c29f1d6eef0d104f.tar.bz2
candle-e04c789230c609c285991b78c29f1d6eef0d104f.zip
Add a quantized variant of whisper (#1017)
* Add the quantized-whisper model. * Quantized the whisper model. * Adapt the whisper example to handle quantization. * Add the quantized flag. * Load the proper weights.
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/whisper/mod.rs20
-rw-r--r--candle-transformers/src/models/whisper/model.rs19
-rw-r--r--candle-transformers/src/models/whisper/quantized_model.rs403
3 files changed, 424 insertions, 18 deletions
diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs
index 7dc8107b..35d35e77 100644
--- a/candle-transformers/src/models/whisper/mod.rs
+++ b/candle-transformers/src/models/whisper/mod.rs
@@ -1,5 +1,25 @@
pub mod audio;
pub mod model;
+pub mod quantized_model;
+
+use serde::Deserialize;
+
+// The names in comments correspond to the original implementation:
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+ pub num_mel_bins: usize, // n_mels
+ pub max_source_positions: usize, // n_audio_ctx
+ pub d_model: usize, // n_audio_state
+ pub encoder_attention_heads: usize, // n_audio_head
+ pub encoder_layers: usize, // n_audio_layer
+ pub vocab_size: usize, // n_vocab
+ pub max_target_positions: usize, // n_text_ctx
+ // pub n_text_state: usize,
+ pub decoder_attention_heads: usize, // n_text_head
+ pub decoder_layers: usize, // n_text_layer
+ pub suppress_tokens: Vec<u32>,
+}
pub const DTYPE: candle::DType = candle::DType::F32;
diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index d2eda796..2a58afaf 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -1,23 +1,6 @@
+use super::Config;
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
-use serde::Deserialize;
-
-// The names in comments correspond to the original implementation:
-// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
-#[derive(Debug, Clone, PartialEq, Deserialize)]
-pub struct Config {
- pub num_mel_bins: usize, // n_mels
- pub max_source_positions: usize, // n_audio_ctx
- pub d_model: usize, // n_audio_state
- pub encoder_attention_heads: usize, // n_audio_head
- pub encoder_layers: usize, // n_audio_layer
- pub vocab_size: usize, // n_vocab
- pub max_target_positions: usize, // n_text_ctx
- // pub n_text_state: usize,
- pub decoder_attention_heads: usize, // n_text_head
- pub decoder_layers: usize, // n_text_layer
- pub suppress_tokens: Vec<u32>,
-}
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs
new file mode 100644
index 00000000..59942cbf
--- /dev/null
+++ b/candle-transformers/src/models/whisper/quantized_model.rs
@@ -0,0 +1,403 @@
+use super::Config;
+use crate::models::{quantized_t5::Embedding, with_tracing::QMatMul};
+pub use crate::quantized_var_builder::VarBuilder;
+use candle::{Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Conv1d, Conv1dConfig, LayerNorm, Module};
+
+#[derive(Debug)]
+struct Linear {
+ weight: QMatMul,
+ bias: Option<Tensor>,
+}
+
+impl Module for Linear {
+ fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
+ let x = x.apply(&self.weight)?;
+ match &self.bias {
+ None => Ok(x),
+ Some(bias) => x.broadcast_add(bias),
+ }
+ }
+}
+
+fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
+ let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
+ let weight = QMatMul::new(in_dim, out_dim, vb)?;
+ Ok(Linear {
+ weight,
+ bias: Some(bias),
+ })
+}
+
+fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
+ let weight = QMatMul::new(in_dim, out_dim, vb)?;
+ Ok(Linear { weight, bias: None })
+}
+
+fn conv1d(
+ in_channels: usize,
+ out_channels: usize,
+ kernel_size: usize,
+ config: Conv1dConfig,
+ vb: VarBuilder,
+) -> Result<Conv1d> {
+ let weight = vb
+ .get((out_channels, in_channels, kernel_size), "weight")?
+ .dequantize(vb.device())?;
+ let bias = vb.get(out_channels, "bias")?.dequantize(vb.device())?;
+ Ok(Conv1d::new(weight, Some(bias), config))
+}
+
+fn layer_norm(size: usize, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
+ let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
+ let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
+ Ok(candle_nn::LayerNorm::new(weight, bias, 1e-5))
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
+struct MultiHeadAttention {
+ query: Linear,
+ key: Linear,
+ value: Linear,
+ out: Linear,
+ n_head: usize,
+ span: tracing::Span,
+ softmax_span: tracing::Span,
+ matmul_span: tracing::Span,
+ kv_cache: Option<(Tensor, Tensor)>,
+}
+
+impl MultiHeadAttention {
+ fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
+ let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
+ let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
+ let query = linear(n_state, n_state, vb.pp("q_proj"))?;
+ let value = linear(n_state, n_state, vb.pp("v_proj"))?;
+ let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
+ let out = linear(n_state, n_state, vb.pp("out_proj"))?;
+ Ok(Self {
+ query,
+ key,
+ value,
+ out,
+ n_head,
+ span,
+ softmax_span,
+ matmul_span,
+ kv_cache: None,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ x: &Tensor,
+ xa: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ flush_cache: bool,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let q = self.query.forward(x)?;
+ let (k, v) = match xa {
+ None => {
+ let k = self.key.forward(x)?;
+ let v = self.value.forward(x)?;
+ (k, v)
+ }
+ Some(x) => {
+ if flush_cache {
+ self.kv_cache = None;
+ }
+ if let Some((k, v)) = &self.kv_cache {
+ (k.clone(), v.clone())
+ } else {
+ let k = self.key.forward(x)?;
+ let v = self.value.forward(x)?;
+ self.kv_cache = Some((k.clone(), v.clone()));
+ (k, v)
+ }
+ }
+ };
+ let wv = self.qkv_attention(&q, &k, &v, mask)?;
+ let out = self.out.forward(&wv)?;
+ Ok(out)
+ }
+
+ fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
+ let (n_batch, n_ctx, n_state) = x.dims3()?;
+ let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
+ x.reshape(target_dims)?.transpose(1, 2)
+ }
+
+ fn qkv_attention(
+ &self,
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ mask: Option<&Tensor>,
+ ) -> Result<Tensor> {
+ let (_, n_ctx, n_state) = q.dims3()?;
+ let scale = ((n_state / self.n_head) as f64).powf(-0.25);
+ let q = (self.reshape_head(q)? * scale)?;
+ let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
+ let v = self.reshape_head(v)?.contiguous()?;
+ let mut qk = {
+ let _enter = self.matmul_span.enter();
+ q.matmul(&k)?
+ };
+ if let Some(mask) = mask {
+ let mask = mask.i((0..n_ctx, 0..n_ctx))?;
+ qk = qk.broadcast_add(&mask)?
+ }
+ let w = {
+ let _enter = self.softmax_span.enter();
+ candle_nn::ops::softmax_last_dim(&qk)?
+ };
+ let wv = {
+ let _enter = self.matmul_span.enter();
+ w.matmul(&v)?
+ }
+ .transpose(1, 2)?
+ .flatten_from(2)?;
+ Ok(wv)
+ }
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
+struct ResidualAttentionBlock {
+ attn: MultiHeadAttention,
+ attn_ln: LayerNorm,
+ cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
+ mlp_linear1: Linear,
+ mlp_linear2: Linear,
+ mlp_ln: LayerNorm,
+ span: tracing::Span,
+}
+
+impl ResidualAttentionBlock {
+ fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
+ let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
+ let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
+ let cross_attn = if ca {
+ let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
+ let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
+ Some((cross_attn, cross_attn_ln))
+ } else {
+ None
+ };
+ let n_mlp = n_state * 4;
+ let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
+ let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
+ let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
+ Ok(Self {
+ attn,
+ attn_ln,
+ cross_attn,
+ mlp_linear1,
+ mlp_linear2,
+ mlp_ln,
+ span,
+ })
+ }
+
+ fn forward(
+ &mut self,
+ x: &Tensor,
+ xa: Option<&Tensor>,
+ mask: Option<&Tensor>,
+ flush_kv_cache: bool,
+ ) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let attn = self
+ .attn
+ .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
+ let mut x = (x + attn)?;
+ if let Some((attn, ln)) = &mut self.cross_attn {
+ x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
+ }
+ let mlp = self.mlp_linear2.forward(
+ &self
+ .mlp_linear1
+ .forward(&self.mlp_ln.forward(&x)?)?
+ .gelu()?,
+ )?;
+ x + mlp
+ }
+}
+
+fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
+ let max_timescale = 10000f32;
+ let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
+ let inv_timescales: Vec<_> = (0..channels / 2)
+ .map(|i| (i as f32 * (-log_timescale_increment)).exp())
+ .collect();
+ let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+ let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
+ .to_dtype(candle::DType::F32)?
+ .unsqueeze(1)?;
+ let sh = (length, channels / 2);
+ let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
+ let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
+ Ok(sincos)
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
+pub struct AudioEncoder {
+ conv1: Conv1d,
+ conv2: Conv1d,
+ positional_embedding: Tensor,
+ blocks: Vec<ResidualAttentionBlock>,
+ ln_post: LayerNorm,
+ span: tracing::Span,
+ conv1_span: tracing::Span,
+ conv2_span: tracing::Span,
+}
+
+impl AudioEncoder {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
+ let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
+ let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
+ let n_state = cfg.d_model;
+ let n_head = cfg.encoder_attention_heads;
+ let n_ctx = cfg.max_source_positions;
+ let cfg1 = Conv1dConfig {
+ padding: 1,
+ stride: 1,
+ groups: 1,
+ dilation: 1,
+ };
+ let cfg2 = Conv1dConfig {
+ padding: 1,
+ stride: 2,
+ groups: 1,
+ dilation: 1,
+ };
+ let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
+ let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
+ let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
+ let blocks = (0..cfg.encoder_layers)
+ .map(|i| {
+ ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
+ Ok(Self {
+ conv1,
+ conv2,
+ positional_embedding,
+ blocks,
+ ln_post,
+ conv1_span,
+ conv2_span,
+ span,
+ })
+ }
+
+ pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let x = {
+ let _enter = self.conv1_span.enter();
+ self.conv1.forward(x)?.gelu()?
+ };
+ let x = {
+ let _enter = self.conv2_span.enter();
+ self.conv2.forward(&x)?.gelu()?
+ };
+ let x = x.transpose(1, 2)?;
+ let (_bsize, seq_len, _hidden) = x.dims3()?;
+ let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
+ let mut x = x.broadcast_add(&positional_embedding)?;
+ for block in self.blocks.iter_mut() {
+ x = block.forward(&x, None, None, flush_kv_cache)?
+ }
+ let x = self.ln_post.forward(&x)?;
+ Ok(x)
+ }
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
+pub struct TextDecoder {
+ token_embedding: Embedding,
+ positional_embedding: Tensor,
+ blocks: Vec<ResidualAttentionBlock>,
+ ln: LayerNorm,
+ mask: Tensor,
+ span: tracing::Span,
+ span_final: tracing::Span,
+}
+
+impl TextDecoder {
+ fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+ let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
+ let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
+ let n_state = cfg.d_model;
+ let n_head = cfg.decoder_attention_heads;
+ let n_ctx = cfg.max_target_positions;
+ let token_embedding = Embedding::new(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
+ let positional_embedding = vb
+ .get((n_ctx, n_state), "embed_positions.weight")?
+ .dequantize(vb.device())?;
+ let blocks = (0..cfg.decoder_layers)
+ .map(|i| {
+ ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
+ let mask: Vec<_> = (0..n_ctx)
+ .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+ .collect();
+ let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
+ Ok(Self {
+ token_embedding,
+ positional_embedding,
+ blocks,
+ ln,
+ mask,
+ span,
+ span_final,
+ })
+ }
+
+ pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
+ let _enter = self.span.enter();
+ let last = x.dim(D::Minus1)?;
+ let token_embedding = self.token_embedding.forward(x)?;
+ let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
+ let mut x = token_embedding.broadcast_add(&positional_embedding)?;
+ for block in self.blocks.iter_mut() {
+ x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
+ }
+ self.ln.forward(&x)
+ }
+
+ pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
+ let b_size = x.dim(0)?;
+ let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
+ let logits = {
+ let _enter = self.span_final.enter();
+ x.matmul(&w.t()?)?
+ };
+ Ok(logits)
+ }
+}
+
+// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
+pub struct Whisper {
+ pub encoder: AudioEncoder,
+ pub decoder: TextDecoder,
+ pub config: Config,
+}
+
+impl Whisper {
+ pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
+ let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
+ let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
+ Ok(Self {
+ encoder,
+ decoder,
+ config,
+ })
+ }
+}