summaryrefslogtreecommitdiff
path: root/candle-core/src/strided_index.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/strided_index.rs')
-rw-r--r--candle-core/src/strided_index.rs25
1 files changed, 16 insertions, 9 deletions
diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs
index 2a23e9ec..e6d2868b 100644
--- a/candle-core/src/strided_index.rs
+++ b/candle-core/src/strided_index.rs
@@ -1,27 +1,28 @@
+use crate::Layout;
+
/// An iterator over offset position for items of an N-dimensional arrays stored in a
/// flat buffer using some potential strides.
#[derive(Debug)]
pub(crate) struct StridedIndex<'a> {
next_storage_index: Option<usize>,
multi_index: Vec<usize>,
- dims: &'a [usize],
- stride: &'a [usize],
+ layout: &'a Layout,
}
impl<'a> StridedIndex<'a> {
- pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self {
+ pub(crate) fn new(layout: &'a Layout) -> Self {
+ let dims = layout.dims();
let elem_count: usize = dims.iter().product();
let next_storage_index = if elem_count == 0 {
None
} else {
// This applies to the scalar case.
- Some(0)
+ Some(layout.start_offset())
};
StridedIndex {
next_storage_index,
multi_index: vec![0; dims.len()],
- dims,
- stride,
+ layout,
}
}
}
@@ -35,7 +36,12 @@ impl<'a> Iterator for StridedIndex<'a> {
Some(storage_index) => storage_index,
};
let mut updated = false;
- for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
+ for (multi_i, max_i) in self
+ .multi_index
+ .iter_mut()
+ .zip(self.layout.dims().iter())
+ .rev()
+ {
let next_i = *multi_i + 1;
if next_i < *max_i {
*multi_i = next_i;
@@ -49,9 +55,10 @@ impl<'a> Iterator for StridedIndex<'a> {
let next_storage_index = self
.multi_index
.iter()
- .zip(self.stride.iter())
+ .zip(self.layout.stride().iter())
.map(|(&x, &y)| x * y)
- .sum();
+ .sum::<usize>()
+ + self.layout.start_offset();
Some(next_storage_index)
} else {
None