diff options
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r-- | candle-core/src/shape.rs | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 1152dc3e..632ef116 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -185,6 +185,7 @@ impl Shape { pub trait Dim { fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>; + fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>; } impl Dim for usize { @@ -200,6 +201,19 @@ impl Dim for usize { Ok(dim) } } + + fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> { + let dim = *self; + if dim > shape.dims().len() { + Err(Error::DimOutOfRange { + shape: shape.clone(), + dim, + op, + })? + } else { + Ok(dim) + } + } } pub enum D { @@ -220,6 +234,19 @@ impl Dim for D { }), } } + + fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> { + let rank = shape.rank(); + match self { + Self::Minus1 if rank >= 1 => Ok(rank), + Self::Minus2 if rank >= 2 => Ok(rank - 1), + _ => Err(Error::DimOutOfRange { + shape: shape.clone(), + dim: 42, // TODO: Have an adequate error + op, + }), + } + } } #[cfg(test)] |