diff options
Diffstat (limited to 'candle-nn/src/dataset.rs')
-rw-r--r-- | candle-nn/src/dataset.rs | 171 |
1 files changed, 0 insertions, 171 deletions
diff --git a/candle-nn/src/dataset.rs b/candle-nn/src/dataset.rs deleted file mode 100644 index b74f1417..00000000 --- a/candle-nn/src/dataset.rs +++ /dev/null @@ -1,171 +0,0 @@ -use candle::{Result, Tensor}; - -pub struct Batcher<I> { - inner: I, - batch_size: usize, - return_last_incomplete_batch: bool, -} - -impl<I> Batcher<I> { - fn new(inner: I) -> Self { - Self { - inner, - batch_size: 16, - return_last_incomplete_batch: false, - } - } - - pub fn batch_size(mut self, batch_size: usize) -> Self { - self.batch_size = batch_size; - self - } - - pub fn return_last_incomplete_batch(mut self, r: bool) -> Self { - self.return_last_incomplete_batch = r; - self - } -} - -pub struct Iter1<I: Iterator<Item = Tensor>> { - inner: I, -} - -pub struct Iter2<I: Iterator<Item = (Tensor, Tensor)>> { - inner: I, -} - -impl<I: Iterator<Item = Tensor>> Batcher<Iter1<I>> { - pub fn new1(inner: I) -> Self { - Self::new(Iter1 { inner }) - } -} - -impl<I: Iterator<Item = (Tensor, Tensor)>> Batcher<Iter2<I>> { - pub fn new2(inner: I) -> Self { - Self::new(Iter2 { inner }) - } -} - -pub struct IterResult1<I: Iterator<Item = Result<Tensor>>> { - inner: I, -} - -pub struct IterResult2<I: Iterator<Item = Result<(Tensor, Tensor)>>> { - inner: I, -} - -impl<I: Iterator<Item = Result<Tensor>>> Batcher<IterResult1<I>> { - pub fn new_r1(inner: I) -> Self { - Self::new(IterResult1 { inner }) - } -} - -impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Batcher<IterResult2<I>> { - pub fn new_r2(inner: I) -> Self { - Self::new(IterResult2 { inner }) - } -} - -impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> { - type Item = Result<Tensor>; - - fn next(&mut self) -> Option<Self::Item> { - let mut items = Vec::with_capacity(self.batch_size); - for _i in 0..self.batch_size { - // We have two levels of inner here so that we can have two implementations of the - // Iterator trait that are different for Iter1 and Iter2. If rust gets better - // specialization at some point we can get rid of this. - match self.inner.inner.next() { - Some(item) => items.push(item), - None => { - if self.return_last_incomplete_batch { - break; - } - return None; - } - } - } - Some(Tensor::stack(&items, 0)) - } -} - -impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> { - type Item = Result<(Tensor, Tensor)>; - - fn next(&mut self) -> Option<Self::Item> { - let mut xs = Vec::with_capacity(self.batch_size); - let mut ys = Vec::with_capacity(self.batch_size); - for _i in 0..self.batch_size { - match self.inner.inner.next() { - Some((x, y)) => { - xs.push(x); - ys.push(y) - } - None => { - if self.return_last_incomplete_batch { - break; - } - return None; - } - } - } - let xs = Tensor::stack(&xs, 0); - let ys = Tensor::stack(&ys, 0); - Some(xs.and_then(|xs| ys.map(|ys| (xs, ys)))) - } -} - -impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> { - type Item = Result<Tensor>; - - fn next(&mut self) -> Option<Self::Item> { - let mut items = Vec::with_capacity(self.batch_size); - for _i in 0..self.batch_size { - // We have two levels of inner here so that we can have two implementations of the - // Iterator trait that are different for Iter1 and Iter2. If rust gets better - // specialization at some point we can get rid of this. - match self.inner.inner.next() { - Some(item) => items.push(item), - None => { - if self.return_last_incomplete_batch { - break; - } - return None; - } - } - } - let items = items.into_iter().collect::<Result<Vec<Tensor>>>(); - Some(items.and_then(|items| Tensor::stack(&items, 0))) - } -} - -impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResult2<I>> { - type Item = Result<(Tensor, Tensor)>; - - fn next(&mut self) -> Option<Self::Item> { - let mut xs = Vec::with_capacity(self.batch_size); - let mut ys = Vec::with_capacity(self.batch_size); - let mut errs = vec![]; - for _i in 0..self.batch_size { - match self.inner.inner.next() { - Some(Ok((x, y))) => { - xs.push(x); - ys.push(y) - } - Some(Err(err)) => errs.push(err), - None => { - if self.return_last_incomplete_batch { - break; - } - return None; - } - } - } - if !errs.is_empty() { - return Some(Err(errs.swap_remove(0))); - } - let xs = Tensor::stack(&xs, 0); - let ys = Tensor::stack(&ys, 0); - Some(xs.and_then(|xs| ys.map(|ys| (xs, ys)))) - } -} |