diff options
author | hhllhhyyds <161805554+hhllhhyyds@users.noreply.github.com> | 2024-12-24 15:41:26 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-24 08:41:26 +0100 |
commit | 11aa30be10ebf42d10799a0726a874c74e30ad3e (patch) | |
tree | c30e22ab90cbb2a404c6763c7f64b97c93dee790 | |
parent | 1be6b090c7920c35f5492845d219e3a99ce4d115 (diff) | |
download | candle-11aa30be10ebf42d10799a0726a874c74e30ad3e.tar.gz candle-11aa30be10ebf42d10799a0726a874c74e30ad3e.tar.bz2 candle-11aa30be10ebf42d10799a0726a874c74e30ad3e.zip |
Fix Batcher iterator break when return_last_incomplete_batch and items.is_empty (#2654) (#2655)
-rw-r--r-- | candle-datasets/src/batcher.rs | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/candle-datasets/src/batcher.rs b/candle-datasets/src/batcher.rs index b74f1417..03e4bbef 100644 --- a/candle-datasets/src/batcher.rs +++ b/candle-datasets/src/batcher.rs @@ -78,7 +78,7 @@ impl<I: Iterator<Item = Tensor>> Iterator for Batcher<Iter1<I>> { match self.inner.inner.next() { Some(item) => items.push(item), None => { - if self.return_last_incomplete_batch { + if self.return_last_incomplete_batch && !items.is_empty() { break; } return None; @@ -102,7 +102,7 @@ impl<I: Iterator<Item = (Tensor, Tensor)>> Iterator for Batcher<Iter2<I>> { ys.push(y) } None => { - if self.return_last_incomplete_batch { + if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() { break; } return None; @@ -127,7 +127,7 @@ impl<I: Iterator<Item = Result<Tensor>>> Iterator for Batcher<IterResult1<I>> { match self.inner.inner.next() { Some(item) => items.push(item), None => { - if self.return_last_incomplete_batch { + if self.return_last_incomplete_batch && !items.is_empty() { break; } return None; @@ -154,7 +154,7 @@ impl<I: Iterator<Item = Result<(Tensor, Tensor)>>> Iterator for Batcher<IterResu } Some(Err(err)) => errs.push(err), None => { - if self.return_last_incomplete_batch { + if self.return_last_incomplete_batch && !xs.is_empty() && !ys.is_empty() { break; } return None; |