#![allow(dead_code)] // We use anyhow rather than candle errors as it provides better support for getting the backtrace // back when using RUST_LIB_BACKTRACE=1. use anyhow::Result; use candle::{Device, Tensor}; use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder}; use serde::Deserialize; // The names in comments correspond to the original implementation: // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17 #[derive(Debug, Clone, PartialEq, Deserialize)] pub struct Config { pub num_mel_bins: usize, // n_mels pub max_source_positions: usize, // n_audio_ctx pub d_model: usize, // n_audio_state pub encoder_attention_heads: usize, // n_audio_head pub encoder_layers: usize, // n_audio_layer pub vocab_size: usize, // n_vocab pub max_target_positions: usize, // n_text_ctx // pub n_text_state: usize, pub decoder_attention_heads: usize, // n_text_head pub decoder_layers: usize, // n_text_layer } impl Config { pub fn tiny_en() -> Self { Self { num_mel_bins: 80, vocab_size: 51864, max_source_positions: 1500, d_model: 384, encoder_attention_heads: 6, encoder_layers: 4, max_target_positions: 448, // n_text_state: 384, decoder_attention_heads: 6, decoder_layers: 4, } } } // The struct below is duplicated from candle_nn::Linear so that it's easier to add some wasm // specific monitoring. #[derive(Debug)] struct Linear { weight: Tensor, bias: Option, } impl Linear { fn new(weight: Tensor, bias: Option) -> Self { Self { weight, bias } } fn forward(&self, x: &Tensor) -> candle::Result { let _timer = crate::Timer::new("Linear::forward"); let w = match x.dims() { &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, _ => self.weight.t()?, }; let x = { let _timer = crate::Timer::new("Linear::matmul"); x.matmul(&w)? }; match &self.bias { None => Ok(x), Some(bias) => x.broadcast_add(bias), } } } fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { let embeddings = vb.get((vocab_size, hidden_size), "weight")?; Ok(Embedding::new(embeddings, hidden_size)) } fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { let weight = vb.get((size2, size1), "weight")?; let bias = vb.get(size2, "bias")?; Ok(Linear::new(weight, Some(bias))) } fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result { let weight = vb.get((size2, size1), "weight")?; Ok(Linear::new(weight, None)) } fn conv1d( in_channels: usize, out_channels: usize, kernel_size: usize, config: Conv1dConfig, vb: VarBuilder, ) -> Result { let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; let bias = vb.get(out_channels, "bias")?; Ok(Conv1d::new(weight, Some(bias), config)) } fn conv1d_no_bias( in_channels: usize, out_channels: usize, kernel_size: usize, config: Conv1dConfig, vb: VarBuilder, ) -> Result { let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; Ok(Conv1d::new(weight, None, config)) } struct Dropout { pr: f64, } impl Dropout { fn new(pr: f64) -> Self { Self { pr } } fn forward(&self, x: &Tensor) -> Result { // TODO Ok(x.clone()) } } fn layer_norm(size: usize, vb: VarBuilder) -> Result { let weight = vb.get(size, "weight")?; let bias = vb.get(size, "bias")?; Ok(LayerNorm::new(weight, bias, 1e-5)) } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 struct MultiHeadAttention { query: Linear, key: Linear, value: Linear, out: Linear, n_head: usize, } impl MultiHeadAttention { fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result { let query = linear(n_state, n_state, vb.pp("q_proj"))?; let value = linear(n_state, n_state, vb.pp("v_proj"))?; let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; let out = linear(n_state, n_state, vb.pp("out_proj"))?; Ok(Self { query, key, value, out, n_head, }) } fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result { let _timer = crate::Timer::new("MultiHeadAttention::forward"); let q = self.query.forward(x)?; let k = self.key.forward(xa.unwrap_or(x))?; let v = self.value.forward(xa.unwrap_or(x))?; let wv = self.qkv_attention(&q, &k, &v, mask)?; let out = self.out.forward(&wv)?; Ok(out) } fn reshape_head(&self, x: &Tensor) -> Result { let (n_batch, n_ctx, n_state) = x.dims3()?; let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; Ok(x.reshape(target_dims)?.transpose(1, 2)?) } fn qkv_attention( &self, q: &Tensor, k: &Tensor, v: &Tensor, mask: Option<&Tensor>, ) -> Result { let (_, n_ctx, n_state) = q.dims3()?; let scale = ((n_state / self.n_head) as f64).powf(-0.25); let q = { let _timer = crate::Timer::new("q::reshape"); (self.reshape_head(q)? * scale)? }; let k = { let _timer = crate::Timer::new("k::reshape"); (self.reshape_head(k)?.transpose(2, 3)? * scale)? }; let v = { let _timer = crate::Timer::new("v::reshape-contiguous"); self.reshape_head(v)?.contiguous()? }; let mut qk = { let _timer = crate::Timer::new("qk::matmul"); q.matmul(&k)? }; if let Some(mask) = mask { let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?; qk = qk.broadcast_add(&mask)? } let w = { let _timer = crate::Timer::new("qk::softmax"); qk.softmax(candle::D::Minus1)? }; let wv = { let _timer = crate::Timer::new("wv::matmul"); w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)? }; Ok(wv) } } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 struct ResidualAttentionBlock { attn: MultiHeadAttention, attn_ln: LayerNorm, cross_attn: Option<(MultiHeadAttention, LayerNorm)>, mlp_linear1: Linear, mlp_linear2: Linear, mlp_ln: LayerNorm, } impl ResidualAttentionBlock { fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result { let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; let cross_attn = if ca { let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; Some((cross_attn, cross_attn_ln)) } else { None }; let n_mlp = n_state * 4; let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; Ok(Self { attn, attn_ln, cross_attn, mlp_linear1, mlp_linear2, mlp_ln, }) } fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result { let _timer = crate::Timer::new("ResidualAttentionBlock::forward"); let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?; let mut x = (x + attn)?; if let Some((attn, ln)) = &self.cross_attn { x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?; } let mlp = self.mlp_linear2.forward( &self .mlp_linear1 .forward(&self.mlp_ln.forward(&x)?)? .gelu()?, )?; Ok((x + mlp)?) } } fn sinusoids(length: usize, channels: usize) -> Result { let max_timescale = 10000f32; let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; let inv_timescales: Vec<_> = (0..channels / 2) .map(|i| (i as f32 * (-log_timescale_increment)).exp()) .collect(); let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; let arange = Tensor::arange(0, length as u32, &Device::Cpu)? .to_dtype(candle::DType::F32)? .unsqueeze(1)?; let sh = (length, channels / 2); let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?; let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; Ok(sincos) } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 pub struct AudioEncoder { conv1: Conv1d, conv2: Conv1d, positional_embedding: Tensor, blocks: Vec, ln_post: LayerNorm, } impl AudioEncoder { fn load(vb: VarBuilder, cfg: &Config) -> Result { let n_state = cfg.d_model; let n_head = cfg.encoder_attention_heads; let n_ctx = cfg.max_source_positions; let cfg1 = Conv1dConfig { padding: 1, stride: 1, }; let cfg2 = Conv1dConfig { padding: 1, stride: 2, }; let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; let blocks = (0..cfg.encoder_layers) .map(|i| { ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) }) .collect::>>()?; let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; Ok(Self { conv1, conv2, positional_embedding, blocks, ln_post, }) } pub fn forward(&self, x: &Tensor) -> Result { let _timer = crate::Timer::new("AudioEncoder::forward"); let x = { let _timer = crate::Timer::new("conv1::forward"); self.conv1.forward(x)?.gelu()? }; let x = { let _timer = crate::Timer::new("conv2::forward"); self.conv2.forward(&x)?.gelu()? }; let x = x.transpose(1, 2)?; let (_bsize, seq_len, _hidden) = x.dims3()?; let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; let mut x = x.broadcast_add(&positional_embedding)?; for block in self.blocks.iter() { x = block.forward(&x, None, None)? } let x = self.ln_post.forward(&x)?; Ok(x) } } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 pub struct TextDecoder { token_embedding: Embedding, positional_embedding: Tensor, blocks: Vec, ln: LayerNorm, mask: Tensor, } impl TextDecoder { fn load(vb: VarBuilder, cfg: &Config) -> Result { let _timer = crate::Timer::new("TextDecoder::forward"); let n_state = cfg.d_model; let n_head = cfg.decoder_attention_heads; let n_ctx = cfg.max_target_positions; let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?; let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?; let blocks = (0..cfg.decoder_layers) .map(|i| { ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}"))) }) .collect::>>()?; let ln = layer_norm(n_state, vb.pp("layer_norm"))?; let mask: Vec<_> = (0..n_ctx) .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) .collect(); let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?; Ok(Self { token_embedding, positional_embedding, blocks, ln, mask, }) } pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result { let x_dims = x.dims(); let last = x_dims[x_dims.len() - 1]; let token_embedding = self.token_embedding.forward(x)?; let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; let mut x = token_embedding.broadcast_add(&positional_embedding)?; for block in self.blocks.iter() { x = block.forward(&x, Some(xa), Some(&self.mask))?; } let x = self.ln.forward(&x)?; let w = self .token_embedding .embeddings() .broadcast_left(x_dims[0])?; let logits = x.matmul(&w.t()?)?; Ok(logits) } } // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 pub struct Whisper { pub encoder: AudioEncoder, pub decoder: TextDecoder, pub config: Config, } impl Whisper { pub fn load(vb: &VarBuilder, config: Config) -> Result { let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?; let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?; Ok(Self { encoder, decoder, config, }) } pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result { let enc = self.encoder.forward(mel)?; let dec = self.decoder.forward(tokens, &enc)?; Ok(dec) } }