summaryrefslogtreecommitdiff
path: root/candle-datasets/src/batcher.rs
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2023-08-05 08:56:50 +0100
committer GitHub <noreply@github.com> 2023-08-05 08:56:50 +0100
commit 620f83cf66073f033d1fdc9846123c155422677e (patch)
tree 028fd4bf3e3318bc172ffffb57cc4a5cac57bb50 /candle-datasets/src/batcher.rs
parent f7b2a0391d4a96607b5f208164e365b50ad0bbf7 (diff)
download candle-620f83cf66073f033d1fdc9846123c155422677e.tar.gz
candle-620f83cf66073f033d1fdc9846123c155422677e.tar.bz2
candle-620f83cf66073f033d1fdc9846123c155422677e.zip
Add the candle-datasets crate (#322)
* Move the vision datasets to a separate crate.
* Move the batcher bits.
* Update the readme.
* Move the tiny-stories bits.

---------

Co-authored-by: Jane Doe <jane.doe@example.org>
Diffstat (limited to 'candle-datasets/src/batcher.rs')
-rw-r--r-- candle-datasets/src/batcher.rs 171
1 files changed, 171 insertions, 0 deletions
diff --git a/candle-datasets/src/batcher.rs b/candle-datasets/src/batcher.rs
new file mode 100644
index 00000000..b74f1417
--- /dev/null
+++ b/candle-datasets/src/batcher.rs
@@ -0,0 +1,171 @@
+use candle::{Result, Tensor};
+
+/// Groups the elements produced by a wrapped iterator into batches.
+///
+/// The `Iterator` implementations in this file pull up to `batch_size` elements from the
+/// wrapped source on each call to `next` and combine them with `Tensor::stack(.., 0)`.
+pub struct Batcher<I> {
+ /// Wrapped element source (one of the `Iter*` wrapper types declared in this file).
+ inner: I,
+ /// Number of elements pulled from `inner` for each batch; `new` initialises it to 16.
+ batch_size: usize,
+ /// When `true`, a trailing batch holding fewer than `batch_size` elements is still
+ /// returned instead of being dropped; `new` initialises it to `false`.
+ return_last_incomplete_batch: bool,
+}
+
+impl<I> Batcher<I> {
+    /// Wraps `inner` with the default settings: 16 elements per batch, and a trailing
+    /// incomplete batch is not returned.
+    fn new(inner: I) -> Self {
+        let batch_size = 16;
+        let return_last_incomplete_batch = false;
+        Self {
+            inner,
+            batch_size,
+            return_last_incomplete_batch,
+        }
+    }
+
+    /// Builder-style setter: consumes the batcher and returns it with the number of
+    /// elements per batch replaced by `batch_size`.
+    pub fn batch_size(self, batch_size: usize) -> Self {
+        Self { batch_size, ..self }
+    }
+
+    /// Builder-style setter: consumes the batcher and returns it with the
+    /// `return_last_incomplete_batch` flag set to `r`.
+    pub fn return_last_incomplete_batch(self, r: bool) -> Self {
+        Self {
+            return_last_incomplete_batch: r,
+            ..self
+        }
+    }
+}
+
+/// Marker wrapper around an iterator that yields plain `Tensor` values.
+///
+/// Wrapping the source in a dedicated type lets `Batcher` carry a separate `Iterator`
+/// implementation for each kind of element.
+pub struct Iter1<I: Iterator<Item = Tensor>> {
+ /// The wrapped iterator of tensors.
+ inner: I,
+}
+
+/// Marker wrapper around an iterator that yields `(Tensor, Tensor)` pairs.
+///
+/// Wrapping the source in a dedicated type lets `Batcher` carry a separate `Iterator`
+/// implementation for each kind of element.
+pub struct Iter2<I: Iterator<Item = (Tensor, Tensor)>> {
+ /// The wrapped iterator of tensor pairs.
+ inner: I,
+}
+
+impl<I> Batcher<Iter1<I>>
+where
+    I: Iterator<Item = Tensor>,
+{
+    /// Creates a batcher over an iterator of single tensors by wrapping it in `Iter1` and
+    /// delegating to the private `new` constructor.
+    pub fn new1(inner: I) -> Self {
+        let source = Iter1 { inner };
+        Self::new(source)
+    }
+}
+
+impl<I> Batcher<Iter2<I>>
+where
+    I: Iterator<Item = (Tensor, Tensor)>,
+{
+    /// Creates a batcher over an iterator of tensor pairs by wrapping it in `Iter2` and
+    /// delegating to the private `new` constructor.
+    pub fn new2(inner: I) -> Self {
+        let source = Iter2 { inner };
+        Self::new(source)
+    }
+}
+
+/// Marker wrapper around an iterator that yields `Result<Tensor>` values, i.e. a source
+/// whose individual elements may fail.
+pub struct IterResult1<I: Iterator<Item = Result<Tensor>>> {
+ /// The wrapped iterator of fallible tensors.
+ inner: I,
+}
+
+/// Marker wrapper around an iterator that yields `Result<(Tensor, Tensor)>` values, i.e. a
+/// source of tensor pairs whose individual elements may fail.
+pub struct IterResult2<I: Iterator<Item = Result<(Tensor, Tensor)>>> {
+ /// The wrapped iterator of fallible tensor pairs.
+ inner: I,
+}
+
+impl<I> Batcher<IterResult1<I>>
+where
+    I: Iterator<Item = Result<Tensor>>,
+{
+    /// Creates a batcher over an iterator of fallible tensors by wrapping it in
+    /// `IterResult1` and delegating to the private `new` constructor.
+    pub fn new_r1(inner: I) -> Self {
+        let source = IterResult1 { inner };
+        Self::new(source)
+    }
+}
+
+impl<I> Batcher<IterResult2<I>>
+where
+    I: Iterator<Item = Result<(Tensor, Tensor)>>,
+{
+    /// Creates a batcher over an iterator of fallible tensor pairs by wrapping it in
+    /// `IterResult2` and delegating to the private `new` constructor.
+    pub fn new_r2(inner: I) -> Self {
+        let source = IterResult2 { inner };
+        Self::new(source)
+    }
+}
+
+impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> {
+    type Item = Result<Tensor>;
+
+    /// Pulls up to `batch_size` tensors from the wrapped iterator and combines them with
+    /// `Tensor::stack(.., 0)`.
+    ///
+    /// Returns `None` when the source runs dry before a full batch is gathered, unless
+    /// `return_last_incomplete_batch` is set and at least one tensor was collected, in
+    /// which case the shorter batch is returned. Any error from stacking is passed through
+    /// as `Some(Err(..))`.
+    fn next(&mut self) -> Option<Self::Item> {
+        let mut items = Vec::with_capacity(self.batch_size);
+        for _ in 0..self.batch_size {
+            // We have two levels of inner here so that we can have two implementations of
+            // the Iterator trait that are different for Iter1 and Iter2. If rust gets better
+            // specialization at some point we can get rid of this.
+            match self.inner.inner.next() {
+                Some(item) => items.push(item),
+                // Only a non-empty leftover counts as a "last incomplete batch". Without the
+                // emptiness check an exhausted source would make every later call stack zero
+                // tensors and return `Some(..)`, so the iteration would never end.
+                None if self.return_last_incomplete_batch && !items.is_empty() => break,
+                None => return None,
+            }
+        }
+        Some(Tensor::stack(&items, 0))
+    }
+}
+
+impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> {
+    type Item = Result<(Tensor, Tensor)>;
+
+    /// Pulls up to `batch_size` `(x, y)` pairs from the wrapped iterator and combines the
+    /// `x`s and the `y`s separately with `Tensor::stack(.., 0)`.
+    ///
+    /// Returns `None` when the source runs dry before a full batch is gathered, unless
+    /// `return_last_incomplete_batch` is set and at least one pair was collected, in which
+    /// case the shorter batch is returned. If either stacking step fails the error is
+    /// returned as `Some(Err(..))`, the `x` error taking precedence.
+    fn next(&mut self) -> Option<Self::Item> {
+        let mut xs = Vec::with_capacity(self.batch_size);
+        let mut ys = Vec::with_capacity(self.batch_size);
+        for _ in 0..self.batch_size {
+            match self.inner.inner.next() {
+                Some((x, y)) => {
+                    xs.push(x);
+                    ys.push(y);
+                }
+                // `xs` and `ys` always have the same length, so checking one is enough.
+                // Only a non-empty leftover counts as a "last incomplete batch"; otherwise
+                // an exhausted source would keep producing `Some(..)` forever.
+                None if self.return_last_incomplete_batch && !xs.is_empty() => break,
+                None => return None,
+            }
+        }
+        let xs = Tensor::stack(&xs, 0);
+        let ys = Tensor::stack(&ys, 0);
+        Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
+    }
+}
+
+impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> {
+    type Item = Result<Tensor>;
+
+    /// Pulls up to `batch_size` fallible tensors from the wrapped iterator and, if all of
+    /// them are `Ok`, combines them with `Tensor::stack(.., 0)`.
+    ///
+    /// Returns `None` when the source runs dry before a full batch is gathered, unless
+    /// `return_last_incomplete_batch` is set and at least one element was collected, in
+    /// which case the shorter batch is processed. The first `Err` among the collected
+    /// elements, or a stacking error, is returned as `Some(Err(..))`.
+    fn next(&mut self) -> Option<Self::Item> {
+        let mut items = Vec::with_capacity(self.batch_size);
+        for _ in 0..self.batch_size {
+            // We have two levels of inner here so that we can have separate implementations
+            // of the Iterator trait for each of the Iter* wrapper types. If rust gets better
+            // specialization at some point we can get rid of this.
+            match self.inner.inner.next() {
+                Some(item) => items.push(item),
+                // Only a non-empty leftover (successes or errors) counts as a "last
+                // incomplete batch"; otherwise an exhausted source would keep producing
+                // `Some(..)` forever instead of ending the iteration.
+                None if self.return_last_incomplete_batch && !items.is_empty() => break,
+                None => return None,
+            }
+        }
+        // Collecting into `Result<Vec<_>>` short-circuits on the first error encountered.
+        let items = items.into_iter().collect::<Result<Vec<Tensor>>>();
+        Some(items.and_then(|items| Tensor::stack(&items, 0)))
+    }
+}
+
+impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResult2<I>> {
+    type Item = Result<(Tensor, Tensor)>;
+
+    /// Pulls up to `batch_size` fallible `(x, y)` pairs from the wrapped iterator and, if
+    /// none of them failed, combines the `x`s and the `y`s separately with
+    /// `Tensor::stack(.., 0)`.
+    ///
+    /// Returns `None` when the source runs dry before a full batch is gathered, unless
+    /// `return_last_incomplete_batch` is set and something (a pair or an error) was
+    /// collected, in which case the shorter batch is processed. When any element of the
+    /// batch is an `Err`, the first such error is returned as `Some(Err(..))`; the remaining
+    /// elements of that batch are still consumed from the source.
+    fn next(&mut self) -> Option<Self::Item> {
+        let mut xs = Vec::with_capacity(self.batch_size);
+        let mut ys = Vec::with_capacity(self.batch_size);
+        // Only the first error is ever reported, so later ones are simply dropped.
+        let mut first_err = None;
+        for _ in 0..self.batch_size {
+            match self.inner.inner.next() {
+                Some(Ok((x, y))) => {
+                    xs.push(x);
+                    ys.push(y);
+                }
+                Some(Err(err)) => {
+                    first_err.get_or_insert(err);
+                }
+                // Only a non-empty leftover counts as a "last incomplete batch"; otherwise
+                // an exhausted source would keep producing `Some(..)` forever. A pending
+                // error also counts, so that it is still reported to the caller.
+                None if self.return_last_incomplete_batch
+                    && (!xs.is_empty() || first_err.is_some()) =>
+                {
+                    break
+                }
+                None => return None,
+            }
+        }
+        if let Some(err) = first_err {
+            return Some(Err(err));
+        }
+        let xs = Tensor::stack(&xs, 0);
+        let ys = Tensor::stack(&ys, 0);
+        Some(xs.and_then(|xs| ys.map(|ys| (xs, ys))))
+    }
+}