summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-08-16 10:29:46 +0200
committerNicolas Patry <patry.nicolas@protonmail.com>2023-08-16 10:36:01 +0200
commit76804730c6a18e18240897369ada71bdb67151cf (patch)
treef126237a226ae230353d0894aa7afef80771f25b /candle-examples/examples/llama
parent965597a873090dc10c3b1f215b9a5cf06017a8ba (diff)
downloadcandle-76804730c6a18e18240897369ada71bdb67151cf.tar.gz
candle-76804730c6a18e18240897369ada71bdb67151cf.tar.bz2
candle-76804730c6a18e18240897369ada71bdb67151cf.zip
Using the real config from the hub when available.
Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r--candle-examples/examples/llama/main.rs28
-rw-r--r--candle-examples/examples/llama/model.rs90
2 files changed, 75 insertions, 43 deletions
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 98ff9cca..def5eb20 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -22,7 +22,7 @@ use hf_hub::api::sync::Api;
use std::io::Write;
mod model;
-use model::{Config, Llama};
+use model::{Config, Llama, LlamaConfig};
const EOS_TOKEN: &str = "</s>";
const MAX_SEQ_LEN: usize = 4096;
@@ -98,18 +98,18 @@ fn main() -> Result<()> {
};
let device = candle_examples::device(args.cpu)?;
- let config = if args.v1 {
- Config::config_7b_v1(args.use_flash_attn)
- } else {
- Config::config_7b_v2(args.use_flash_attn)
- };
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
- let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
- let (llama, tokenizer_filename) = match args.npy {
+ let (llama, tokenizer_filename, cache) = match args.npy {
Some(filename) => {
+ let config = if args.v1 {
+ Config::config_7b_v1(args.use_flash_attn)
+ } else {
+ Config::config_7b_v2(args.use_flash_attn)
+ };
+ let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = VarBuilder::from_npz(filename, dtype, &device)?;
let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
- (Llama::load(vb, &cache, &config)?, tokenizer)
+ (Llama::load(vb, &cache, &config)?, tokenizer, cache)
}
None => {
let api = Api::new()?;
@@ -128,6 +128,13 @@ fn main() -> Result<()> {
_ => api.get("tokenizer.json")?,
};
+ let config_filename = match &args.local_weights {
+ Some(path) => (path.to_owned() + "config.json").into(),
+ _ => api.get("config.json")?,
+ };
+ let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+ let config = config.into_config(args.use_flash_attn);
+
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
@@ -153,9 +160,10 @@ fn main() -> Result<()> {
.iter()
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
+ let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
- (Llama::load(vb, &cache, &config)?, tokenizer_filename)
+ (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
}
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index f5ac587e..6ee4a585 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -2,17 +2,43 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
+use serde::Deserialize;
use super::MAX_SEQ_LEN;
+#[derive(Deserialize)]
+pub struct LlamaConfig {
+ pub hidden_size: usize,
+ pub intermediate_size: usize,
+ pub vocab_size: usize,
+ pub num_hidden_layers: usize,
+ pub num_attention_heads: usize,
+ pub num_key_value_heads: usize,
+ pub rms_norm_eps: f64,
+}
+
+impl LlamaConfig{
+ pub fn into_config(&self, use_flash_attn: bool) -> Config{
+ Config{
+ hidden_size: self.hidden_size,
+ intermediate_size: self.intermediate_size,
+ vocab_size: self.vocab_size,
+ num_hidden_layers: self.num_hidden_layers,
+ num_attention_heads: self.num_attention_heads,
+ num_key_value_heads: self.num_key_value_heads,
+ rms_norm_eps: self.rms_norm_eps,
+ use_flash_attn
+ }
+ }
+}
+
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
- pub n_layer: usize,
- pub n_head: usize,
- pub n_embd: usize,
- pub n_key_value_head: usize,
+ pub num_hidden_layers: usize,
+ pub num_attention_heads: usize,
+ pub num_key_value_heads: usize,
pub use_flash_attn: bool,
pub rms_norm_eps: f64,
}
@@ -23,10 +49,9 @@ impl Config {
hidden_size: 4096,
intermediate_size: 11008,
vocab_size: 32000,
- n_layer: 32,
- n_head: 32,
- n_embd: 4096,
- n_key_value_head: 32,
+ num_hidden_layers: 32,
+ num_attention_heads: 32,
+ num_key_value_heads: 32,
use_flash_attn,
rms_norm_eps: 1e-6,
}
@@ -37,10 +62,9 @@ impl Config {
hidden_size: 4096,
intermediate_size: 11008,
vocab_size: 32000,
- n_layer: 32,
- n_head: 32,
- n_embd: 4096,
- n_key_value_head: 32,
+ num_hidden_layers: 32,
+ num_attention_heads: 32,
+ num_key_value_heads: 32,
use_flash_attn,
rms_norm_eps: 1e-5,
}
@@ -76,7 +100,7 @@ pub struct Cache {
impl Cache {
pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
// precompute freqs_cis
- let n_elem = config.n_embd / config.n_head;
+ let n_elem = config.hidden_size / config.num_attention_heads;
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
@@ -94,7 +118,7 @@ impl Cache {
Ok(Self {
masks: Arc::new(Mutex::new(HashMap::new())),
use_kv_cache,
- kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
+ kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
device: device.clone(),
cos,
sin,
@@ -169,8 +193,8 @@ struct CausalSelfAttention {
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
- n_head: usize,
- n_key_value_head: usize,
+ num_attention_heads: usize,
+ num_key_value_heads: usize,
head_dim: usize,
cache: Cache,
use_flash_attn: bool,
@@ -197,13 +221,13 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
- let (b_sz, _, seq_len, n_embd) = x.dims4()?;
+ let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
- let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
- let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
- let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
- let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
+ let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
+ let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
+ let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
+ let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
Ok(rope)
@@ -211,19 +235,19 @@ impl CausalSelfAttention {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let _enter = self.span.enter();
- let (b_sz, seq_len, n_embd) = x.dims3()?;
+ let (b_sz, seq_len, hidden_size) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
let q = q
- .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+ .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
- .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
+ .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
let mut v = v
- .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
+ .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
@@ -272,13 +296,13 @@ impl CausalSelfAttention {
// Convert to contiguous as matmul doesn't support strided vs for now.
att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
};
- let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
+ let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
- let n_rep = self.n_head / self.n_key_value_head;
+ let n_rep = self.num_attention_heads / self.num_key_value_heads;
if n_rep == 1 {
Ok(x)
} else {
@@ -295,8 +319,8 @@ impl CausalSelfAttention {
let span = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let size_in = cfg.hidden_size;
- let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
- let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
+ let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
+ let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
@@ -306,9 +330,9 @@ impl CausalSelfAttention {
k_proj,
v_proj,
o_proj,
- n_head: cfg.n_head,
- n_key_value_head: cfg.n_key_value_head,
- head_dim: cfg.hidden_size / cfg.n_head,
+ num_attention_heads: cfg.num_attention_heads,
+ num_key_value_heads: cfg.num_key_value_heads,
+ head_dim: cfg.hidden_size / cfg.num_attention_heads,
cache: cache.clone(),
use_flash_attn: cfg.use_flash_attn,
span,
@@ -417,7 +441,7 @@ impl Llama {
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
- let blocks: Vec<_> = (0..cfg.n_layer)
+ let blocks: Vec<_> = (0..cfg.num_hidden_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
.collect();