diff options
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r-- | candle-core/src/shape.rs | 34 |
1 files changed, 17 insertions, 17 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index beaa9455..32ebb23f 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -478,23 +478,6 @@ extract_dims!( (usize, usize, usize, usize, usize) ); -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn stride() { - let shape = Shape::from(()); - assert_eq!(shape.stride_contiguous(), Vec::<usize>::new()); - let shape = Shape::from(42); - assert_eq!(shape.stride_contiguous(), [1]); - let shape = Shape::from((42, 1337)); - assert_eq!(shape.stride_contiguous(), [1337, 1]); - let shape = Shape::from((299, 792, 458)); - assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]); - } -} - pub trait ShapeWithOneHole { fn into_shape(self, el_count: usize) -> Result<Shape>; } @@ -627,3 +610,20 @@ impl ShapeWithOneHole for (usize, usize, usize, usize, ()) { Ok((d1, d2, d3, d4, d).into()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn stride() { + let shape = Shape::from(()); + assert_eq!(shape.stride_contiguous(), Vec::<usize>::new()); + let shape = Shape::from(42); + assert_eq!(shape.stride_contiguous(), [1]); + let shape = Shape::from((42, 1337)); + assert_eq!(shape.stride_contiguous(), [1337, 1]); + let shape = Shape::from((299, 792, 458)); + assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]); + } +} |