summaryrefslogtreecommitdiff
path: root/candle-core/src/shape.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r--candle-core/src/shape.rs27
1 files changed, 16 insertions, 11 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index a5e21aad..83d11c09 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -79,20 +79,25 @@ impl From<Vec<usize>> for Shape {
macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
+ pub fn $fn_name(dims: &[usize]) -> Result<$out_type> {
+ if dims.len() != $cnt {
+ Err(Error::UnexpectedNumberOfDims {
+ expected: $cnt,
+ got: dims.len(),
+ shape: Shape::from(dims),
+ }
+ .bt())
+ } else {
+ Ok($dims(dims))
+ }
+ }
+
impl Shape {
pub fn $fn_name(&self) -> Result<$out_type> {
- if self.0.len() != $cnt {
- Err(Error::UnexpectedNumberOfDims {
- expected: $cnt,
- got: self.0.len(),
- shape: self.clone(),
- }
- .bt())
- } else {
- Ok($dims(&self.0))
- }
+ $fn_name(self.0.as_slice())
}
}
+
impl crate::Tensor {
pub fn $fn_name(&self) -> Result<$out_type> {
self.shape().$fn_name()
@@ -340,7 +345,7 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
-extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
+extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(