diff options
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r-- | candle-core/src/shape.rs | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index cc068004..1152dc3e 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -183,6 +183,45 @@ impl Shape { } } +pub trait Dim { + fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>; +} + +impl Dim for usize { + fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> { + let dim = *self; + if dim >= shape.dims().len() { + Err(Error::DimOutOfRange { + shape: shape.clone(), + dim, + op, + })? + } else { + Ok(dim) + } + } +} + +pub enum D { + Minus1, + Minus2, +} + +impl Dim for D { + fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> { + let rank = shape.rank(); + match self { + Self::Minus1 if rank >= 1 => Ok(rank - 1), + Self::Minus2 if rank >= 2 => Ok(rank - 2), + _ => Err(Error::DimOutOfRange { + shape: shape.clone(), + dim: 42, // TODO: Have an adequate error + op, + }), + } + } +} + #[cfg(test)] mod tests { use super::*; |