diff options
-rw-r--r-- | README.md | 2 | ||||
-rw-r--r-- | candle-examples/examples/llama/main.rs | 47 | ||||
-rw-r--r-- | candle-examples/examples/llama/model.rs | 184 |
3 files changed, 123 insertions, 110 deletions
@@ -13,7 +13,7 @@ let c = a.matmul(&b)?; Check out our [examples](./candle-examples/examples/): - [Whisper](./candle-examples/examples/whisper/) -- [Llama](./candle-examples/examples/llama/) +- [Llama and Llama-v2](./candle-examples/examples/llama/) - [Bert](./candle-examples/examples/bert/) (Useful for sentence embeddings) - [Falcon](./candle-examples/examples/falcon/) diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 40f1af06..ea3ee4d7 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -15,7 +15,7 @@ extern crate intel_mkl_src; use anyhow::{Error as E, Result}; use clap::Parser; -use candle::{DType, Device, Tensor, D}; +use candle::{DType, Tensor}; use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; @@ -76,23 +76,6 @@ Whate'er it bodes, henceforward will I bear Upon my target three fair-shining suns. "; -fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> { - let n_elem = config.n_embd / config.n_head; - let theta: Vec<_> = (0..n_elem) - .step_by(2) - .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32)) - .collect(); - let theta = Tensor::new(theta.as_slice(), device)?; - let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? - .to_dtype(DType::F32)? - .reshape((MAX_SEQ_LEN, 1))? - .matmul(&theta.reshape((1, theta.elem_count()))?)?; - let shape = [1, MAX_SEQ_LEN, n_elem / 2, 1]; - let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?; - let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?; - Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], D::Minus1)?) -} - #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { @@ -127,6 +110,12 @@ struct Args { /// Use f32 computations rather than f16. #[arg(long)] use_f32: bool, + + #[arg(long)] + model_id: Option<String>, + + #[arg(long)] + v2: bool, } fn main() -> Result<()> { @@ -136,7 +125,7 @@ fn main() -> Result<()> { let device = candle_examples::device(args.cpu)?; let config = Config::config_7b(); - let cache = model::Cache::new(!args.no_kv_cache, &config, &device); + let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?; let dtype = if args.use_f32 { DType::F32 } else { DType::F16 }; let (llama, tokenizer_filename) = match args.npy { Some(filename) => { @@ -146,8 +135,15 @@ fn main() -> Result<()> { } None => { let api = Api::new()?; - let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model); - println!("loading the model weights"); + let model_id = args.model_id.unwrap_or_else(|| { + if args.v2 { + "meta-llama/Llama-2-7b-hf".to_string() + } else { + "Narsil/amall-7b".to_string() + } + }); + println!("loading the model weights from {model_id}"); + let repo = Repo::new(model_id, RepoType::Model); let tokenizer_filename = api.get(&repo, "tokenizer.json")?; let mut filenames = vec![]; for rfilename in [ @@ -180,8 +176,6 @@ fn main() -> Result<()> { .get_ids() .to_vec(); - println!("pre-computing the positional embeddings"); - let freqs_cis = precompute_freqs_cis(&config, &device)?; println!("starting the inference loop"); let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature); let mut new_tokens = vec![]; @@ -196,12 +190,7 @@ fn main() -> Result<()> { }; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; - let freqs_cis = if cache.use_kv_cache { - freqs_cis.narrow(1, index_pos, ctxt.len())? - } else { - freqs_cis.clone() - }; - let logits = llama.forward(&input, &freqs_cis)?; + let logits = llama.forward(&input, index_pos)?; let logits = logits.squeeze(0)?; index_pos += ctxt.len(); diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs index ce2e6d2e..f3e30ec9 100644 --- a/candle-examples/examples/llama/model.rs +++ b/candle-examples/examples/llama/model.rs @@ -12,6 +12,7 @@ pub struct Config { pub n_layer: usize, pub n_head: usize, pub n_embd: usize, + pub n_key_value_head: usize, } impl Config { @@ -23,6 +24,7 @@ impl Config { n_layer: 32, n_head: 32, n_embd: 4096, + n_key_value_head: 32, } } } @@ -33,17 +35,37 @@ pub struct Cache { pub use_kv_cache: bool, #[allow(clippy::type_complexity)] kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>, + cos: Tensor, + sin: Tensor, device: Device, } impl Cache { - pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self { - Self { + pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Result<Self> { + // precompute freqs_cis + let n_elem = config.n_embd / config.n_head; + let theta: Vec<_> = (0..n_elem) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? + .matmul(&theta.reshape((1, theta.elem_count()))?)?; + // This is different from the paper, see: + // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112 + let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok(Self { masks: Arc::new(Mutex::new(HashMap::new())), use_kv_cache, kvs: Arc::new(Mutex::new(vec![None; config.n_layer])), device: device.clone(), - } + cos, + sin, + }) } fn mask(&self, t: usize) -> Result<Tensor> { @@ -97,7 +119,7 @@ impl RmsNorm { let (b_sz, seq_len, hidden_size) = x.shape().r3()?; let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?; - let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?; + let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?; let size = self.scale.shape().r1()?; let scale = self .scale @@ -110,63 +132,52 @@ impl RmsNorm { } struct CausalSelfAttention { - c_attn: Linear, - c_proj: Linear, + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, n_head: usize, + n_key_value_head: usize, + head_dim: usize, cache: Cache, } impl CausalSelfAttention { - fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self { - Self { - c_attn, - c_proj, - n_head, - cache: cache.clone(), - } - } - - fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> { - let mut dims = x.dims().to_vec(); - let fcis_dims = freqs_cis.dims(); - let freqs_cis = if dims[2] < fcis_dims[1] { - freqs_cis.narrow(1, 0, dims[2])? - } else { - freqs_cis.clone() - }; - let v = dims.pop().unwrap(); - dims.push(v / 2); - dims.push(2); - let x = x.reshape(dims)?; - let re_x = x.narrow(D::Minus1, 0, 1)?; - let im_x = x.narrow(D::Minus1, 1, 1)?; - let re_f = freqs_cis - .narrow(D::Minus1, 0, 1)? - .broadcast_as(re_x.shape())?; - let im_f = freqs_cis - .narrow(D::Minus1, 1, 1)? - .broadcast_as(im_x.shape())?; - let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?; - let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?; - let rope = Tensor::cat(&[&re, &im], D::Minus1)?; - let rope = rope.flatten_from(D::Minus2)?; + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { + let (b_sz, _, seq_len, n_embd) = x.shape().r4()?; + let cos = self.cache.cos.narrow(0, index_pos, seq_len)?; + let sin = self.cache.sin.narrow(0, index_pos, seq_len)?; + let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?; + let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?; + let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?; + let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?; + let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?; + let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?; Ok(rope) } - fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> { + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { let x_dtype = x.dtype(); let (b_sz, seq_len, n_embd) = x.shape().r3()?; - let qkv = self.c_attn.forward(x)?; - let qkv = qkv.to_dtype(DType::F32)?; - let q = qkv.narrow(D::Minus1, 0, n_embd)?; - let k = qkv.narrow(D::Minus1, n_embd, n_embd)?; - let v = qkv.narrow(D::Minus1, 2 * n_embd, n_embd)?; - let target_dim = [b_sz, seq_len, self.n_head, n_embd / self.n_head]; - let k = k.reshape(target_dim.as_slice())?.transpose(1, 2)?; - let q = q.reshape(target_dim.as_slice())?.transpose(1, 2)?; - let mut v = v.reshape(target_dim.as_slice())?.transpose(1, 2)?; - let q = self.apply_rotary_emb(&q, freqs_cis)?; - let mut k = self.apply_rotary_emb(&k, freqs_cis)?; + let q = self.q_proj.forward(x)?; + let k = self.k_proj.forward(x)?; + let v = self.v_proj.forward(x)?; + + let q = q + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? + .transpose(1, 2)? + .to_dtype(DType::F32)?; + let k = k + .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))? + .transpose(1, 2)? + .to_dtype(DType::F32)?; + let mut v = v + .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))? + .transpose(1, 2)? + .to_dtype(DType::F32)?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let mut k = self.apply_rotary_emb(&k, index_pos)?; if self.cache.use_kv_cache { let mut cache = self.cache.kvs.lock().unwrap(); @@ -189,7 +200,9 @@ impl CausalSelfAttention { cache[block_idx] = Some((k.clone(), v.clone())) } - let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?; + let k = self.repeat_kv(k)?; + let v = self.repeat_kv(v)?; + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?; let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; let att = att.softmax(D::Minus1)?; @@ -197,31 +210,42 @@ impl CausalSelfAttention { let y = att.matmul(&v.contiguous()?)?; let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; let y = y.to_dtype(x_dtype)?; - let y = self.c_proj.forward(&y)?; + let y = self.o_proj.forward(&y)?; Ok(y) } + fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { + let n_rep = self.n_head / self.n_key_value_head; + if n_rep == 1 { + Ok(x) + } else { + let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?; + let x = x + .unsqueeze(2)? + .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))? + .reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?; + Ok(x) + } + } + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> { let size_in = cfg.hidden_size; - let size = (cfg.hidden_size / cfg.n_head) * cfg.n_head; - let q_proj = vb.get((size_in, size), "q_proj.weight")?; - let k_proj = vb.get((size_in, size), "k_proj.weight")?; - let v_proj = vb.get((size_in, size), "v_proj.weight")?; - // Invert the transformation from: - // https://github.com/huggingface/transformers/blob/2642d8d04b14c18199ebe7b35f976da02df61752/src/transformers/models/llama/convert_llama_weights_to_hf.py#L101 - let n_head = cfg.n_head; - let q_proj = q_proj - .reshape((n_head, 2, size / n_head / 2, size_in))? - .transpose(1, 2)? - .reshape((size_in, size))?; - let k_proj = k_proj - .reshape((n_head, 2, size / n_head / 2, size_in))? - .transpose(1, 2)? - .reshape((size_in, size))?; - let attn_weight = Tensor::cat(&[q_proj, k_proj, v_proj], 0)?; - let c_attn = Linear::new(attn_weight, None); - let o_proj = linear(size, size_in, vb.pp("o_proj"))?; - Ok(Self::new(c_attn, o_proj, cfg.n_head, cache)) + let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head; + let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head; + let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?; + let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; + let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; + let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + n_head: cfg.n_head, + n_key_value_head: cfg.n_key_value_head, + head_dim: cfg.hidden_size / cfg.n_head, + cache: cache.clone(), + }) } } @@ -279,12 +303,12 @@ impl Block { } } - fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> { - let x = (self - .attn - .forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)? - + x)?; - let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?; + fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> { + let residual = x; + let x = self.rms_1.forward(x)?; + let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; + let residual = &x; + let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; Ok(x) } @@ -320,11 +344,11 @@ impl Llama { } } - pub fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> { + pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> { let (_b_sz, seq_len) = x.shape().r2()?; let mut x = self.wte.forward(x)?; for (block_idx, block) in self.blocks.iter().enumerate() { - x = block.forward(&x, freqs_cis, block_idx)?; + x = block.forward(&x, index_pos, block_idx)?; } let x = self.ln_f.forward(&x)?; let x = x.i((.., seq_len - 1, ..))?; |