diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-05 08:56:50 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-05 08:56:50 +0100 |
commit | 620f83cf66073f033d1fdc9846123c155422677e (patch) | |
tree | 028fd4bf3e3318bc172ffffb57cc4a5cac57bb50 /candle-datasets/src/batcher.rs | |
parent | f7b2a0391d4a96607b5f208164e365b50ad0bbf7 (diff) | |
download | candle-620f83cf66073f033d1fdc9846123c155422677e.tar.gz candle-620f83cf66073f033d1fdc9846123c155422677e.tar.bz2 candle-620f83cf66073f033d1fdc9846123c155422677e.zip |
Add the candle-datasets crate (#322)
* Move the vision datasets to a separate crate.
* Move the batcher bits.
* Update the readme.
* Move the tiny-stories bits.
---------
Co-authored-by: Jane Doe <jane.doe@example.org>
Diffstat (limited to 'candle-datasets/src/batcher.rs')
-rw-r--r-- | candle-datasets/src/batcher.rs | 171 |
1 files changed, 171 insertions, 0 deletions
diff --git a/candle-datasets/src/batcher.rs b/candle-datasets/src/batcher.rs new file mode 100644 index 00000000..b74f1417 --- /dev/null +++ b/candle-datasets/src/batcher.rs @@ -0,0 +1,171 @@ +use candle::{Result, Tensor}; + +pub struct Batcher<I> { + inner: I, + batch_size: usize, + return_last_incomplete_batch: bool, +} + +impl<I> Batcher<I> { + fn new(inner: I) -> Self { + Self { + inner, + batch_size: 16, + return_last_incomplete_batch: false, + } + } + + pub fn batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + pub fn return_last_incomplete_batch(mut self, r: bool) -> Self { + self.return_last_incomplete_batch = r; + self + } +} + +pub struct Iter1<I: Iterator<Item = Tensor>> { + inner: I, +} + +pub struct Iter2<I: Iterator<Item = (Tensor, Tensor)>> { + inner: I, +} + +impl<I: Iterator<Item = Tensor>> Batcher<Iter1<I>> { + pub fn new1(inner: I) -> Self { + Self::new(Iter1 { inner }) + } +} + +impl<I: Iterator<Item = (Tensor, Tensor)>> Batcher<Iter2<I>> { + pub fn new2(inner: I) -> Self { + Self::new(Iter2 { inner }) + } +} + +pub struct IterResult1<I: Iterator<Item = Result<Tensor>>> { + inner: I, +} + +pub struct IterResult2<I: Iterator<Item = Result<(Tensor, Tensor)>>> { + inner: I, +} + +impl<I: Iterator<Item = Result<Tensor>>> Batcher<IterResult1<I>> { + pub fn new_r1(inner: I) -> Self { + Self::new(IterResult1 { inner }) + } +} + +impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Batcher<IterResult2<I>> { + pub fn new_r2(inner: I) -> Self { + Self::new(IterResult2 { inner }) + } +} + +impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> { + type Item = Result<Tensor>; + + fn next(&mut self) -> Option<Self::Item> { + let mut items = Vec::with_capacity(self.batch_size); + for _i in 0..self.batch_size { + // We have two levels of inner here so that we can have two implementations of the + // Iterator trait that are different for Iter1 and Iter2. If rust gets better + // specialization at some point we can get rid of this. + match self.inner.inner.next() { + Some(item) => items.push(item), + None => { + if self.return_last_incomplete_batch { + break; + } + return None; + } + } + } + Some(Tensor::stack(&items, 0)) + } +} + +impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> { + type Item = Result<(Tensor, Tensor)>; + + fn next(&mut self) -> Option<Self::Item> { + let mut xs = Vec::with_capacity(self.batch_size); + let mut ys = Vec::with_capacity(self.batch_size); + for _i in 0..self.batch_size { + match self.inner.inner.next() { + Some((x, y)) => { + xs.push(x); + ys.push(y) + } + None => { + if self.return_last_incomplete_batch { + break; + } + return None; + } + } + } + let xs = Tensor::stack(&xs, 0); + let ys = Tensor::stack(&ys, 0); + Some(xs.and_then(|xs| ys.map(|ys| (xs, ys)))) + } +} + +impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> { + type Item = Result<Tensor>; + + fn next(&mut self) -> Option<Self::Item> { + let mut items = Vec::with_capacity(self.batch_size); + for _i in 0..self.batch_size { + // We have two levels of inner here so that we can have two implementations of the + // Iterator trait that are different for Iter1 and Iter2. If rust gets better + // specialization at some point we can get rid of this. + match self.inner.inner.next() { + Some(item) => items.push(item), + None => { + if self.return_last_incomplete_batch { + break; + } + return None; + } + } + } + let items = items.into_iter().collect::<Result<Vec<Tensor>>>(); + Some(items.and_then(|items| Tensor::stack(&items, 0))) + } +} + +impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResult2<I>> { + type Item = Result<(Tensor, Tensor)>; + + fn next(&mut self) -> Option<Self::Item> { + let mut xs = Vec::with_capacity(self.batch_size); + let mut ys = Vec::with_capacity(self.batch_size); + let mut errs = vec![]; + for _i in 0..self.batch_size { + match self.inner.inner.next() { + Some(Ok((x, y))) => { + xs.push(x); + ys.push(y) + } + Some(Err(err)) => errs.push(err), + None => { + if self.return_last_incomplete_batch { + break; + } + return None; + } + } + } + if !errs.is_empty() { + return Some(Err(errs.swap_remove(0))); + } + let xs = Tensor::stack(&xs, 0); + let ys = Tensor::stack(&ys, 0); + Some(xs.and_then(|xs| ys.map(|ys| (xs, ys)))) + } +} |