summary refs log tree commit diff
path: root/candle-examples/examples/llama2-c
diff options
context:
space:
mode:
author	laurent <laurent.mazare@gmail.com>	2024-02-22 10:22:03 +0100
committer	laurent <laurent.mazare@gmail.com>	2024-02-22 10:22:03 +0100
commit	544018b6d0ffa0a4b0ac6c30de10ec2012765fcb (patch)
tree	e309f5185f5277235879fdbe9e09608343094c5f /candle-examples/examples/llama2-c
parent	c753f72c8552ba3e108bd3f1a04971e8abbf3012 (diff)
download	candle-544018b6d0ffa0a4b0ac6c30de10ec2012765fcb.tar.gz
	candle-544018b6d0ffa0a4b0ac6c30de10ec2012765fcb.tar.bz2
	candle-544018b6d0ffa0a4b0ac6c30de10ec2012765fcb.zip
Explicit caching in llama2.c.
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r--candle-examples/examples/llama2-c/main.rs30
-rw-r--r--candle-examples/examples/llama2-c/training.rs11
2 files changed, 21 insertions(+), 20 deletions(-)
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 27ebc80f..1a82bf1f 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
use std::io::Write;
use tokenizers::Tokenizer;
-use model::{Config, Llama};
+use model::{Cache, Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights;
@@ -160,10 +160,10 @@ enum Model {
}
impl Model {
- fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
+ fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
match self {
- Self::Llama(l) => Ok(l.forward(xs, pos)?),
- Self::QLlama(l) => Ok(l.forward(xs, pos)?),
+ Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
+ Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
}
}
}
@@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let config = Config::from_reader(&mut file)?;
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
- let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
- let model = Llama::load(vb, &cache, config)?;
+ let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
+ let model = Llama::load(vb, config)?;
let tokens = match &args.pretokenized_dir {
None => {
@@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
- let logits = model.forward(&inp, 0)?;
+ let logits = model.forward(&inp, 0, &mut cache)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
println!("{}", loss.to_vec0::<f32>()?);
}
@@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let is_safetensors = config_path
.extension()
.map_or(false, |v| v == "safetensors");
- let (model, config) = if is_gguf {
+ let (model, config, mut cache) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
@@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
&device,
);
let cache = model::Cache::new(true, &config, fake_vb)?;
- let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
- (model, config)
+ let model = Model::QLlama(QLlama::load(vb, config.clone())?);
+ (model, config, cache)
} else if is_safetensors {
let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
- let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
- (model, config)
+ let model = Model::Llama(Llama::load(vb, config.clone())?);
+ (model, config, cache)
} else {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
@@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
- let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
- (model, config)
+ let model = Model::Llama(Llama::load(vb, config.clone())?);
+ (model, config, cache)
};
println!("starting the inference loop");
@@ -338,7 +338,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
- let logits = model.forward(&input, index_pos)?;
+ let logits = model.forward(&input, index_pos, &mut cache)?;
let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
logits
diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index b2aa0889..c83ca43f 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -8,6 +8,7 @@ fn valid_loss(
model: &Llama,
args: &crate::TrainingCmd,
device: &Device,
+ cache: &mut Cache,
) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
@@ -15,7 +16,7 @@ fn valid_loss(
let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) {
let (inp, tgt) = inp_tgt?;
- let logits = model.forward(&inp, 0)?;
+ let logits = model.forward(&inp, 0, cache)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
sum_ce += loss.to_vec0::<f32>()? as f64;
cnt += 1;
@@ -37,8 +38,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
- let cache = Cache::new(false, &config, vb.pp("rot"))?;
- let model = Llama::load(vb, &cache, config)?;
+ let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
+ let model = Llama::load(vb, config)?;
let params = candle_nn::ParamsAdamW {
lr: args.learning_rate,
..Default::default()
@@ -46,14 +47,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
for (batch_index, batch) in batch_iter.enumerate() {
let (inp, tgt) = batch?;
- let logits = model.forward(&inp, 0)?;
+ let logits = model.forward(&inp, 0, &mut cache)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
opt.backward_step(&loss)?;
if batch_index > 0 && batch_index % 100 == 0 {
// TODO: Add a way to deactivate the backprop graph tracking when computing the
// validation loss.
- let loss = valid_loss(&dataset, &model, args, &device)?;
+ let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
println!("{batch_index} {loss}");
}
if batch_index > 0 && batch_index % 1000 == 0 {