summaryrefslogtreecommitdiff
path: root/candle-core/src/shape.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src/shape.rs')
-rw-r--r--candle-core/src/shape.rs205
1 files changed, 205 insertions, 0 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index aea8b887..4d500e7f 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -1,3 +1,4 @@
+//! The shape of a tensor is a tuple with the size of each of its dimensions.
#![allow(clippy::redundant_closure_call)]
use crate::{Error, Result};
@@ -72,6 +73,14 @@ impl From<(usize, usize, usize, usize, usize)> for Shape {
}
}
+impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
+ fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
+ Self(vec![
+ d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
+ ])
+ }
+}
+
impl From<Vec<usize>> for Shape {
fn from(dims: Vec<usize>) -> Self {
Self(dims)
@@ -119,6 +128,7 @@ impl Shape {
Self(dims.to_vec())
}
+ /// The rank is the number of dimensions, 0 for a scalar value, 1 for a vector, etc.
pub fn rank(&self) -> usize {
self.0.len()
}
@@ -127,10 +137,12 @@ impl Shape {
self.0
}
+ /// The dimensions as a slice of `usize`.
pub fn dims(&self) -> &[usize] {
&self.0
}
+ /// The total number of elements, this is the product of all dimension sizes.
pub fn elem_count(&self) -> usize {
self.0.iter().product()
}
@@ -182,6 +194,8 @@ impl Shape {
true
}
+ /// Modifies the shape by adding a list of additional dimensions at the end of the existing
+ /// dimensions.
pub fn extend(mut self, additional_dims: &[usize]) -> Self {
self.0.extend(additional_dims);
self
@@ -419,6 +433,29 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
}
}
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5) {
+ fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+ let d0 = self.0.to_index(shape, op)?;
+ let d1 = self.1.to_index(shape, op)?;
+ let d2 = self.2.to_index(shape, op)?;
+ let d3 = self.3.to_index(shape, op)?;
+ let d4 = self.4.to_index(shape, op)?;
+ Ok(vec![d0, d1, d2, d3, d4])
+ }
+}
+
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
+ fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+ let d0 = self.0.to_index(shape, op)?;
+ let d1 = self.1.to_index(shape, op)?;
+ let d2 = self.2.to_index(shape, op)?;
+ let d3 = self.3.to_index(shape, op)?;
+ let d4 = self.4.to_index(shape, op)?;
+ let d5 = self.5.to_index(shape, op)?;
+ Ok(vec![d0, d1, d2, d3, d4, d5])
+ }
+}
+
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
@@ -457,3 +494,171 @@ mod tests {
assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
}
}
+
+pub trait ShapeWithOneHole {
+ fn into_shape(self, el_count: usize) -> Result<Shape>;
+}
+
+impl<S: Into<Shape>> ShapeWithOneHole for S {
+ fn into_shape(self, _el_count: usize) -> Result<Shape> {
+ Ok(self.into())
+ }
+}
+
+impl ShapeWithOneHole for ((),) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ Ok(el_count.into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1) = self;
+ if el_count % d1 != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+ }
+ Ok((el_count / d1, d1).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, ()) = self;
+ if el_count % d1 != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+ }
+ Ok((d1, el_count / d1).into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1, d2) = self;
+ let d = d1 * d2;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((el_count / d, d1, d2).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, (), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, (), d2) = self;
+ let d = d1 * d2;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, el_count / d, d2).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, ()) = self;
+ let d = d1 * d2;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, el_count / d).into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1, d2, d3) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((el_count / d, d1, d2, d3).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, (), d2, d3) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, el_count / d, d2, d3).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, (), d3) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, el_count / d, d3).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, d3, ()) = self;
+ let d = d1 * d2 * d3;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, d3, el_count / d).into())
+ }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let ((), d1, d2, d3, d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((el_count / d, d1, d2, d3, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, (), d2, d3, d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, el_count / d, d2, d3, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, (), d3, d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, el_count / d, d3, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, d3, (), d4) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, d3, el_count / d, d4).into())
+ }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
+ fn into_shape(self, el_count: usize) -> Result<Shape> {
+ let (d1, d2, d3, d4, ()) = self;
+ let d = d1 * d2 * d3 * d4;
+ if el_count % d != 0 {
+ crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+ }
+ Ok((d1, d2, d3, d4, el_count / d).into())
+ }
+}