diff options
Diffstat (limited to 'candle-examples/examples/whisper/model.rs')
-rw-r--r-- | candle-examples/examples/whisper/model.rs | 416 |
1 files changed, 0 insertions, 416 deletions
diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs deleted file mode 100644 index e58ab2ca..00000000 --- a/candle-examples/examples/whisper/model.rs +++ /dev/null @@ -1,416 +0,0 @@ -use candle::{Device, IndexOp, Result, Tensor, D}; -use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; -use serde::Deserialize; - -// The names in comments correspond to the original implementation: -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17 -#[derive(Debug, Clone, PartialEq, Deserialize)] -pub struct Config { - pub num_mel_bins: usize, // n_mels - pub max_source_positions: usize, // n_audio_ctx - pub d_model: usize, // n_audio_state - pub encoder_attention_heads: usize, // n_audio_head - pub encoder_layers: usize, // n_audio_layer - pub vocab_size: usize, // n_vocab - pub max_target_positions: usize, // n_text_ctx - // pub n_text_state: usize, - pub decoder_attention_heads: usize, // n_text_head - pub decoder_layers: usize, // n_text_layer - pub suppress_tokens: Vec<u32>, -} - -fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> { - let embeddings = vb.get((vocab_size, hidden_size), "weight")?; - Ok(Embedding::new(embeddings, hidden_size)) -} -// -// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting -// model. -#[derive(Debug)] -pub struct Linear { - inner: candle_nn::Linear, - span: tracing::Span, -} - -impl Linear { - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - let inner = candle_nn::linear(size1, size2, vb)?; - Ok(Linear { inner, span }) -} - -fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - let inner = candle_nn::linear_no_bias(size1, size2, vb)?; - Ok(Linear { inner, span }) -} - -fn conv1d( - in_channels: usize, - out_channels: usize, - kernel_size: usize, - config: Conv1dConfig, - vb: VarBuilder, -) -> Result<Conv1d> { - let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; - let bias = vb.get(out_channels, "bias")?; - Ok(Conv1d::new(weight, Some(bias), config)) -} - -fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> { - let weight = vb.get(size, "weight")?; - let bias = vb.get(size, "bias")?; - Ok(LayerNorm::new(weight, bias, 1e-5)) -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 -struct MultiHeadAttention { - query: Linear, - key: Linear, - value: Linear, - out: Linear, - n_head: usize, - span: tracing::Span, - softmax_span: tracing::Span, - matmul_span: tracing::Span, - kv_cache: Option<(Tensor, Tensor)>, -} - -impl MultiHeadAttention { - fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn"); - let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax"); - let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul"); - let query = linear(n_state, n_state, vb.pp("q_proj"))?; - let value = linear(n_state, n_state, vb.pp("v_proj"))?; - let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; - let out = linear(n_state, n_state, vb.pp("out_proj"))?; - Ok(Self { - query, - key, - value, - out, - n_head, - span, - softmax_span, - matmul_span, - kv_cache: None, - }) - } - - fn forward( - &mut self, - x: &Tensor, - xa: Option<&Tensor>, - mask: Option<&Tensor>, - flush_cache: bool, - ) -> Result<Tensor> { - let _enter = self.span.enter(); - let q = self.query.forward(x)?; - let (k, v) = match xa { - None => { - let k = self.key.forward(x)?; - let v = self.value.forward(x)?; - (k, v) - } - Some(x) => { - if flush_cache { - self.kv_cache = None; - } - if let Some((k, v)) = &self.kv_cache { - (k.clone(), v.clone()) - } else { - let k = self.key.forward(x)?; - let v = self.value.forward(x)?; - self.kv_cache = Some((k.clone(), v.clone())); - (k, v) - } - } - }; - let wv = self.qkv_attention(&q, &k, &v, mask)?; - let out = self.out.forward(&wv)?; - Ok(out) - } - - fn reshape_head(&self, x: &Tensor) -> Result<Tensor> { - let (n_batch, n_ctx, n_state) = x.dims3()?; - let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; - x.reshape(target_dims)?.transpose(1, 2) - } - - fn qkv_attention( - &self, - q: &Tensor, - k: &Tensor, - v: &Tensor, - mask: Option<&Tensor>, - ) -> Result<Tensor> { - let (_, n_ctx, n_state) = q.dims3()?; - let scale = ((n_state / self.n_head) as f64).powf(-0.25); - let q = (self.reshape_head(q)? * scale)?; - let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?; - let v = self.reshape_head(v)?.contiguous()?; - let mut qk = { - let _enter = self.matmul_span.enter(); - q.matmul(&k)? - }; - if let Some(mask) = mask { - let mask = mask.i((0..n_ctx, 0..n_ctx))?; - qk = qk.broadcast_add(&mask)? - } - let w = { - let _enter = self.softmax_span.enter(); - softmax(&qk, D::Minus1)? - }; - let wv = { - let _enter = self.matmul_span.enter(); - w.matmul(&v)? - } - .transpose(1, 2)? - .flatten_from(2)?; - Ok(wv) - } -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 -struct ResidualAttentionBlock { - attn: MultiHeadAttention, - attn_ln: LayerNorm, - cross_attn: Option<(MultiHeadAttention, LayerNorm)>, - mlp_linear1: Linear, - mlp_linear2: Linear, - mlp_ln: LayerNorm, - span: tracing::Span, -} - -impl ResidualAttentionBlock { - fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "residual-attn"); - let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; - let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; - let cross_attn = if ca { - let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; - let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; - Some((cross_attn, cross_attn_ln)) - } else { - None - }; - let n_mlp = n_state * 4; - let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; - let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; - let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; - Ok(Self { - attn, - attn_ln, - cross_attn, - mlp_linear1, - mlp_linear2, - mlp_ln, - span, - }) - } - - fn forward( - &mut self, - x: &Tensor, - xa: Option<&Tensor>, - mask: Option<&Tensor>, - flush_kv_cache: bool, - ) -> Result<Tensor> { - let _enter = self.span.enter(); - let attn = self - .attn - .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?; - let mut x = (x + attn)?; - if let Some((attn, ln)) = &mut self.cross_attn { - x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?; - } - let mlp = self.mlp_linear2.forward( - &self - .mlp_linear1 - .forward(&self.mlp_ln.forward(&x)?)? - .gelu()?, - )?; - x + mlp - } -} - -fn sinusoids(length: usize, channels: usize) -> Result<Tensor> { - let max_timescale = 10000f32; - let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; - let inv_timescales: Vec<_> = (0..channels / 2) - .map(|i| (i as f32 * (-log_timescale_increment)).exp()) - .collect(); - let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; - let arange = Tensor::arange(0, length as u32, &Device::Cpu)? - .to_dtype(candle::DType::F32)? - .unsqueeze(1)?; - let sh = (length, channels / 2); - let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?; - let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; - Ok(sincos) -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 -pub struct AudioEncoder { - conv1: Conv1d, - conv2: Conv1d, - positional_embedding: Tensor, - blocks: Vec<ResidualAttentionBlock>, - ln_post: LayerNorm, - span: tracing::Span, - conv1_span: tracing::Span, - conv2_span: tracing::Span, -} - -impl AudioEncoder { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "audio-encoder"); - let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1"); - let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2"); - let n_state = cfg.d_model; - let n_head = cfg.encoder_attention_heads; - let n_ctx = cfg.max_source_positions; - let cfg1 = Conv1dConfig { - padding: 1, - stride: 1, - groups: 1, - dilation: 1, - }; - let cfg2 = Conv1dConfig { - padding: 1, - stride: 2, - groups: 1, - dilation: 1, - }; - let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; - let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; - let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; - let blocks = (0..cfg.encoder_layers) - .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) - }) - .collect::<Result<Vec<_>>>()?; - let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; - Ok(Self { - conv1, - conv2, - positional_embedding, - blocks, - ln_post, - conv1_span, - conv2_span, - span, - }) - } - - pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { - let _enter = self.span.enter(); - let x = { - let _enter = self.conv1_span.enter(); - self.conv1.forward(x)?.gelu()? - }; - let x = { - let _enter = self.conv2_span.enter(); - self.conv2.forward(&x)?.gelu()? - }; - let x = x.transpose(1, 2)?; - let (_bsize, seq_len, _hidden) = x.dims3()?; - let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; - let mut x = x.broadcast_add(&positional_embedding)?; - for block in self.blocks.iter_mut() { - x = block.forward(&x, None, None, flush_kv_cache)? - } - let x = self.ln_post.forward(&x)?; - Ok(x) - } -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 -pub struct TextDecoder { - token_embedding: Embedding, - positional_embedding: Tensor, - blocks: Vec<ResidualAttentionBlock>, - ln: LayerNorm, - mask: Tensor, - span: tracing::Span, - span_final: tracing::Span, -} - -impl TextDecoder { - fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "text-decoder"); - let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final"); - let n_state = cfg.d_model; - let n_head = cfg.decoder_attention_heads; - let n_ctx = cfg.max_target_positions; - let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?; - let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?; - let blocks = (0..cfg.decoder_layers) - .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}"))) - }) - .collect::<Result<Vec<_>>>()?; - let ln = layer_norm(n_state, vb.pp("layer_norm"))?; - let mask: Vec<_> = (0..n_ctx) - .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) - .collect(); - let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?; - Ok(Self { - token_embedding, - positional_embedding, - blocks, - ln, - mask, - span, - span_final, - }) - } - - pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> { - let _enter = self.span.enter(); - let last = x.dim(D::Minus1)?; - let token_embedding = self.token_embedding.forward(x)?; - let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; - let mut x = token_embedding.broadcast_add(&positional_embedding)?; - for block in self.blocks.iter_mut() { - x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?; - } - self.ln.forward(&x) - } - - pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> { - let b_size = x.dim(0)?; - let w = self.token_embedding.embeddings().broadcast_left(b_size)?; - let logits = { - let _enter = self.span_final.enter(); - x.matmul(&w.t()?)? - }; - Ok(logits) - } -} - -// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 -pub struct Whisper { - pub encoder: AudioEncoder, - pub decoder: TextDecoder, - pub config: Config, -} - -impl Whisper { - pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> { - let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?; - let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?; - Ok(Self { - encoder, - decoder, - config, - }) - } -} |