summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama2-c
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-28 17:51:19 +0200
committerGitHub <noreply@github.com>2023-10-28 16:51:19 +0100
commit95a857cf57c56a34ecdaae5372f2a13ebd900001 (patch)
tree9b0bac74758528addfdd27db331d3dcbae20f3ac /candle-examples/examples/llama2-c
parent612f5b81561150ca6651368c245ac2065c04159a (diff)
downloadcandle-95a857cf57c56a34ecdaae5372f2a13ebd900001.tar.gz
candle-95a857cf57c56a34ecdaae5372f2a13ebd900001.tar.bz2
candle-95a857cf57c56a34ecdaae5372f2a13ebd900001.zip
Move the llama2-c model in transformers. (#1205)
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r--candle-examples/examples/llama2-c/main.rs6
-rw-r--r--candle-examples/examples/llama2-c/model.rs314
-rw-r--r--candle-examples/examples/llama2-c/qmodel.rs227
-rw-r--r--candle-examples/examples/llama2-c/weights.rs168
4 files changed, 3 insertions, 712 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 77dbc677..a3f01ae2 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -6,10 +6,10 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
-mod model;
-mod qmodel;
+use candle_transformers::models::llama2_c as model;
+use candle_transformers::models::llama2_c_weights as weights;
+use candle_transformers::models::quantized_llama2_c as qmodel;
mod training;
-mod weights;
use clap::{Parser, Subcommand};
use anyhow::{Error as E, Result};
diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
deleted file mode 100644
index 07a6e2f2..00000000
--- a/candle-examples/examples/llama2-c/model.rs
+++ /dev/null
@@ -1,314 +0,0 @@
-use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::linear_no_bias as linear;
-use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
-use std::collections::HashMap;
-use std::sync::{Arc, Mutex};
-
-#[derive(Debug, Clone)]
-pub struct Config {
- pub dim: usize, // transformer dimension
- pub hidden_dim: usize, // for ffn layers
- pub n_layers: usize, // number of layers
- pub n_heads: usize, // number of query heads
- pub n_kv_heads: usize, // number of key/value heads (can be < query heads because of multiquery)
- pub vocab_size: usize, // vocabulary size, usually 256 (byte-level)
- pub seq_len: usize, // max sequence length
- pub norm_eps: f64,
-}
-
-impl Config {
- pub fn tiny() -> Self {
- Self {
- dim: 288,
- hidden_dim: 768,
- n_layers: 6,
- n_heads: 6,
- n_kv_heads: 6,
- vocab_size: 32000,
- seq_len: 256,
- norm_eps: 1e-5,
- }
- }
-}
-
-#[derive(Clone)]
-pub struct Cache {
- masks: Arc<Mutex<HashMap<usize, Tensor>>>,
- pub use_kv_cache: bool,
- #[allow(clippy::type_complexity)]
- pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
- pub cos: Tensor,
- pub sin: Tensor,
- device: Device,
-}
-
-impl Cache {
- pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
- let n_elem = cfg.dim / cfg.n_heads;
- let theta: Vec<_> = (0..n_elem)
- .step_by(2)
- .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
- .collect();
- let theta = Tensor::new(theta.as_slice(), vb.device())?;
- let idx_theta = Tensor::arange(0, cfg.seq_len as u32, vb.device())?
- .to_dtype(DType::F32)?
- .reshape((cfg.seq_len, 1))?
- .matmul(&theta.reshape((1, theta.elem_count()))?)?;
- let precomputed_cos = idx_theta.cos()?;
- let precomputed_sin = idx_theta.sin()?;
-
- let freq_cis_real = vb
- .get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")
- .unwrap_or(precomputed_cos);
- let freq_cis_imag = vb
- .get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")
- .unwrap_or(precomputed_sin);
- let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
- let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
- Ok(Self {
- masks: Arc::new(Mutex::new(HashMap::new())),
- use_kv_cache,
- kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
- cos,
- sin,
- device: vb.device().clone(),
- })
- }
-
- pub fn mask(&self, t: usize) -> Result<Tensor> {
- let mut masks = self.masks.lock().unwrap();
- if let Some(mask) = masks.get(&t) {
- Ok(mask.clone())
- } else {
- let mask: Vec<_> = (0..t)
- .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
- .collect();
- let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
- masks.insert(t, mask.clone());
- Ok(mask)
- }
- }
-}
-
-fn silu(xs: &Tensor) -> Result<Tensor> {
- xs / (xs.neg()?.exp()? + 1.0)?
-}
-
-struct CausalSelfAttention {
- q_proj: Linear,
- k_proj: Linear,
- v_proj: Linear,
- o_proj: Linear,
- n_head: usize,
- n_key_value_head: usize,
- head_dim: usize,
- cache: Cache,
-}
-
-impl CausalSelfAttention {
- fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
- let (b_sz, seq_len, h, n_embd) = x.dims4()?;
- let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
- let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
- let cos = cos.unsqueeze(1)?;
- let sin = sin.unsqueeze(1)?;
- let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
- let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
- let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
- let x0 = x.narrow(D::Minus1, 0, 1)?;
- let x1 = x.narrow(D::Minus1, 1, 1)?;
- let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
- let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
- let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
- Ok(rope)
- }
-
- fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
- let (b_sz, seq_len, n_embd) = x.dims3()?;
- let q = self.q_proj.forward(x)?;
- let k = self.k_proj.forward(x)?;
- let v = self.v_proj.forward(x)?;
-
- let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
- let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
- let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
-
- let q = self.apply_rotary_emb(&q, index_pos)?;
- let mut k = self.apply_rotary_emb(&k, index_pos)?;
-
- if self.cache.use_kv_cache {
- let mut cache = self.cache.kvs.lock().unwrap();
- if let Some((cache_k, cache_v)) = &cache[block_idx] {
- k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
- v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
- }
- cache[block_idx] = Some((k.clone(), v.clone()))
- }
-
- let k = self.repeat_kv(k)?;
- let v = self.repeat_kv(v)?;
-
- let q = q.transpose(1, 2)?.contiguous()?;
- let k = k.transpose(1, 2)?.contiguous()?;
- let v = v.transpose(1, 2)?.contiguous()?;
-
- let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
- let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
- let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = candle_nn::ops::softmax(&att, D::Minus1)?;
- // Convert to contiguous as matmul doesn't support strided vs for now.
- let y = att.matmul(&v.contiguous()?)?;
- let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
- let y = self.o_proj.forward(&y)?;
- Ok(y)
- }
-
- fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
- let n_rep = self.n_head / self.n_key_value_head;
- if n_rep == 1 {
- Ok(x)
- } else {
- let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
- let x = x
- .unsqueeze(3)?
- .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
- .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
- Ok(x)
- }
- }
-
- fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
- let size_in = cfg.dim;
- let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
- let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
- let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
- let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
- let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
- let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
- Ok(Self {
- q_proj,
- k_proj,
- v_proj,
- o_proj,
- n_head: cfg.n_heads,
- n_key_value_head: cfg.n_kv_heads,
- head_dim: cfg.dim / cfg.n_heads,
- cache: cache.clone(),
- })
- }
-}
-
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
- let shape = mask.shape();
- let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
- let m = mask.where_cond(&on_true, on_false)?;
- Ok(m)
-}
-
-struct Mlp {
- c_fc1: Linear,
- c_fc2: Linear,
- c_proj: Linear,
-}
-
-impl Mlp {
- fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
- Self {
- c_fc1,
- c_fc2,
- c_proj,
- }
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
- self.c_proj.forward(&x)
- }
-
- fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let h_size = cfg.dim;
- let i_size = cfg.hidden_dim;
- let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
- let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
- let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
- Ok(Self::new(c_fc1, c_fc2, c_proj))
- }
-}
-
-struct Block {
- rms_1: RmsNorm,
- attn: CausalSelfAttention,
- rms_2: RmsNorm,
- mlp: Mlp,
-}
-
-impl Block {
- fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
- Self {
- rms_1,
- attn,
- rms_2,
- mlp,
- }
- }
-
- fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
- let residual = x;
- let x = self.rms_1.forward(x)?;
- let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
- let residual = &x;
- let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
- Ok(x)
- }
-
- fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
- let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
- let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
- let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
- let post_attention_layernorm =
- rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
- Ok(Self::new(
- input_layernorm,
- attn,
- post_attention_layernorm,
- mlp,
- ))
- }
-}
-
-pub struct Llama {
- wte: Embedding,
- blocks: Vec<Block>,
- ln_f: RmsNorm,
- lm_head: Linear,
- pub config: Config,
-}
-
-impl Llama {
- pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
- let (_b_sz, _seq_len) = x.dims2()?;
- let mut x = self.wte.forward(x)?;
- for (block_idx, block) in self.blocks.iter().enumerate() {
- x = block.forward(&x, index_pos, block_idx)?;
- }
- let x = self.ln_f.forward(&x)?;
- let logits = self.lm_head.forward(&x)?;
- logits.to_dtype(DType::F32)
- }
-
- pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
- let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
- let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
- let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
- let blocks: Vec<_> = (0..cfg.n_layers)
- .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
- .collect();
- Ok(Self {
- wte,
- blocks,
- ln_f,
- lm_head,
- config: cfg,
- })
- }
-}
diff --git a/candle-examples/examples/llama2-c/qmodel.rs b/candle-examples/examples/llama2-c/qmodel.rs
deleted file mode 100644
index 07db146e..00000000
--- a/candle-examples/examples/llama2-c/qmodel.rs
+++ /dev/null
@@ -1,227 +0,0 @@
-use super::model::{Cache, Config};
-use candle::{DType, IndexOp, Module, Result, Tensor, D};
-use candle_transformers::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
-pub use candle_transformers::quantized_var_builder::VarBuilder;
-
-fn silu(xs: &Tensor) -> Result<Tensor> {
- xs / (xs.neg()?.exp()? + 1.0)?
-}
-
-struct CausalSelfAttention {
- q_proj: Linear,
- k_proj: Linear,
- v_proj: Linear,
- o_proj: Linear,
- n_head: usize,
- n_key_value_head: usize,
- head_dim: usize,
- cache: Cache,
-}
-
-impl CausalSelfAttention {
- fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
- let (b_sz, seq_len, h, n_embd) = x.dims4()?;
- let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
- let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
- let cos = cos.unsqueeze(1)?;
- let sin = sin.unsqueeze(1)?;
- let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
- let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
- let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
- let x0 = x.narrow(D::Minus1, 0, 1)?;
- let x1 = x.narrow(D::Minus1, 1, 1)?;
- let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
- let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
- let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
- Ok(rope)
- }
-
- fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
- let (b_sz, seq_len, n_embd) = x.dims3()?;
- let q = self.q_proj.forward(x)?;
- let k = self.k_proj.forward(x)?;
- let v = self.v_proj.forward(x)?;
-
- let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
- let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
- let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
-
- let q = self.apply_rotary_emb(&q, index_pos)?;
- let mut k = self.apply_rotary_emb(&k, index_pos)?;
-
- if self.cache.use_kv_cache {
- let mut cache = self.cache.kvs.lock().unwrap();
- if let Some((cache_k, cache_v)) = &cache[block_idx] {
- k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
- v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
- }
- cache[block_idx] = Some((k.clone(), v.clone()))
- }
-
- let k = self.repeat_kv(k)?;
- let v = self.repeat_kv(v)?;
-
- let q = q.transpose(1, 2)?.contiguous()?;
- let k = k.transpose(1, 2)?.contiguous()?;
- let v = v.transpose(1, 2)?.contiguous()?;
-
- let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
- let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
- let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
- let att = candle_nn::ops::softmax(&att, D::Minus1)?;
- // Convert to contiguous as matmul doesn't support strided vs for now.
- let y = att.matmul(&v.contiguous()?)?;
- let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
- let y = self.o_proj.forward(&y)?;
- Ok(y)
- }
-
- fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
- let n_rep = self.n_head / self.n_key_value_head;
- if n_rep == 1 {
- Ok(x)
- } else {
- let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
- let x = x
- .unsqueeze(3)?
- .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
- .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
- Ok(x)
- }
- }
-
- fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
- let size_in = cfg.dim;
- let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
- let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
- let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
- let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
- let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
- let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
- Ok(Self {
- q_proj,
- k_proj,
- v_proj,
- o_proj,
- n_head: cfg.n_heads,
- n_key_value_head: cfg.n_kv_heads,
- head_dim: cfg.dim / cfg.n_heads,
- cache: cache.clone(),
- })
- }
-}
-
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
- let shape = mask.shape();
- let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
- let m = mask.where_cond(&on_true, on_false)?;
- Ok(m)
-}
-
-struct Mlp {
- c_fc1: Linear,
- c_fc2: Linear,
- c_proj: Linear,
-}
-
-impl Mlp {
- fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
- Self {
- c_fc1,
- c_fc2,
- c_proj,
- }
- }
-
- fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
- self.c_proj.forward(&x)
- }
-
- fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
- let h_size = cfg.dim;
- let i_size = cfg.hidden_dim;
- let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
- let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
- let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
- Ok(Self::new(c_fc1, c_fc2, c_proj))
- }
-}
-
-struct Block {
- rms_1: RmsNorm,
- attn: CausalSelfAttention,
- rms_2: RmsNorm,
- mlp: Mlp,
-}
-
-impl Block {
- fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
- Self {
- rms_1,
- attn,
- rms_2,
- mlp,
- }
- }
-
- fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
- let residual = x;
- let x = self.rms_1.forward(x)?;
- let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
- let residual = &x;
- let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
- Ok(x)
- }
-
- fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
- let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
- let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
- let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
- let post_attention_layernorm =
- RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
- Ok(Self::new(
- input_layernorm,
- attn,
- post_attention_layernorm,
- mlp,
- ))
- }
-}
-
-pub struct QLlama {
- wte: Embedding,
- blocks: Vec<Block>,
- ln_f: RmsNorm,
- lm_head: Linear,
- pub config: Config,
-}
-
-impl QLlama {
- pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
- let (_b_sz, _seq_len) = x.dims2()?;
- let mut x = self.wte.forward(x)?;
- for (block_idx, block) in self.blocks.iter().enumerate() {
- x = block.forward(&x, index_pos, block_idx)?;
- }
- let x = self.ln_f.forward(&x)?;
- let logits = self.lm_head.forward(&x)?;
- logits.to_dtype(DType::F32)
- }
-
- pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
- let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
- let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
- let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
- let blocks: Vec<_> = (0..cfg.n_layers)
- .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg).unwrap())
- .collect();
- Ok(Self {
- wte,
- blocks,
- ln_f,
- lm_head,
- config: cfg,
- })
- }
-}
diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-examples/examples/llama2-c/weights.rs
deleted file mode 100644
index b78418ce..00000000
--- a/candle-examples/examples/llama2-c/weights.rs
+++ /dev/null
@@ -1,168 +0,0 @@
-use anyhow::Result;
-use byteorder::{LittleEndian, ReadBytesExt};
-use candle::{DType, Device, IndexOp, Shape, Tensor};
-use candle_nn::VarBuilder;
-
-use crate::model::Config;
-
-pub struct TransformerWeights {
- // token embedding table
- token_embedding_table: Tensor, // (vocab_size, dim)
- // weights for rmsnorms
- rms_att_weight: Tensor, // (layer, dim) rmsnorm weights
- rms_ffn_weight: Tensor, // (layer, dim)
- // weights for matmuls
- wq: Tensor, // (layer, dim, dim)
- wk: Tensor, // (layer, dim, dim)
- wv: Tensor, // (layer, dim, dim)
- wo: Tensor, // (layer, dim, dim)
- // weights for ffn
- w1: Tensor, // (layer, hidden_dim, dim)
- w2: Tensor, // (layer, dim, hidden_dim)
- w3: Tensor, // (layer, hidden_dim, dim)
- // final rmsnorm
- rms_final_weight: Tensor, // (dim,)
- // freq_cis for RoPE relatively positional embeddings
- freq_cis_real: Tensor, // (seq_len, head_size/2)
- freq_cis_imag: Tensor, // (seq_len, head_size/2)
-}
-
-fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
- let mut buf = [0u8; 4];
- r.read_exact(&mut buf)?;
- Ok(i32::from_le_bytes(buf))
-}
-
-fn read_tensor<R: std::io::Read, S: Into<Shape>>(
- r: &mut R,
- shape: S,
- dev: &Device,
-) -> Result<Tensor> {
- let shape = shape.into();
- let mut data_t = vec![0f32; shape.elem_count()];
- r.read_f32_into::<LittleEndian>(&mut data_t)?;
- let tensor = Tensor::from_vec(data_t, shape, dev)?;
- Ok(tensor)
-}
-
-impl Config {
- pub fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> {
- let dim = read_i32(r)? as usize;
- let hidden_dim = read_i32(r)? as usize;
- let n_layers = read_i32(r)? as usize;
- let n_heads = read_i32(r)? as usize;
- let n_kv_heads = read_i32(r)? as usize;
- let vocab_size = read_i32(r)? as usize;
- let seq_len = read_i32(r)? as usize;
- Ok(Self {
- dim,
- hidden_dim,
- n_layers,
- n_heads,
- n_kv_heads,
- vocab_size,
- seq_len,
- norm_eps: 1e-5,
- })
- }
-
- pub fn head_size(&self) -> usize {
- self.dim / self.n_heads
- }
-}
-
-impl TransformerWeights {
- pub fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> {
- let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?;
- let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
- let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
- let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
- let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
- let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
- let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
- let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
- let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
- let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
- let rms_final_weight = read_tensor(r, c.dim, dev)?;
- let head_size = c.head_size();
- let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
- let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
- Ok(Self {
- token_embedding_table,
- rms_att_weight,
- wq,
- wk,
- wv,
- wo,
- rms_ffn_weight,
- w1,
- w2,
- w3,
- rms_final_weight,
- freq_cis_real,
- freq_cis_imag,
- })
- }
-
- pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
- // TODO: As of 2023-08-04, gemm is slower than expected when multiplying a matrix of
- // size (1, k) with the transpose of a matrix of size (k, n) as it ends up transposing the
- // second matrix back. We detect this case here and as a temporary hack make the weight
- // matrix column major rather than row major. This ends up speeding up text generation from
- // 120 token/s to 220 token/s on a Ryzen 2600X.
- let tr = device.is_cpu() && !candle::utils::has_mkl();
- let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
- let mut ws = std::collections::HashMap::new();
- let mut insert = |name: &str, t: Tensor| {
- ws.insert(name.to_string(), t);
- };
- insert("rot.freq_cis_real", self.freq_cis_real.clone());
- insert("rot.freq_cis_imag", self.freq_cis_imag.clone());
- insert(
- "model.embed_tokens.weight",
- self.token_embedding_table.clone(),
- );
- insert("lm_head.weight", tr(self.token_embedding_table.clone())?);
- insert("model.norm.weight", self.rms_final_weight.clone());
- for layer in 0..cfg.n_layers {
- ws.insert(
- format!("model.layers.{layer}.self_attn.q_proj.weight"),
- tr(self.wq.i(layer)?)?,
- );
- ws.insert(
- format!("model.layers.{layer}.self_attn.k_proj.weight"),
- tr(self.wk.i(layer)?)?,
- );
- ws.insert(
- format!("model.layers.{layer}.self_attn.v_proj.weight"),
- tr(self.wv.i(layer)?)?,
- );
- ws.insert(
- format!("model.layers.{layer}.self_attn.o_proj.weight"),
- tr(self.wo.i(layer)?)?,
- );
- ws.insert(
- format!("model.layers.{layer}.mlp.gate_proj.weight"),
- tr(self.w1.i(layer)?)?,
- );
- ws.insert(
- format!("model.layers.{layer}.mlp.down_proj.weight"),
- tr(self.w2.i(layer)?)?,
- );
- ws.insert(
- format!("model.layers.{layer}.mlp.up_proj.weight"),
- tr(self.w3.i(layer)?)?,
- );
- ws.insert(
- format!("model.layers.{layer}.input_layernorm.weight"),
- self.rms_att_weight.i(layer)?,
- );
- ws.insert(
- format!("model.layers.{layer}.post_attention_layernorm.weight"),
- self.rms_ffn_weight.i(layer)?,
- );
- }
- let vb = VarBuilder::from_tensors(ws, DType::F32, device);
- Ok(vb)
- }
-}