diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-06-29 11:56:49 +0000 |
---|---|---|
committer | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-06-29 12:00:16 +0000 |
commit | 2fe1d3e36d195200336f6fec478045ea7ea3e01c (patch) | |
tree | 48c8cc60717f26c95a0618806b7f9e62ca2c32e5 /candle-core/examples | |
parent | 31396a3b9f18dd51dc5b7b8c05ed08a08ed4e00a (diff) | |
download | candle-2fe1d3e36d195200336f6fec478045ea7ea3e01c.tar.gz candle-2fe1d3e36d195200336f6fec478045ea7ea3e01c.tar.bz2 candle-2fe1d3e36d195200336f6fec478045ea7ea3e01c.zip |
Moving llama to f16.
Diffstat (limited to 'candle-core/examples')
-rw-r--r-- | candle-core/examples/llama/main.rs | 206 |
1 files changed, 128 insertions, 78 deletions
diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs index 8feb7fb0..83a1d69a 100644 --- a/candle-core/examples/llama/main.rs +++ b/candle-core/examples/llama/main.rs @@ -1,7 +1,7 @@ -// An implementation of LLaMA https://github.com/facebookresearch/llama +// An implementation of LLaMA https://github.com/facebookresearch/llama");");"); // // This is based on nanoGPT in a similar way to: -// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py +// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py"); // // The tokenizer config can be retrieved from: // https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json @@ -138,7 +138,7 @@ impl Embedding { } fn forward(&self, indexes: &Tensor) -> Result<Tensor> { - Ok(Tensor::embedding(indexes, &self.embeddings)?) + Ok(Tensor::embedding(indexes, &self.embeddings).unwrap()) } } @@ -152,7 +152,7 @@ impl Linear { } fn forward(&self, x: &Tensor) -> Result<Tensor> { - let x = x.matmul(&self.weight.to_dtype(DType::F32)?.t()?)?; + let x = x.matmul(&self.weight.t().unwrap()).unwrap(); Ok(x) } } @@ -167,16 +167,21 @@ impl RmsNorm { } fn forward(&self, x: &Tensor) -> Result<Tensor> { - let (seq_len, hidden_size) = x.shape().r2()?; - let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?; - let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?; - let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?; - let size = self.scale.shape().r1()?; + let x = x.to_dtype(DType::F32)?; + let (seq_len, hidden_size) = x.shape().r2().unwrap(); + let norm_x = ((&x * &x).unwrap().sum(&[1]).unwrap() / hidden_size as f64).unwrap(); + let norm_x = norm_x.broadcast_as((seq_len, hidden_size)).unwrap(); + let x_normed = (x / (norm_x + 1e-5).unwrap().sqrt().unwrap()).unwrap(); + let size = self.scale.shape().r1().unwrap(); let scale = self .scale - .to_dtype(DType::F32)? - .broadcast_as((seq_len, size))?; - Ok((scale * x_normed)?) + .to_dtype(DType::F32) + .unwrap() + .broadcast_as((seq_len, size)) + .unwrap(); + let x = (scale * x_normed).unwrap(); + let x = x.to_dtype(DType::F16)?; + Ok(x) } } @@ -187,7 +192,7 @@ struct Mlp { } fn silu(xs: &Tensor) -> Result<Tensor> { - Ok((xs / (xs.neg()?.exp()? + 1.0)?)?) + Ok((xs / (xs.neg().unwrap().exp().unwrap() + 1.0).unwrap()).unwrap()) } impl Mlp { @@ -200,15 +205,19 @@ impl Mlp { } fn forward(&self, x: &Tensor) -> Result<Tensor> { - let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?; + let x = (silu(&self.c_fc1.forward(x).unwrap()).unwrap() * self.c_fc2.forward(x).unwrap()) + .unwrap(); self.c_proj.forward(&x) } } fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { let shape = mask.shape(); - let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?; - let m = mask.where_cond(&on_true, on_false)?; + let on_true = Tensor::new(on_true, &on_false.device()) + .unwrap() + .broadcast_as(shape.dims()) + .unwrap(); + let m = mask.where_cond(&on_true, on_false).unwrap(); Ok(m) } @@ -235,7 +244,7 @@ impl Cache { let mask: Vec<_> = (0..t) .flat_map(|i| (0..t).map(move |j| u32::from(j > i))) .collect(); - let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; + let mask = Tensor::from_slice(&mask, (t, t), &self.device).unwrap(); masks.insert(t, mask.clone()); Ok(mask) } @@ -265,45 +274,70 @@ impl CausalSelfAttention { let v = dims.pop().unwrap(); dims.push(v / 2); dims.push(2); - let x = x.reshape(dims)?; + let x = x.reshape(dims).unwrap(); let rank = x.rank(); - let re_x = x.narrow(rank - 1, 0, 1)?; - let im_x = x.narrow(rank - 1, 1, 1)?; + let re_x = x.narrow(rank - 1, 0, 1).unwrap(); + let im_x = x.narrow(rank - 1, 1, 1).unwrap(); let re_f = freqs_cis - .narrow(rank - 1, 0, 1)? - .broadcast_as(re_x.shape())?; + .narrow(rank - 1, 0, 1) + .unwrap() + .broadcast_as(re_x.shape()) + .unwrap(); let im_f = freqs_cis - .narrow(rank - 1, 1, 1)? - .broadcast_as(im_x.shape())?; - let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?; - let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?; - let rope = Tensor::cat(&[&re, &im], rank - 1)?; - let rope = rope.flatten(Some(rope.rank() - 2), None)?; + .narrow(rank - 1, 1, 1) + .unwrap() + .broadcast_as(im_x.shape()) + .unwrap(); + let re = ((&re_x * &re_f).unwrap() - (&im_x * &im_f).unwrap()).unwrap(); + let im = ((&re_x * &im_f).unwrap() + (&im_x * &re_f).unwrap()).unwrap(); + let rope = Tensor::cat(&[&re, &im], rank - 1).unwrap(); + let rope = rope.flatten(Some(rope.rank() - 2), None).unwrap(); Ok(rope) } fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> { - let (t, c) = x.shape().r2()?; - let qkv = self.c_attn.forward(x)?; + let (t, c) = x.shape().r2().unwrap(); + let qkv = self.c_attn.forward(x).unwrap(); + let qkv = qkv.to_dtype(DType::F32).unwrap(); let n_embd = c; - let q = qkv.narrow(1, 0, n_embd)?; - let k = qkv.narrow(1, n_embd, n_embd)?; - let v = qkv.narrow(1, 2 * n_embd, n_embd)?; + let q = qkv.narrow(1, 0, n_embd).unwrap(); + let k = qkv.narrow(1, n_embd, n_embd).unwrap(); + let v = qkv.narrow(1, 2 * n_embd, n_embd).unwrap(); let target_dim = [t, self.n_head, c / self.n_head]; - let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?; - let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?; - let v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?; - let q = self.apply_rotary_emb(&q, freqs_cis)?; - let k = self.apply_rotary_emb(&k, freqs_cis)?; + let k = k + .reshape(target_dim.as_slice()) + .unwrap() + .transpose(0, 1) + .unwrap(); + let q = q + .reshape(target_dim.as_slice()) + .unwrap() + .transpose(0, 1) + .unwrap(); + let v = v + .reshape(target_dim.as_slice()) + .unwrap() + .transpose(0, 1) + .unwrap(); + let q = self.apply_rotary_emb(&q, freqs_cis).unwrap(); + let k = self.apply_rotary_emb(&k, freqs_cis).unwrap(); let k_shape = k.shape(); - let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?; - let mask = self.cache.mask(t)?.broadcast_as(att.shape())?; - let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; - let att = att.softmax(att.rank() - 1)?; + let att = (q.matmul(&k.t().unwrap()).unwrap() + / (*k_shape.dims().last().unwrap() as f64).sqrt()) + .unwrap(); + let mask = self + .cache + .mask(t) + .unwrap() + .broadcast_as(att.shape()) + .unwrap(); + let att = masked_fill(&att, &mask, f32::NEG_INFINITY).unwrap(); + let att = att.softmax(att.rank() - 1).unwrap(); // Convert to contiguous as matmul doesn't support strided vs for now. - let y = att.matmul(&v.contiguous()?)?; - let y = y.transpose(0, 1)?.reshape(&[t, c])?; - let y = self.c_proj.forward(&y)?; + let y = att.matmul(&v.contiguous().unwrap()).unwrap(); + let y = y.transpose(0, 1).unwrap().reshape(&[t, c]).unwrap(); + let y = y.to_dtype(DType::F16).unwrap(); + let y = self.c_proj.forward(&y).unwrap(); Ok(y) } } @@ -326,8 +360,13 @@ impl Block { } fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> { - let x = (self.attn.forward(&self.rms_1.forward(x)?, freqs_cis)? + x)?; - let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?; + let x = (self + .attn + .forward(&self.rms_1.forward(x).unwrap(), freqs_cis) + .unwrap() + + x) + .unwrap(); + let x = (self.mlp.forward(&self.rms_2.forward(&x).unwrap()).unwrap() + x).unwrap(); Ok(x) } } @@ -351,18 +390,18 @@ impl Llama { fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> { // TODO: Support for mini-batches? (i.e. r2) - let t = x.shape().r1()?; - let x = self.wte.forward(x)?; - let mut x = x.to_dtype(DType::F32)?; + let t = x.shape().r1().unwrap(); + let mut x = self.wte.forward(x).unwrap(); for block in self.blocks.iter() { - x = block.forward(&x, freqs_cis)?; + x = block.forward(&x, freqs_cis).unwrap(); } - let x = self.ln_f.forward(&x)?; - let x = x.narrow(0, t - 1, 1)?; - let logits = self.lm_head.forward(&x)?; - let (b, vocab_size) = logits.shape().r2()?; + let x = self.ln_f.forward(&x).unwrap(); + let x = x.narrow(0, t - 1, 1).unwrap(); + let logits = self.lm_head.forward(&x).unwrap(); + let logits = logits.to_dtype(DType::F32)?; + let (b, vocab_size) = logits.shape().r2().unwrap(); assert_eq!(b, 1); - Ok(logits.reshape(vocab_size)?) + Ok(logits.reshape(vocab_size).unwrap()) } } @@ -374,16 +413,18 @@ fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> { .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32)) .collect(); let arange: Vec<_> = (0..seq_len).map(|c| c as f32).collect(); - let theta = Tensor::new(theta.as_slice(), device)?; - let arange = Tensor::new(arange.as_slice(), device)?; + let theta = Tensor::new(theta.as_slice(), device).unwrap(); + let arange = Tensor::new(arange.as_slice(), device).unwrap(); let idx_theta = arange - .reshape((arange.elem_count(), 1))? - .matmul(&theta.reshape((1, theta.elem_count()))?)?; + .reshape((arange.elem_count(), 1)) + .unwrap() + .matmul(&theta.reshape((1, theta.elem_count())).unwrap()) + .unwrap(); let shape = [1, seq_len, n_elem / 2, 1]; - let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?; - let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?; + let idx_theta_cos = idx_theta.cos().unwrap().reshape(&shape).unwrap(); + let idx_theta_sin = idx_theta.sin().unwrap().reshape(&shape).unwrap(); let last_dim = idx_theta_cos.rank() - 1; - Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim)?) + Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], last_dim).unwrap()) } #[derive(Parser, Debug)] @@ -418,19 +459,19 @@ async fn main() -> Result<()> { let device = if args.cpu { Device::Cpu } else { - Device::new_cuda(0)? + Device::new_cuda(0).unwrap() }; - let api = Api::new()?; let config = Config::config_7b(); let cache = Cache::new(&device); let start = std::time::Instant::now(); let (llama, tokenizer_filename) = if args.npy { println!("building the model (NPY)"); ( - Llama::load_npy(&device, "/data/llama.npz", &cache, &config)?, + Llama::load_npy(&device, "/data/llama.npz", &cache, &config).unwrap(), std::path::Path::new("llama-tokenizer.json").to_path_buf(), ) } else { + let api = Api::new()?; let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model); let tokenizer_filename = api.get(&repo, "tokenizer.json").await?; let mut filenames = vec![]; @@ -444,20 +485,23 @@ async fn main() -> Result<()> { println!("building the model (SF)"); ( - Llama::load(&device, &filenames, &cache, &config)?, + Llama::load(&device, &filenames, &cache, &config).unwrap(), tokenizer_filename, ) }; println!("Loaded in {:?}", start.elapsed()); - let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + let tokenizer = Tokenizer::from_file(tokenizer_filename) + .map_err(E::msg) + .unwrap(); let mut tokens = tokenizer .encode(START_PROMPT, true) - .map_err(E::msg)? + .map_err(E::msg) + .unwrap() .get_ids() .to_vec(); println!("pre-computing the positional embeddings"); - let freqs_cis = precompute_freqs_cis(&config, &device)?; + let freqs_cis = precompute_freqs_cis(&config, &device).unwrap(); println!("starting the inference loop"); let mut new_tokens = vec![]; let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed); @@ -465,18 +509,21 @@ async fn main() -> Result<()> { for index in 0..args.sample_len { let start_gen = std::time::Instant::now(); let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..]; - let input = Tensor::new(ctxt, &device)?; - let logits = llama.forward(&input, &freqs_cis)?; + let input = Tensor::new(ctxt, &device).unwrap(); + let logits = llama.forward(&input, &freqs_cis).unwrap(); let next_token = if let Some(temperature) = args.temperature { println!("Sampling with temperature {temperature:?}"); - let prs = (&logits / temperature)?.softmax(logits.rank() - 1)?; - let logits_v: Vec<f32> = prs.to_vec1()?; - let distr = rand::distributions::WeightedIndex::new(&logits_v)?; + let prs = (&logits / temperature) + .unwrap() + .softmax(logits.rank() - 1) + .unwrap(); + let logits_v: Vec<f32> = prs.to_vec1().unwrap(); + let distr = rand::distributions::WeightedIndex::new(&logits_v).unwrap(); distr.sample(&mut rng) as u32 } else { - let logits_v: Vec<f32> = logits.to_vec1()?; + let logits_v: Vec<f32> = logits.to_vec1().unwrap(); logits_v .iter() .enumerate() @@ -491,7 +538,10 @@ async fn main() -> Result<()> { "{} token: {} '{}'", index + 1, next_token, - tokenizer.decode(vec![next_token], true).map_err(E::msg)? + tokenizer + .decode(vec![next_token], true) + .map_err(E::msg) + .unwrap() ); } let dt = start_gen.elapsed(); @@ -499,7 +549,7 @@ async fn main() -> Result<()> { "{} tokens generated ({} token/s)\n----\n{}\n----", args.sample_len, args.sample_len as f64 / dt.as_secs_f64(), - tokenizer.decode(new_tokens, true).map_err(E::msg)? + tokenizer.decode(new_tokens, true).map_err(E::msg).unwrap() ); Ok(()) } |