diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-08-16 10:29:46 +0200 |
---|---|---|
committer | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-08-16 10:36:01 +0200 |
commit | 76804730c6a18e18240897369ada71bdb67151cf (patch) | |
tree | f126237a226ae230353d0894aa7afef80771f25b /candle-examples/examples/llama | |
parent | 965597a873090dc10c3b1f215b9a5cf06017a8ba (diff) | |
download | candle-76804730c6a18e18240897369ada71bdb67151cf.tar.gz candle-76804730c6a18e18240897369ada71bdb67151cf.tar.bz2 candle-76804730c6a18e18240897369ada71bdb67151cf.zip |
Using the real config from the hub when available.
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r-- | candle-examples/examples/llama/main.rs | 28 | ||||
-rw-r--r-- | candle-examples/examples/llama/model.rs | 90 |
2 files changed, 75 insertions, 43 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 98ff9cca..def5eb20 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -22,7 +22,7 @@ use hf_hub::api::sync::Api; use std::io::Write; mod model; -use model::{Config, Llama}; +use model::{Config, Llama, LlamaConfig}; const EOS_TOKEN: &str = "</s>"; const MAX_SEQ_LEN: usize = 4096; @@ -98,18 +98,18 @@ fn main() -> Result<()> { }; let device = candle_examples::device(args.cpu)?; - let config = if args.v1 { - Config::config_7b_v1(args.use_flash_attn) - } else { - Config::config_7b_v2(args.use_flash_attn) - }; let dtype = if args.use_f32 { DType::F32 } else { DType::F16 }; - let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; - let (llama, tokenizer_filename) = match args.npy { + let (llama, tokenizer_filename, cache) = match args.npy { Some(filename) => { + let config = if args.v1 { + Config::config_7b_v1(args.use_flash_attn) + } else { + Config::config_7b_v2(args.use_flash_attn) + }; + let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; let vb = VarBuilder::from_npz(filename, dtype, &device)?; let tokenizer = std::path::PathBuf::from("llama-tokenizer.json"); - (Llama::load(vb, &cache, &config)?, tokenizer) + (Llama::load(vb, &cache, &config)?, tokenizer, cache) } None => { let api = Api::new()?; @@ -128,6 +128,13 @@ fn main() -> Result<()> { _ => api.get("tokenizer.json")?, }; + let config_filename = match &args.local_weights { + Some(path) => (path.to_owned() + "config.json").into(), + _ => api.get("config.json")?, + }; + let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?; + let config = config.into_config(args.use_flash_attn); + let mut filenames = vec![]; for rfilename in [ "model-00001-of-00002.safetensors", @@ -153,9 +160,10 @@ fn main() -> Result<()> { .iter() .map(|h| Ok(h.deserialize()?)) .collect::<Result<Vec<_>>>()?; + let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; let vb = VarBuilder::from_safetensors(tensors, dtype, &device); - (Llama::load(vb, &cache, &config)?, tokenizer_filename) + (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache) } }; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs index f5ac587e..6ee4a585 100644 --- a/candle-examples/examples/llama/model.rs +++ b/candle-examples/examples/llama/model.rs @@ -2,17 +2,43 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{Embedding, VarBuilder}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; +use serde::Deserialize; use super::MAX_SEQ_LEN; +#[derive(Deserialize)] +pub struct LlamaConfig { + pub hidden_size: usize, + pub intermediate_size: usize, + pub vocab_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub rms_norm_eps: f64, +} + +impl LlamaConfig{ + pub fn into_config(&self, use_flash_attn: bool) -> Config{ + Config{ + hidden_size: self.hidden_size, + intermediate_size: self.intermediate_size, + vocab_size: self.vocab_size, + num_hidden_layers: self.num_hidden_layers, + num_attention_heads: self.num_attention_heads, + num_key_value_heads: self.num_key_value_heads, + rms_norm_eps: self.rms_norm_eps, + use_flash_attn + } + } +} + pub struct Config { pub hidden_size: usize, pub intermediate_size: usize, pub vocab_size: usize, - pub n_layer: usize, - pub n_head: usize, - pub n_embd: usize, - pub n_key_value_head: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, pub use_flash_attn: bool, pub rms_norm_eps: f64, } @@ -23,10 +49,9 @@ impl Config { hidden_size: 4096, intermediate_size: 11008, vocab_size: 32000, - n_layer: 32, - n_head: 32, - n_embd: 4096, - n_key_value_head: 32, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 32, use_flash_attn, rms_norm_eps: 1e-6, } @@ -37,10 +62,9 @@ impl Config { hidden_size: 4096, intermediate_size: 11008, vocab_size: 32000, - n_layer: 32, - n_head: 32, - n_embd: 4096, - n_key_value_head: 32, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 32, use_flash_attn, rms_norm_eps: 1e-5, } @@ -76,7 +100,7 @@ pub struct Cache { impl Cache { pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> { // precompute freqs_cis - let n_elem = config.n_embd / config.n_head; + let n_elem = config.hidden_size / config.num_attention_heads; let theta: Vec<_> = (0..n_elem) .step_by(2) .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32)) @@ -94,7 +118,7 @@ impl Cache { Ok(Self { masks: Arc::new(Mutex::new(HashMap::new())), use_kv_cache, - kvs: Arc::new(Mutex::new(vec![None; config.n_layer])), + kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])), device: device.clone(), cos, sin, @@ -169,8 +193,8 @@ struct CausalSelfAttention { k_proj: Linear, v_proj: Linear, o_proj: Linear, - n_head: usize, - n_key_value_head: usize, + num_attention_heads: usize, + num_key_value_heads: usize, head_dim: usize, cache: Cache, use_flash_attn: bool, @@ -197,13 +221,13 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten impl CausalSelfAttention { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { let _enter = self.span_rot.enter(); - let (b_sz, _, seq_len, n_embd) = x.dims4()?; + let (b_sz, _, seq_len, hidden_size) = x.dims4()?; let cos = self.cache.cos.narrow(0, index_pos, seq_len)?; let sin = self.cache.sin.narrow(0, index_pos, seq_len)?; - let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?; - let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?; - let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?; - let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?; + let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?; + let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?; + let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?; + let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?; let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?; Ok(rope) @@ -211,19 +235,19 @@ impl CausalSelfAttention { fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { let _enter = self.span.enter(); - let (b_sz, seq_len, n_embd) = x.dims3()?; + let (b_sz, seq_len, hidden_size) = x.dims3()?; let q = self.q_proj.forward(x)?; let k = self.k_proj.forward(x)?; let v = self.v_proj.forward(x)?; let q = q - .reshape((b_sz, seq_len, self.n_head, self.head_dim))? + .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? .transpose(1, 2)?; let k = k - .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))? + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? .transpose(1, 2)?; let mut v = v - .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))? + .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? .transpose(1, 2)?; let q = self.apply_rotary_emb(&q, index_pos)?; @@ -272,13 +296,13 @@ impl CausalSelfAttention { // Convert to contiguous as matmul doesn't support strided vs for now. att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)? }; - let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?; let y = self.o_proj.forward(&y)?; Ok(y) } fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { - let n_rep = self.n_head / self.n_key_value_head; + let n_rep = self.num_attention_heads / self.num_key_value_heads; if n_rep == 1 { Ok(x) } else { @@ -295,8 +319,8 @@ impl CausalSelfAttention { let span = tracing::span!(tracing::Level::TRACE, "attn"); let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); let size_in = cfg.hidden_size; - let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head; - let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head; + let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads; + let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads; let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?; let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; @@ -306,9 +330,9 @@ impl CausalSelfAttention { k_proj, v_proj, o_proj, - n_head: cfg.n_head, - n_key_value_head: cfg.n_key_value_head, - head_dim: cfg.hidden_size / cfg.n_head, + num_attention_heads: cfg.num_attention_heads, + num_key_value_heads: cfg.num_key_value_heads, + head_dim: cfg.hidden_size / cfg.num_attention_heads, cache: cache.clone(), use_flash_attn: cfg.use_flash_attn, span, @@ -417,7 +441,7 @@ impl Llama { let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; - let blocks: Vec<_> = (0..cfg.n_layer) + let blocks: Vec<_> = (0..cfg.num_hidden_layers) .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap()) .collect(); |