diff options
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r-- | candle-core/src/shape.rs | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 982f9db0..b016ead5 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -87,6 +87,12 @@ macro_rules! extract_dims { } } } + impl crate::Tensor { + pub fn $fn_name(&self) -> Result<$out_type> { + self.shape().$fn_name() + } + } + impl std::convert::TryInto<$out_type> for Shape { type Error = crate::Error; fn try_into(self) -> std::result::Result<$out_type, Self::Error> { @@ -328,23 +334,23 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) { } } -extract_dims!(r0, 0, |_: &Vec<usize>| (), ()); -extract_dims!(r1, 1, |d: &[usize]| d[0], usize); -extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); +extract_dims!(dims0, 0, |_: &Vec<usize>| (), ()); +extract_dims!(dims1, 1, |d: &[usize]| d[0], usize); +extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); extract_dims!( - r3, + dims3, 3, |d: &[usize]| (d[0], d[1], d[2]), (usize, usize, usize) ); extract_dims!( - r4, + dims4, 4, |d: &[usize]| (d[0], d[1], d[2], d[3]), (usize, usize, usize, usize) ); extract_dims!( - r5, + dims5, 5, |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]), (usize, usize, usize, usize, usize) |