summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama2-c
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-01 09:16:10 +0100
committerGitHub <noreply@github.com>2023-08-01 09:16:10 +0100
commite1e8127f154e83c3c8877033c3c50344cca06083 (patch)
treead81f648b3007542c42a8ec76b4bb268135e2ff1 /candle-examples/examples/llama2-c
parentfa98ca0c35861e15a532528ed27df0bd40bb4ce5 (diff)
downloadcandle-e1e8127f154e83c3c8877033c3c50344cca06083.tar.gz
candle-e1e8127f154e83c3c8877033c3c50344cca06083.tar.bz2
candle-e1e8127f154e83c3c8877033c3c50344cca06083.zip
Add the batcher. (#293)
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r--candle-examples/examples/llama2-c/main.rs32
1 files changed, 14 insertions, 18 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index f9bbe149..ff2a53fe 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -319,26 +319,22 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
println!("dataset loaded and encoded: {} tokens", tokens.len());
let seq_len = model.config.seq_len;
- let mut inputs = vec![];
- let mut targets = vec![];
- for start_idx in (0..tokens.len()).step_by(seq_len) {
+ let iter = (0..tokens.len()).step_by(seq_len).flat_map(|start_idx| {
if start_idx + seq_len + 1 > tokens.len() {
- break;
- }
- let tokens = &tokens[start_idx..start_idx + seq_len + 1];
- let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
- let targets_ = Tensor::new(&tokens[1..], &device)?;
- inputs.push(inputs_);
- targets.push(targets_);
- if inputs.len() >= args.batch_size {
- let inp = Tensor::stack(&inputs, 0)?;
- let tgt = Tensor::stack(&targets, 0)?;
- let logits = model.forward(&inp, 0)?;
- let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
- println!("{}", loss.to_vec0::<f32>()?);
- inputs.clear();
- targets.clear();
+ None
+ } else {
+ let tokens = &tokens[start_idx..start_idx + seq_len + 1];
+ let inputs = Tensor::new(&tokens[..seq_len], &device).ok();
+ let targets = Tensor::new(&tokens[1..], &device).ok();
+ inputs.zip(targets)
}
+ });
+ let batch_iter = candle_nn::dataset::Batcher::new2(iter).batch_size(args.batch_size);
+ for inp_tgt in batch_iter {
+ let (inp, tgt) = inp_tgt?;
+ let logits = model.forward(&inp, 0)?;
+ let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+ println!("{}", loss.to_vec0::<f32>()?);
}
Ok(())
}