| author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-03-09 11:06:04 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-09 11:06:04 +0100 |
| commit | dd00482ea3456111482ec1cee045d2ae8efaf8ba (patch) | |
| tree | 1bc4d566d8c8599f887eb8f8a1ed07be2afb7715 /candle-transformers/src | |
| parent | 936f6a48407ee111f52742cf48eccc61f6b62325 (diff) | |
| download | candle-dd00482ea3456111482ec1cee045d2ae8efaf8ba.tar.gz candle-dd00482ea3456111482ec1cee045d2ae8efaf8ba.tar.bz2 candle-dd00482ea3456111482ec1cee045d2ae8efaf8ba.zip | |
Quantized version of the metavoice model. (#1824)
* Quantized version of the metavoice model.
* Integrate the quantized version of metavoice.
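
For orientation, a minimal sketch of how the new module can be driven from application code. It assumes the float model's `transformer::Config::cfg1b_v0_1()` constructor, the two-argument `VarBuilder::from_gguf(path, &device)` form, and a placeholder GGUF file name; none of these are pinned down by this commit itself.

```rust
use candle::{Device, Result};
use candle_transformers::models::metavoice::transformer::Config;
use candle_transformers::models::quantized_metavoice::transformer::Model;
use candle_transformers::quantized_var_builder::VarBuilder;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Placeholder file name: the commit does not ship or pin a quantized artifact.
    let vb = VarBuilder::from_gguf("metavoice-1b.q4_0.gguf", &device)?;
    // The configuration is shared with the non-quantized model; only the weights differ.
    let cfg = Config::cfg1b_v0_1();
    let mut model = Model::new(&cfg, vb)?;
    // Generation would then call model.forward(&tokens, &spk_emb, pos) step by step,
    // resetting the per-layer KV caches between utterances.
    model.clear_kv_cache();
    Ok(())
}
```

The quantized `Model` mirrors the float `transformer::Model` API (same `forward` and `clear_kv_cache` signatures), which is what makes the drop-in integration mentioned in the commit message possible.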
Diffstat (limited to 'candle-transformers/src')
| -rw-r--r-- | candle-transformers/src/models/metavoice.rs | 8 |
| -rw-r--r-- | candle-transformers/src/models/mod.rs | 1 |
| -rw-r--r-- | candle-transformers/src/models/quantized_metavoice.rs | 226 |
| -rw-r--r-- | candle-transformers/src/quantized_nn.rs | 10 |
4 files changed, 241 insertions, 4 deletions
diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs
index 35cb30c7..2eeb0713 100644
--- a/candle-transformers/src/models/metavoice.rs
+++ b/candle-transformers/src/models/metavoice.rs
@@ -2,7 +2,7 @@ use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
 
 // Equivalent to torch.repeat_interleave
-fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
+pub(crate) fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
     let img = img.unsqueeze(dim + 1)?;
     let mut dims = img.dims().to_vec();
     dims[dim + 1] = repeats;
@@ -664,15 +664,15 @@ pub mod transformer {
             }
         }
 
-        fn n_local_heads(&self) -> usize {
+        pub(crate) fn n_local_heads(&self) -> usize {
             self.n_local_heads.unwrap_or(self.n_head)
         }
 
-        fn head_dim(&self) -> usize {
+        pub(crate) fn head_dim(&self) -> usize {
             self.dim / self.n_head
         }
 
-        fn intermediate_size(&self) -> usize {
+        pub(crate) fn intermediate_size(&self) -> usize {
             match self.intermediate_size {
                 Some(intermediate_size) => intermediate_size,
                 None => {
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 66e06e0e..389d1a80 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -30,6 +30,7 @@ pub mod quantized_blip;
 pub mod quantized_blip_text;
 pub mod quantized_llama;
 pub mod quantized_llama2_c;
+pub mod quantized_metavoice;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
 pub mod quantized_mpt;
diff --git a/candle-transformers/src/models/quantized_metavoice.rs b/candle-transformers/src/models/quantized_metavoice.rs
new file mode 100644
index 00000000..16545150
--- /dev/null
+++ b/candle-transformers/src/models/quantized_metavoice.rs
@@ -0,0 +1,226 @@
+use crate::quantized_nn::{linear_b, Embedding, Linear, RmsNorm};
+pub use crate::quantized_var_builder::VarBuilder;
+
+use crate::models::metavoice::repeat_interleave;
+use candle::{Module, Result, Tensor, D};
+
+pub mod transformer {
+    use super::*;
+
+    type Config = crate::models::metavoice::transformer::Config;
+
+    #[derive(Debug, Clone)]
+    struct FeedForward {
+        w1: Linear,
+        w2: Linear,
+        w3: Linear,
+    }
+
+    impl FeedForward {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let i_size = cfg.intermediate_size();
+            let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
+            let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
+            let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
+            Ok(Self { w1, w2, w3 })
+        }
+    }
+
+    impl Module for FeedForward {
+        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
+            swiglu.apply(&self.w2)
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    struct Attention {
+        wqkv: Linear,
+        wo: Linear,
+        dim: usize,
+        kv_size: usize,
+        n_local_heads: usize,
+        head_dim: usize,
+        n_head: usize,
+        kv_cache: Option<(Tensor, Tensor)>,
+    }
+
+    impl Attention {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let n_local_heads = cfg.n_local_heads();
+            let head_dim = cfg.head_dim();
+            let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
+            let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
+            let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
+            Ok(Self {
+                wqkv,
+                wo,
+                dim: cfg.dim,
+                kv_size: n_local_heads * head_dim,
+                n_local_heads,
+                head_dim,
+                n_head: cfg.n_head,
+                kv_cache: None,
+            })
+        }
+
+        fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let (b_sz, seqlen, _) = xs.dims3()?;
+
+            let qkv = xs.apply(&self.wqkv)?;
+            let q = qkv.narrow(D::Minus1, 0, self.dim)?;
+            let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
+            let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
+            let q = q
+                .reshape((b_sz, seqlen, self.n_head, self.head_dim))?
+                .transpose(1, 2)?
+                .contiguous()?;
+            let k = k
+                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
+                .transpose(1, 2)?;
+            let v = v
+                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
+                .transpose(1, 2)?;
+
+            let (k, v) = match &self.kv_cache {
+                None => (k, v),
+                Some((prev_k, prev_v)) => {
+                    let k = Tensor::cat(&[prev_k, &k], 2)?;
+                    let v = Tensor::cat(&[prev_v, &v], 2)?;
+                    (k, v)
+                }
+            };
+            self.kv_cache = Some((k.clone(), v.clone()));
+
+            let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
+            let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;
+
+            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+            let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
+
+            let attn_weights = attn_weights.broadcast_add(mask)?;
+            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+            let attn_output = attn_weights.matmul(&v)?;
+            attn_output
+                .transpose(1, 2)?
+                .reshape((b_sz, seqlen, self.dim))?
+                .apply(&self.wo)
+        }
+
+        fn clear_kv_cache(&mut self) {
+            self.kv_cache = None
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    struct Block {
+        attention: Attention,
+        feed_forward: FeedForward,
+        ffn_norm: RmsNorm,
+        attention_norm: RmsNorm,
+    }
+
+    impl Block {
+        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let attention = Attention::new(cfg, vb.pp("attention"))?;
+            let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
+            let ffn_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
+            let attention_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
+            Ok(Self {
+                attention,
+                feed_forward,
+                ffn_norm,
+                attention_norm,
+            })
+        }
+
+        fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
+            let hs = xs.apply(&self.attention_norm)?;
+            let hs = (xs + self.attention.forward(&hs, pos, mask))?;
+            &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
+        }
+
+        fn clear_kv_cache(&mut self) {
+            self.attention.clear_kv_cache()
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    pub struct Model {
+        tok_embeddings: Embedding,
+        pos_embeddings: Embedding,
+        speaker_cond_pos: Linear,
+        layers: Vec<Block>,
+        norm: RmsNorm,
+        output: Linear,
+        spk_cond_mask: Tensor,
+    }
+
+    impl Model {
+        pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+            let tok_embeddings = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
+            let pos_embeddings = Embedding::new(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
+            let speaker_cond_pos = linear_b(
+                cfg.speaker_emb_dim,
+                cfg.dim,
+                false,
+                vb.pp("speaker_cond_pos"),
+            )?;
+            let mut layers = Vec::with_capacity(cfg.n_layer);
+            let vb_l = vb.pp("layers");
+            for layer_idx in 0..cfg.n_layer {
+                let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
+                layers.push(layer)
+            }
+            let norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
+            let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
+            let spk_cond_mask = Tensor::cat(
+                &[
+                    Tensor::ones((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
+                    Tensor::zeros((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
+                ],
+                0,
+            )?;
+            Ok(Self {
+                tok_embeddings,
+                pos_embeddings,
+                speaker_cond_pos,
+                layers,
+                norm,
+                output,
+                spk_cond_mask,
+            })
+        }
+
+        pub fn clear_kv_cache(&mut self) {
+            for layer in self.layers.iter_mut() {
+                layer.clear_kv_cache()
+            }
+        }
+
+        pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
+            let (_b_sz, seqlen) = xs.dims2()?;
+            let mask: Vec<_> = (0..seqlen)
+                .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
+                .collect();
+            let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
+            let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
+            let tok_embeddings = xs.apply(&self.tok_embeddings)?;
+            let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
+            let mut xs = tok_embeddings
+                .broadcast_add(&pos_embeddings)?
+                .broadcast_add(
+                    &spk_emb
+                        .apply(&self.speaker_cond_pos)?
+                        .broadcast_mul(&self.spk_cond_mask)?,
+                )?;
+            let mask = mask.to_dtype(xs.dtype())?;
+            for layer in self.layers.iter_mut() {
+                xs = layer.forward(&xs, pos, &mask)?
+            }
+            xs.narrow(1, seqlen - 1, 1)?
+                .apply(&self.norm)?
+                .apply(&self.output)
+        }
+    }
+}
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
index 99e8d45b..21c88430 100644
--- a/candle-transformers/src/quantized_nn.rs
+++ b/candle-transformers/src/quantized_nn.rs
@@ -50,6 +50,16 @@ impl Module for Linear {
     }
 }
 
+pub fn linear_b(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
+    let bias = if bias {
+        Some(vb.get(out_dim, "bias")?.dequantize(vb.device())?)
+    } else {
+        None
+    };
+    let weight = QMatMul::new(in_dim, out_dim, vb)?;
+    Ok(Linear { weight, bias })
+}
+
 pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
     let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
     let weight = QMatMul::new(in_dim, out_dim, vb)?;
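
A side note on the `pub(crate)` change to `repeat_interleave` above: the quantized attention reuses it to expand the `n_local_heads` key/value heads to the full `n_head` query heads before the attention matmul (grouped-query attention). The helper stays crate-private, so the sketch below reproduces the same expansion with public tensor ops and toy shapes; the dimensions are assumed for illustration and the function is an equivalent reimplementation, not the crate's own code verbatim.

```rust
use candle::{Device, Result, Tensor};

// Same effect as torch.repeat_interleave along `dim`: each slice is repeated
// `repeats` times back to back, e.g. (b, kv_heads, t, d) -> (b, kv_heads * repeats, t, d).
fn repeat_interleave(xs: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
    let xs = xs.unsqueeze(dim + 1)?;
    let mut dims = xs.dims().to_vec();
    dims[dim + 1] = repeats;
    // Broadcast the inserted size-1 axis to `repeats`, materialize it, then merge it into `dim`.
    xs.broadcast_as(dims)?.contiguous()?.flatten(dim, dim + 1)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy KV tensor: batch = 1, kv_heads = 2, seq_len = 3, head_dim = 4.
    let k = Tensor::arange(0f32, 24., &device)?.reshape((1, 2, 3, 4))?;
    // Expand 2 KV heads to 8 query heads (repeat factor 4) along the head axis.
    let k = repeat_interleave(&k, 4, 1)?;
    assert_eq!(k.dims(), &[1usize, 8, 3, 4]);
    Ok(())
}
```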