summaryrefslogtreecommitdiff
path: root/candle-core/src/shape.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r--candle-core/src/shape.rs39
1 files changed, 39 insertions, 0 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index cc068004..1152dc3e 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -183,6 +183,45 @@ impl Shape {
}
}
+pub trait Dim {
+ fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize>;
+}
+
+impl Dim for usize {
+ fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+ let dim = *self;
+ if dim >= shape.dims().len() {
+ Err(Error::DimOutOfRange {
+ shape: shape.clone(),
+ dim,
+ op,
+ })?
+ } else {
+ Ok(dim)
+ }
+ }
+}
+
+pub enum D {
+ Minus1,
+ Minus2,
+}
+
+impl Dim for D {
+ fn to_index(&self, shape: &Shape, op: &'static str) -> Result<usize> {
+ let rank = shape.rank();
+ match self {
+ Self::Minus1 if rank >= 1 => Ok(rank - 1),
+ Self::Minus2 if rank >= 2 => Ok(rank - 2),
+ _ => Err(Error::DimOutOfRange {
+ shape: shape.clone(),
+ dim: 42, // TODO: Have an adequate error
+ op,
+ }),
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;