diff options
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r-- | candle-core/src/shape.rs | 34 |
1 files changed, 20 insertions, 14 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 632ef116..b5e64454 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -194,7 +194,7 @@ impl Dim for usize { if dim >= shape.dims().len() { Err(Error::DimOutOfRange { shape: shape.clone(), - dim, + dim: dim as i32, op, })? } else { @@ -207,7 +207,7 @@ impl Dim for usize { if dim > shape.dims().len() { Err(Error::DimOutOfRange { shape: shape.clone(), - dim, + dim: dim as i32, op, })? } else { @@ -221,30 +221,36 @@ pub enum D { Minus2, } +impl D { + fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error { + let dim = match self { + Self::Minus1 => -1, + Self::Minus2 => -2, + }; + Error::DimOutOfRange { + shape: shape.clone(), + dim, + op, + } + } +} + impl Dim for D { fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> { let rank = shape.rank(); match self { Self::Minus1 if rank >= 1 => Ok(rank - 1), Self::Minus2 if rank >= 2 => Ok(rank - 2), - _ => Err(Error::DimOutOfRange { - shape: shape.clone(), - dim: 42, // TODO: Have an adequate error - op, - }), + _ => Err(self.out_of_range(shape, op)), } } fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> { let rank = shape.rank(); match self { - Self::Minus1 if rank >= 1 => Ok(rank), - Self::Minus2 if rank >= 2 => Ok(rank - 1), - _ => Err(Error::DimOutOfRange { - shape: shape.clone(), - dim: 42, // TODO: Have an adequate error - op, - }), + Self::Minus1 => Ok(rank), + Self::Minus2 if rank >= 1 => Ok(rank - 1), + _ => Err(self.out_of_range(shape, op)), } } } |