diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2024-08-17 20:29:01 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-17 21:29:01 +0200 |
commit | 7cff5898ec2d6fca33a513ea02414c957f3ea0f1 (patch) | |
tree | 63f2086a3fe1a94b2560f9019bf6884b16a004d2 | |
parent | b75ef051cfa67b9c2a2cee822ad2e1ae796a4704 (diff) | |
download | candle-7cff5898ec2d6fca33a513ea02414c957f3ea0f1.tar.gz candle-7cff5898ec2d6fca33a513ea02414c957f3ea0f1.tar.bz2 candle-7cff5898ec2d6fca33a513ea02414c957f3ea0f1.zip |
Support Minus(u) for arbitrary values of u, e.g. Minus(3). (#2428)
* Support Minus(u) for arbitrary values of u, e.g. Minus(3).
* Forces u to be strictly positive.
-rw-r--r-- | candle-core/src/shape.rs | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 567a711b..90a37be6 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -304,6 +304,7 @@ impl Dim for usize { pub enum D { Minus1, Minus2, + Minus(usize), } impl D { @@ -311,6 +312,7 @@ impl D { let dim = match self { Self::Minus1 => -1, Self::Minus2 => -2, + Self::Minus(u) => -(*u as i32), }; Error::DimOutOfRange { shape: shape.clone(), @@ -327,6 +329,7 @@ impl Dim for D { match self { Self::Minus1 if rank >= 1 => Ok(rank - 1), Self::Minus2 if rank >= 2 => Ok(rank - 2), + Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u), _ => Err(self.out_of_range(shape, op)), } } @@ -336,6 +339,7 @@ impl Dim for D { match self { Self::Minus1 => Ok(rank), Self::Minus2 if rank >= 1 => Ok(rank - 1), + Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u), _ => Err(self.out_of_range(shape, op)), } } |