diff options
author | laurent <laurent.mazare@gmail.com> | 2024-02-22 10:22:03 +0100 |
---|---|---|
committer | laurent <laurent.mazare@gmail.com> | 2024-02-22 10:22:03 +0100 |
commit | 544018b6d0ffa0a4b0ac6c30de10ec2012765fcb (patch) | |
tree | e309f5185f5277235879fdbe9e09608343094c5f /candle-examples/examples/llama2-c | |
parent | c753f72c8552ba3e108bd3f1a04971e8abbf3012 (diff) | |
download | candle-544018b6d0ffa0a4b0ac6c30de10ec2012765fcb.tar.gz candle-544018b6d0ffa0a4b0ac6c30de10ec2012765fcb.tar.bz2 candle-544018b6d0ffa0a4b0ac6c30de10ec2012765fcb.zip |
Explicit caching in llama2.c.
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r-- | candle-examples/examples/llama2-c/main.rs | 30 |
-rw-r--r-- | candle-examples/examples/llama2-c/training.rs | 11 |
2 files changed, 21 insertions, 20 deletions
```diff
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 27ebc80f..1a82bf1f 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
 use std::io::Write;
 use tokenizers::Tokenizer;
 
-use model::{Config, Llama};
+use model::{Cache, Config, Llama};
 use qmodel::QLlama;
 use weights::TransformerWeights;
 
@@ -160,10 +160,10 @@ enum Model {
 }
 
 impl Model {
-    fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
+    fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
         match self {
-            Self::Llama(l) => Ok(l.forward(xs, pos)?),
-            Self::QLlama(l) => Ok(l.forward(xs, pos)?),
+            Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
+            Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
         }
     }
 }
@@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     let config = Config::from_reader(&mut file)?;
     let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
     let vb = weights.var_builder(&config, &device)?;
-    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, config)?;
+    let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, config)?;
 
     let tokens = match &args.pretokenized_dir {
         None => {
@@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
     for inp_tgt in batch_iter {
         let (inp, tgt) = inp_tgt?;
-        let logits = model.forward(&inp, 0)?;
+        let logits = model.forward(&inp, 0, &mut cache)?;
         let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
         println!("{}", loss.to_vec0::<f32>()?);
     }
@@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let is_safetensors = config_path
         .extension()
         .map_or(false, |v| v == "safetensors");
-    let (model, config) = if is_gguf {
+    let (model, config, mut cache) = if is_gguf {
         let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
         let (_vocab_size, dim) = vb
             .get_no_shape("model.embed_tokens.weight")?
@@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
             &device,
         );
         let cache = model::Cache::new(true, &config, fake_vb)?;
-        let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
-        (model, config)
+        let model = Model::QLlama(QLlama::load(vb, config.clone())?);
+        (model, config, cache)
     } else if is_safetensors {
         let config = Config::tiny_15m();
         let tensors = candle::safetensors::load(config_path, &device)?;
         let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
         let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
-        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
-        (model, config)
+        let model = Model::Llama(Llama::load(vb, config.clone())?);
+        (model, config, cache)
     } else {
         let mut file = std::fs::File::open(config_path)?;
         let config = Config::from_reader(&mut file)?;
@@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
         let vb = weights.var_builder(&config, &device)?;
         let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
-        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
-        (model, config)
+        let model = Model::Llama(Llama::load(vb, config.clone())?);
+        (model, config, cache)
     };
 
     println!("starting the inference loop");
@@ -338,7 +338,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        let logits = model.forward(&input, index_pos)?;
+        let logits = model.forward(&input, index_pos, &mut cache)?;
         let logits = logits.i((0, logits.dim(1)? - 1))?;
         let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
             logits
diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index b2aa0889..c83ca43f 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -8,6 +8,7 @@ fn valid_loss(
     model: &Llama,
     args: &crate::TrainingCmd,
     device: &Device,
+    cache: &mut Cache,
 ) -> Result<f64> {
     let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
     let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
@@ -15,7 +16,7 @@ fn valid_loss(
     let mut cnt = 0usize;
     for inp_tgt in batch_iter.take(50) {
         let (inp, tgt) = inp_tgt?;
-        let logits = model.forward(&inp, 0)?;
+        let logits = model.forward(&inp, 0, cache)?;
         let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
         sum_ce += loss.to_vec0::<f32>()? as f64;
         cnt += 1;
@@ -37,8 +38,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
     let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
 
-    let cache = Cache::new(false, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, config)?;
+    let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, config)?;
     let params = candle_nn::ParamsAdamW {
         lr: args.learning_rate,
         ..Default::default()
@@ -46,14 +47,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
     for (batch_index, batch) in batch_iter.enumerate() {
         let (inp, tgt) = batch?;
-        let logits = model.forward(&inp, 0)?;
+        let logits = model.forward(&inp, 0, &mut cache)?;
         let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
         opt.backward_step(&loss)?;
 
         if batch_index > 0 && batch_index % 100 == 0 {
             // TODO: Add a way to deactivate the backprop graph tracking when computing the
             // validation loss.
-            let loss = valid_loss(&dataset, &model, args, &device)?;
+            let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
             println!("{batch_index} {loss}");
         }
         if batch_index > 0 && batch_index % 1000 == 0 {
```