author     Laurent Mazare <laurent.mazare@gmail.com>    2024-03-09 11:06:04 +0100
committer  GitHub <noreply@github.com>                  2024-03-09 11:06:04 +0100
commit     dd00482ea3456111482ec1cee045d2ae8efaf8ba (patch)
tree       1bc4d566d8c8599f887eb8f8a1ed07be2afb7715 /candle-transformers/src
parent     936f6a48407ee111f52742cf48eccc61f6b62325 (diff)
Quantized version of the metavoice model. (#1824)
* Quantized version of the metavoice model.
* Integrate the quantized version of metavoice.
Diffstat (limited to 'candle-transformers/src')
-rw-r--r--  candle-transformers/src/models/metavoice.rs             8
-rw-r--r--  candle-transformers/src/models/mod.rs                   1
-rw-r--r--  candle-transformers/src/models/quantized_metavoice.rs 226
-rw-r--r--  candle-transformers/src/quantized_nn.rs                10
4 files changed, 241 insertions, 4 deletions
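
A minimal usage sketch of the new quantized first-stage transformer (not part of this commit; the gguf filename, the Config::cfg1b_v0_1() constructor and the token values are assumptions):

use candle::{DType, Device, Result, Tensor};
use candle_transformers::models::metavoice::transformer::Config;
use candle_transformers::models::quantized_metavoice::transformer::Model;
use candle_transformers::quantized_var_builder::VarBuilder;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Assumed checkpoint path; the quantized VarBuilder loads gguf weights.
    let vb = VarBuilder::from_gguf("first_stage_q4_0.gguf", &device)?;
    let cfg = Config::cfg1b_v0_1();
    let mut model = Model::new(&cfg, vb)?;

    // Batch of 2: the spk_cond_mask built in Model::new zeroes the speaker
    // conditioning on the second row (used for classifier-free guidance).
    let tokens = Tensor::from_slice(&[1u32, 2, 3, 1, 2, 3], (2, 3), &device)?;
    let spk_emb = Tensor::zeros((1, cfg.speaker_emb_dim), DType::F32, &device)?;

    // forward returns the logits for the last position only, here (2, 1, vocab_size).
    let logits = model.forward(&tokens, &spk_emb, 0)?;
    println!("{:?}", logits.shape());
    model.clear_kv_cache();
    Ok(())
}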
diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs
index 35cb30c7..2eeb0713 100644
--- a/candle-transformers/src/models/metavoice.rs
+++ b/candle-transformers/src/models/metavoice.rs
@@ -2,7 +2,7 @@ use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
// Equivalent to torch.repeat_interleave
-fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
+pub(crate) fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
let img = img.unsqueeze(dim + 1)?;
let mut dims = img.dims().to_vec();
dims[dim + 1] = repeats;
@@ -664,15 +664,15 @@ pub mod transformer {
}
}
- fn n_local_heads(&self) -> usize {
+ pub(crate) fn n_local_heads(&self) -> usize {
self.n_local_heads.unwrap_or(self.n_head)
}
- fn head_dim(&self) -> usize {
+ pub(crate) fn head_dim(&self) -> usize {
self.dim / self.n_head
}
- fn intermediate_size(&self) -> usize {
+ pub(crate) fn intermediate_size(&self) -> usize {
match self.intermediate_size {
Some(intermediate_size) => intermediate_size,
None => {
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 66e06e0e..389d1a80 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -30,6 +30,7 @@ pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_llama;
pub mod quantized_llama2_c;
+pub mod quantized_metavoice;
pub mod quantized_mistral;
pub mod quantized_mixformer;
pub mod quantized_mpt;
diff --git a/candle-transformers/src/models/quantized_metavoice.rs b/candle-transformers/src/models/quantized_metavoice.rs
new file mode 100644
index 00000000..16545150
--- /dev/null
+++ b/candle-transformers/src/models/quantized_metavoice.rs
@@ -0,0 +1,226 @@
+use crate::quantized_nn::{linear_b, Embedding, Linear, RmsNorm};
+pub use crate::quantized_var_builder::VarBuilder;
+
+use crate::models::metavoice::repeat_interleave;
+use candle::{Module, Result, Tensor, D};
+
+pub mod transformer {
+ use super::*;
+
+ type Config = crate::models::metavoice::transformer::Config;
+
+ #[derive(Debug, Clone)]
+ struct FeedForward {
+ w1: Linear,
+ w2: Linear,
+ w3: Linear,
+ }
+
+ impl FeedForward {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let i_size = cfg.intermediate_size();
+ let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
+ let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
+ let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
+ Ok(Self { w1, w2, w3 })
+ }
+ }
+
+ impl Module for FeedForward {
+ fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
+ swiglu.apply(&self.w2)
+ }
+ }
+
+ #[derive(Debug, Clone)]
+ struct Attention {
+ wqkv: Linear,
+ wo: Linear,
+ dim: usize,
+ kv_size: usize,
+ n_local_heads: usize,
+ head_dim: usize,
+ n_head: usize,
+ kv_cache: Option<(Tensor, Tensor)>,
+ }
+
+ impl Attention {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let n_local_heads = cfg.n_local_heads();
+ let head_dim = cfg.head_dim();
+ let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
+ let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
+ let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
+ Ok(Self {
+ wqkv,
+ wo,
+ dim: cfg.dim,
+ kv_size: n_local_heads * head_dim,
+ n_local_heads,
+ head_dim,
+ n_head: cfg.n_head,
+ kv_cache: None,
+ })
+ }
+
+ fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
+ let (b_sz, seqlen, _) = xs.dims3()?;
+
+ let qkv = xs.apply(&self.wqkv)?;
+ let q = qkv.narrow(D::Minus1, 0, self.dim)?;
+ let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
+ let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
+ let q = q
+ .reshape((b_sz, seqlen, self.n_head, self.head_dim))?
+ .transpose(1, 2)?
+ .contiguous()?;
+ let k = k
+ .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
+ .transpose(1, 2)?;
+ let v = v
+ .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
+ .transpose(1, 2)?;
+
+ let (k, v) = match &self.kv_cache {
+ None => (k, v),
+ Some((prev_k, prev_v)) => {
+ let k = Tensor::cat(&[prev_k, &k], 2)?;
+ let v = Tensor::cat(&[prev_v, &v], 2)?;
+ (k, v)
+ }
+ };
+ self.kv_cache = Some((k.clone(), v.clone()));
+
+ let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
+ let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;
+
+ let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+ let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
+
+ let attn_weights = attn_weights.broadcast_add(mask)?;
+ let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+ let attn_output = attn_weights.matmul(&v)?;
+ attn_output
+ .transpose(1, 2)?
+ .reshape((b_sz, seqlen, self.dim))?
+ .apply(&self.wo)
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.kv_cache = None
+ }
+ }
+
+ #[derive(Debug, Clone)]
+ struct Block {
+ attention: Attention,
+ feed_forward: FeedForward,
+ ffn_norm: RmsNorm,
+ attention_norm: RmsNorm,
+ }
+
+ impl Block {
+ fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let attention = Attention::new(cfg, vb.pp("attention"))?;
+ let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
+ let ffn_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
+ let attention_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
+ Ok(Self {
+ attention,
+ feed_forward,
+ ffn_norm,
+ attention_norm,
+ })
+ }
+
+ fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
+ let hs = xs.apply(&self.attention_norm)?;
+ let hs = (xs + self.attention.forward(&hs, pos, mask))?;
+ &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
+ }
+
+ fn clear_kv_cache(&mut self) {
+ self.attention.clear_kv_cache()
+ }
+ }
+
+ #[derive(Debug, Clone)]
+ pub struct Model {
+ tok_embeddings: Embedding,
+ pos_embeddings: Embedding,
+ speaker_cond_pos: Linear,
+ layers: Vec<Block>,
+ norm: RmsNorm,
+ output: Linear,
+ spk_cond_mask: Tensor,
+ }
+
+ impl Model {
+ pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+ let tok_embeddings = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
+ let pos_embeddings = Embedding::new(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
+ let speaker_cond_pos = linear_b(
+ cfg.speaker_emb_dim,
+ cfg.dim,
+ false,
+ vb.pp("speaker_cond_pos"),
+ )?;
+ let mut layers = Vec::with_capacity(cfg.n_layer);
+ let vb_l = vb.pp("layers");
+ for layer_idx in 0..cfg.n_layer {
+ let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
+ layers.push(layer)
+ }
+ let norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
+ let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
+ let spk_cond_mask = Tensor::cat(
+ &[
+ Tensor::ones((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
+ Tensor::zeros((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
+ ],
+ 0,
+ )?;
+ Ok(Self {
+ tok_embeddings,
+ pos_embeddings,
+ speaker_cond_pos,
+ layers,
+ norm,
+ output,
+ spk_cond_mask,
+ })
+ }
+
+ pub fn clear_kv_cache(&mut self) {
+ for layer in self.layers.iter_mut() {
+ layer.clear_kv_cache()
+ }
+ }
+
+ pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
+ let (_b_sz, seqlen) = xs.dims2()?;
+ let mask: Vec<_> = (0..seqlen)
+ .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
+ .collect();
+ let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
+ let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
+ let tok_embeddings = xs.apply(&self.tok_embeddings)?;
+ let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
+ let mut xs = tok_embeddings
+ .broadcast_add(&pos_embeddings)?
+ .broadcast_add(
+ &spk_emb
+ .apply(&self.speaker_cond_pos)?
+ .broadcast_mul(&self.spk_cond_mask)?,
+ )?;
+ let mask = mask.to_dtype(xs.dtype())?;
+ for layer in self.layers.iter_mut() {
+ xs = layer.forward(&xs, pos, &mask)?
+ }
+ xs.narrow(1, seqlen - 1, 1)?
+ .apply(&self.norm)?
+ .apply(&self.output)
+ }
+ }
+}
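
In the attention above, repeat_interleave (now pub(crate) in metavoice.rs) expands the grouped kv heads so they line up with the query heads before the matmul. A standalone sketch of that expansion, using a demo helper equivalent to torch.repeat_interleave (its body past the hunk shown earlier is an assumption):

use candle::{Device, Result, Tensor};

// Repeat each slice along `dim` `repeats` times, back to back.
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
    let img = img.unsqueeze(dim + 1)?;
    let mut dims = img.dims().to_vec();
    dims[dim + 1] = repeats;
    img.broadcast_as(dims)?.flatten(dim, dim + 1)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // 2 kv heads, 3 positions, head_dim 4; expand to 6 query heads (repeats = 3).
    let k = Tensor::randn(0f32, 1f32, (1, 2, 3, 4), &device)?;
    let k = repeat_interleave(&k, 3, 1)?;
    assert_eq!(k.dims(), &[1, 6, 3, 4]);
    Ok(())
}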
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
index 99e8d45b..21c88430 100644
--- a/candle-transformers/src/quantized_nn.rs
+++ b/candle-transformers/src/quantized_nn.rs
@@ -50,6 +50,16 @@ impl Module for Linear {
}
}
+pub fn linear_b(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
+ let bias = if bias {
+ Some(vb.get(out_dim, "bias")?.dequantize(vb.device())?)
+ } else {
+ None
+ };
+ let weight = QMatMul::new(in_dim, out_dim, vb)?;
+ Ok(Linear { weight, bias })
+}
+
pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
let weight = QMatMul::new(in_dim, out_dim, vb)?;
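
For comparison, a hedged sketch of how the new linear_b sits next to the existing linear helper: linear_b skips the bias lookup entirely when bias is false, whereas linear always dequantizes a bias tensor from the checkpoint. The var-builder prefixes below are illustrative only.

use candle::Result;
use candle_transformers::quantized_nn::{linear, linear_b, Linear};
use candle_transformers::quantized_var_builder::VarBuilder;

// Builds one bias-free projection and one biased projection from quantized weights.
fn build_projections(dim: usize, vb: VarBuilder) -> Result<(Linear, Linear)> {
    // No "bias" tensor is read from the checkpoint when the flag is false.
    let wqkv = linear_b(dim, 3 * dim, false, vb.pp("wqkv"))?;
    // `linear` always loads and dequantizes a "bias" tensor.
    let proj = linear(dim, dim, vb.pp("proj"))?;
    Ok((wqkv, proj))
}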