diff options
author | Nicolas Patry <patry.nicolas@protonmail.com> | 2023-07-10 15:21:24 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-10 15:21:24 +0200 |
commit | dc5825967957e28e6ac4f57da18c7963f2be542c (patch) | |
tree | f8249c4d0259c1c8f0c1e46c7f1ecd95da258580 /candle-core/src/shape.rs | |
parent | 204618b7d37229cd19a7f85ed38e6ab916e1e0d1 (diff) | |
parent | 9a667155fd554fe270561783f6708445e2deb929 (diff) | |
download | candle-dc5825967957e28e6ac4f57da18c7963f2be542c.tar.gz candle-dc5825967957e28e6ac4f57da18c7963f2be542c.tar.bz2 candle-dc5825967957e28e6ac4f57da18c7963f2be542c.zip |
Merge pull request #120 from LaurentMazare/some_doc_plus_fix_stack
Adding some doc + Extended `stack` to work with extra final dimensions.
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r-- | candle-core/src/shape.rs | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 1152dc3e..632ef116 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -185,6 +185,7 @@ impl Shape { pub trait Dim { fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>; + fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>; } impl Dim for usize { @@ -200,6 +201,19 @@ impl Dim for usize { Ok(dim) } } + + fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> { + let dim = *self; + if dim > shape.dims().len() { + Err(Error::DimOutOfRange { + shape: shape.clone(), + dim, + op, + })? + } else { + Ok(dim) + } + } } pub enum D { @@ -220,6 +234,19 @@ impl Dim for D { }), } } + + fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> { + let rank = shape.rank(); + match self { + Self::Minus1 if rank >= 1 => Ok(rank), + Self::Minus2 if rank >= 2 => Ok(rank - 1), + _ => Err(Error::DimOutOfRange { + shape: shape.clone(), + dim: 42, // TODO: Have an adequate error + op, + }), + } + } } #[cfg(test)] |