summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/shape.rs4
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 567a711b..90a37be6 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -304,6 +304,7 @@ impl Dim for usize {
pub enum D {
Minus1,
Minus2,
+ Minus(usize),
}
impl D {
@@ -311,6 +312,7 @@ impl D {
let dim = match self {
Self::Minus1 => -1,
Self::Minus2 => -2,
+ Self::Minus(u) => -(*u as i32),
};
Error::DimOutOfRange {
shape: shape.clone(),
@@ -327,6 +329,7 @@ impl Dim for D {
match self {
Self::Minus1 if rank >= 1 => Ok(rank - 1),
Self::Minus2 if rank >= 2 => Ok(rank - 2),
+ Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
_ => Err(self.out_of_range(shape, op)),
}
}
@@ -336,6 +339,7 @@ impl Dim for D {
match self {
Self::Minus1 => Ok(rank),
Self::Minus2 if rank >= 1 => Ok(rank - 1),
+ Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
_ => Err(self.out_of_range(shape, op)),
}
}