summaryrefslogtreecommitdiff
path: root/candle-core/src/shape.rs
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2023-07-10 15:21:24 +0200
committerGitHub <noreply@github.com>2023-07-10 15:21:24 +0200
commitdc5825967957e28e6ac4f57da18c7963f2be542c (patch)
treef8249c4d0259c1c8f0c1e46c7f1ecd95da258580 /candle-core/src/shape.rs
parent204618b7d37229cd19a7f85ed38e6ab916e1e0d1 (diff)
parent9a667155fd554fe270561783f6708445e2deb929 (diff)
downloadcandle-dc5825967957e28e6ac4f57da18c7963f2be542c.tar.gz
candle-dc5825967957e28e6ac4f57da18c7963f2be542c.tar.bz2
candle-dc5825967957e28e6ac4f57da18c7963f2be542c.zip
Merge pull request #120 from LaurentMazare/some_doc_plus_fix_stack
Adding some doc + Extended `stack` to work with extra final dimensions.
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r--candle-core/src/shape.rs27
1 files changed, 27 insertions, 0 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 1152dc3e..632ef116 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -185,6 +185,7 @@ impl Shape {
pub trait Dim {
fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
+ fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize>;
}
impl Dim for usize {
@@ -200,6 +201,19 @@ impl Dim for usize {
Ok(dim)
}
}
+
+ fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+ let dim = *self;
+ if dim > shape.dims().len() {
+ Err(Error::DimOutOfRange {
+ shape: shape.clone(),
+ dim,
+ op,
+ })?
+ } else {
+ Ok(dim)
+ }
+ }
}
pub enum D {
@@ -220,6 +234,19 @@ impl Dim for D {
}),
}
}
+
+ fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+ let rank = shape.rank();
+ match self {
+ Self::Minus1 if rank >= 1 => Ok(rank),
+ Self::Minus2 if rank >= 2 => Ok(rank - 1),
+ _ => Err(Error::DimOutOfRange {
+ shape: shape.clone(),
+ dim: 42, // TODO: Have an adequate error
+ op,
+ }),
+ }
+ }
}
#[cfg(test)]