diff options
Diffstat (limited to 'candle-datasets/src/batcher.rs')
-rw-r--r-- | candle-datasets/src/batcher.rs | 171 |
1 files changed, 171 insertions, 0 deletions
diff --git a/candle-datasets/src/batcher.rs b/candle-datasets/src/batcher.rs new file mode 100644 index 00000000..b74f1417 --- /dev/null +++ b/candle-datasets/src/batcher.rs @@ -0,0 +1,171 @@ +use candle::{Result, Tensor}; + +pub struct Batcher<I> { + inner: I, + batch_size: usize, + return_last_incomplete_batch: bool, +} + +impl<I> Batcher<I> { + fn new(inner: I) -> Self { + Self { + inner, + batch_size: 16, + return_last_incomplete_batch: false, + } + } + + pub fn batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + pub fn return_last_incomplete_batch(mut self, r: bool) -> Self { + self.return_last_incomplete_batch = r; + self + } +} + +pub struct Iter1<I: Iterator<Item = Tensor>> { + inner: I, +} + +pub struct Iter2<I: Iterator<Item = (Tensor, Tensor)>> { + inner: I, +} + +impl<I: Iterator<Item = Tensor>> Batcher<Iter1<I>> { + pub fn new1(inner: I) -> Self { + Self::new(Iter1 { inner }) + } +} + +impl<I: Iterator<Item = (Tensor, Tensor)>> Batcher<Iter2<I>> { + pub fn new2(inner: I) -> Self { + Self::new(Iter2 { inner }) + } +} + +pub struct IterResult1<I: Iterator<Item = Result<Tensor>>> { + inner: I, +} + +pub struct IterResult2<I: Iterator<Item = Result<(Tensor, Tensor)>>> { + inner: I, +} + +impl<I: Iterator<Item = Result<Tensor>>> Batcher<IterResult1<I>> { + pub fn new_r1(inner: I) -> Self { + Self::new(IterResult1 { inner }) + } +} + +impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Batcher<IterResult2<I>> { + pub fn new_r2(inner: I) -> Self { + Self::new(IterResult2 { inner }) + } +} + +impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> { + type Item = Result<Tensor>; + + fn next(&mut self) -> Option<Self::Item> { + let mut items = Vec::with_capacity(self.batch_size); + for _i in 0..self.batch_size { + // We have two levels of inner here so that we can have two implementations of the + // Iterator trait that are different for Iter1 and Iter2. If rust gets better + // specialization at some point we can get rid of this. + match self.inner.inner.next() { + Some(item) => items.push(item), + None => { + if self.return_last_incomplete_batch { + break; + } + return None; + } + } + } + Some(Tensor::stack(&items, 0)) + } +} + +impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> { + type Item = Result<(Tensor, Tensor)>; + + fn next(&mut self) -> Option<Self::Item> { + let mut xs = Vec::with_capacity(self.batch_size); + let mut ys = Vec::with_capacity(self.batch_size); + for _i in 0..self.batch_size { + match self.inner.inner.next() { + Some((x, y)) => { + xs.push(x); + ys.push(y) + } + None => { + if self.return_last_incomplete_batch { + break; + } + return None; + } + } + } + let xs = Tensor::stack(&xs, 0); + let ys = Tensor::stack(&ys, 0); + Some(xs.and_then(|xs| ys.map(|ys| (xs, ys)))) + } +} + +impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> { + type Item = Result<Tensor>; + + fn next(&mut self) -> Option<Self::Item> { + let mut items = Vec::with_capacity(self.batch_size); + for _i in 0..self.batch_size { + // We have two levels of inner here so that we can have two implementations of the + // Iterator trait that are different for Iter1 and Iter2. If rust gets better + // specialization at some point we can get rid of this. + match self.inner.inner.next() { + Some(item) => items.push(item), + None => { + if self.return_last_incomplete_batch { + break; + } + return None; + } + } + } + let items = items.into_iter().collect::<Result<Vec<Tensor>>>(); + Some(items.and_then(|items| Tensor::stack(&items, 0))) + } +} + +impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResult2<I>> { + type Item = Result<(Tensor, Tensor)>; + + fn next(&mut self) -> Option<Self::Item> { + let mut xs = Vec::with_capacity(self.batch_size); + let mut ys = Vec::with_capacity(self.batch_size); + let mut errs = vec![]; + for _i in 0..self.batch_size { + match self.inner.inner.next() { + Some(Ok((x, y))) => { + xs.push(x); + ys.push(y) + } + Some(Err(err)) => errs.push(err), + None => { + if self.return_last_incomplete_batch { + break; + } + return None; + } + } + } + if !errs.is_empty() { + return Some(Err(errs.swap_remove(0))); + } + let xs = Tensor::stack(&xs, 0); + let ys = Tensor::stack(&ys, 0); + Some(xs.and_then(|xs| ys.map(|ys| (xs, ys)))) + } +} |