blob: e6d2868b70008fb2159a6391817241527b547c43 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
|
use crate::Layout;
/// An iterator over the storage offsets of the elements of an N-dimensional array that is
/// stored in a flat buffer, possibly with non-contiguous strides.
///
/// Offsets are produced in row-major order (the last dimension advances fastest), starting
/// from the layout's start offset.
#[derive(Debug)]
pub(crate) struct StridedIndex<'a> {
    /// Offset to be returned by the next call to `next`, or `None` once iteration has
    /// finished (or from the start, when the layout has no elements).
    next_storage_index: Option<usize>,
    /// Per-dimension index of the element whose offset is `next_storage_index`.
    multi_index: Vec<usize>,
    /// Layout supplying the dimensions, strides and start offset being walked.
    layout: &'a Layout,
}
impl<'a> StridedIndex<'a> {
pub(crate) fn new(layout: &'a Layout) -> Self {
let dims = layout.dims();
let elem_count: usize = dims.iter().product();
let next_storage_index = if elem_count == 0 {
None
} else {
// This applies to the scalar case.
Some(layout.start_offset())
};
StridedIndex {
next_storage_index,
multi_index: vec![0; dims.len()],
layout,
}
}
}
impl<'a> Iterator for StridedIndex<'a> {
type Item = usize;
fn next(&mut self) -> Option<Self::Item> {
let storage_index = match self.next_storage_index {
None => return None,
Some(storage_index) => storage_index,
};
let mut updated = false;
for (multi_i, max_i) in self
.multi_index
.iter_mut()
.zip(self.layout.dims().iter())
.rev()
{
let next_i = *multi_i + 1;
if next_i < *max_i {
*multi_i = next_i;
updated = true;
break;
} else {
*multi_i = 0
}
}
self.next_storage_index = if updated {
let next_storage_index = self
.multi_index
.iter()
.zip(self.layout.stride().iter())
.map(|(&x, &y)| x * y)
.sum::<usize>()
+ self.layout.start_offset();
Some(next_storage_index)
} else {
None
};
Some(storage_index)
}
}
|