diff options
author | Zhuo Jinggang <jg.zhuo@outlook.com> | 2024-07-12 16:00:03 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-12 10:00:03 +0200 |
commit | c63048d3748649c6f13148eb01e6d812d897a0d2 (patch) | |
tree | 275f50476521bf47bb89530dd822a45ae776e6d3 /candle-transformers | |
parent | a226a9736baee550b01de53cb3e416d3d94e69d3 (diff) | |
download | candle-c63048d3748649c6f13148eb01e6d812d897a0d2.tar.gz candle-c63048d3748649c6f13148eb01e6d812d897a0d2.tar.bz2 candle-c63048d3748649c6f13148eb01e6d812d897a0d2.zip |
add quantized qwen2 (#2329)
* add quantized version of qwen2 and corresponding example for qwen2-instruct
* fix quantized qwen2 clippy error
Diffstat (limited to 'candle-transformers')
-rw-r--r-- | candle-transformers/src/models/mod.rs | 1 | ||||
-rw-r--r-- | candle-transformers/src/models/quantized_qwen2.rs | 323 |
2 files changed, 324 insertions, 0 deletions
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 86a0ec08..7baa12e6 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -47,6 +47,7 @@ pub mod quantized_moondream; pub mod quantized_mpt; pub mod quantized_phi; pub mod quantized_phi3; +pub mod quantized_qwen2; pub mod quantized_recurrent_gemma; pub mod quantized_rwkv_v5; pub mod quantized_rwkv_v6; diff --git a/candle-transformers/src/models/quantized_qwen2.rs b/candle-transformers/src/models/quantized_qwen2.rs new file mode 100644 index 00000000..addfab2b --- /dev/null +++ b/candle-transformers/src/models/quantized_qwen2.rs @@ -0,0 +1,323 @@ +use crate::{quantized_nn::RmsNorm, utils::repeat_kv}; +use candle::{ + quantized::{gguf_file, QMatMul}, + DType, Device, IndexOp, Result, Tensor, +}; +use candle_nn::{Embedding, Module}; +use std::collections::HashMap; + +#[derive(Debug, Clone)] +struct Mlp { + feed_forward_w1: QMatMul, + feed_forward_w2: QMatMul, + feed_forward_w3: QMatMul, +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let w1 = self.feed_forward_w1.forward(xs)?; + let w3 = self.feed_forward_w3.forward(xs)?; + self.feed_forward_w2 + .forward(&(candle_nn::ops::silu(&w1)? * w3)?) + } +} + +#[derive(Debug, Clone)] +struct LayerWeights { + attention_wq: QMatMul, + attention_wk: QMatMul, + attention_wv: QMatMul, + attention_bq: Tensor, + attention_bk: Tensor, + attention_bv: Tensor, + attention_wo: QMatMul, + attention_norm: RmsNorm, + mlp: Mlp, + ffn_norm: RmsNorm, + n_head: usize, + n_kv_head: usize, + head_dim: usize, + cos: Tensor, + sin: Tensor, + neg_inf: Tensor, + kv_cache: Option<(Tensor, Tensor)>, + span_attn: tracing::Span, + span_rot: tracing::Span, + span_mlp: tracing::Span, +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> { + let shape = mask.shape(); + let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?; + Ok(m) +} + +impl LayerWeights { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_rot.enter(); + let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?; + let cos = self.cos.narrow(0, index_pos, seq_len)?; + let sin = self.sin.narrow(0, index_pos, seq_len)?; + candle_nn::rotary_emb::rope(&x.contiguous()?, &cos, &sin) + } + + fn forward_attn( + &mut self, + x: &Tensor, + mask: Option<&Tensor>, + index_pos: usize, + ) -> Result<Tensor> { + let _enter = self.span_attn.enter(); + let (b_sz, seq_len, n_embd) = x.dims3()?; + + let q = self.attention_wq.forward(x)?; + let k = self.attention_wk.forward(x)?; + let v = self.attention_wv.forward(x)?; + + let q = q.broadcast_add(&self.attention_bq)?; + let k = k.broadcast_add(&self.attention_bk)?; + let v = v.broadcast_add(&self.attention_bv)?; + + let q = q + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let k = k + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let v = v + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + + // let (q, k) = self + // .rotary_embedding + // .apply_rotary_emb_qkv(&q, &k, index_pos)?; + let q = self.apply_rotary_emb(&q, index_pos)?; + let k = self.apply_rotary_emb(&k, index_pos)?; + + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((k_cache, v_cache)) => { + if index_pos == 0 { + (k, v) + } else { + let k = Tensor::cat(&[k_cache, &k], 2)?; + let v = Tensor::cat(&[v_cache, &v], 2)?; + (k, v) + } + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + + // Support for MQA, useful for 70B models and mistral. + let k = repeat_kv(k, self.n_head / self.n_kv_head)?; + let v = repeat_kv(v, self.n_head / self.n_kv_head)?; + + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let att = match mask { + None => att, + Some(mask) => { + let mask = mask.broadcast_as(att.shape())?; + masked_fill(&att, &mask, &self.neg_inf)? + } + }; + let att = candle_nn::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = self.attention_wo.forward(&y)?; + Ok(y) + } +} + +pub struct ModelWeights { + tok_embeddings: Embedding, + layers: Vec<LayerWeights>, + norm: RmsNorm, + output: QMatMul, + masks: HashMap<usize, Tensor>, + span: tracing::Span, + span_output: tracing::Span, +} + +fn precomput_freqs_cis( + head_dim: usize, + freq_base: f32, + context_length: usize, + device: &Device, +) -> Result<(Tensor, Tensor)> { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, context_length as u32, device)? + .to_dtype(DType::F32)? + .reshape((context_length, 1))? + .matmul(&theta.reshape((1, theta.elem_count()))?)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok((cos, sin)) +} + +impl ModelWeights { + pub fn from_gguf<R: std::io::Seek + std::io::Read>( + ct: gguf_file::Content, + reader: &mut R, + device: &Device, + ) -> Result<Self> { + let md_get = |s: &str| match ct.metadata.get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + + let head_count = md_get("qwen2.attention.head_count")?.to_u32()? as usize; + let head_count_kv = md_get("qwen2.attention.head_count_kv")?.to_u32()? as usize; + let embedding_length = md_get("qwen2.embedding_length")?.to_u32()? as usize; + let context_length = md_get("qwen2.context_length")?.to_u32()? as usize; + let block_count = md_get("qwen2.block_count")?.to_u32()? as usize; + let rms_norm_eps = md_get("qwen2.attention.layer_norm_rms_epsilon")?.to_f32()? as f64; + let rope_freq_base = md_get("qwen2.rope.freq_base") + .and_then(|m| m.to_f32()) + .unwrap_or(10000f32); + + let head_dim = embedding_length / head_count; + + let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; + + let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = tok_embeddings.dequantize(device)?; + let norm = RmsNorm::from_qtensor( + ct.tensor(reader, "output_norm.weight", device)?, + rms_norm_eps, + )?; + let output = match ct.tensor(reader, "output.weight", device) { + Ok(v) => QMatMul::from_qtensor(v)?, + _ => { + // use tie_word_embeddings + QMatMul::from_qtensor(ct.tensor(reader, "token_embd.weight", device)?)? + } + }; + + let (cos, sin) = precomput_freqs_cis(head_dim, rope_freq_base, context_length, device)?; + + let mut layers = Vec::with_capacity(block_count); + + for layer_idx in 0..block_count { + let prefix = format!("blk.{layer_idx}"); + let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?; + let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?; + let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?; + + let attention_bq = ct.tensor(reader, &format!("{prefix}.attn_q.bias"), device)?; + let attention_bk = ct.tensor(reader, &format!("{prefix}.attn_k.bias"), device)?; + let attention_bv = ct.tensor(reader, &format!("{prefix}.attn_v.bias"), device)?; + + let attention_wo = + ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; + + let mlp = { + let feed_forward_w1 = + ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?; + let feed_forward_w2 = + ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?; + let feed_forward_w3 = + ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; + Mlp { + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, + } + }; + + let attention_norm = + ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?; + let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?; + + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + + layers.push(LayerWeights { + attention_wq: QMatMul::from_qtensor(attention_wq)?, + attention_wk: QMatMul::from_qtensor(attention_wk)?, + attention_wv: QMatMul::from_qtensor(attention_wv)?, + attention_bq: attention_bq.dequantize(device)?, + attention_bk: attention_bk.dequantize(device)?, + attention_bv: attention_bv.dequantize(device)?, + attention_wo: QMatMul::from_qtensor(attention_wo)?, + attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?, + cos: cos.clone(), + sin: sin.clone(), + mlp, + ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?, + n_head: head_count, + n_kv_head: head_count_kv, + head_dim, + neg_inf: neg_inf.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }); + } + + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, embedding_length), + layers, + norm, + output, + masks: HashMap::new(), + span, + span_output, + }) + } + + fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> { + if let Some(mask) = self.masks.get(&t) { + Ok(mask.clone()) + } else { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), device)?; + self.masks.insert(t, mask.clone()); + Ok(mask) + } + } + + pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let (_b_sz, seq_len) = x.dims2()?; + let mask = if seq_len == 1 { + None + } else { + Some(self.mask(seq_len, x.device())?) + }; + let _enter = self.span.enter(); + let mut layer_in = self.tok_embeddings.forward(x)?; + for layer in self.layers.iter_mut() { + let x = layer_in; + let residual = &x; + let x = layer.attention_norm.forward(&x)?; + let attn = layer.forward_attn(&x, mask.as_ref(), index_pos)?; + let x = (attn + residual)?; + + // MLP + let _enter = layer.span_mlp.enter(); + let residual = &x; + let x = layer.ffn_norm.forward(&x)?; + let x = layer.mlp.forward(&x)?; + let x = (x + residual)?; + layer_in = x + } + let x = self.norm.forward(&layer_in)?; + let x = x.i((.., seq_len - 1, ..))?; + let _enter = self.span_output.enter(); + self.output.forward(&x) + } +} |