summaryrefslogtreecommitdiff
path: root/candle-core/src/shape.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r--candle-core/src/shape.rs34
1 files changed, 20 insertions, 14 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 632ef116..b5e64454 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -194,7 +194,7 @@ impl Dim for usize {
if dim >= shape.dims().len() {
Err(Error::DimOutOfRange {
shape: shape.clone(),
- dim,
+ dim: dim as i32,
op,
})?
} else {
@@ -207,7 +207,7 @@ impl Dim for usize {
if dim > shape.dims().len() {
Err(Error::DimOutOfRange {
shape: shape.clone(),
- dim,
+ dim: dim as i32,
op,
})?
} else {
@@ -221,30 +221,36 @@ pub enum D {
Minus2,
}
+impl D {
+ fn out_of_range(&self, shape: &Shape, op: &'static str) -> Error {
+ let dim = match self {
+ Self::Minus1 => -1,
+ Self::Minus2 => -2,
+ };
+ Error::DimOutOfRange {
+ shape: shape.clone(),
+ dim,
+ op,
+ }
+ }
+}
+
impl Dim for D {
fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
let rank = shape.rank();
match self {
Self::Minus1 if rank >= 1 => Ok(rank - 1),
Self::Minus2 if rank >= 2 => Ok(rank - 2),
- _ => Err(Error::DimOutOfRange {
- shape: shape.clone(),
- dim: 42, // TODO: Have an adequate error
- op,
- }),
+ _ => Err(self.out_of_range(shape, op)),
}
}
fn to_index_plus_one(&self, shape: &Shape, op: &'static str) -> Result<usize> {
let rank = shape.rank();
match self {
- Self::Minus1 if rank >= 1 => Ok(rank),
- Self::Minus2 if rank >= 2 => Ok(rank - 1),
- _ => Err(Error::DimOutOfRange {
- shape: shape.clone(),
- dim: 42, // TODO: Have an adequate error
- op,
- }),
+ Self::Minus1 => Ok(rank),
+ Self::Minus2 if rank >= 1 => Ok(rank - 1),
+ _ => Err(self.out_of_range(shape, op)),
}
}
}