diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-08 08:38:13 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-08 08:38:13 +0100 |
commit | 0e250aee4fcff8991c086ba0606a90db92b4e488 (patch) | |
tree | 394731df7cc9f7c21e5447d2a71c0fa2f14905d7 | |
parent | cfcbec9fc70aca2b0e08f382dec8634f88b61bce (diff) | |
download | candle-0e250aee4fcff8991c086ba0606a90db92b4e488.tar.gz candle-0e250aee4fcff8991c086ba0606a90db92b4e488.tar.bz2 candle-0e250aee4fcff8991c086ba0606a90db92b4e488.zip |
Shape with holes (#770)
* Shape with holes.
* rustfmt.
-rw-r--r-- | candle-core/src/shape.rs | 168 | ||||
-rw-r--r-- | candle-core/src/tensor.rs | 13 | ||||
-rw-r--r-- | candle-flash-attn/build.rs | 9 |
3 files changed, 184 insertions, 6 deletions
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 9617d1ac..b1f56817 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -482,3 +482,171 @@ mod tests { assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]); } } + +pub trait ShapeWithOneHole { + fn into_shape(self, el_count: usize) -> Result<Shape>; +} + +impl<S: Into<Shape>> ShapeWithOneHole for S { + fn into_shape(self, _el_count: usize) -> Result<Shape> { + Ok(self.into()) + } +} + +impl ShapeWithOneHole for ((),) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + Ok(el_count.into()) + } +} + +impl ShapeWithOneHole for ((), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1) = self; + if el_count % d1 != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d1}") + } + Ok((el_count / d1, d1).into()) + } +} + +impl ShapeWithOneHole for (usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, ()) = self; + if el_count % d1 != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d1}") + } + Ok((d1, el_count / d1).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1, d2) = self; + let d = d1 * d2; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((el_count / d, d1, d2).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, (), d2) = self; + let d = d1 * d2; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, el_count / d, d2).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, ()) = self; + let d = d1 * d2; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, el_count / d).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1, d2, d3) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((el_count / d, d1, d2, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, (), d2, d3) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, el_count / d, d2, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, (), d3) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, el_count / d, d3).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, d3, ()) = self; + let d = d1 * d2 * d3; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, d3, el_count / d).into()) + } +} + +impl ShapeWithOneHole for ((), usize, usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let ((), d1, d2, d3, d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((el_count / d, d1, d2, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, (), usize, usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, (), d2, d3, d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, el_count / d, d2, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, (), usize, usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, (), d3, d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, el_count / d, d3, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, (), usize) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, d3, (), d4) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, d3, el_count / d, d4).into()) + } +} + +impl ShapeWithOneHole for (usize, usize, usize, usize, ()) { + fn into_shape(self, el_count: usize) -> Result<Shape> { + let (d1, d2, d3, d4, ()) = self; + let d = d1 * d2 * d3 * d4; + if el_count % d != 0 { + crate::bail!("tensor number of elements {el_count} is not divisible by {d}") + } + Ok((d1, d2, d3, d4, el_count / d).into()) + } +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 1eca694c..6bb3d740 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1685,12 +1685,15 @@ impl Tensor { Ok(from_storage(storage, shape, BackpropOp::none(), true)) } - // TODO: Do we want to allow target shape using -1 on some dimensions? /// Reshape returns a tensor with the target shape provided that the number of elements of the /// original tensor is the same. /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses /// a new storage and copies the data over, the returned tensor is always contiguous. /// + /// The shape can be specified using a tuple of `usize` and at most one `()` in which case + /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so + /// as to match the number of elements in the tensor. + /// /// ```rust /// # use candle_core::{Tensor, DType, Device, D}; /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; @@ -1700,10 +1703,14 @@ impl Tensor { /// /// let c = a.reshape((3, 2))?; /// assert_eq!(c.shape().dims(), &[3, 2]); + /// + /// let c = a.reshape((2, (), 1))?; + /// assert_eq!(c.shape().dims(), &[2, 3, 1]); + /// /// # Ok::<(), candle_core::Error>(()) /// ``` - pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> { - let shape = shape.into(); + pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> { + let shape = s.into_shape(self.elem_count())?; if shape.elem_count() != self.elem_count() { return Err(Error::ShapeMismatchBinaryOp { lhs: self.shape().clone(), diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs index 4cc7e5fb..64275fda 100644 --- a/candle-flash-attn/build.rs +++ b/candle-flash-attn/build.rs @@ -57,10 +57,13 @@ fn main() -> Result<()> { #[allow(clippy::redundant_clone)] out_dir.clone() } - Ok(build_dir) => - { + Ok(build_dir) => { let path = PathBuf::from(build_dir); - path.canonicalize().expect(&format!("Directory doesn't exists: {} (the current directory is {})", &path.display(), std::env::current_dir()?.display())) + path.canonicalize().expect(&format!( + "Directory doesn't exists: {} (the current directory is {})", + &path.display(), + std::env::current_dir()?.display() + )) } }; set_cuda_include_dir()?; |