summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- candle-examples/examples/llama2-c/main.rs 32
-rw-r--r-- candle-nn/src/dataset.rs 96
-rw-r--r-- candle-nn/src/lib.rs 1
3 files changed, 111 insertions, 18 deletions
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index f9bbe149..ff2a53fe 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -319,26 +319,22 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
println!("dataset loaded and encoded: {} tokens", tokens.len());
let seq_len = model.config.seq_len;
- let mut inputs = vec![];
- let mut targets = vec![];
- for start_idx in (0..tokens.len()).step_by(seq_len) {
+ let iter = (0..tokens.len()).step_by(seq_len).flat_map(|start_idx| {
if start_idx + seq_len + 1 > tokens.len() {
- break;
- }
- let tokens = &tokens[start_idx..start_idx + seq_len + 1];
- let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
- let targets_ = Tensor::new(&tokens[1..], &device)?;
- inputs.push(inputs_);
- targets.push(targets_);
- if inputs.len() >= args.batch_size {
- let inp = Tensor::stack(&inputs, 0)?;
- let tgt = Tensor::stack(&targets, 0)?;
- let logits = model.forward(&inp, 0)?;
- let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
- println!("{}", loss.to_vec0::<f32>()?);
- inputs.clear();
- targets.clear();
+ None
+ } else {
+ let tokens = &tokens[start_idx..start_idx + seq_len + 1];
+ let inputs = Tensor::new(&tokens[..seq_len], &device).ok();
+ let targets = Tensor::new(&tokens[1..], &device).ok();
+ inputs.zip(targets)
}
+ });
+ let batch_iter = candle_nn::dataset::Batcher::new2(iter).batch_size(args.batch_size);
+ for inp_tgt in batch_iter {
+ let (inp, tgt) = inp_tgt?;
+ let logits = model.forward(&inp, 0)?;
+ let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
+ println!("{}", loss.to_vec0::<f32>()?);
}
Ok(())
}
diff --git a/candle-nn/src/dataset.rs b/candle-nn/src/dataset.rs
new file mode 100644
index 00000000..affe7b48
--- /dev/null
+++ b/candle-nn/src/dataset.rs
@@ -0,0 +1,96 @@
+use candle::{Result, Tensor};
+
+/// Adapter that groups the items of an inner iterator into batches.
+///
+/// Instances are built through `new1` (iterator of tensors) or `new2` (iterator of tensor
+/// pairs) and configured with the builder-style `batch_size` and
+/// `return_last_incomplete_batch` methods. The `Iterator` implementations combine each group
+/// of items with `Tensor::stack` along dimension 0.
+pub struct Batcher<I> {
+ // Wrapped item source; in practice an `Iter1` or `Iter2` marker wrapper.
+ inner: I,
+ // Number of items pulled from `inner` for each batch (initialised to 16 by `new`).
+ batch_size: usize,
+ // When false, items left over once `inner` runs dry are dropped rather than being
+ // returned as a shorter batch.
+ return_last_incomplete_batch: bool,
+}
+
+impl<I> Batcher<I> {
+    /// Wraps `inner` with the default configuration: batches of 16 items, and leftover
+    /// items that do not fill a whole batch are discarded.
+    ///
+    /// Kept private: callers go through the typed constructors so that the wrapped
+    /// iterator is always one of the supported marker types.
+    fn new(inner: I) -> Self {
+        let batch_size = 16;
+        let return_last_incomplete_batch = false;
+        Self {
+            inner,
+            batch_size,
+            return_last_incomplete_batch,
+        }
+    }
+
+    /// Returns the batcher reconfigured to pull `batch_size` items per batch.
+    pub fn batch_size(self, batch_size: usize) -> Self {
+        Self { batch_size, ..self }
+    }
+
+    /// Returns the batcher reconfigured so that, when `r` is true, items left over at the
+    /// end of the inner iterator are emitted as a final, shorter batch.
+    pub fn return_last_incomplete_batch(self, r: bool) -> Self {
+        Self {
+            return_last_incomplete_batch: r,
+            ..self
+        }
+    }
+}
+
+/// Marker wrapper around an iterator of single tensors.
+///
+/// Wrapping the iterator in a dedicated type lets `Batcher` carry a separate `Iterator`
+/// implementation for the single-tensor case.
+pub struct Iter1<I>
+where
+    I: Iterator<Item = Tensor>,
+{
+    inner: I,
+}
+
+/// Marker wrapper around an iterator of tensor pairs.
+///
+/// Wrapping the iterator in a dedicated type lets `Batcher` carry a separate `Iterator`
+/// implementation for the paired-tensor case.
+pub struct Iter2<I>
+where
+    I: Iterator<Item = (Tensor, Tensor)>,
+{
+    inner: I,
+}
+
+impl<I> Batcher<Iter1<I>>
+where
+    I: Iterator<Item = Tensor>,
+{
+    /// Builds a batcher over an iterator that yields one tensor per item, using the
+    /// default configuration set by `new`.
+    pub fn new1(inner: I) -> Self {
+        let wrapped = Iter1 { inner };
+        Self::new(wrapped)
+    }
+}
+
+impl<I> Batcher<Iter2<I>>
+where
+    I: Iterator<Item = (Tensor, Tensor)>,
+{
+    /// Builds a batcher over an iterator that yields a pair of tensors per item, using
+    /// the default configuration set by `new`.
+    pub fn new2(inner: I) -> Self {
+        let wrapped = Iter2 { inner };
+        Self::new(wrapped)
+    }
+}
+
+/// Yields batches built from up to `batch_size` consecutive tensors of the inner iterator,
+/// combined with `Tensor::stack` along dimension 0.
+///
+/// Each item is the `Result` of the stacking call, so stacking errors are passed on to the
+/// caller. Iteration ends (`None`) when the inner iterator cannot fill a batch, unless
+/// `return_last_incomplete_batch` is set and at least one tensor was collected, in which
+/// case the shorter batch is returned first.
+impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
+    type Item = Result<Tensor>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let mut items = Vec::with_capacity(self.batch_size);
+        for _i in 0..self.batch_size {
+            // We have two levels of inner here so that we can have two implementations of the
+            // Iterator trait that are different for Iter1 and Iter2. If rust gets better
+            // specialization at some point we can get rid of this.
+            match self.inner.inner.next() {
+                Some(item) => items.push(item),
+                None => {
+                    // Only emit a partial batch when there is something in it: breaking
+                    // with an empty buffer would try to stack zero tensors on every call
+                    // and the iterator would never terminate with `None`.
+                    if self.return_last_incomplete_batch && !items.is_empty() {
+                        break;
+                    }
+                    return None;
+                }
+            }
+        }
+        Some(Tensor::stack(&items, 0))
+    }
+}
+
+/// Yields batches built from up to `batch_size` consecutive tensor pairs of the inner
+/// iterator. The first and second elements of the pairs are combined separately with
+/// `Tensor::stack` along dimension 0.
+///
+/// Each item is a `Result`, so an error from either stacking call is passed on to the
+/// caller. Iteration ends (`None`) when the inner iterator cannot fill a batch, unless
+/// `return_last_incomplete_batch` is set and at least one pair was collected, in which case
+/// the shorter batch is returned first.
+impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
+    type Item = Result<(Tensor, Tensor)>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let mut xs = Vec::with_capacity(self.batch_size);
+        let mut ys = Vec::with_capacity(self.batch_size);
+        for _i in 0..self.batch_size {
+            match self.inner.inner.next() {
+                Some((x, y)) => {
+                    xs.push(x);
+                    ys.push(y)
+                }
+                None => {
+                    // Only emit a partial batch when there is something in it: breaking
+                    // with empty buffers would try to stack zero tensors on every call
+                    // and the iterator would never terminate with `None`. `xs` and `ys`
+                    // always have the same length, so checking one is enough.
+                    if self.return_last_incomplete_batch && !xs.is_empty() {
+                        break;
+                    }
+                    return None;
+                }
+            }
+        }
+        let xs = Tensor::stack(&xs, 0);
+        let ys = Tensor::stack(&ys, 0);
+        // Report the first failure, otherwise pair up the two stacked tensors.
+        Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
+    }
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index d0b62dbb..e8086238 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -2,6 +2,7 @@
// error type if needed or add some specialized cases on the candle-core side.
pub mod activation;
pub mod conv;
+pub mod dataset;
pub mod embedding;
pub mod init;
pub mod layer_norm;