author    | Laurent Mazare <laurent.mazare@gmail.com> | 2023-10-28 17:51:19 +0200
committer | GitHub <noreply@github.com>               | 2023-10-28 16:51:19 +0100
commit    | 95a857cf57c56a34ecdaae5372f2a13ebd900001 (patch)
tree      | 9b0bac74758528addfdd27db331d3dcbae20f3ac /candle-examples/examples/llama2-c
parent    | 612f5b81561150ca6651368c245ac2065c04159a (diff)
Move the llama2-c model in transformers. (#1205)
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r-- | candle-examples/examples/llama2-c/main.rs    |   6
-rw-r--r-- | candle-examples/examples/llama2-c/model.rs   | 314
-rw-r--r-- | candle-examples/examples/llama2-c/qmodel.rs  | 227
-rw-r--r-- | candle-examples/examples/llama2-c/weights.rs | 168
4 files changed, 3 insertions, 712 deletions
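The only file that keeps content is main.rs: its local `mod model;` / `mod qmodel;` / `mod weights;` declarations are replaced by imports of the copies that now live in candle-transformers. A minimal sketch of the new import surface (the module paths are exactly those added in the diff below; the helper function is illustrative):

    // The example now pulls the llama2.c model code from candle-transformers.
    use candle_transformers::models::llama2_c as model;
    use candle_transformers::models::llama2_c_weights as weights;
    use candle_transformers::models::quantized_llama2_c as qmodel;

    // For instance, the default "tiny" configuration defined in the moved model code:
    fn tiny_config() -> model::Config {
        model::Config::tiny()
    }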
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 77dbc677..a3f01ae2 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -6,10 +6,10 @@ extern crate accelerate_src;
 
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
-mod model;
-mod qmodel;
+use candle_transformers::models::llama2_c as model;
+use candle_transformers::models::llama2_c_weights as weights;
+use candle_transformers::models::quantized_llama2_c as qmodel;
 mod training;
-mod weights;
 use clap::{Parser, Subcommand};
 use anyhow::{Error as E, Result};

diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-examples/examples/llama2-c/model.rs
deleted file mode 100644
index 07a6e2f2..00000000
--- a/candle-examples/examples/llama2-c/model.rs
+++ /dev/null
@@ -1,314 +0,0 @@
-use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::linear_no_bias as linear;
-use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
-use std::collections::HashMap;
-use std::sync::{Arc, Mutex};
-
-#[derive(Debug, Clone)]
-pub struct Config {
-    pub dim: usize,        // transformer dimension
-    pub hidden_dim: usize, // for ffn layers
-    pub n_layers: usize,   // number of layers
-    pub n_heads: usize,    // number of query heads
-    pub n_kv_heads: usize, // number of key/value heads (can be < query heads because of multiquery)
-    pub vocab_size: usize, // vocabulary size, usually 256 (byte-level)
-    pub seq_len: usize,    // max sequence length
-    pub norm_eps: f64,
-}
-
-impl Config {
-    pub fn tiny() -> Self {
-        Self {
-            dim: 288,
-            hidden_dim: 768,
-            n_layers: 6,
-            n_heads: 6,
-            n_kv_heads: 6,
-            vocab_size: 32000,
-            seq_len: 256,
-            norm_eps: 1e-5,
-        }
-    }
-}
-
-#[derive(Clone)]
-pub struct Cache {
-    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
-    pub use_kv_cache: bool,
-    #[allow(clippy::type_complexity)]
-    pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
-    pub cos: Tensor,
-    pub sin: Tensor,
-    device: Device,
-}
-
-impl Cache {
-    pub fn new(use_kv_cache: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let n_elem = cfg.dim / cfg.n_heads;
-        let theta: Vec<_> = (0..n_elem)
-            .step_by(2)
-            .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
-            .collect();
-        let theta = Tensor::new(theta.as_slice(), vb.device())?;
-        let idx_theta = Tensor::arange(0, cfg.seq_len as u32, vb.device())?
-            .to_dtype(DType::F32)?
-            .reshape((cfg.seq_len, 1))?
-            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
-        let precomputed_cos = idx_theta.cos()?;
-        let precomputed_sin = idx_theta.sin()?;
-
-        let freq_cis_real = vb
-            .get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_real")
-            .unwrap_or(precomputed_cos);
-        let freq_cis_imag = vb
-            .get((cfg.seq_len, cfg.head_size() / 2), "freq_cis_imag")
-            .unwrap_or(precomputed_sin);
-        let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
-        let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
-        Ok(Self {
-            masks: Arc::new(Mutex::new(HashMap::new())),
-            use_kv_cache,
-            kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
-            cos,
-            sin,
-            device: vb.device().clone(),
-        })
-    }
-
-    pub fn mask(&self, t: usize) -> Result<Tensor> {
-        let mut masks = self.masks.lock().unwrap();
-        if let Some(mask) = masks.get(&t) {
-            Ok(mask.clone())
-        } else {
-            let mask: Vec<_> = (0..t)
-                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
-                .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
-            masks.insert(t, mask.clone());
-            Ok(mask)
-        }
-    }
-}
-
-fn silu(xs: &Tensor) -> Result<Tensor> {
-    xs / (xs.neg()?.exp()? + 1.0)?
-}
-
-struct CausalSelfAttention {
-    q_proj: Linear,
-    k_proj: Linear,
-    v_proj: Linear,
-    o_proj: Linear,
-    n_head: usize,
-    n_key_value_head: usize,
-    head_dim: usize,
-    cache: Cache,
-}
-
-impl CausalSelfAttention {
-    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (b_sz, seq_len, h, n_embd) = x.dims4()?;
-        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
-        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
-        let cos = cos.unsqueeze(1)?;
-        let sin = sin.unsqueeze(1)?;
-        let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
-        let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
-        let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
-        let x0 = x.narrow(D::Minus1, 0, 1)?;
-        let x1 = x.narrow(D::Minus1, 1, 1)?;
-        let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
-        let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
-        let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
-        Ok(rope)
-    }
-
-    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
-        let (b_sz, seq_len, n_embd) = x.dims3()?;
-        let q = self.q_proj.forward(x)?;
-        let k = self.k_proj.forward(x)?;
-        let v = self.v_proj.forward(x)?;
-
-        let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
-        let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
-        let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
-
-        let q = self.apply_rotary_emb(&q, index_pos)?;
-        let mut k = self.apply_rotary_emb(&k, index_pos)?;
-
-        if self.cache.use_kv_cache {
-            let mut cache = self.cache.kvs.lock().unwrap();
-            if let Some((cache_k, cache_v)) = &cache[block_idx] {
-                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
-                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
-            }
-            cache[block_idx] = Some((k.clone(), v.clone()))
-        }
-
-        let k = self.repeat_kv(k)?;
-        let v = self.repeat_kv(v)?;
-
-        let q = q.transpose(1, 2)?.contiguous()?;
-        let k = k.transpose(1, 2)?.contiguous()?;
-        let v = v.transpose(1, 2)?.contiguous()?;
-
-        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
-        // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
-        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
-        let y = self.o_proj.forward(&y)?;
-        Ok(y)
-    }
-
-    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
-        let n_rep = self.n_head / self.n_key_value_head;
-        if n_rep == 1 {
-            Ok(x)
-        } else {
-            let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
-            let x = x
-                .unsqueeze(3)?
-                .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
-                .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
-            Ok(x)
-        }
-    }
-
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let size_in = cfg.dim;
-        let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
-        let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
-        let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
-        let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
-        let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
-        let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
-        Ok(Self {
-            q_proj,
-            k_proj,
-            v_proj,
-            o_proj,
-            n_head: cfg.n_heads,
-            n_key_value_head: cfg.n_kv_heads,
-            head_dim: cfg.dim / cfg.n_heads,
-            cache: cache.clone(),
-        })
-    }
-}
-
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
-struct Mlp {
-    c_fc1: Linear,
-    c_fc2: Linear,
-    c_proj: Linear,
-}
-
-impl Mlp {
-    fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
-        Self {
-            c_fc1,
-            c_fc2,
-            c_proj,
-        }
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
-        self.c_proj.forward(&x)
-    }
-
-    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let h_size = cfg.dim;
-        let i_size = cfg.hidden_dim;
-        let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
-        let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
-        let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
-        Ok(Self::new(c_fc1, c_fc2, c_proj))
-    }
-}
-
-struct Block {
-    rms_1: RmsNorm,
-    attn: CausalSelfAttention,
-    rms_2: RmsNorm,
-    mlp: Mlp,
-}
-
-impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
-        Self {
-            rms_1,
-            attn,
-            rms_2,
-            mlp,
-        }
-    }
-
-    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
-        let residual = x;
-        let x = self.rms_1.forward(x)?;
-        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
-        let residual = &x;
-        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
-        Ok(x)
-    }
-
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
-        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
-        let post_attention_layernorm =
-            rms_norm(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
-        Ok(Self::new(
-            input_layernorm,
-            attn,
-            post_attention_layernorm,
-            mlp,
-        ))
-    }
-}
-
-pub struct Llama {
-    wte: Embedding,
-    blocks: Vec<Block>,
-    ln_f: RmsNorm,
-    lm_head: Linear,
-    pub config: Config,
-}
-
-impl Llama {
-    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (_b_sz, _seq_len) = x.dims2()?;
-        let mut x = self.wte.forward(x)?;
-        for (block_idx, block) in self.blocks.iter().enumerate() {
-            x = block.forward(&x, index_pos, block_idx)?;
-        }
-        let x = self.ln_f.forward(&x)?;
-        let logits = self.lm_head.forward(&x)?;
-        logits.to_dtype(DType::F32)
-    }
-
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
-        let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
-        let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
-        let blocks: Vec<_> = (0..cfg.n_layers)
-            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
-            .collect();
-        Ok(Self {
-            wte,
-            blocks,
-            ln_f,
-            lm_head,
-            config: cfg,
-        })
-    }
-}

diff --git a/candle-examples/examples/llama2-c/qmodel.rs b/candle-examples/examples/llama2-c/qmodel.rs
deleted file mode 100644
index 07db146e..00000000
--- a/candle-examples/examples/llama2-c/qmodel.rs
+++ /dev/null
@@ -1,227 +0,0 @@
-use super::model::{Cache, Config};
-use candle::{DType, IndexOp, Module, Result, Tensor, D};
-use candle_transformers::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
-pub use candle_transformers::quantized_var_builder::VarBuilder;
-
-fn silu(xs: &Tensor) -> Result<Tensor> {
-    xs / (xs.neg()?.exp()? + 1.0)?
-}
-
-struct CausalSelfAttention {
-    q_proj: Linear,
-    k_proj: Linear,
-    v_proj: Linear,
-    o_proj: Linear,
-    n_head: usize,
-    n_key_value_head: usize,
-    head_dim: usize,
-    cache: Cache,
-}
-
-impl CausalSelfAttention {
-    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (b_sz, seq_len, h, n_embd) = x.dims4()?;
-        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
-        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
-        let cos = cos.unsqueeze(1)?;
-        let sin = sin.unsqueeze(1)?;
-        let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
-        let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
-        let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
-        let x0 = x.narrow(D::Minus1, 0, 1)?;
-        let x1 = x.narrow(D::Minus1, 1, 1)?;
-        let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
-        let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
-        let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
-        Ok(rope)
-    }
-
-    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
-        let (b_sz, seq_len, n_embd) = x.dims3()?;
-        let q = self.q_proj.forward(x)?;
-        let k = self.k_proj.forward(x)?;
-        let v = self.v_proj.forward(x)?;
-
-        let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
-        let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
-        let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
-
-        let q = self.apply_rotary_emb(&q, index_pos)?;
-        let mut k = self.apply_rotary_emb(&k, index_pos)?;
-
-        if self.cache.use_kv_cache {
-            let mut cache = self.cache.kvs.lock().unwrap();
-            if let Some((cache_k, cache_v)) = &cache[block_idx] {
-                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
-                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
-            }
-            cache[block_idx] = Some((k.clone(), v.clone()))
-        }
-
-        let k = self.repeat_kv(k)?;
-        let v = self.repeat_kv(v)?;
-
-        let q = q.transpose(1, 2)?.contiguous()?;
-        let k = k.transpose(1, 2)?.contiguous()?;
-        let v = v.transpose(1, 2)?.contiguous()?;
-
-        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
-        // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
-        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
-        let y = self.o_proj.forward(&y)?;
-        Ok(y)
-    }
-
-    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
-        let n_rep = self.n_head / self.n_key_value_head;
-        if n_rep == 1 {
-            Ok(x)
-        } else {
-            let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
-            let x = x
-                .unsqueeze(3)?
-                .expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
-                .reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
-            Ok(x)
-        }
-    }
-
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let size_in = cfg.dim;
-        let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
-        let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
-        let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
-        let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
-        let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
-        let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
-        Ok(Self {
-            q_proj,
-            k_proj,
-            v_proj,
-            o_proj,
-            n_head: cfg.n_heads,
-            n_key_value_head: cfg.n_kv_heads,
-            head_dim: cfg.dim / cfg.n_heads,
-            cache: cache.clone(),
-        })
-    }
-}
-
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
-struct Mlp {
-    c_fc1: Linear,
-    c_fc2: Linear,
-    c_proj: Linear,
-}
-
-impl Mlp {
-    fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
-        Self {
-            c_fc1,
-            c_fc2,
-            c_proj,
-        }
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
-        self.c_proj.forward(&x)
-    }
-
-    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let h_size = cfg.dim;
-        let i_size = cfg.hidden_dim;
-        let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
-        let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
-        let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
-        Ok(Self::new(c_fc1, c_fc2, c_proj))
-    }
-}
-
-struct Block {
-    rms_1: RmsNorm,
-    attn: CausalSelfAttention,
-    rms_2: RmsNorm,
-    mlp: Mlp,
-}
-
-impl Block {
-    fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
-        Self {
-            rms_1,
-            attn,
-            rms_2,
-            mlp,
-        }
-    }
-
-    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
-        let residual = x;
-        let x = self.rms_1.forward(x)?;
-        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
-        let residual = &x;
-        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
-        Ok(x)
-    }
-
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
-        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
-        let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
-        let post_attention_layernorm =
-            RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
-        Ok(Self::new(
-            input_layernorm,
-            attn,
-            post_attention_layernorm,
-            mlp,
-        ))
-    }
-}
-
-pub struct QLlama {
-    wte: Embedding,
-    blocks: Vec<Block>,
-    ln_f: RmsNorm,
-    lm_head: Linear,
-    pub config: Config,
-}
-
-impl QLlama {
-    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (_b_sz, _seq_len) = x.dims2()?;
-        let mut x = self.wte.forward(x)?;
-        for (block_idx, block) in self.blocks.iter().enumerate() {
-            x = block.forward(&x, index_pos, block_idx)?;
-        }
-        let x = self.ln_f.forward(&x)?;
-        let logits = self.lm_head.forward(&x)?;
-        logits.to_dtype(DType::F32)
-    }
-
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
-        let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
-        let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
-        let blocks: Vec<_> = (0..cfg.n_layers)
-            .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg).unwrap())
-            .collect();
-        Ok(Self {
-            wte,
-            blocks,
-            ln_f,
-            lm_head,
-            config: cfg,
-        })
-    }
-}

diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-examples/examples/llama2-c/weights.rs
deleted file mode 100644
index b78418ce..00000000
--- a/candle-examples/examples/llama2-c/weights.rs
+++ /dev/null
@@ -1,168 +0,0 @@
-use anyhow::Result;
-use byteorder::{LittleEndian, ReadBytesExt};
-use candle::{DType, Device, IndexOp, Shape, Tensor};
-use candle_nn::VarBuilder;
-
-use crate::model::Config;
-
-pub struct TransformerWeights {
-    // token embedding table
-    token_embedding_table: Tensor, // (vocab_size, dim)
-    // weights for rmsnorms
-    rms_att_weight: Tensor, // (layer, dim) rmsnorm weights
-    rms_ffn_weight: Tensor, // (layer, dim)
-    // weights for matmuls
-    wq: Tensor, // (layer, dim, dim)
-    wk: Tensor, // (layer, dim, dim)
-    wv: Tensor, // (layer, dim, dim)
-    wo: Tensor, // (layer, dim, dim)
-    // weights for ffn
-    w1: Tensor, // (layer, hidden_dim, dim)
-    w2: Tensor, // (layer, dim, hidden_dim)
-    w3: Tensor, // (layer, hidden_dim, dim)
-    // final rmsnorm
-    rms_final_weight: Tensor, // (dim,)
-    // freq_cis for RoPE relatively positional embeddings
-    freq_cis_real: Tensor, // (seq_len, head_size/2)
-    freq_cis_imag: Tensor, // (seq_len, head_size/2)
-}
-
-fn read_i32<R: std::io::Read>(r: &mut R) -> Result<i32> {
-    let mut buf = [0u8; 4];
-    r.read_exact(&mut buf)?;
-    Ok(i32::from_le_bytes(buf))
-}
-
-fn read_tensor<R: std::io::Read, S: Into<Shape>>(
-    r: &mut R,
-    shape: S,
-    dev: &Device,
-) -> Result<Tensor> {
-    let shape = shape.into();
-    let mut data_t = vec![0f32; shape.elem_count()];
-    r.read_f32_into::<LittleEndian>(&mut data_t)?;
-    let tensor = Tensor::from_vec(data_t, shape, dev)?;
-    Ok(tensor)
-}
-
-impl Config {
-    pub fn from_reader<R: std::io::Read>(r: &mut R) -> Result<Self> {
-        let dim = read_i32(r)? as usize;
-        let hidden_dim = read_i32(r)? as usize;
-        let n_layers = read_i32(r)? as usize;
-        let n_heads = read_i32(r)? as usize;
-        let n_kv_heads = read_i32(r)? as usize;
-        let vocab_size = read_i32(r)? as usize;
-        let seq_len = read_i32(r)? as usize;
-        Ok(Self {
-            dim,
-            hidden_dim,
-            n_layers,
-            n_heads,
-            n_kv_heads,
-            vocab_size,
-            seq_len,
-            norm_eps: 1e-5,
-        })
-    }
-
-    pub fn head_size(&self) -> usize {
-        self.dim / self.n_heads
-    }
-}
-
-impl TransformerWeights {
-    pub fn from_reader<R: std::io::Read>(r: &mut R, c: &Config, dev: &Device) -> Result<Self> {
-        let token_embedding_table = read_tensor(r, (c.vocab_size, c.dim), dev)?;
-        let rms_att_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
-        let wq = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
-        let wk = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
-        let wv = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
-        let wo = read_tensor(r, (c.n_layers, c.dim, c.dim), dev)?;
-        let rms_ffn_weight = read_tensor(r, (c.n_layers, c.dim), dev)?;
-        let w1 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
-        let w2 = read_tensor(r, (c.n_layers, c.dim, c.hidden_dim), dev)?;
-        let w3 = read_tensor(r, (c.n_layers, c.hidden_dim, c.dim), dev)?;
-        let rms_final_weight = read_tensor(r, c.dim, dev)?;
-        let head_size = c.head_size();
-        let freq_cis_real = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
-        let freq_cis_imag = read_tensor(r, (c.seq_len, head_size / 2), dev)?;
-        Ok(Self {
-            token_embedding_table,
-            rms_att_weight,
-            wq,
-            wk,
-            wv,
-            wo,
-            rms_ffn_weight,
-            w1,
-            w2,
-            w3,
-            rms_final_weight,
-            freq_cis_real,
-            freq_cis_imag,
-        })
-    }
-
-    pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
-        // TODO: As of 2023-08-04, gemm is slower than expected when multiplying a matrix of
-        // size (1, k) with the transpose of a matrix of size (k, n) as it ends up transposing the
-        // second matrix back. We detect this case here and as a temporary hack make the weight
-        // matrix column major rather than row major. This ends up speeding up text generation from
-        // 120 token/s to 220 token/s on a Ryzen 2600X.
-        let tr = device.is_cpu() && !candle::utils::has_mkl();
-        let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
-        let mut ws = std::collections::HashMap::new();
-        let mut insert = |name: &str, t: Tensor| {
-            ws.insert(name.to_string(), t);
-        };
-        insert("rot.freq_cis_real", self.freq_cis_real.clone());
-        insert("rot.freq_cis_imag", self.freq_cis_imag.clone());
-        insert(
-            "model.embed_tokens.weight",
-            self.token_embedding_table.clone(),
-        );
-        insert("lm_head.weight", tr(self.token_embedding_table.clone())?);
-        insert("model.norm.weight", self.rms_final_weight.clone());
-        for layer in 0..cfg.n_layers {
-            ws.insert(
-                format!("model.layers.{layer}.self_attn.q_proj.weight"),
-                tr(self.wq.i(layer)?)?,
-            );
-            ws.insert(
-                format!("model.layers.{layer}.self_attn.k_proj.weight"),
-                tr(self.wk.i(layer)?)?,
-            );
-            ws.insert(
-                format!("model.layers.{layer}.self_attn.v_proj.weight"),
-                tr(self.wv.i(layer)?)?,
-            );
-            ws.insert(
-                format!("model.layers.{layer}.self_attn.o_proj.weight"),
-                tr(self.wo.i(layer)?)?,
-            );
-            ws.insert(
-                format!("model.layers.{layer}.mlp.gate_proj.weight"),
-                tr(self.w1.i(layer)?)?,
-            );
-            ws.insert(
-                format!("model.layers.{layer}.mlp.down_proj.weight"),
-                tr(self.w2.i(layer)?)?,
-            );
-            ws.insert(
-                format!("model.layers.{layer}.mlp.up_proj.weight"),
-                tr(self.w3.i(layer)?)?,
-            );
-            ws.insert(
-                format!("model.layers.{layer}.input_layernorm.weight"),
-                self.rms_att_weight.i(layer)?,
-            );
-            ws.insert(
-                format!("model.layers.{layer}.post_attention_layernorm.weight"),
-                self.rms_ffn_weight.i(layer)?,
-            );
-        }
-        let vb = VarBuilder::from_tensors(ws, DType::F32, device);
-        Ok(vb)
-    }
-}
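For orientation, the loading flow implemented by the deleted files (and now exposed from candle-transformers) is: Config::from_reader parses the llama2.c header, TransformerWeights::from_reader reads the raw f32 tensors, var_builder turns them into named tensors, and Llama::load builds the model. A rough sketch under the assumption that the moved modules keep these signatures; the function name and error handling are illustrative:

    use candle::Device;
    use candle_transformers::models::llama2_c::{Cache, Config, Llama};
    use candle_transformers::models::llama2_c_weights::TransformerWeights;

    // Sketch: load a llama2.c-format checkpoint with the relocated helpers.
    fn load_llama2_c(path: &str) -> anyhow::Result<Llama> {
        let device = Device::Cpu;
        let mut file = std::fs::File::open(path)?;
        // Header fields: dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len.
        let config = Config::from_reader(&mut file)?;
        // Raw f32 tensors, in the fixed order read by TransformerWeights::from_reader.
        let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
        let vb = weights.var_builder(&config, &device)?;
        // The cache picks up rot.freq_cis_real / rot.freq_cis_imag from the VarBuilder.
        let cache = Cache::new(true, &config, vb.pp("rot"))?;
        Ok(Llama::load(vb, &cache, config)?)
    }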