Diffstat (limited to 'candle-examples/examples/llama')
-rw-r--r--  candle-examples/examples/llama/convert_checkpoint.py    68
-rw-r--r--  candle-examples/examples/llama/main.rs                  572
-rw-r--r--  candle-examples/examples/llama/var_store.rs             145
-rw-r--r--  candle-examples/examples/llama/weights.rs               127
4 files changed, 912 insertions, 0 deletions
diff --git a/candle-examples/examples/llama/convert_checkpoint.py b/candle-examples/examples/llama/convert_checkpoint.py
new file mode 100644
index 00000000..245c167c
--- /dev/null
+++ b/candle-examples/examples/llama/convert_checkpoint.py
@@ -0,0 +1,68 @@
+# Adapted from https://github.com/Lightning-AI/lit-llama/blob/main/scripts/convert_checkpoint.py
+import sys
+import torch
+import numpy as np
+from typing import Dict
+from pathlib import Path
+
+def tr(v):
+ return np.ascontiguousarray(np.transpose(v))
+
+def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, np.ndarray]:
+    print("starting the conversion")
+
+ def get_and_remove(key, transpose=False):
+ v = state_dict[key].to(dtype).numpy()
+ if transpose:
+ v = tr(v)
+ del state_dict[key]
+ return v
+
+ converted = {}
+ converted["transformer.wte.weight"] = get_and_remove("tok_embeddings.weight")
+ converted["lm_head.weight"] = get_and_remove("output.weight", transpose=True)
+ converted["transformer.ln_f.scale"] = get_and_remove("norm.weight")
+
+    for layer_idx in sorted({k.split(".")[1] for k in state_dict if k.startswith("layers")}, key=int):
+ print(layer_idx)
+
+ # attention
+ # the wq, wk, wv from the FB model are stacked in our model as c_attn
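+        # wq/wk/wv are each (n_embd, n_embd); concatenating along dim 0 gives a
+        # (3 * n_embd, n_embd) matrix and the final tr() stores it as (n_embd, 3 * n_embd).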
+ converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = tr(np.concatenate(
+ (
+ get_and_remove(f"layers.{layer_idx}.attention.wq.weight"),
+ get_and_remove(f"layers.{layer_idx}.attention.wk.weight"),
+ get_and_remove(f"layers.{layer_idx}.attention.wv.weight"),
+ )
+ ))
+ converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = tr(get_and_remove(
+ f"layers.{layer_idx}.attention.wo.weight"
+ ))
+ # mlp
+ converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = get_and_remove(
+ f"layers.{layer_idx}.feed_forward.w1.weight", transpose=True,
+ )
+ converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = get_and_remove(
+ f"layers.{layer_idx}.feed_forward.w2.weight", transpose=True,
+ )
+ converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = get_and_remove(
+ f"layers.{layer_idx}.feed_forward.w3.weight", transpose=True,
+ )
+ # rms norm
+ converted[f"transformer.h.{layer_idx}.rms_1.scale"] = get_and_remove(f"layers.{layer_idx}.attention_norm.weight")
+ converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
+ return converted
+
+def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
+ dt = getattr(torch, dtype, None)
+ if not isinstance(dt, torch.dtype):
+ raise ValueError(f"{dtype} is not a valid dtype.")
+ checkpoint = torch.load(llama_ckpt, map_location="cpu")
+ converted = convert_state_dict(checkpoint, dtype=dt)
+ del checkpoint
+ np.savez(output_npz, **converted)
+
+if __name__ == "__main__":
+ if len(sys.argv) != 2:
+        raise ValueError("usage: convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth")
+ convert_weights(sys.argv[1])
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
new file mode 100644
index 00000000..73db15e0
--- /dev/null
+++ b/candle-examples/examples/llama/main.rs
@@ -0,0 +1,572 @@
+// An implementation of LLaMA https://github.com/facebookresearch/llama
+//
+// The implementation follows the structure of nanoGPT, in the same way as lit-llama:
+// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py
+//
+// The tokenizer config can be retrieved from:
+// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json
+//
+// In order to convert the llama weights to a .npz file, run:
+// python examples/llama/convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth
+
+// TODO: This does not use a batch dimension. If adding it back, be cautious about the
+// transposition operations.
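+//
+// A typical invocation (a sketch; assuming the example is run from the candle
+// workspace root and that llama.npz was produced by the conversion script above):
+//   cargo run --example llama --release -- --npy llama.npz --prompt "Hello, world"
+// When --npy is omitted, the safetensors weights are fetched from the hub instead.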
+use anyhow::{Error as E, Result};
+use clap::Parser;
+use rand::{distributions::Distribution, SeedableRng};
+
+use candle::{DType, Device, Tensor};
+use candle_hub::{api::Api, Repo, RepoType};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+mod var_store;
+mod weights;
+
+const MAX_SEQ_LEN: usize = 4096;
+const DTYPE: DType = DType::F16;
+const DEFAULT_PROMPT: &str = r"
+EDWARD:
+I wonder how our princely father 'scaped,
+Or whether he be 'scaped away or no
+From Clifford's and Northumberland's pursuit:
+Had he been ta'en, we should have heard the news;
+Had he been slain, we should have heard the news;
+Or had he 'scaped, methinks we should have heard
+The happy tidings of his good escape.
+How fares my brother? why is he so sad?
+
+RICHARD:
+I cannot joy, until I be resolved
+Where our right valiant father is become.
+I saw him in the battle range about;
+And watch'd him how he singled Clifford forth.
+Methought he bore him in the thickest troop
+As doth a lion in a herd of neat;
+Or as a bear, encompass'd round with dogs,
+Who having pinch'd a few and made them cry,
+The rest stand all aloof, and bark at him.
+So fared our father with his enemies;
+So fled his enemies my warlike father:
+Methinks, 'tis prize enough to be his son.
+See how the morning opes her golden gates,
+And takes her farewell of the glorious sun!
+How well resembles it the prime of youth,
+Trimm'd like a younker prancing to his love!
+
+EDWARD:
+Dazzle mine eyes, or do I see three suns?
+
+RICHARD:
+Three glorious suns, each one a perfect sun;
+Not separated with the racking clouds,
+But sever'd in a pale clear-shining sky.
+See, see! they join, embrace, and seem to kiss,
+As if they vow'd some league inviolable:
+Now are they but one lamp, one light, one sun.
+In this the heaven figures some event.
+
+EDWARD:
+'Tis wondrous strange, the like yet never heard of.
+I think it cites us, brother, to the field,
+That we, the sons of brave Plantagenet,
+Each one already blazing by our meeds,
+Should notwithstanding join our lights together
+And over-shine the earth as this the world.
+Whate'er it bodes, henceforward will I bear
+Upon my target three fair-shining suns.
+";
+
+#[allow(dead_code)]
+struct Config {
+ block_size: usize,
+ vocab_size: usize,
+ n_layer: usize,
+ n_head: usize,
+ n_embd: usize,
+}
+
+#[allow(dead_code)]
+impl Config {
+ fn config_7b() -> Self {
+ Self {
+ block_size: 4096,
+ vocab_size: 32000,
+ n_layer: 32,
+ n_head: 32,
+ n_embd: 4096,
+ }
+ }
+
+ fn config_13b() -> Self {
+ Self {
+ block_size: 4096,
+ vocab_size: 32000,
+ n_layer: 40,
+ n_head: 40,
+ n_embd: 5120,
+ }
+ }
+
+ fn config_30b() -> Self {
+ Self {
+ block_size: 4096,
+ vocab_size: 32000,
+ n_layer: 60,
+ n_head: 52,
+ n_embd: 6656,
+ }
+ }
+
+ fn config_65b() -> Self {
+ Self {
+ block_size: 4096,
+ vocab_size: 32000,
+ n_layer: 80,
+ n_head: 64,
+ n_embd: 8192,
+ }
+ }
+}
+
+struct Embedding {
+ embeddings: Tensor,
+}
+
+impl Embedding {
+ fn new(embeddings: Tensor) -> Self {
+ Self { embeddings }
+ }
+
+ fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
+ let embeddings = self.embeddings.to_dtype(DTYPE)?;
+ Ok(Tensor::embedding(indexes, &embeddings)?)
+ }
+}
+
+struct Linear {
+ weight: Tensor,
+}
+
+impl Linear {
+ fn new(weight: Tensor) -> Self {
+ Self { weight }
+ }
+
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let weight = self.weight.to_dtype(DTYPE)?;
+ let x = x.matmul(&weight.t()?)?;
+ Ok(x)
+ }
+}
+
+struct RmsNorm {
+ scale: Tensor,
+}
+
+impl RmsNorm {
+ fn new(scale: Tensor) -> Self {
+ Self { scale }
+ }
+
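+    // RMSNorm: y = scale * x / sqrt(mean(x^2) + eps) over the hidden dimension,
+    // with eps = 1e-5, computed in f32 for stability before converting back to DTYPE.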
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ // This is a no-op if x's dtype is already f32.
+ let x = x.to_dtype(DType::F32)?;
+ let (seq_len, hidden_size) = x.shape().r2()?;
+ let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
+ let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
+ let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
+ let size = self.scale.shape().r1()?;
+ let scale = self
+ .scale
+ .to_dtype(DType::F32)?
+ .broadcast_as((seq_len, size))?;
+ let x = (scale * x_normed)?;
+ let x = x.to_dtype(DTYPE)?;
+ Ok(x)
+ }
+}
+
+struct Mlp {
+ c_fc1: Linear,
+ c_fc2: Linear,
+ c_proj: Linear,
+}
+
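+// SiLU activation: x * sigmoid(x), written here as x / (1 + exp(-x)).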
+fn silu(xs: &Tensor) -> Result<Tensor> {
+ Ok((xs / (xs.neg()?.exp()? + 1.0)?)?)
+}
+
+impl Mlp {
+ fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
+ Self {
+ c_fc1,
+ c_fc2,
+ c_proj,
+ }
+ }
+
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
+ self.c_proj.forward(&x)
+ }
+}
+
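+// Replaces the entries of `on_false` where `mask` is non-zero with `on_true`
+// (broadcast to the mask's shape) and keeps `on_false` everywhere else.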
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+ let shape = mask.shape();
+ let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
+ Ok(m)
+}
+
+#[derive(Clone)]
+struct Cache {
+ masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+ use_kv_cache: bool,
+ #[allow(clippy::type_complexity)]
+ kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
+ device: Device,
+}
+
+impl Cache {
+ fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self {
+ Self {
+ masks: Arc::new(Mutex::new(HashMap::new())),
+ use_kv_cache,
+ kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
+ device: device.clone(),
+ }
+ }
+
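+    // Returns a (t, t) mask whose (i, j) entry is 1 when j > i, i.e. it marks the
+    // future positions that get filled with -inf in the attention scores.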
+ fn mask(&self, t: usize) -> Result<Tensor> {
+ let mut masks = self.masks.lock().unwrap();
+ if let Some(mask) = masks.get(&t) {
+ Ok(mask.clone())
+ } else {
+            // TODO: Use a bool or u8 tensor here once those dtypes are supported.
+ let mask: Vec<_> = (0..t)
+ .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
+ .collect();
+ let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
+ masks.insert(t, mask.clone());
+ Ok(mask)
+ }
+ }
+}
+
+struct CausalSelfAttention {
+ c_attn: Linear,
+ c_proj: Linear,
+ n_head: usize,
+ cache: Cache,
+}
+
+impl CausalSelfAttention {
+ fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
+ Self {
+ c_attn,
+ c_proj,
+ n_head,
+ cache: cache.clone(),
+ }
+ }
+
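+    // Rotary embeddings: the last dimension is viewed as pairs (re, im) and each
+    // pair is rotated by the matching (cos, sin) entry of freqs_cis, i.e. a
+    // complex multiplication applied to the query/key channels.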
+ fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
+ let mut dims = x.dims().to_vec();
+ let fcis_dims = freqs_cis.dims();
+ let freqs_cis = if dims[1] < fcis_dims[1] {
+ freqs_cis.narrow(1, 0, dims[1])?
+ } else {
+ freqs_cis.clone()
+ };
+ let v = dims.pop().unwrap();
+ dims.push(v / 2);
+ dims.push(2);
+ let x = x.reshape(dims)?;
+ let rank = x.rank();
+ let re_x = x.narrow(rank - 1, 0, 1)?;
+ let im_x = x.narrow(rank - 1, 1, 1)?;
+ let re_f = freqs_cis
+ .narrow(rank - 1, 0, 1)?
+ .broadcast_as(re_x.shape())?;
+ let im_f = freqs_cis
+ .narrow(rank - 1, 1, 1)?
+ .broadcast_as(im_x.shape())?;
+ let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
+ let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
+ let rope = Tensor::cat(&[&re, &im], rank - 1)?;
+ let rope = rope.flatten(Some(rope.rank() - 2), None)?;
+ Ok(rope)
+ }
+
+ fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
+ let (t, c) = x.shape().r2()?;
+ let qkv = self.c_attn.forward(x)?;
+ let qkv = qkv.to_dtype(DType::F32)?;
+ let n_embd = c;
+ let q = qkv.narrow(1, 0, n_embd)?;
+ let k = qkv.narrow(1, n_embd, n_embd)?;
+ let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
+ let target_dim = [t, self.n_head, c / self.n_head];
+ let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+ let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+ let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+ let q = self.apply_rotary_emb(&q, freqs_cis)?;
+ let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
+
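+        // With the kv-cache enabled, append the new k/v to the values cached from
+        // previous steps and keep at most the last MAX_SEQ_LEN positions.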
+ if self.cache.use_kv_cache {
+ let mut cache = self.cache.kvs.lock().unwrap();
+ if let Some((cache_k, cache_v)) = &cache[block_idx] {
+ k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
+ v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
+ let k_seq_len = k.dims()[1];
+ if k_seq_len > MAX_SEQ_LEN {
+ k = k
+ .narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+ .contiguous()?
+ }
+ let v_seq_len = v.dims()[1];
+                if v_seq_len > MAX_SEQ_LEN {
+ v = v
+ .narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+ .contiguous()?
+ }
+ }
+ cache[block_idx] = Some((k.clone(), v.clone()))
+ }
+
+ let k_shape = k.shape();
+ let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
+ let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
+ let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+ let att = att.softmax(att.rank() - 1)?;
+        // Convert to contiguous as matmul doesn't support strided values (v) for now.
+ let y = att.matmul(&v.contiguous()?)?;
+ let y = y.transpose(0, 1)?.reshape(&[t, c])?;
+ let y = y.to_dtype(DTYPE)?;
+ let y = self.c_proj.forward(&y)?;
+ Ok(y)
+ }
+}
+
+struct Block {
+ rms_1: RmsNorm,
+ attn: CausalSelfAttention,
+ rms_2: RmsNorm,
+ mlp: Mlp,
+}
+
+impl Block {
+ fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
+ Self {
+ rms_1,
+ attn,
+ rms_2,
+ mlp,
+ }
+ }
+
+ fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
+ let x = (self
+ .attn
+ .forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)?
+ + x)?;
+ let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
+ Ok(x)
+ }
+}
+
+struct Llama {
+ wte: Embedding,
+ blocks: Vec<Block>,
+ ln_f: RmsNorm,
+ lm_head: Linear,
+}
+
+impl Llama {
+ fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
+ Self {
+ wte,
+ blocks,
+ ln_f,
+ lm_head,
+ }
+ }
+
+ fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
+ // TODO: Support for mini-batches? (i.e. r2)
+ let t = x.shape().r1()?;
+ let mut x = self.wte.forward(x)?;
+ for (block_idx, block) in self.blocks.iter().enumerate() {
+ x = block.forward(&x, freqs_cis, block_idx)?;
+ }
+ let x = self.ln_f.forward(&x)?;
+ let x = x.narrow(0, t - 1, 1)?;
+ let logits = self.lm_head.forward(&x)?;
+ let logits = logits.to_dtype(DType::F32)?;
+ let (b, vocab_size) = logits.shape().r2()?;
+ assert_eq!(b, 1);
+ Ok(logits.reshape(vocab_size)?)
+ }
+}
+
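+// Precomputes the rotary embedding table: the frequency for channel pair i is
+// 1 / 10000^(2i / head_dim) for i in 0..head_dim/2; it is multiplied by every
+// position in 0..MAX_SEQ_LEN and returned as cos/sin stacked on the last
+// dimension, with shape (1, MAX_SEQ_LEN, head_dim / 2, 2).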
+fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
+ let n_elem = config.n_embd / config.n_head;
+ let theta: Vec<_> = (0..n_elem)
+ .step_by(2)
+ .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
+ .collect();
+ let arange: Vec<_> = (0..MAX_SEQ_LEN).map(|c| c as f32).collect();
+ let theta = Tensor::new(theta.as_slice(), device)?;
+ let arange = Tensor::new(arange.as_slice(), device)?;
+ let idx_theta = arange
+ .reshape((arange.elem_count(), 1))?
+ .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+ let shape = [1, MAX_SEQ_LEN, n_elem / 2, 1];
+ let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
+ let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
+ let last_dim = idx_theta_cos.rank() - 1;
+ Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim)?)
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+    /// Path to an .npz weights file to use instead of the safetensors from the hub.
+ #[arg(long)]
+ npy: Option<String>,
+
+ /// The temperature used to generate samples.
+ #[arg(long)]
+ temperature: Option<f64>,
+
+ /// The seed to use when generating random samples.
+ #[arg(long, default_value_t = 299792458)]
+ seed: u64,
+
+ /// The length of the sample to generate (in tokens).
+ #[arg(long, default_value_t = 100)]
+ sample_len: usize,
+
+ /// Disable the key-value cache.
+ #[arg(long)]
+ no_kv_cache: bool,
+
+ /// The initial prompt.
+ #[arg(long)]
+ prompt: Option<String>,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ use tokenizers::Tokenizer;
+
+ let args = Args::parse();
+ let device = if args.cpu {
+ Device::Cpu
+ } else {
+ Device::new_cuda(0)?
+ };
+ let config = Config::config_7b();
+ let cache = Cache::new(!args.no_kv_cache, &config, &device);
+ let start = std::time::Instant::now();
+ let (llama, tokenizer_filename) = match args.npy {
+ Some(npy) => {
+ println!("building the model (NPY)");
+ let weights = Llama::load_npy(&device, &npy, &cache, &config)?;
+ let token_path = std::path::Path::new("llama-tokenizer.json").to_path_buf();
+ (weights, token_path)
+ }
+ None => {
+ let api = Api::new()?;
+ let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
+ println!("building the model");
+ let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
+ let mut filenames = vec![];
+ for rfilename in [
+ "model-00001-of-00002.safetensors",
+ "model-00002-of-00002.safetensors",
+ ] {
+ let filename = api.get(&repo, rfilename).await?;
+ filenames.push(filename);
+ }
+
+ println!("building the model (SF)");
+ (
+ Llama::load(&device, &filenames, &cache, &config)?,
+ tokenizer_filename,
+ )
+ }
+ };
+ println!("Loaded in {:?}", start.elapsed());
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+ let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
+ let mut tokens = tokenizer
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+
+ println!("pre-computing the positional embeddings");
+ let freqs_cis = precompute_freqs_cis(&config, &device)?;
+ println!("starting the inference loop");
+ let mut new_tokens = vec![];
+ let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
+ let start_gen = std::time::Instant::now();
+ let mut index_pos = 0;
+ for index in 0..args.sample_len {
+ let start_gen = std::time::Instant::now();
+ let context_size = if cache.use_kv_cache && index > 0 {
+ 1
+ } else {
+ tokens.len()
+ };
+ let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
+ let input = Tensor::new(ctxt, &device)?;
+ let freqs_cis = if cache.use_kv_cache {
+ freqs_cis.narrow(1, index_pos, ctxt.len())?
+ } else {
+ freqs_cis.clone()
+ };
+ let logits = llama.forward(&input, &freqs_cis)?;
+ index_pos += ctxt.len();
+
+ let next_token = if let Some(temperature) = args.temperature {
+ println!("Sampling with temperature {temperature:?}");
+ let prs = (&logits / temperature)?.softmax(logits.rank() - 1)?;
+ let logits_v: Vec<f32> = prs.to_vec1()?;
+ let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
+
+ distr.sample(&mut rng) as u32
+ } else {
+ let logits_v: Vec<f32> = logits.to_vec1()?;
+ logits_v
+ .iter()
+ .enumerate()
+ .max_by(|(_, u), (_, v)| u.total_cmp(v))
+ .map(|(i, _)| i as u32)
+ .unwrap()
+ };
+ tokens.push(next_token);
+ new_tokens.push(next_token);
+ println!("> {:?}", start_gen.elapsed());
+ println!(
+ "{} token: {} '{}'",
+ index + 1,
+ next_token,
+ tokenizer.decode(vec![next_token], true).map_err(E::msg)?
+ );
+ }
+ let dt = start_gen.elapsed();
+ println!(
+ "{} tokens generated ({} token/s)\n----\n{}\n----",
+ args.sample_len,
+ args.sample_len as f64 / dt.as_secs_f64(),
+ tokenizer.decode(new_tokens, true).map_err(E::msg)?
+ );
+ Ok(())
+}
diff --git a/candle-examples/examples/llama/var_store.rs b/candle-examples/examples/llama/var_store.rs
new file mode 100644
index 00000000..1a22bd89
--- /dev/null
+++ b/candle-examples/examples/llama/var_store.rs
@@ -0,0 +1,145 @@
+use super::*;
+use candle::{DType, Device, Result, Shape, Tensor};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+#[allow(dead_code)]
+#[derive(Clone)]
+struct NamedVar {
+ path: String,
+ dtype: DType,
+ shape: Shape,
+}
+
+#[derive(Clone)]
+pub struct VarBuilder {
+ path: Vec<String>,
+ default_device: Device,
+ tensors: Arc<Mutex<HashMap<String, Tensor>>>,
+}
+
+impl VarBuilder {
+ pub fn new(device: &Device, tensors: HashMap<String, Tensor>) -> Self {
+ Self {
+ path: vec![],
+ tensors: Arc::new(Mutex::new(tensors)),
+ default_device: device.clone(),
+ }
+ }
+
+ pub fn get_and_remove(&self, s: &str) -> Result<Tensor> {
+ let path = format!("{}.{s}", self.path.join("."));
+ let mut tensors = self.tensors.as_ref().lock().unwrap();
+ let parameter = match tensors.remove(&path) {
+ Some(tensor) => tensor.to_device(&self.default_device)?,
+ None => panic!("cannot find tensor for {path}"),
+ };
+ Ok(parameter)
+ }
+}
+
+impl<S: ToString> std::ops::Div<S> for &VarBuilder {
+ type Output = VarBuilder;
+
+ fn div(self, rhs: S) -> VarBuilder {
+ let mut path = self.path.clone();
+ path.push(rhs.to_string());
+ VarBuilder {
+ path,
+ default_device: self.default_device.clone(),
+ tensors: self.tensors.clone(),
+ }
+ }
+}
+
+impl<S: ToString> std::ops::Div<S> for VarBuilder {
+ type Output = VarBuilder;
+
+ fn div(self, rhs: S) -> VarBuilder {
+ &self / rhs
+ }
+}
+
+impl Embedding {
+ fn load_npy(vb: VarBuilder) -> Result<Self> {
+ let embeddings = vb.get_and_remove("weight")?.to_dtype(DTYPE)?;
+ Ok(Self { embeddings })
+ }
+}
+
+impl Linear {
+ fn load_npy(vb: VarBuilder) -> Result<Self> {
+ let weight = vb.get_and_remove("weight")?.to_dtype(DTYPE)?.t()?;
+ Ok(Self { weight })
+ }
+}
+
+impl RmsNorm {
+ fn load_npy(vb: VarBuilder) -> Result<Self> {
+ let scale = vb.get_and_remove("scale")?.to_dtype(DTYPE)?;
+ Ok(Self::new(scale))
+ }
+}
+
+impl CausalSelfAttention {
+ fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
+ let c_attn = Linear::load_npy(&vb / "c_attn")?;
+ let c_proj = Linear::load_npy(&vb / "c_proj")?;
+ Ok(Self::new(c_attn, c_proj, config.n_head, cache))
+ }
+}
+
+impl Mlp {
+ fn load_npy(vb: VarBuilder) -> Result<Self> {
+ let c_fc1 = Linear::load_npy(&vb / "c_fc1")?;
+ let c_fc2 = Linear::load_npy(&vb / "c_fc2")?;
+ let c_proj = Linear::load_npy(&vb / "c_proj")?;
+ Ok(Self::new(c_fc1, c_fc2, c_proj))
+ }
+}
+
+impl Block {
+ fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
+ let attn = CausalSelfAttention::load_npy(&vb / "attn", cache, config)?;
+ let mlp = Mlp::load_npy(&vb / "mlp")?;
+ let input_layernorm = RmsNorm::load_npy(&vb / "rms_1")?;
+ let post_attention_layernorm = RmsNorm::load_npy(&vb / "rms_2")?;
+ Ok(Self::new(
+ input_layernorm,
+ attn,
+ post_attention_layernorm,
+ mlp,
+ ))
+ }
+}
+
+impl Llama {
+ pub fn load_npy(
+ device: &Device,
+ filename: &str,
+ cache: &Cache,
+ config: &Config,
+ ) -> anyhow::Result<Self> {
+ let weight_path = std::path::Path::new(filename);
+ let weights = if weight_path.exists() {
+ println!("loading weights from {weight_path:?}");
+ let start_load = std::time::Instant::now();
+ let tensors = Tensor::read_npz(weight_path)?;
+ println!("loaded weights in {:?}", start_load.elapsed());
+ let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
+ tensors
+ } else {
+ anyhow::bail!("cannot find {weight_path:?}")
+ };
+ let vb = VarBuilder::new(device, weights);
+
+ let wte = Embedding::load_npy(&vb / "transformer" / "wte")?;
+ let lm_head = Linear::load_npy(&vb / "lm_head")?;
+ let norm = RmsNorm::load_npy(&vb / "transformer" / "ln_f")?;
+ let blocks: Vec<_> = (0..config.n_layer)
+ .map(|i| Block::load_npy(&vb / "transformer" / "h" / i, cache, config).unwrap())
+ .collect();
+
+ Ok(Self::new(wte, blocks, norm, lm_head))
+ }
+}
diff --git a/candle-examples/examples/llama/weights.rs b/candle-examples/examples/llama/weights.rs
new file mode 100644
index 00000000..c3364cef
--- /dev/null
+++ b/candle-examples/examples/llama/weights.rs
@@ -0,0 +1,127 @@
+use super::*;
+use candle::{safetensors::SafeTensors, Device, Result, Tensor};
+use std::path::PathBuf;
+
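+// `routing` maps every tensor name to the index of the safetensors shard that
+// contains it, so `get` can read the tensor from the right memory-mapped file.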
+pub struct VarBuilder<'a> {
+ routing: HashMap<String, usize>,
+ safetensors: Vec<SafeTensors<'a>>,
+ device: Device,
+}
+
+impl<'a> VarBuilder<'a> {
+ pub fn new(safetensors: Vec<SafeTensors<'a>>, device: Device) -> Self {
+ let mut routing = HashMap::new();
+ for (index, sf) in safetensors.iter().enumerate() {
+ for k in sf.names() {
+ routing.insert(k.to_string(), index);
+ }
+ }
+
+ Self {
+ safetensors,
+ device,
+ routing,
+ }
+ }
+
+ pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
+        // Default to shard 0 for unknown names so that the lookup below surfaces the proper error.
+ let index = self.routing.get(tensor_name).unwrap_or(&0);
+ self.safetensors[*index]
+ .tensor(tensor_name, &self.device)?
+ .to_dtype(DTYPE)
+ }
+}
+
+impl Linear {
+ fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
+ let weight = vb.get(&format!("{prefix}.weight"))?;
+ Ok(Self::new(weight))
+ }
+
+ fn load_multi(prefixes: &[&str], vb: &VarBuilder) -> Result<Self> {
+ let weights: Vec<_> = prefixes
+ .iter()
+ .map(|p| vb.get(&format!("{p}.weight")).unwrap())
+ .collect();
+ let weight = Tensor::cat(&weights, 0)?;
+ Ok(Self::new(weight))
+ }
+}
+
+impl RmsNorm {
+ fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
+ let scale = vb.get(&format!("{prefix}.weight"))?;
+ Ok(Self::new(scale))
+ }
+}
+
+impl CausalSelfAttention {
+ fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
+ let c_attn = Linear::load_multi(
+ &[
+ &format!("{prefix}.q_proj"),
+ &format!("{prefix}.k_proj"),
+ &format!("{prefix}.v_proj"),
+ ],
+ vb,
+ )?;
+ let o_proj = Linear::load(&format!("{prefix}.o_proj"), vb)?;
+ Ok(Self::new(c_attn, o_proj, config.n_head, cache))
+ }
+}
+
+impl Mlp {
+ fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
+ let c_fc1 = Linear::load(&format!("{prefix}.gate_proj"), vb)?;
+ let c_fc2 = Linear::load(&format!("{prefix}.up_proj"), vb)?;
+ let c_proj = Linear::load(&format!("{prefix}.down_proj"), vb)?;
+ Ok(Self::new(c_fc1, c_fc2, c_proj))
+ }
+}
+
+impl Block {
+ fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
+ let attn = CausalSelfAttention::load(&format!("{prefix}.self_attn"), vb, cache, config)?;
+ let mlp = Mlp::load(&format!("{prefix}.mlp"), vb)?;
+ let input_layernorm = RmsNorm::load(&format!("{prefix}.input_layernorm"), vb)?;
+ let post_attention_layernorm =
+ RmsNorm::load(&format!("{prefix}.post_attention_layernorm"), vb)?;
+ Ok(Self::new(
+ input_layernorm,
+ attn,
+ post_attention_layernorm,
+ mlp,
+ ))
+ }
+}
+
+impl Llama {
+ pub fn load(
+ device: &Device,
+ filenames: &[PathBuf],
+ cache: &Cache,
+ config: &Config,
+ ) -> Result<Self> {
+ let handles: Vec<_> = filenames
+ .iter()
+ .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
+ .collect::<Result<Vec<_>>>()?;
+ let tensors: Vec<_> = handles
+ .iter()
+ .map(|h| h.deserialize())
+ .collect::<Result<Vec<_>>>()?;
+
+ let vb = VarBuilder::new(tensors, device.clone());
+
+ let embedding = vb.get("model.embed_tokens.weight")?;
+ let wte = Embedding::new(embedding);
+ let lm_head = Linear::load("lm_head", &vb)?;
+ let norm = RmsNorm::load("model.norm", &vb)?;
+ let blocks: Vec<_> = (0..config.n_layer)
+ .map(|i| Block::load(&format!("model.layers.{i}"), &vb, cache, config).unwrap())
+ .collect();
+
+ Ok(Self::new(wte, blocks, norm, lm_head))
+ }
+}