diff options
Diffstat (limited to 'candle-core/src/strided_index.rs')
-rw-r--r-- | candle-core/src/strided_index.rs | 25 |
1 files changed, 16 insertions, 9 deletions
diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs index 2a23e9ec..e6d2868b 100644 --- a/candle-core/src/strided_index.rs +++ b/candle-core/src/strided_index.rs @@ -1,27 +1,28 @@ +use crate::Layout; + /// An iterator over offset position for items of an N-dimensional arrays stored in a /// flat buffer using some potential strides. #[derive(Debug)] pub(crate) struct StridedIndex<'a> { next_storage_index: Option<usize>, multi_index: Vec<usize>, - dims: &'a [usize], - stride: &'a [usize], + layout: &'a Layout, } impl<'a> StridedIndex<'a> { - pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self { + pub(crate) fn new(layout: &'a Layout) -> Self { + let dims = layout.dims(); let elem_count: usize = dims.iter().product(); let next_storage_index = if elem_count == 0 { None } else { // This applies to the scalar case. - Some(0) + Some(layout.start_offset()) }; StridedIndex { next_storage_index, multi_index: vec![0; dims.len()], - dims, - stride, + layout, } } } @@ -35,7 +36,12 @@ impl<'a> Iterator for StridedIndex<'a> { Some(storage_index) => storage_index, }; let mut updated = false; - for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() { + for (multi_i, max_i) in self + .multi_index + .iter_mut() + .zip(self.layout.dims().iter()) + .rev() + { let next_i = *multi_i + 1; if next_i < *max_i { *multi_i = next_i; @@ -49,9 +55,10 @@ impl<'a> Iterator for StridedIndex<'a> { let next_storage_index = self .multi_index .iter() - .zip(self.stride.iter()) + .zip(self.layout.stride().iter()) .map(|(&x, &y)| x * y) - .sum(); + .sum::<usize>() + + self.layout.start_offset(); Some(next_storage_index) } else { None |