summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2024-08-17 20:29:01 +0100
committerGitHub <noreply@github.com>2024-08-17 21:29:01 +0200
commit7cff5898ec2d6fca33a513ea02414c957f3ea0f1 (patch)
tree63f2086a3fe1a94b2560f9019bf6884b16a004d2
parentb75ef051cfa67b9c2a2cee822ad2e1ae796a4704 (diff)
downloadcandle-7cff5898ec2d6fca33a513ea02414c957f3ea0f1.tar.gz
candle-7cff5898ec2d6fca33a513ea02414c957f3ea0f1.tar.bz2
candle-7cff5898ec2d6fca33a513ea02414c957f3ea0f1.zip
Support Minus(u) for arbitrary values of u, e.g. Minus(3). (#2428)
* Support Minus(u) for arbitrary values of u, e.g. Minus(3). * Forces u to be strictly positive.
-rw-r--r--candle-core/src/shape.rs4
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 567a711b..90a37be6 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -304,6 +304,7 @@ impl Dim for usize {
pub enum D {
Minus1,
Minus2,
+ Minus(usize),
}
impl D {
@@ -311,6 +312,7 @@ impl D {
let dim = match self {
Self::Minus1 => -1,
Self::Minus2 => -2,
+ Self::Minus(u) => -(*u as i32),
};
Error::DimOutOfRange {
shape: shape.clone(),
@@ -327,6 +329,7 @@ impl Dim for D {
match self {
Self::Minus1 if rank >= 1 => Ok(rank - 1),
Self::Minus2 if rank >= 2 => Ok(rank - 2),
+ Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u),
_ => Err(self.out_of_range(shape, op)),
}
}
@@ -336,6 +339,7 @@ impl Dim for D {
match self {
Self::Minus1 => Ok(rank),
Self::Minus2 if rank >= 1 => Ok(rank - 1),
+ Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u),
_ => Err(self.out_of_range(shape, op)),
}
}