diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-06-27 11:57:27 +0200 |
---|---|---|
committer | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-06-27 11:57:27 +0200 |
commit | d7f729fb8f1d4b224f18ca3d7ae1163afe57a094 (patch) | |
tree | f60643690f3c7f34ae64923771cd568d75d85f5c /candle-core/src/strided_index.rs | |
parent | 6c4a960b15404b9307328fa4e2c929f813b6b092 (diff) | |
download | candle-d7f729fb8f1d4b224f18ca3d7ae1163afe57a094.tar.gz candle-d7f729fb8f1d4b224f18ca3d7ae1163afe57a094.tar.bz2 candle-d7f729fb8f1d4b224f18ca3d7ae1163afe57a094.zip |
Refactor the hierarchy.
Diffstat (limited to 'candle-core/src/strided_index.rs')
-rw-r--r-- | candle-core/src/strided_index.rs | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs new file mode 100644 index 00000000..2a23e9ec --- /dev/null +++ b/candle-core/src/strided_index.rs @@ -0,0 +1,61 @@ +/// An iterator over offset position for items of an N-dimensional arrays stored in a +/// flat buffer using some potential strides. +#[derive(Debug)] +pub(crate) struct StridedIndex<'a> { + next_storage_index: Option<usize>, + multi_index: Vec<usize>, + dims: &'a [usize], + stride: &'a [usize], +} + +impl<'a> StridedIndex<'a> { + pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self { + let elem_count: usize = dims.iter().product(); + let next_storage_index = if elem_count == 0 { + None + } else { + // This applies to the scalar case. + Some(0) + }; + StridedIndex { + next_storage_index, + multi_index: vec![0; dims.len()], + dims, + stride, + } + } +} + +impl<'a> Iterator for StridedIndex<'a> { + type Item = usize; + + fn next(&mut self) -> Option<Self::Item> { + let storage_index = match self.next_storage_index { + None => return None, + Some(storage_index) => storage_index, + }; + let mut updated = false; + for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() { + let next_i = *multi_i + 1; + if next_i < *max_i { + *multi_i = next_i; + updated = true; + break; + } else { + *multi_i = 0 + } + } + self.next_storage_index = if updated { + let next_storage_index = self + .multi_index + .iter() + .zip(self.stride.iter()) + .map(|(&x, &y)| x * y) + .sum(); + Some(next_storage_index) + } else { + None + }; + Some(storage_index) + } +} |