use candle::{Result, Tensor};

/// Iterator adapter that groups the items of an inner iterator into batches.
///
/// Each batch is built by pulling up to `batch_size` items from the wrapped
/// iterator and combining them with `Tensor::stack(.., 0)`.
///
/// Construct it with [`Batcher::new1`], [`Batcher::new2`], [`Batcher::new_r1`]
/// or [`Batcher::new_r2`] depending on the item type of the source iterator,
/// then tune it with the builder-style setters.
pub struct Batcher<I> {
    inner: I,
    batch_size: usize,
    return_last_incomplete_batch: bool,
}

impl<I> Batcher<I> {
    /// Wraps `inner` with the defaults: a batch size of 16 and the trailing
    /// incomplete batch being dropped.
    fn new(inner: I) -> Self {
        Self {
            inner,
            batch_size: 16,
            return_last_incomplete_batch: false,
        }
    }

    /// Sets the number of items pulled from the source for each batch.
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Chooses whether a final batch holding fewer than `batch_size` items is
    /// returned (`true`) or discarded (`false`).
    pub fn return_last_incomplete_batch(mut self, r: bool) -> Self {
        self.return_last_incomplete_batch = r;
        self
    }

    /// Called when the source runs dry in the middle of filling a batch.
    ///
    /// Returns `true` when the partially filled batch should be emitted. An
    /// empty partial batch is never emitted: stacking zero tensors is an error,
    /// so emitting it would turn an exhausted source into an endless stream of
    /// `Some(Err(..))` instead of a terminating `None`.
    fn emit_partial(&self, pulled_any: bool) -> bool {
        self.return_last_incomplete_batch && pulled_any
    }
}

/// Marker wrapper selecting the batching implementation for plain tensors.
pub struct Iter1<I: Iterator<Item = Tensor>> {
    inner: I,
}

/// Marker wrapper selecting the batching implementation for tensor pairs.
pub struct Iter2<I: Iterator<Item = (Tensor, Tensor)>> {
    inner: I,
}

impl<I: Iterator<Item = Tensor>> Batcher<Iter1<I>> {
    /// Creates a batcher over an iterator of tensors.
    pub fn new1(inner: I) -> Self {
        Self::new(Iter1 { inner })
    }
}

impl<I: Iterator<Item = (Tensor, Tensor)>> Batcher<Iter2<I>> {
    /// Creates a batcher over an iterator of tensor pairs.
    pub fn new2(inner: I) -> Self {
        Self::new(Iter2 { inner })
    }
}

/// Marker wrapper selecting the batching implementation for fallible tensors.
pub struct IterResult1<I: Iterator<Item = Result<Tensor>>> {
    inner: I,
}

/// Marker wrapper selecting the batching implementation for fallible tensor pairs.
pub struct IterResult2<I: Iterator<Item = Result<(Tensor, Tensor)>>> {
    inner: I,
}

impl<I: Iterator<Item = Result<Tensor>>> Batcher<IterResult1<I>> {
    /// Creates a batcher over an iterator of `Result<Tensor>`.
    pub fn new_r1(inner: I) -> Self {
        Self::new(IterResult1 { inner })
    }
}

impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Batcher<IterResult2<I>> {
    /// Creates a batcher over an iterator of `Result<(Tensor, Tensor)>`.
    pub fn new_r2(inner: I) -> Self {
        Self::new(IterResult2 { inner })
    }
}

/// Stacks both halves of a pair batch, reporting the `xs` error first if both fail.
fn stack_pair(xs: &[Tensor], ys: &[Tensor]) -> Result<(Tensor, Tensor)> {
    // Both stacks are evaluated before combining, like the eager original.
    let xs = Tensor::stack(xs, 0);
    let ys = Tensor::stack(ys, 0);
    xs.and_then(|xs| ys.map(|ys| (xs, ys)))
}

impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
    type Item = Result<Tensor>;

    /// Returns the next stacked batch, or `None` once the source cannot
    /// provide one (see [`Batcher::return_last_incomplete_batch`]).
    fn next(&mut self) -> Option<Self::Item> {
        let mut items = Vec::with_capacity(self.batch_size);
        for _ in 0..self.batch_size {
            // We have two levels of inner here so that we can have two implementations of the
            // Iterator trait that are different for Iter1 and Iter2. If rust gets better
            // specialization at some point we can get rid of this.
            match self.inner.inner.next() {
                Some(item) => items.push(item),
                None => {
                    if self.emit_partial(!items.is_empty()) {
                        break;
                    }
                    return None;
                }
            }
        }
        Some(Tensor::stack(&items, 0))
    }
}

impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
    type Item = Result<(Tensor, Tensor)>;

    /// Returns the next pair of stacked batches, or `None` once the source
    /// cannot provide one.
    fn next(&mut self) -> Option<Self::Item> {
        let mut xs = Vec::with_capacity(self.batch_size);
        let mut ys = Vec::with_capacity(self.batch_size);
        for _ in 0..self.batch_size {
            match self.inner.inner.next() {
                Some((x, y)) => {
                    xs.push(x);
                    ys.push(y);
                }
                None => {
                    if self.emit_partial(!xs.is_empty()) {
                        break;
                    }
                    return None;
                }
            }
        }
        Some(stack_pair(&xs, &ys))
    }
}

impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> {
    type Item = Result<Tensor>;

    /// Returns the next stacked batch, the first error found among the items
    /// pulled for it, or `None` once the source cannot provide a batch.
    fn next(&mut self) -> Option<Self::Item> {
        let mut items = Vec::with_capacity(self.batch_size);
        for _ in 0..self.batch_size {
            match self.inner.inner.next() {
                // Errors are buffered alongside successes and surfaced below.
                Some(item) => items.push(item),
                None => {
                    if self.emit_partial(!items.is_empty()) {
                        break;
                    }
                    return None;
                }
            }
        }
        let items = items.into_iter().collect::<Result<Vec<Tensor>>>();
        Some(items.and_then(|items| Tensor::stack(&items, 0)))
    }
}

impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResult2<I>> {
    type Item = Result<(Tensor, Tensor)>;

    /// Returns the next pair of stacked batches, the first error found among
    /// the items pulled for it, or `None` once the source cannot provide a batch.
    fn next(&mut self) -> Option<Self::Item> {
        let mut xs = Vec::with_capacity(self.batch_size);
        let mut ys = Vec::with_capacity(self.batch_size);
        // Only the first error of a batch is reported; later ones are dropped,
        // but the whole batch is still consumed from the source.
        let mut first_err = None;
        for _ in 0..self.batch_size {
            match self.inner.inner.next() {
                Some(Ok((x, y))) => {
                    xs.push(x);
                    ys.push(y);
                }
                Some(Err(err)) => {
                    if first_err.is_none() {
                        first_err = Some(err);
                    }
                }
                None => {
                    // A pulled error counts as content so that it still gets
                    // reported when the partial batch holds no valid pair.
                    let pulled_any = !xs.is_empty() || first_err.is_some();
                    if self.emit_partial(pulled_any) {
                        break;
                    }
                    return None;
                }
            }
        }
        if let Some(err) = first_err {
            return Some(Err(err));
        }
        Some(stack_pair(&xs, &ys))
    }
}