summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/examples/llama/main.rs222
1 files changed, 90 insertions, 132 deletions
diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index 83a1d69a..3a025683 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -13,7 +13,7 @@
// transposition operations.
use anyhow::{Error as E, Result};
use clap::Parser;
-use rand::{distributions::Distribution, SeedableRng};
+use rand::{distributions::Distribution, thread_rng};
use candle::{DType, Device, Tensor};
use candle_hub::{api::Api, Repo, RepoType};
@@ -138,7 +138,7 @@ impl Embedding {
}
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
- Ok(Tensor::embedding(indexes, &self.embeddings).unwrap())
+ Ok(Tensor::embedding(indexes, &self.embeddings)?)
}
}
@@ -152,7 +152,7 @@ impl Linear {
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = x.matmul(&self.weight.t().unwrap()).unwrap();
+ let x = x.matmul(&self.weight.t()?)?;
Ok(x)
}
}
@@ -168,18 +168,16 @@ impl RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.to_dtype(DType::F32)?;
- let (seq_len, hidden_size) = x.shape().r2().unwrap();
- let norm_x = ((&x * &x).unwrap().sum(&[1]).unwrap() / hidden_size as f64).unwrap();
- let norm_x = norm_x.broadcast_as((seq_len, hidden_size)).unwrap();
- let x_normed = (x / (norm_x + 1e-5).unwrap().sqrt().unwrap()).unwrap();
- let size = self.scale.shape().r1().unwrap();
+ let (seq_len, hidden_size) = x.shape().r2()?;
+ let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
+ let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
+ let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
+ let size = self.scale.shape().r1()?;
let scale = self
.scale
- .to_dtype(DType::F32)
- .unwrap()
- .broadcast_as((seq_len, size))
- .unwrap();
- let x = (scale * x_normed).unwrap();
+ .to_dtype(DType::F32)?
+ .broadcast_as((seq_len, size))?;
+ let x = (scale * x_normed)?;
let x = x.to_dtype(DType::F16)?;
Ok(x)
}
@@ -192,7 +190,7 @@ struct Mlp {
}
fn silu(xs: &Tensor) -> Result<Tensor> {
- Ok((xs / (xs.neg().unwrap().exp().unwrap() + 1.0).unwrap()).unwrap())
+ Ok((xs / (xs.neg()?.exp()? + 1.0)?)?)
}
impl Mlp {
@@ -205,19 +203,15 @@ impl Mlp {
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
- let x = (silu(&self.c_fc1.forward(x).unwrap()).unwrap() * self.c_fc2.forward(x).unwrap())
- .unwrap();
+ let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
- let on_true = Tensor::new(on_true, &on_false.device())
- .unwrap()
- .broadcast_as(shape.dims())
- .unwrap();
- let m = mask.where_cond(&on_true, on_false).unwrap();
+ let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
+ let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
@@ -244,7 +238,7 @@ impl Cache {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
.collect();
- let mask = Tensor::from_slice(&mask, (t, t), &self.device).unwrap();
+ let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
masks.insert(t, mask.clone());
Ok(mask)
}
@@ -274,70 +268,47 @@ impl CausalSelfAttention {
let v = dims.pop().unwrap();
dims.push(v / 2);
dims.push(2);
- let x = x.reshape(dims).unwrap();
+ let x = x.reshape(dims)?;
let rank = x.rank();
- let re_x = x.narrow(rank - 1, 0, 1).unwrap();
- let im_x = x.narrow(rank - 1, 1, 1).unwrap();
+ let re_x = x.narrow(rank - 1, 0, 1)?;
+ let im_x = x.narrow(rank - 1, 1, 1)?;
let re_f = freqs_cis
- .narrow(rank - 1, 0, 1)
- .unwrap()
- .broadcast_as(re_x.shape())
- .unwrap();
+ .narrow(rank - 1, 0, 1)?
+ .broadcast_as(re_x.shape())?;
let im_f = freqs_cis
- .narrow(rank - 1, 1, 1)
- .unwrap()
- .broadcast_as(im_x.shape())
- .unwrap();
- let re = ((&re_x * &re_f).unwrap() - (&im_x * &im_f).unwrap()).unwrap();
- let im = ((&re_x * &im_f).unwrap() + (&im_x * &re_f).unwrap()).unwrap();
- let rope = Tensor::cat(&[&re, &im], rank - 1).unwrap();
- let rope = rope.flatten(Some(rope.rank() - 2), None).unwrap();
+ .narrow(rank - 1, 1, 1)?
+ .broadcast_as(im_x.shape())?;
+ let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
+ let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
+ let rope = Tensor::cat(&[&re, &im], rank - 1)?;
+ let rope = rope.flatten(Some(rope.rank() - 2), None)?;
Ok(rope)
}
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
- let (t, c) = x.shape().r2().unwrap();
- let qkv = self.c_attn.forward(x).unwrap();
- let qkv = qkv.to_dtype(DType::F32).unwrap();
+ let (t, c) = x.shape().r2()?;
+ let qkv = self.c_attn.forward(x)?;
+ let qkv = qkv.to_dtype(DType::F32)?;
let n_embd = c;
- let q = qkv.narrow(1, 0, n_embd).unwrap();
- let k = qkv.narrow(1, n_embd, n_embd).unwrap();
- let v = qkv.narrow(1, 2 * n_embd, n_embd).unwrap();
+ let q = qkv.narrow(1, 0, n_embd)?;
+ let k = qkv.narrow(1, n_embd, n_embd)?;
+ let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
let target_dim = [t, self.n_head, c / self.n_head];
- let k = k
- .reshape(target_dim.as_slice())
- .unwrap()
- .transpose(0, 1)
- .unwrap();
- let q = q
- .reshape(target_dim.as_slice())
- .unwrap()
- .transpose(0, 1)
- .unwrap();
- let v = v
- .reshape(target_dim.as_slice())
- .unwrap()
- .transpose(0, 1)
- .unwrap();
- let q = self.apply_rotary_emb(&q, freqs_cis).unwrap();
- let k = self.apply_rotary_emb(&k, freqs_cis).unwrap();
+ let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+ let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+ let v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+ let q = self.apply_rotary_emb(&q, freqs_cis)?;
+ let k = self.apply_rotary_emb(&k, freqs_cis)?;
let k_shape = k.shape();
- let att = (q.matmul(&k.t().unwrap()).unwrap()
- / (*k_shape.dims().last().unwrap() as f64).sqrt())
- .unwrap();
- let mask = self
- .cache
- .mask(t)
- .unwrap()
- .broadcast_as(att.shape())
- .unwrap();
- let att = masked_fill(&att, &mask, f32::NEG_INFINITY).unwrap();
- let att = att.softmax(att.rank() - 1).unwrap();
+ let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
+ let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
+ let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+ let att = att.softmax(att.rank() - 1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
- let y = att.matmul(&v.contiguous().unwrap()).unwrap();
- let y = y.transpose(0, 1).unwrap().reshape(&[t, c]).unwrap();
- let y = y.to_dtype(DType::F16).unwrap();
- let y = self.c_proj.forward(&y).unwrap();
+ let y = att.matmul(&v.contiguous()?)?;
+ let y = y.transpose(0, 1)?.reshape(&[t, c])?;
+ let y = y.to_dtype(DType::F16)?;
+ let y = self.c_proj.forward(&y)?;
Ok(y)
}
}
@@ -360,13 +331,8 @@ impl Block {
}
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
- let x = (self
- .attn
- .forward(&self.rms_1.forward(x).unwrap(), freqs_cis)
- .unwrap()
- + x)
- .unwrap();
- let x = (self.mlp.forward(&self.rms_2.forward(&x).unwrap()).unwrap() + x).unwrap();
+ let x = (self.attn.forward(&self.rms_1.forward(x)?, freqs_cis)? + x)?;
+ let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
Ok(x)
}
}
@@ -390,18 +356,18 @@ impl Llama {
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
// TODO: Support for mini-batches? (i.e. r2)
- let t = x.shape().r1().unwrap();
- let mut x = self.wte.forward(x).unwrap();
+ let t = x.shape().r1()?;
+ let mut x = self.wte.forward(x)?;
for block in self.blocks.iter() {
- x = block.forward(&x, freqs_cis).unwrap();
+ x = block.forward(&x, freqs_cis)?;
}
- let x = self.ln_f.forward(&x).unwrap();
- let x = x.narrow(0, t - 1, 1).unwrap();
- let logits = self.lm_head.forward(&x).unwrap();
+ let x = self.ln_f.forward(&x)?;
+ let x = x.narrow(0, t - 1, 1)?;
+ let logits = self.lm_head.forward(&x)?;
let logits = logits.to_dtype(DType::F32)?;
- let (b, vocab_size) = logits.shape().r2().unwrap();
+ let (b, vocab_size) = logits.shape().r2()?;
assert_eq!(b, 1);
- Ok(logits.reshape(vocab_size).unwrap())
+ Ok(logits.reshape(vocab_size)?)
}
}
@@ -413,18 +379,16 @@ fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
.collect();
let arange: Vec<_> = (0..seq_len).map(|c| c as f32).collect();
- let theta = Tensor::new(theta.as_slice(), device).unwrap();
- let arange = Tensor::new(arange.as_slice(), device).unwrap();
+ let theta = Tensor::new(theta.as_slice(), device)?;
+ let arange = Tensor::new(arange.as_slice(), device)?;
let idx_theta = arange
- .reshape((arange.elem_count(), 1))
- .unwrap()
- .matmul(&theta.reshape((1, theta.elem_count())).unwrap())
- .unwrap();
+ .reshape((arange.elem_count(), 1))?
+ .matmul(&theta.reshape((1, theta.elem_count()))?)?;
let shape = [1, seq_len, n_elem / 2, 1];
- let idx_theta_cos = idx_theta.cos().unwrap().reshape(&shape).unwrap();
- let idx_theta_sin = idx_theta.sin().unwrap().reshape(&shape).unwrap();
+ let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
+ let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
let last_dim = idx_theta_cos.rank() - 1;
- Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim).unwrap())
+ Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim)?)
}
#[derive(Parser, Debug)]
@@ -442,10 +406,6 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
- /// The seed to use when generating random samples.
- #[arg(long, default_value_t = 299792458)]
- seed: u64,
-
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)]
sample_len: usize,
@@ -453,26 +413,28 @@ struct Args {
#[tokio::main]
async fn main() -> Result<()> {
+ //use rand::prelude::*;
use tokenizers::Tokenizer;
let args = Args::parse();
let device = if args.cpu {
Device::Cpu
} else {
- Device::new_cuda(0).unwrap()
+ Device::new_cuda(0)?
};
+ let api = Api::new()?;
+ let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
+ println!("building the model");
let config = Config::config_7b();
let cache = Cache::new(&device);
let start = std::time::Instant::now();
let (llama, tokenizer_filename) = if args.npy {
println!("building the model (NPY)");
(
- Llama::load_npy(&device, "/data/llama.npz", &cache, &config).unwrap(),
+ Llama::load_npy(&device, "/data/llama.npz", &cache, &config)?,
std::path::Path::new("llama-tokenizer.json").to_path_buf(),
)
} else {
- let api = Api::new()?;
- let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
let mut filenames = vec![];
for rfilename in [
@@ -485,51 +447,50 @@ async fn main() -> Result<()> {
println!("building the model (SF)");
(
- Llama::load(&device, &filenames, &cache, &config).unwrap(),
+ Llama::load(&device, &filenames, &cache, &config)?,
tokenizer_filename,
)
};
println!("Loaded in {:?}", start.elapsed());
- let tokenizer = Tokenizer::from_file(tokenizer_filename)
- .map_err(E::msg)
- .unwrap();
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mut tokens = tokenizer
.encode(START_PROMPT, true)
- .map_err(E::msg)
- .unwrap()
+ .map_err(E::msg)?
.get_ids()
.to_vec();
println!("pre-computing the positional embeddings");
- let freqs_cis = precompute_freqs_cis(&config, &device).unwrap();
+ let freqs_cis = precompute_freqs_cis(&config, &device)?;
println!("starting the inference loop");
let mut new_tokens = vec![];
- let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
+ let mut rng = thread_rng();
let start_gen = std::time::Instant::now();
for index in 0..args.sample_len {
let start_gen = std::time::Instant::now();
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
- let input = Tensor::new(ctxt, &device).unwrap();
- let logits = llama.forward(&input, &freqs_cis).unwrap();
+ let input = Tensor::new(ctxt, &device)?;
+ let logits = llama.forward(&input, &freqs_cis)?;
let next_token = if let Some(temperature) = args.temperature {
println!("Sampling with temperature {temperature:?}");
- let prs = (&logits / temperature)
- .unwrap()
- .softmax(logits.rank() - 1)
- .unwrap();
- let logits_v: Vec<f32> = prs.to_vec1().unwrap();
- let distr = rand::distributions::WeightedIndex::new(&logits_v).unwrap();
+ let prs = (&logits / temperature)?.softmax(logits.rank() - 1)?;
+ let logits_v: Vec<f32> = prs.to_vec1()?;
+ let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut rng) as u32
} else {
- let logits_v: Vec<f32> = logits.to_vec1().unwrap();
+ let logits_v: Vec<f32> = logits.to_vec1()?;
logits_v
.iter()
.enumerate()
- .max_by(|(_, u), (_, v)| u.total_cmp(v))
- .map(|(i, _)| i as u32)
- .unwrap()
+ .fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
+ if &val_max > val {
+ (idx_max, val_max)
+ } else {
+ (idx, *val)
+ }
+ })
+ .0 as u32
};
tokens.push(next_token);
new_tokens.push(next_token);
@@ -538,10 +499,7 @@ async fn main() -> Result<()> {
"{} token: {} '{}'",
index + 1,
next_token,
- tokenizer
- .decode(vec![next_token], true)
- .map_err(E::msg)
- .unwrap()
+ tokenizer.decode(vec![next_token], true).map_err(E::msg)?
);
}
let dt = start_gen.elapsed();
@@ -549,7 +507,7 @@ async fn main() -> Result<()> {
"{} tokens generated ({} token/s)\n----\n{}\n----",
args.sample_len,
args.sample_len as f64 / dt.as_secs_f64(),
- tokenizer.decode(new_tokens, true).map_err(E::msg).unwrap()
+ tokenizer.decode(new_tokens, true).map_err(E::msg)?
);
Ok(())
}