diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-01 09:16:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-01 09:16:10 +0100 |
commit | e1e8127f154e83c3c8877033c3c50344cca06083 (patch) | |
tree | ad81f648b3007542c42a8ec76b4bb268135e2ff1 /candle-examples/examples/llama2-c | |
parent | fa98ca0c35861e15a532528ed27df0bd40bb4ce5 (diff) | |
download | candle-e1e8127f154e83c3c8877033c3c50344cca06083.tar.gz candle-e1e8127f154e83c3c8877033c3c50344cca06083.tar.bz2 candle-e1e8127f154e83c3c8877033c3c50344cca06083.zip |
Add the batcher. (#293)
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r-- | candle-examples/examples/llama2-c/main.rs | 32 |
1 file changed, 14 insertions, 18 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index f9bbe149..ff2a53fe 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -319,26 +319,22 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     println!("dataset loaded and encoded: {} tokens", tokens.len());
     let seq_len = model.config.seq_len;
-    let mut inputs = vec![];
-    let mut targets = vec![];
-    for start_idx in (0..tokens.len()).step_by(seq_len) {
+    let iter = (0..tokens.len()).step_by(seq_len).flat_map(|start_idx| {
         if start_idx + seq_len + 1 > tokens.len() {
-            break;
-        }
-        let tokens = &tokens[start_idx..start_idx + seq_len + 1];
-        let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
-        let targets_ = Tensor::new(&tokens[1..], &device)?;
-        inputs.push(inputs_);
-        targets.push(targets_);
-        if inputs.len() >= args.batch_size {
-            let inp = Tensor::stack(&inputs, 0)?;
-            let tgt = Tensor::stack(&targets, 0)?;
-            let logits = model.forward(&inp, 0)?;
-            let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
-            println!("{}", loss.to_vec0::<f32>()?);
-            inputs.clear();
-            targets.clear();
+            None
+        } else {
+            let tokens = &tokens[start_idx..start_idx + seq_len + 1];
+            let inputs = Tensor::new(&tokens[..seq_len], &device).ok();
+            let targets = Tensor::new(&tokens[1..], &device).ok();
+            inputs.zip(targets)
         }
+    });
+    let batch_iter = candle_nn::dataset::Batcher::new2(iter).batch_size(args.batch_size);
+    for inp_tgt in batch_iter {
+        let (inp, tgt) = inp_tgt?;
+        let logits = model.forward(&inp, 0)?;
+        let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+        println!("{}", loss.to_vec0::<f32>()?);
     }
     Ok(())
 }