diff options
Diffstat (limited to 'src/shape.rs')
-rw-r--r-- | src/shape.rs | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/src/shape.rs b/src/shape.rs index ebc497cf..aa66e706 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -1,7 +1,7 @@ use crate::{Error, Result}; #[derive(Clone, PartialEq, Eq)] -pub struct Shape(pub(crate) Vec<usize>); +pub struct Shape(Vec<usize>); impl std::fmt::Debug for Shape { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -63,6 +63,12 @@ impl From<(usize, usize, usize)> for Shape { } } +impl From<Vec<usize>> for Shape { + fn from(dims: Vec<usize>) -> Self { + Self(dims) + } +} + macro_rules! extract_dims { ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => { pub fn $fn_name(&self) -> Result<$out_type> { @@ -142,6 +148,11 @@ impl Shape { } true } + + pub fn extend(mut self, additional_dims: &[usize]) -> Self { + self.0.extend(additional_dims); + self + } } #[cfg(test)] |