summaryrefslogtreecommitdiff
path: root/candle-core/src/shape.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r--candle-core/src/shape.rs18
1 files changed, 12 insertions, 6 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 982f9db0..b016ead5 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -87,6 +87,12 @@ macro_rules! extract_dims {
}
}
}
+ impl crate::Tensor {
+ pub fn $fn_name(&self) -> Result<$out_type> {
+ self.shape().$fn_name()
+ }
+ }
+
impl std::convert::TryInto<$out_type> for Shape {
type Error = crate::Error;
fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
@@ -328,23 +334,23 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
-extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
-extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
-extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
+extract_dims!(dims0, 0, |_: &Vec<usize>| (), ());
+extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
+extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(
- r3,
+ dims3,
3,
|d: &[usize]| (d[0], d[1], d[2]),
(usize, usize, usize)
);
extract_dims!(
- r4,
+ dims4,
4,
|d: &[usize]| (d[0], d[1], d[2], d[3]),
(usize, usize, usize, usize)
);
extract_dims!(
- r5,
+ dims5,
5,
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
(usize, usize, usize, usize, usize)