diff options
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r-- | candle-core/src/shape.rs | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index a5e21aad..83d11c09 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -79,20 +79,25 @@ impl From<Vec<usize>> for Shape { macro_rules! extract_dims { ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => { + pub fn $fn_name(dims: &[usize]) -> Result<$out_type> { + if dims.len() != $cnt { + Err(Error::UnexpectedNumberOfDims { + expected: $cnt, + got: dims.len(), + shape: Shape::from(dims), + } + .bt()) + } else { + Ok($dims(dims)) + } + } + impl Shape { pub fn $fn_name(&self) -> Result<$out_type> { - if self.0.len() != $cnt { - Err(Error::UnexpectedNumberOfDims { - expected: $cnt, - got: self.0.len(), - shape: self.clone(), - } - .bt()) - } else { - Ok($dims(&self.0)) - } + $fn_name(self.0.as_slice()) } } + impl crate::Tensor { pub fn $fn_name(&self) -> Result<$out_type> { self.shape().$fn_name() @@ -340,7 +345,7 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) { } } -extract_dims!(dims0, 0, |_: &Vec<usize>| (), ()); +extract_dims!(dims0, 0, |_: &[usize]| (), ()); extract_dims!(dims1, 1, |d: &[usize]| d[0], usize); extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); extract_dims!( |