diff options
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r-- | candle-core/src/shape.rs | 205 |
1 files changed, 205 insertions, 0 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index aea8b887..4d500e7f 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -1,3 +1,4 @@ +//! The shape of a tensor is a tuple with the size of each of its dimensions. #![allow(clippy::redundant_closure_call)] use crate::{Error, Result}; @@ -72,6 +73,14 @@ impl From<(usize, usize, usize, usize, usize)> for Shape { } } +impl From<(usize, usize, usize, usize, usize, usize)> for Shape { + fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self { + Self(vec![ + d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5, + ]) + } +} + impl From<Vec<usize>> for Shape { fn from(dims: Vec<usize>) -> Self { Self(dims) @@ -119,6 +128,7 @@ impl Shape { Self(dims.to_vec()) } + /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc. pub fn rank(&self) -> usize { self.0.len() } @@ -127,10 +137,12 @@ impl Shape { self.0 } + /// The dimensions as a slice of `usize`. pub fn dims(&self) -> &[usize] { &self.0 } + /// The total number of elements, this is the product of all dimension sizes. pub fn elem_count(&self) -> usize { self.0.iter().product() } @@ -182,6 +194,8 @@ impl Shape { true } + /// Modifies the shape by adding a list of additional dimensions at the end of the existing + /// dimensions. pub fn extend(mut self, additional_dims: &[usize]) -> Self { self.0.extend(additional_dims); self @@ -419,6 +433,29 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) { } } +impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + let d3 = self.3.to_index(shape, op)?; + let d4 = self.4.to_index(shape, op)?; + Ok(vec![d0, d1, d2, d3, d4]) + } +} + +impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) { + fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> { + let d0 = self.0.to_index(shape, op)?; + let d1 = self.1.to_index(shape, op)?; + let d2 = self.2.to_index(shape, op)?; + let d3 = self.3.to_index(shape, op)?; + let d4 = self.4.to_index(shape, op)?; + let d5 = self.5.to_index(shape, op)?; + Ok(vec![d0, d1, d2, d3, d4, d5]) + } +} + extract_dims!(dims0, 0, |_: &[usize]| (), ()); extract_dims!(dims1, 1, |d: &[usize]| d[0], usize); extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize)); @@ -457,3 +494,171 @@ mod tests { assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]); } } + +pub trait ShapeWithOneHole { + fn into_shape(self, el_count: usize) -> Result<Shape>; +} + +impl<S: Into<Shape>> ShapeWithOneHole for S { + fn into_shape(self, _el_count: usize) -> Result<Shape> { + Ok(self.into()) + } +} + +impl ShapeWithOneHole for ((),) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + Ok(el_count.into()) + } +} + +impl ShapeWithOneHole for ((), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1) = self; + if el_count % d1 != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d1}") + } + Ok((el_count / d1, d1).into()) + } +} + +impl ShapeWithOneHole for (usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, ()) = self; + if el_count % d1 != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d1}") + } + Ok((d1, el_count / d1).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1, d2) = self; + let d = d1 * d2; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((el_count / d, d1, d2).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, (), d2) = self; + let d = d1 * d2; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, el_count / d, d2).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, ()) = self; + let d = d1 * d2; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, el_count / d).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1, d2, d3) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((el_count / d, d1, d2, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, (), d2, d3) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, el_count / d, d2, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, (), d3) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, el_count / d, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, d3, ()) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, d3, el_count / d).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1, d2, d3, d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((el_count / d, d1, d2, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, (), d2, d3, d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, el_count / d, d2, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, (), usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, (), d3, d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, el_count / d, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, d3, (), d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, d3, el_count / d, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, d3, d4, ()) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, d3, d4, el_count / d).into()) + } +} |