summaryrefslogtreecommitdiff
path: root/src/shape.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/shape.rs')
-rw-r--r--src/shape.rs13
1 files changed, 12 insertions, 1 deletions
diff --git a/src/shape.rs b/src/shape.rs
index ebc497cf..aa66e706 100644
--- a/src/shape.rs
+++ b/src/shape.rs
@@ -1,7 +1,7 @@
use crate::{Error, Result};
#[derive(Clone, PartialEq, Eq)]
-pub struct Shape(pub(crate) Vec<usize>);
+pub struct Shape(Vec<usize>);
impl std::fmt::Debug for Shape {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -63,6 +63,12 @@ impl From<(usize, usize, usize)> for Shape {
}
}
+impl From<Vec<usize>> for Shape {
+ fn from(dims: Vec<usize>) -> Self {
+ Self(dims)
+ }
+}
+
macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
pub fn $fn_name(&self) -> Result<$out_type> {
@@ -142,6 +148,11 @@ impl Shape {
}
true
}
+
+ pub fn extend(mut self, additional_dims: &[usize]) -> Self {
+ self.0.extend(additional_dims);
+ self
+ }
}
#[cfg(test)]