diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-06-28 15:59:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-28 15:59:53 +0100 |
commit | 0cfa21f26a88820fba91bb8ff02cf850eeeb97c3 (patch) | |
tree | efc1279fc9ba273425689e79ac5577801b1bddae /candle-core/src/strided_index.rs | |
parent | 8b4b2d1830e6fb5aed2c410256bb4e7076e5007d (diff) | |
parent | 6c9e6b5a99d4070be5c20d7c383e0ef7e3228260 (diff) | |
download | candle-0cfa21f26a88820fba91bb8ff02cf850eeeb97c3.tar.gz candle-0cfa21f26a88820fba91bb8ff02cf850eeeb97c3.tar.bz2 candle-0cfa21f26a88820fba91bb8ff02cf850eeeb97c3.zip |
Merge pull request #27 from LaurentMazare/layout-refactor
Refactor the stride/shape handling
Diffstat (limited to 'candle-core/src/strided_index.rs')
-rw-r--r-- | candle-core/src/strided_index.rs | 25 |
1 files changed, 16 insertions, 9 deletions
diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs index 2a23e9ec..e6d2868b 100644 --- a/candle-core/src/strided_index.rs +++ b/candle-core/src/strided_index.rs @@ -1,27 +1,28 @@ +use crate::Layout; + /// An iterator over offset position for items of an N-dimensional arrays stored in a /// flat buffer using some potential strides. #[derive(Debug)] pub(crate) struct StridedIndex<'a> { next_storage_index: Option<usize>, multi_index: Vec<usize>, - dims: &'a [usize], - stride: &'a [usize], + layout: &'a Layout, } impl<'a> StridedIndex<'a> { - pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self { + pub(crate) fn new(layout: &'a Layout) -> Self { + let dims = layout.dims(); let elem_count: usize = dims.iter().product(); let next_storage_index = if elem_count == 0 { None } else { // This applies to the scalar case. - Some(0) + Some(layout.start_offset()) }; StridedIndex { next_storage_index, multi_index: vec![0; dims.len()], - dims, - stride, + layout, } } } @@ -35,7 +36,12 @@ impl<'a> Iterator for StridedIndex<'a> { Some(storage_index) => storage_index, }; let mut updated = false; - for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() { + for (multi_i, max_i) in self + .multi_index + .iter_mut() + .zip(self.layout.dims().iter()) + .rev() + { let next_i = *multi_i + 1; if next_i < *max_i { *multi_i = next_i; @@ -49,9 +55,10 @@ impl<'a> Iterator for StridedIndex<'a> { let next_storage_index = self .multi_index .iter() - .zip(self.stride.iter()) + .zip(self.layout.stride().iter()) .map(|(&x, &y)| x * y) - .sum(); + .sum::<usize>() + + self.layout.start_offset(); Some(next_storage_index) } else { None |