summaryrefslogtreecommitdiff
path: root/candle-core/src/shape.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r--candle-core/src/shape.rs34
1 files changed, 17 insertions, 17 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index beaa9455..32ebb23f 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -478,23 +478,6 @@ extract_dims!(
(usize, usize, usize, usize, usize)
);
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn stride() {
- let shape = Shape::from(());
- assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
- let shape = Shape::from(42);
- assert_eq!(shape.stride_contiguous(), [1]);
- let shape = Shape::from((42, 1337));
- assert_eq!(shape.stride_contiguous(), [1337, 1]);
- let shape = Shape::from((299, 792, 458));
- assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
- }
-}
-
pub trait ShapeWithOneHole {
fn into_shape(self, el_count: usize) -> Result<Shape>;
}
@@ -627,3 +610,20 @@ impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
Ok((d1, d2, d3, d4, d).into())
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn stride() {
+ let shape = Shape::from(());
+ assert_eq!(shape.stride_contiguous(), Vec::<usize>::new());
+ let shape = Shape::from(42);
+ assert_eq!(shape.stride_contiguous(), [1]);
+ let shape = Shape::from((42, 1337));
+ assert_eq!(shape.stride_contiguous(), [1337, 1]);
+ let shape = Shape::from((299, 792, 458));
+ assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
+ }
+}