summaryrefslogtreecommitdiff
path: root/candle-examples/examples/llama2-c
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-01 09:40:34 +0100
committerGitHub <noreply@github.com>2023-08-01 09:40:34 +0100
commit614f911e9e91eefafb55c7701fea712413625d4b (patch)
tree4463b9d0304a65079b42c4b4918f414a0c5d6bcc /candle-examples/examples/llama2-c
parente1e8127f154e83c3c8877033c3c50344cca06083 (diff)
downloadcandle-614f911e9e91eefafb55c7701fea712413625d4b.tar.gz
candle-614f911e9e91eefafb55c7701fea712413625d4b.tar.bz2
candle-614f911e9e91eefafb55c7701fea712413625d4b.zip
Add some batcher variants that handle errors. (#294)
Diffstat (limited to 'candle-examples/examples/llama2-c')
-rw-r--r--candle-examples/examples/llama2-c/main.rs8
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index ff2a53fe..2cf71bb5 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -324,12 +324,12 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
None
} else {
let tokens = &tokens[start_idx..start_idx + seq_len + 1];
- let inputs = Tensor::new(&tokens[..seq_len], &device).ok();
- let targets = Tensor::new(&tokens[1..], &device).ok();
- inputs.zip(targets)
+ let inputs = Tensor::new(&tokens[..seq_len], &device);
+ let targets = Tensor::new(&tokens[1..], &device);
+ Some(inputs.and_then(|inputs| targets.map(|targets| (inputs, targets))))
}
});
- let batch_iter = candle_nn::dataset::Batcher::new2(iter).batch_size(args.batch_size);
+ let batch_iter = candle_nn::dataset::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0)?;