author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-29 22:28:46 +0100
---|---|---
committer | GitHub <noreply@github.com> | 2023-08-29 22:28:46 +0100
commit | a1a5ab8b0a59b717eb04ff9aa4ef49018ec99e7f (patch) |
tree | 3aa8059a1ea95b6f3bc8857ab86c7c84f9a8785d /candle-examples/examples/quantized |
parent | 59b731de99f2f351fdb6ecb428224a29e8cf2d35 (diff) |
Neon optimized vecdot (#666)
* Q5k vecdot.
* Add the q3k vecdot.
* Q2k vecdot.
* Move the quantized model to its own file.
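The Neon kernels mentioned in the commit message vectorize the dot products between blocks of k-quantized weights. The sketch below is not candle's implementation: it uses a deliberately simplified, hypothetical block format (32 signed 8-bit values sharing one f32 scale, unlike the packed sub-block layouts of Q2_K/Q3_K/Q5_K) purely to illustrate the scale-and-accumulate structure that the Neon versions vectorize.

```rust
/// Simplified quantized block: 32 signed 8-bit weights sharing one f32 scale.
/// The real k-quant formats pack bits more aggressively and carry per-sub-block
/// scales, but the dot-product structure is the same.
#[derive(Clone, Copy)]
struct BlockQ8 {
    scale: f32,
    qs: [i8; 32],
}

/// Scalar reference for the vecdot: for each block pair, accumulate the integer
/// dot product, then apply the combined scale once. A Neon kernel would load
/// 16 i8 lanes at a time and use widening multiply-adds instead of this loop.
fn vec_dot(lhs: &[BlockQ8], rhs: &[BlockQ8]) -> f32 {
    assert_eq!(lhs.len(), rhs.len());
    let mut acc = 0f32;
    for (a, b) in lhs.iter().zip(rhs.iter()) {
        // Accumulate in i32 to avoid overflow before converting to f32.
        let mut isum = 0i32;
        for (&x, &y) in a.qs.iter().zip(b.qs.iter()) {
            isum += x as i32 * y as i32;
        }
        acc += a.scale * b.scale * isum as f32;
    }
    acc
}
```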
Diffstat (limited to 'candle-examples/examples/quantized')
-rw-r--r-- | candle-examples/examples/quantized/main.rs | 368 |
-rw-r--r-- | candle-examples/examples/quantized/model.rs | 367 |
2 files changed, 371 insertions, 364 deletions
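With the model code moved into model.rs, main.rs keeps only the CLI, tokenizer, and sampling logic and pulls the weights in through `mod model; use model::ModelWeights;`, as the hunk below shows. A minimal sketch of the resulting loading path follows; reading the GGUF header with `gguf_file::Content::read` happens in a part of main.rs not shown in this diff, so treat that call as an assumption.

```rust
mod model;
use model::ModelWeights;

use candle::quantized::gguf_file;

/// Open a .gguf file and build the quantized model from it.
fn load_gguf(path: &std::path::Path) -> candle::Result<ModelWeights> {
    let mut file = std::fs::File::open(path)?;
    // Assumed entry point for parsing the GGUF header; this hunk only shows
    // that `ModelWeights::from_gguf` takes the parsed `gguf_file::Content`
    // together with the still-open reader used to pull tensor data.
    let content = gguf_file::Content::read(&mut file)?;
    ModelWeights::from_gguf(content, &mut file)
}
```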
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index a1e3eabd..53be19b9 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -5,377 +5,17 @@ extern crate intel_mkl_src; extern crate accelerate_src; use clap::{Parser, ValueEnum}; -use std::collections::HashMap; use std::io::Write; use tokenizers::Tokenizer; -use candle::quantized::QTensor; use candle::quantized::{ggml_file, gguf_file}; -use candle::{DType, Device, IndexOp, Result, Tensor, D}; -use candle_nn::{Embedding, Module}; +use candle::{Device, Tensor}; use candle_transformers::generation::LogitsProcessor; -const MAX_SEQ_LEN: usize = 4096; -const DEFAULT_PROMPT: &str = "My favorite theorem is "; - -struct RmsNorm { - inner: candle_nn::LayerNorm, - span: tracing::Span, -} - -impl RmsNorm { - fn new(scale: QTensor, eps: f32) -> Result<Self> { - let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); - let scale = scale.dequantize(&Device::Cpu)?; - let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); - Ok(Self { inner, span }) - } - - fn forward(&self, x: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -// QMatMul wrapper adding some tracing. -struct QMatMul { - inner: candle::quantized::QMatMul, - span: tracing::Span, -} - -impl QMatMul { - fn from_qtensor(qtensor: QTensor) -> Self { - let inner = candle::quantized::QMatMul::from_qtensor(qtensor); - let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); - Self { inner, span } - } - - fn forward(&self, xs: &Tensor) -> Result<Tensor> { - let _enter = self.span.enter(); - self.inner.forward(xs) - } -} - -struct LayerWeights { - attention_wq: QMatMul, - attention_wk: QMatMul, - attention_wv: QMatMul, - attention_wo: QMatMul, - attention_norm: RmsNorm, - feed_forward_w1: QMatMul, - feed_forward_w2: QMatMul, - feed_forward_w3: QMatMul, - ffn_norm: RmsNorm, - n_head: usize, - n_kv_head: usize, - head_dim: usize, - cos: Tensor, - sin: Tensor, - kv_cache: Option<(Tensor, Tensor)>, - span_attn: tracing::Span, - span_rot: tracing::Span, - span_mlp: tracing::Span, -} - -fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { - let shape = mask.shape(); - let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; - let m = mask.where_cond(&on_true, on_false)?; - Ok(m) -} - -impl LayerWeights { - fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { - let _enter = self.span_rot.enter(); - let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; - let cos = self - .cos - .narrow(0, index_pos, seq_len)? - .reshape((seq_len, n_embd / 2, 1))?; - let sin = self - .sin - .narrow(0, index_pos, seq_len)? - .reshape((seq_len, n_embd / 2, 1))?; - let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; - let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; - // This mimics the llama.cpp behavior. - // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105 - // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. - // The resulting y0 and y1 are also interleaved with: - // y0 = x0*cos - x1*sin - // y1 = x0*sin + x1*cos - let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; - let x0 = x.narrow(D::Minus1, 0, 1)?; - let x1 = x.narrow(D::Minus1, 1, 1)?; - let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; - let y1 = (x0.broadcast_mul(&sin)? 
+ x1.broadcast_mul(&cos)?)?; - let rope = Tensor::cat(&[y0, y1], D::Minus1)?; - let rope = rope.flatten_from(D::Minus2)?; - Ok(rope) - } - - fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> { - let _enter = self.span_attn.enter(); - let (b_sz, seq_len, n_embd) = x.dims3()?; - let q = self.attention_wq.forward(x)?; - let k = self.attention_wk.forward(x)?; - let v = self.attention_wv.forward(x)?; - - let q = q - .reshape((b_sz, seq_len, self.n_head, self.head_dim))? - .transpose(1, 2)?; - let k = k - .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? - .transpose(1, 2)?; - let v = v - .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? - .transpose(1, 2)?; - - let q = self.apply_rotary_emb(&q, index_pos)?; - let k = self.apply_rotary_emb(&k, index_pos)?; - - let (k, v) = match &self.kv_cache { - None => (k, v), - Some((k_cache, v_cache)) => { - let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?; - let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?; - (k, v) - } - }; - self.kv_cache = Some((k.clone(), v.clone())); - - // Support for MQA, useful for 70B models. - let k = self.repeat_kv(k)?; - let v = self.repeat_kv(v)?; - - let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; - let mask = mask.broadcast_as(att.shape())?; - let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; - let att = candle_nn::ops::softmax(&att, D::Minus1)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - let y = att.matmul(&v.contiguous()?)?; - let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; - let y = self.attention_wo.forward(&y)?; - Ok(y) - } +mod model; +use model::ModelWeights; - fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { - let n_rep = self.n_head / self.n_kv_head; - if n_rep == 1 { - Ok(x) - } else { - let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?; - let x = x - .unsqueeze(2)? - .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? - .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?; - Ok(x) - } - } -} - -struct ModelWeights { - tok_embeddings: Embedding, - layers: Vec<LayerWeights>, - norm: RmsNorm, - output: QMatMul, - masks: HashMap<usize, Tensor>, - span: tracing::Span, - span_output: tracing::Span, -} - -fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> { - let theta: Vec<_> = (0..head_dim) - .step_by(2) - .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) - .collect(); - let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?; - let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)? - .to_dtype(DType::F32)? - .reshape((MAX_SEQ_LEN, 1))? 
- .matmul(&theta.reshape((1, theta.elem_count()))?)?; - let cos = idx_theta.cos()?; - let sin = idx_theta.sin()?; - Ok((cos, sin)) -} - -impl ModelWeights { - fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> { - let cpu = &Device::Cpu; - let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; - let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?; - let tok_embeddings = ct.remove("tok_embeddings.weight")?; - let tok_embeddings = tok_embeddings.dequantize(cpu)?; - let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; - let output = ct.remove("output.weight")?; - let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); - for layer_idx in 0..ct.hparams.n_layer { - let prefix = format!("layers.{layer_idx}"); - let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?; - let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; - let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; - let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; - let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; - let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; - let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; - let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; - let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; - let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); - let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); - let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); - layers.push(LayerWeights { - attention_wq: QMatMul::from_qtensor(attention_wq), - attention_wk: QMatMul::from_qtensor(attention_wk), - attention_wv: QMatMul::from_qtensor(attention_wv), - attention_wo: QMatMul::from_qtensor(attention_wo), - attention_norm: RmsNorm::new(attention_norm, 1e-5)?, - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), - ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?, - n_head: ct.hparams.n_head as usize, - n_kv_head: ct.hparams.n_head as usize / gqa, - head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, - cos: cos.clone(), - sin: sin.clone(), - kv_cache: None, - span_attn, - span_rot, - span_mlp, - }) - } - let span = tracing::span!(tracing::Level::TRACE, "model"); - let span_output = tracing::span!(tracing::Level::TRACE, "output"); - Ok(Self { - tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize), - layers, - norm, - output: QMatMul::from_qtensor(output), - masks: HashMap::new(), - span, - span_output, - }) - } - - fn from_gguf<R: std::io::Seek + std::io::Read>( - ct: gguf_file::Content, - reader: &mut R, - ) -> Result<Self> { - let cpu = &Device::Cpu; - let md_get = |s: &str| match ct.metadata.get(s) { - None => candle::bail!("cannot find {s} in metadata"), - Some(v) => Ok(v), - }; - - // Parameter extraction from metadata. - let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; - let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; - let block_count = md_get("llama.block_count")?.to_u32()? as usize; - let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; - let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; - // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. 
- let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; - - let rope_freq_base = md_get("llama.rope.freq_base") - .and_then(|m| m.to_f32()) - .unwrap_or(10000f32); - let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?; - - let tok_embeddings = ct.tensor(reader, "token_embd.weight")?; - let tok_embeddings = tok_embeddings.dequantize(cpu)?; - let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?; - let output = ct.tensor(reader, "output.weight")?; - let mut layers = Vec::with_capacity(block_count); - for layer_idx in 0..block_count { - let prefix = format!("blk.{layer_idx}"); - let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?; - let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; - let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; - let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; - let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; - let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; - let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; - let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; - let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; - let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); - let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); - let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); - layers.push(LayerWeights { - attention_wq: QMatMul::from_qtensor(attention_wq), - attention_wk: QMatMul::from_qtensor(attention_wk), - attention_wv: QMatMul::from_qtensor(attention_wv), - attention_wo: QMatMul::from_qtensor(attention_wo), - attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?, - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), - ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?, - n_head: head_count, - n_kv_head: head_count_kv, - head_dim: embedding_length / head_count, - cos: cos.clone(), - sin: sin.clone(), - kv_cache: None, - span_attn, - span_rot, - span_mlp, - }) - } - let span = tracing::span!(tracing::Level::TRACE, "model"); - let span_output = tracing::span!(tracing::Level::TRACE, "output"); - Ok(Self { - tok_embeddings: Embedding::new(tok_embeddings, embedding_length), - layers, - norm, - output: QMatMul::from_qtensor(output), - masks: HashMap::new(), - span, - span_output, - }) - } - - fn mask(&mut self, t: usize) -> Result<Tensor> { - if let Some(mask) = self.masks.get(&t) { - Ok(mask.clone()) - } else { - let mask: Vec<_> = (0..t) - .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) - .collect(); - let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; - self.masks.insert(t, mask.clone()); - Ok(mask) - } - } - - fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { - let (_b_sz, seq_len) = x.dims2()?; - let mask = self.mask(seq_len)?; - let _enter = self.span.enter(); - let mut layer_in = self.tok_embeddings.forward(x)?; - for layer in self.layers.iter_mut() { - let x = layer_in; - let residual = &x; - let x = layer.attention_norm.forward(&x)?; - let attn = layer.forward_attn(&x, &mask, index_pos)?; - let x = (attn + residual)?; - - // MLP - let _enter = layer.span_mlp.enter(); - let residual = &x; - let x = layer.ffn_norm.forward(&x)?; - let w1 = 
layer.feed_forward_w1.forward(&x)?; - let w3 = layer.feed_forward_w3.forward(&x)?; - let mlp = layer - .feed_forward_w2 - .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?; - layer_in = (mlp + residual)?; - } - let x = self.norm.forward(&layer_in)?; - let x = x.i((.., seq_len - 1, ..))?; - let _enter = self.span_output.enter(); - self.output.forward(&x) - } -} +const DEFAULT_PROMPT: &str = "My favorite theorem is "; #[derive(Clone, Debug, Copy, ValueEnum)] enum Which { diff --git a/candle-examples/examples/quantized/model.rs b/candle-examples/examples/quantized/model.rs new file mode 100644 index 00000000..27ac18a9 --- /dev/null +++ b/candle-examples/examples/quantized/model.rs @@ -0,0 +1,367 @@ +use std::collections::HashMap; + +use candle::quantized::QTensor; +use candle::quantized::{ggml_file, gguf_file}; +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, Module}; + +const MAX_SEQ_LEN: usize = 4096; + +struct RmsNorm { + inner: candle_nn::LayerNorm, + span: tracing::Span, +} + +impl RmsNorm { + fn new(scale: QTensor, eps: f32) -> Result<Self> { + let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); + let scale = scale.dequantize(&Device::Cpu)?; + let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); + Ok(Self { inner, span }) + } + + fn forward(&self, x: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} + +// QMatMul wrapper adding some tracing. +struct QMatMul { + inner: candle::quantized::QMatMul, + span: tracing::Span, +} + +impl QMatMul { + fn from_qtensor(qtensor: QTensor) -> Self { + let inner = candle::quantized::QMatMul::from_qtensor(qtensor); + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Self { inner, span } + } + + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +struct LayerWeights { + attention_wq: QMatMul, + attention_wk: QMatMul, + attention_wv: QMatMul, + attention_wo: QMatMul, + attention_norm: RmsNorm, + feed_forward_w1: QMatMul, + feed_forward_w2: QMatMul, + feed_forward_w3: QMatMul, + ffn_norm: RmsNorm, + n_head: usize, + n_kv_head: usize, + head_dim: usize, + cos: Tensor, + sin: Tensor, + kv_cache: Option<(Tensor, Tensor)>, + span_attn: tracing::Span, + span_rot: tracing::Span, + span_mlp: tracing::Span, +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +impl LayerWeights { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_rot.enter(); + let (b_sz, n_head, seq_len, n_embd) = x.dims4()?; + let cos = self + .cos + .narrow(0, index_pos, seq_len)? + .reshape((seq_len, n_embd / 2, 1))?; + let sin = self + .sin + .narrow(0, index_pos, seq_len)? + .reshape((seq_len, n_embd / 2, 1))?; + let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; + let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd / 2, 1))?; + // This mimics the llama.cpp behavior. + // https://github.com/ggerganov/llama.cpp/blob/1f0bccb27929e261744c979bc75114955da49e98/ggml.c#L12104-L12105 + // The x0 and x1 value are interleaved on the n_embd (= head_dim) dimension. 
+ // The resulting y0 and y1 are also interleaved with: + // y0 = x0*cos - x1*sin + // y1 = x0*sin + x1*cos + let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?; + let x0 = x.narrow(D::Minus1, 0, 1)?; + let x1 = x.narrow(D::Minus1, 1, 1)?; + let y0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; + let y1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; + let rope = Tensor::cat(&[y0, y1], D::Minus1)?; + let rope = rope.flatten_from(D::Minus2)?; + Ok(rope) + } + + fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> { + let _enter = self.span_attn.enter(); + let (b_sz, seq_len, n_embd) = x.dims3()?; + let q = self.attention_wq.forward(x)?; + let k = self.attention_wk.forward(x)?; + let v = self.attention_wv.forward(x)?; + + let q = q + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let k = self.apply_rotary_emb(&k, index_pos)?; + + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((k_cache, v_cache)) => { + let k = Tensor::cat(&[k_cache, &k], 2)?.contiguous()?; + let v = Tensor::cat(&[v_cache, &v], 2)?.contiguous()?; + (k, v) + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + + // Support for MQA, useful for 70B models. + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let mask = mask.broadcast_as(att.shape())?; + let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; + let att = candle_nn::ops::softmax(&att, D::Minus1)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = self.attention_wo.forward(&y)?; + Ok(y) + } + + fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { + let n_rep = self.n_head / self.n_kv_head; + if n_rep == 1 { + Ok(x) + } else { + let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?; + let x = x + .unsqueeze(2)? + .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? + .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?; + Ok(x) + } + } +} + +pub struct ModelWeights { + tok_embeddings: Embedding, + layers: Vec<LayerWeights>, + norm: RmsNorm, + output: QMatMul, + masks: HashMap<usize, Tensor>, + span: tracing::Span, + span_output: tracing::Span, +} + +fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? 
+ .matmul(&theta.reshape((1, theta.elem_count()))?)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok((cos, sin)) +} + +impl ModelWeights { + pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> { + let cpu = &Device::Cpu; + let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; + let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?; + let tok_embeddings = ct.remove("tok_embeddings.weight")?; + let tok_embeddings = tok_embeddings.dequantize(cpu)?; + let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; + let output = ct.remove("output.weight")?; + let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); + for layer_idx in 0..ct.hparams.n_layer { + let prefix = format!("layers.{layer_idx}"); + let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?; + let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; + let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; + let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; + let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; + let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; + let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; + let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; + let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + layers.push(LayerWeights { + attention_wq: QMatMul::from_qtensor(attention_wq), + attention_wk: QMatMul::from_qtensor(attention_wk), + attention_wv: QMatMul::from_qtensor(attention_wv), + attention_wo: QMatMul::from_qtensor(attention_wo), + attention_norm: RmsNorm::new(attention_norm, 1e-5)?, + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), + ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?, + n_head: ct.hparams.n_head as usize, + n_kv_head: ct.hparams.n_head as usize / gqa, + head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, + cos: cos.clone(), + sin: sin.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize), + layers, + norm, + output: QMatMul::from_qtensor(output), + masks: HashMap::new(), + span, + span_output, + }) + } + + pub fn from_gguf<R: std::io::Seek + std::io::Read>( + ct: gguf_file::Content, + reader: &mut R, + ) -> Result<Self> { + let cpu = &Device::Cpu; + let md_get = |s: &str| match ct.metadata.get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + + // Parameter extraction from metadata. + let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; + let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; + let block_count = md_get("llama.block_count")?.to_u32()? as usize; + let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; + let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; + // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. 
+ let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()?; + + let rope_freq_base = md_get("llama.rope.freq_base") + .and_then(|m| m.to_f32()) + .unwrap_or(10000f32); + let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?; + + let tok_embeddings = ct.tensor(reader, "token_embd.weight")?; + let tok_embeddings = tok_embeddings.dequantize(cpu)?; + let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?; + let output = ct.tensor(reader, "output.weight")?; + let mut layers = Vec::with_capacity(block_count); + for layer_idx in 0..block_count { + let prefix = format!("blk.{layer_idx}"); + let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?; + let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?; + let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?; + let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?; + let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?; + let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?; + let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?; + let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?; + let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + layers.push(LayerWeights { + attention_wq: QMatMul::from_qtensor(attention_wq), + attention_wk: QMatMul::from_qtensor(attention_wk), + attention_wv: QMatMul::from_qtensor(attention_wv), + attention_wo: QMatMul::from_qtensor(attention_wo), + attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?, + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), + ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?, + n_head: head_count, + n_kv_head: head_count_kv, + head_dim: embedding_length / head_count, + cos: cos.clone(), + sin: sin.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, embedding_length), + layers, + norm, + output: QMatMul::from_qtensor(output), + masks: HashMap::new(), + span, + span_output, + }) + } + + fn mask(&mut self, t: usize) -> Result<Tensor> { + if let Some(mask) = self.masks.get(&t) { + Ok(mask.clone()) + } else { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; + self.masks.insert(t, mask.clone()); + Ok(mask) + } + } + + pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let (_b_sz, seq_len) = x.dims2()?; + let mask = self.mask(seq_len)?; + let _enter = self.span.enter(); + let mut layer_in = self.tok_embeddings.forward(x)?; + for layer in self.layers.iter_mut() { + let x = layer_in; + let residual = &x; + let x = layer.attention_norm.forward(&x)?; + let attn = layer.forward_attn(&x, &mask, index_pos)?; + let x = (attn + residual)?; + + // MLP + let _enter = layer.span_mlp.enter(); + let residual = &x; + let x = layer.ffn_norm.forward(&x)?; + let w1 = 
layer.feed_forward_w1.forward(&x)?; + let w3 = layer.feed_forward_w3.forward(&x)?; + let mlp = layer + .feed_forward_w2 + .forward(&(candle_nn::ops::silu(&w1)? * w3)?)?; + layer_in = (mlp + residual)?; + } + let x = self.norm.forward(&layer_in)?; + let x = x.i((.., seq_len - 1, ..))?; + let _enter = self.span_output.enter(); + self.output.forward(&x) + } +} |
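For completeness, this is roughly how the sampling loop in main.rs drives the relocated model: the prompt is processed in one forward pass at position 0, then each generated token is fed back one at a time with `index_pos` advancing so the rotary embeddings and the KV cache stay aligned. The `LogitsProcessor::new(seed, temperature, top_p)` signature and the shape handling are assumptions inferred from the imports visible above, not part of this diff.

```rust
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;

use crate::model::ModelWeights;

/// Sampling loop: one prompt pass at position 0, then one token per step.
fn generate(
    model: &mut ModelWeights,
    mut tokens: Vec<u32>,
    max_new_tokens: usize,
) -> candle::Result<Vec<u32>> {
    // Assumed constructor arguments: (seed, temperature, top_p).
    let mut logits_processor = LogitsProcessor::new(299792458, Some(0.8), None);

    // Prompt pass: (1, prompt_len) input; forward returns logits for the last
    // position only, since it indexes `seq_len - 1` before the output matmul.
    let input = Tensor::new(tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
    let logits = model.forward(&input, 0)?.squeeze(0)?;
    let mut next = logits_processor.sample(&logits)?;
    tokens.push(next);

    // Decode passes: a single token each, with index_pos set to its position
    // so the rotary embedding offset and the KV cache stay in sync.
    for _ in 1..max_new_tokens {
        let input = Tensor::new(&[next], &Device::Cpu)?.unsqueeze(0)?;
        let logits = model.forward(&input, tokens.len() - 1)?.squeeze(0)?;
        next = logits_processor.sample(&logits)?;
        tokens.push(next);
    }
    Ok(tokens)
}
```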