author    | Laurent Mazare <laurent.mazare@gmail.com> | 2024-08-18 19:42:08 +0100
committer | GitHub <noreply@github.com>               | 2024-08-18 20:42:08 +0200
commit    | 58197e189657b6587a254882abdb232e83e86848 (patch)
tree      | 01dbed067341d47e933b821a1b33100524611a50 /candle-transformers/src/models/parler_tts.rs
parent    | 736d8eb7521dd48e777827848f2b9ed8a7473571 (diff)
parler-tts support (#2431)
* Start sketching parler-tts support.
* Implement the attention.
* Add the example code.
* Fix the example.
* Add the description input and encode it with T5.
* More of the parler forward pass.
* Fix the positional embeddings.
* Support random sampling in generation.
* Handle EOS.
* Add the python decoder.
* Proper causality mask.
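
On the last item, the causality mask: `prepare_causal_mask` in the diff below builds a `(q_len, kv_len)` matrix whose entry `(i, j)` is `-inf` whenever key `j` lies strictly in the future of query `i`, with the queries taken to be the last `q_len` positions of the `kv_len`-long sequence. The standalone sketch below (plain Rust, no candle dependency, sizes picked arbitrarily for illustration) reproduces that indexing rule, showing the square prefill mask next to the single-row mask used during incremental decoding:

```rust
// Reproduces the masking rule from `prepare_causal_mask`: query i (one of the
// last q_len positions) may not attend to key j when i + kv_len < j + q_len.
fn causal_mask(q_len: usize, kv_len: usize) -> Vec<Vec<f32>> {
    (0..q_len)
        .map(|i| {
            (0..kv_len)
                .map(|j| if i + kv_len < j + q_len { f32::NEG_INFINITY } else { 0.0 })
                .collect()
        })
        .collect()
}

fn main() {
    // Prefill (pos == 0): a (prompt_len + 1) x (prompt_len + 1) lower-triangular mask.
    for row in causal_mask(4, 4) {
        println!("{row:?}");
    }
    // Incremental decoding: one query row over pos + 1 cached keys, nothing masked.
    println!("{:?}", causal_mask(1, 4)[0]);
}
```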
Diffstat (limited to 'candle-transformers/src/models/parler_tts.rs')
-rw-r--r-- | candle-transformers/src/models/parler_tts.rs | 452
1 file changed, 452 insertions, 0 deletions
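
One detail of the generation loop in the diff below: the codebooks are filled in with a delay pattern. At step `s` only codebooks `0..=s` get a freshly sampled token (the `if logit_idx > step { break; }` check), and start/pad tokens are dropped when the per-codebook streams are collected, so codebook `k` effectively starts `k` steps late. A tiny sketch of that schedule (the codebook count of 4 is just for illustration):

```rust
// Which codebooks receive a newly sampled token at each generation step,
// following the `if logit_idx > step { break; }` rule of the decoding loop.
fn main() {
    let num_codebooks = 4;
    for step in 0..6 {
        let sampled: Vec<usize> = (0..num_codebooks).filter(|&cb| cb <= step).collect();
        println!("step {step}: sample codebooks {sampled:?}");
    }
}
```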
diff --git a/candle-transformers/src/models/parler_tts.rs b/candle-transformers/src/models/parler_tts.rs
new file mode 100644
index 00000000..9c66c93a
--- /dev/null
+++ b/candle-transformers/src/models/parler_tts.rs
@@ -0,0 +1,452 @@
+use crate::generation::LogitsProcessor;
+use crate::models::t5;
+use candle::{IndexOp, Result, Tensor};
+use candle_nn::{layer_norm, linear_b as linear, Activation, LayerNorm, Linear, VarBuilder};
+
+#[derive(serde::Deserialize, Debug, Clone)]
+pub struct DecoderConfig {
+    pub vocab_size: usize,
+    pub max_position_embeddings: usize,
+    pub num_hidden_layers: usize,
+    pub ffn_dim: usize,
+    pub num_attention_heads: usize,
+    pub num_key_value_heads: Option<usize>,
+    pub num_cross_attention_key_value_heads: Option<usize>,
+    pub activation_function: Activation,
+    pub hidden_size: usize,
+    pub scale_embedding: bool,
+    pub num_codebooks: usize,
+    pub pad_token_id: usize,
+    pub bos_token_id: usize,
+    pub eos_token_id: usize,
+    pub tie_word_embeddings: bool,
+    pub rope_embeddings: bool,
+    pub rope_theta: f64,
+}
+
+#[derive(serde::Deserialize, Debug, Clone)]
+pub struct Config {
+    pub decoder_start_token_id: u32,
+    pub pad_token_id: u32,
+    pub decoder: DecoderConfig,
+    pub text_encoder: t5::Config,
+    pub vocab_size: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct Attention {
+    k_proj: Linear,
+    v_proj: Linear,
+    q_proj: Linear,
+    out_proj: Linear,
+    is_causal: bool,
+    kv_cache: Option<(Tensor, Tensor)>,
+    scaling: f64,
+    num_heads: usize,
+    num_kv_heads: usize,
+    num_kv_groups: usize,
+    head_dim: usize,
+}
+
+impl Attention {
+    fn new(
+        num_kv_heads: usize,
+        is_causal: bool,
+        cfg: &DecoderConfig,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        if cfg.rope_embeddings {
+            candle::bail!("rope embeddings are not supported");
+        }
+        let embed_dim = cfg.hidden_size;
+        let head_dim = embed_dim / cfg.num_attention_heads;
+        let kv_out_dim = num_kv_heads * head_dim;
+        let k_proj = linear(embed_dim, kv_out_dim, false, vb.pp("k_proj"))?;
+        let v_proj = linear(embed_dim, kv_out_dim, false, vb.pp("v_proj"))?;
+        let q_proj = linear(embed_dim, embed_dim, false, vb.pp("q_proj"))?;
+        let out_proj = linear(embed_dim, embed_dim, false, vb.pp("out_proj"))?;
+        Ok(Self {
+            k_proj,
+            v_proj,
+            q_proj,
+            out_proj,
+            is_causal,
+            kv_cache: None,
+            scaling: (head_dim as f64).powf(-0.5),
+            num_heads: cfg.num_attention_heads,
+            num_kv_heads,
+            num_kv_groups: cfg.num_attention_heads / num_kv_heads,
+            head_dim,
+        })
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        key_value_states: Option<&Tensor>,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let (b_sz, tgt_len, _) = xs.dims3()?;
+        let query_states = (xs.apply(&self.q_proj)? * self.scaling)?
+            .reshape((b_sz, tgt_len, self.num_heads, self.head_dim))?
+            .transpose(1, 2)?
+            .contiguous()?;
+        let key_states = match key_value_states {
+            Some(states) => states.apply(&self.k_proj)?,
+            None => xs.apply(&self.k_proj)?,
+        };
+        let key_states = key_states
+            .reshape((b_sz, (), self.num_kv_heads, self.head_dim))?
+            .transpose(1, 2)?
+            .contiguous()?;
+        let value_states = match key_value_states {
+            Some(states) => states.apply(&self.v_proj)?,
+            None => xs.apply(&self.v_proj)?,
+        };
+        let value_states = value_states
+            .reshape((b_sz, (), self.num_kv_heads, self.head_dim))?
+            .transpose(1, 2)?
+            .contiguous()?;
+
+        let (key_states, value_states) = match &self.kv_cache {
+            None => (key_states, value_states),
+            Some((prev_k, prev_v)) => {
+                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
+                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
+                (key_states, value_states)
+            }
+        };
+        if self.is_causal {
+            self.kv_cache = Some((key_states.clone(), value_states.clone()));
+        }
+
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+        let value_states =
+            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
+
+        let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?;
+        let attn_weights = match attention_mask {
+            None => attn_weights,
+            Some(mask) => attn_weights.broadcast_add(mask)?,
+        };
+        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+        let attn_output = attn_weights.matmul(&value_states)?;
+        attn_output
+            .transpose(1, 2)?
+            .reshape((b_sz, tgt_len, ()))?
+            .apply(&self.out_proj)
+    }
+
+    fn clear_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct DecoderLayer {
+    self_attn: Attention,
+    self_attn_layer_norm: LayerNorm,
+    encoder_attn: Attention,
+    encoder_attn_layer_norm: LayerNorm,
+    fc1: Linear,
+    fc2: Linear,
+    final_layer_norm: LayerNorm,
+    activation: Activation,
+}
+
+impl DecoderLayer {
+    fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result<Self> {
+        let kv_heads = cfg.num_key_value_heads.unwrap_or(cfg.num_attention_heads);
+        let kv_heads_cross = cfg.num_cross_attention_key_value_heads.unwrap_or(kv_heads);
+
+        let self_attn = Attention::new(kv_heads, true, cfg, vb.pp("self_attn"))?;
+        let encoder_attn = Attention::new(kv_heads_cross, false, cfg, vb.pp("encoder_attn"))?;
+        let self_attn_layer_norm =
+            layer_norm(cfg.hidden_size, 1e-5, vb.pp("self_attn_layer_norm"))?;
+        let encoder_attn_layer_norm =
+            layer_norm(cfg.hidden_size, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
+        let fc1 = linear(cfg.hidden_size, cfg.ffn_dim, false, vb.pp("fc1"))?;
+        let fc2 = linear(cfg.ffn_dim, cfg.hidden_size, false, vb.pp("fc2"))?;
+        let final_layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb.pp("final_layer_norm"))?;
+        Ok(Self {
+            self_attn,
+            self_attn_layer_norm,
+            encoder_attn,
+            encoder_attn_layer_norm,
+            fc1,
+            fc2,
+            final_layer_norm,
+            activation: cfg.activation_function,
+        })
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        attention_mask: Option<&Tensor>,
+        encoder_xs: &Tensor,
+        encoder_attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        // Self attention
+        let residual = xs;
+        let xs = xs.apply(&self.self_attn_layer_norm)?;
+        let xs = self.self_attn.forward(&xs, None, attention_mask)?;
+        let xs = (residual + xs)?;
+
+        // Cross attention
+        let residual = &xs;
+        let xs = xs.apply(&self.encoder_attn_layer_norm)?;
+        let xs = self
+            .encoder_attn
+            .forward(&xs, Some(encoder_xs), encoder_attention_mask)?;
+        let xs = (residual + xs)?;
+
+        // Fully connected
+        let residual = &xs;
+        let xs = xs
+            .apply(&self.final_layer_norm)?
+            .apply(&self.fc1)?
+            .apply(&self.activation)?
+            .apply(&self.fc2)?;
+        residual + xs
+    }
+
+    fn clear_kv_cache(&mut self) {
+        self.self_attn.clear_kv_cache();
+        self.encoder_attn.clear_kv_cache();
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Decoder {
+    embed_tokens: Vec<candle_nn::Embedding>,
+    embed_positions: Tensor,
+    layers: Vec<DecoderLayer>,
+    layer_norm: LayerNorm,
+    num_codebooks: usize,
+    hidden_size: usize,
+    lm_heads: Vec<Linear>,
+    dtype: candle::DType,
+}
+
+impl Decoder {
+    pub fn new(cfg: &DecoderConfig, vb: VarBuilder) -> Result<Self> {
+        let vb_d = vb.pp("model.decoder");
+        let mut embed_tokens = Vec::with_capacity(cfg.num_codebooks);
+        let vb_e = vb_d.pp("embed_tokens");
+        for embed_idx in 0..cfg.num_codebooks {
+            let e = candle_nn::embedding(cfg.vocab_size + 1, cfg.hidden_size, vb_e.pp(embed_idx))?;
+            embed_tokens.push(e)
+        }
+        let embed_positions = vb_d.get(
+            (cfg.max_position_embeddings, cfg.hidden_size),
+            "embed_positions.weights",
+        )?;
+        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb_l = vb_d.pp("layers");
+        for layer_idx in 0..cfg.num_hidden_layers {
+            let layer = DecoderLayer::new(cfg, vb_l.pp(layer_idx))?;
+            layers.push(layer)
+        }
+        let layer_norm = layer_norm(cfg.hidden_size, 1e-5, vb_d.pp("layer_norm"))?;
+
+        let mut lm_heads = Vec::with_capacity(cfg.num_codebooks);
+        let vb_l = vb.pp("lm_heads");
+        for lm_idx in 0..cfg.num_codebooks {
+            let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb_l.pp(lm_idx))?;
+            lm_heads.push(lm_head)
+        }
+        Ok(Self {
+            embed_tokens,
+            embed_positions,
+            layers,
+            layer_norm,
+            num_codebooks: cfg.num_codebooks,
+            lm_heads,
+            hidden_size: cfg.hidden_size,
+            dtype: vb.dtype(),
+        })
+    }
+
+    pub fn forward(
+        &mut self,
+        input_ids: &Tensor,
+        prompt_hidden_states: Option<&Tensor>,
+        attention_mask: Option<&Tensor>,
+        encoder_xs: &Tensor,
+        encoder_attention_mask: Option<&Tensor>,
+        seqlen_offset: usize,
+    ) -> Result<Vec<Tensor>> {
+        let (b_sz, num_codebooks, seq_len) = input_ids.dims3()?;
+        if num_codebooks != self.num_codebooks {
+            candle::bail!("unexpected num codebooks in input {:?}", input_ids.shape())
+        }
+        let mut inputs_embeds = Tensor::zeros(
+            (b_sz, seq_len, self.hidden_size),
+            self.dtype,
+            input_ids.device(),
+        )?;
+        for (idx, embs) in self.embed_tokens.iter().enumerate() {
+            let e = input_ids.i((.., idx))?.apply(embs)?;
+            inputs_embeds = (inputs_embeds + e)?
+        }
+        let inputs_embeds = match prompt_hidden_states {
+            None => inputs_embeds,
+            Some(pis) => Tensor::cat(&[pis, &inputs_embeds], 1)?,
+        };
+        let embed_positions = self
+            .embed_positions
+            .i(seqlen_offset..seqlen_offset + inputs_embeds.dim(1)?)?;
+        let mut xs = (inputs_embeds + embed_positions.unsqueeze(0))?;
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, attention_mask, encoder_xs, encoder_attention_mask)?;
+        }
+        let xs = xs.apply(&self.layer_norm)?;
+        let mut lm_logits = Vec::with_capacity(self.num_codebooks);
+        for lm_head in self.lm_heads.iter() {
+            let logits = xs.apply(lm_head)?;
+            lm_logits.push(logits)
+        }
+        Ok(lm_logits)
+    }
+
+    pub fn clear_kv_cache(&mut self) {
+        for layer in self.layers.iter_mut() {
+            layer.clear_kv_cache()
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    pub embed_prompts: candle_nn::Embedding,
+    pub enc_to_dec_proj: Option<Linear>,
+    pub decoder: Decoder,
+    pub text_encoder: t5::T5EncoderModel,
+    pub decoder_start_token_id: u32,
+    pub pad_token_id: u32,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.text_encoder)?;
+        let decoder = Decoder::new(&cfg.decoder, vb.pp("decoder"))?;
+        let embed_prompts = candle_nn::embedding(
+            cfg.vocab_size,
+            cfg.decoder.hidden_size,
+            vb.pp("embed_prompts"),
+        )?;
+        let enc_to_dec_proj = if cfg.text_encoder.d_model != cfg.decoder.hidden_size {
+            let proj = linear(
+                cfg.text_encoder.d_model,
+                cfg.decoder.hidden_size,
+                true,
+                vb.pp("enc_to_dec_proj"),
+            )?;
+            Some(proj)
+        } else {
+            None
+        };
+        Ok(Self {
+            decoder,
+            text_encoder,
+            embed_prompts,
+            enc_to_dec_proj,
+            decoder_start_token_id: cfg.decoder_start_token_id,
+            pad_token_id: cfg.pad_token_id,
+        })
+    }
+
+    /// Note that the returned tensor uses the CPU device.
+    pub fn generate(
+        &mut self,
+        prompt_tokens: &Tensor,
+        description_tokens: &Tensor,
+        mut lp: LogitsProcessor,
+        max_steps: usize,
+    ) -> Result<Tensor> {
+        self.decoder.clear_kv_cache();
+        self.text_encoder.clear_kv_cache();
+        let encoded = self.text_encoder.forward(description_tokens)?;
+        let encoded = match self.enc_to_dec_proj.as_ref() {
+            None => encoded,
+            Some(proj) => encoded.apply(proj)?,
+        };
+        let prompt_hidden_states = prompt_tokens.apply(&self.embed_prompts)?;
+        let num_codebooks = self.decoder.num_codebooks;
+        let mut audio_tokens = vec![self.decoder_start_token_id; num_codebooks];
+        let mut all_audio_tokens = vec![vec![]; num_codebooks];
+        let prompt_len = prompt_hidden_states.dim(1)?;
+        for step in 0..max_steps {
+            let input_ids = Tensor::from_slice(
+                audio_tokens.as_slice(),
+                (1, num_codebooks, 1),
+                prompt_tokens.device(),
+            )?;
+            let (prompt_hidden_states, pos) = if step == 0 {
+                (Some(&prompt_hidden_states), 0)
+            } else {
+                (None, step + prompt_len)
+            };
+            let causal_mask = if pos == 0 {
+                self.prepare_causal_mask(prompt_len + 1, prompt_len + 1, input_ids.device())?
+            } else {
+                self.prepare_causal_mask(1, pos + 1, input_ids.device())?
+            };
+            let logits = self.decoder.forward(
+                &input_ids,
+                prompt_hidden_states,
+                Some(&causal_mask),
+                &encoded,
+                None,
+                pos,
+            )?;
+            for (logit_idx, logit) in logits.iter().enumerate() {
+                if logit_idx > step {
+                    break;
+                }
+                if audio_tokens[logit_idx] != self.pad_token_id {
+                    let logit = logit.i((0, logit.dim(1)? - 1))?;
+                    let token = lp.sample(&logit)?;
+                    audio_tokens[logit_idx] = token
+                }
+            }
+            if audio_tokens.iter().all(|v| v == &self.pad_token_id) {
+                break;
+            }
+            for (cb_idx, &token) in audio_tokens.iter().enumerate() {
+                if token != self.decoder_start_token_id && token != self.pad_token_id {
+                    all_audio_tokens[cb_idx].push(token)
+                }
+            }
+        }
+
+        let min_len = all_audio_tokens.iter().map(|v| v.len()).min().unwrap_or(0);
+        all_audio_tokens.iter_mut().for_each(|v| {
+            v.resize(min_len, 0);
+            v.push(self.pad_token_id)
+        });
+        let all_audio_tokens = Tensor::new(all_audio_tokens, &candle::Device::Cpu)?;
+        Ok(all_audio_tokens)
+    }
+
+    fn prepare_causal_mask(
+        &self,
+        q_len: usize,
+        kv_len: usize,
+        device: &candle::Device,
+    ) -> Result<Tensor> {
+        let mask: Vec<_> = (0..q_len)
+            .flat_map(|i| {
+                (0..kv_len).map(move |j| {
+                    if i + kv_len < j + q_len {
+                        f32::NEG_INFINITY
+                    } else {
+                        0.
+                    }
+                })
+            })
+            .collect();
+        Tensor::from_slice(&mask, (q_len, kv_len), device)
+    }
+}
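
For context on how the pieces fit together, here is a minimal sketch of driving the new `Model::generate` API, a rough, simplified version of what the example binary added in this PR does. The file names, the placeholder token ids, and the `serde_json`/`anyhow` dependencies are assumptions for illustration only; the real example tokenizes a text prompt and a voice description with the model's tokenizer, and the returned codebook ids still have to be turned into audio (the Python decoder mentioned in the commit message). Crate paths follow the aliases used inside the candle workspace (`candle` standing for `candle-core`).

```rust
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::parler_tts::{Config, Model};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Hypothetical local files exported from a Parler-TTS checkpoint.
    let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let mut model = Model::new(&config, vb)?;

    // Placeholder token ids standing in for the tokenized prompt text and
    // the tokenized voice description.
    let prompt_tokens = Tensor::new(&[[25u32, 90, 119, 34]], &device)?;
    let description_tokens = Tensor::new(&[[311u32, 7, 12, 1]], &device)?;

    // Random sampling as added in this commit: a seeded sampler with a temperature.
    let lp = LogitsProcessor::new(42, Some(0.7), None);
    let audio_codes = model.generate(&prompt_tokens, &description_tokens, lp, 512)?;
    println!("generated audio codes: {:?}", audio_codes.shape());
    Ok(())
}
```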