author    Laurent Mazare <laurent.mazare@gmail.com>  2023-09-08 08:38:13 +0100
committer GitHub <noreply@github.com>  2023-09-08 08:38:13 +0100
commit    0e250aee4fcff8991c086ba0606a90db92b4e488 (patch)
tree      394731df7cc9f7c21e5447d2a71c0fa2f14905d7
parent    cfcbec9fc70aca2b0e08f382dec8634f88b61bce (diff)
download  candle-0e250aee4fcff8991c086ba0606a90db92b4e488.tar.gz
          candle-0e250aee4fcff8991c086ba0606a90db92b4e488.tar.bz2
          candle-0e250aee4fcff8991c086ba0606a90db92b4e488.zip
Shape with holes (#770)

* Shape with holes.

* rustfmt.
-rw-r--r--  candle-core/src/shape.rs    168
-rw-r--r--  candle-core/src/tensor.rs    13
-rw-r--r--  candle-flash-attn/build.rs     9
3 files changed, 184 insertions(+), 6 deletions(-)
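
This change lets `Tensor::reshape` accept a shape containing a single `()` "hole" that is filled in from the tensor's element count, mirroring PyTorch's `-1`. As a quick orientation before the diff, here is a minimal usage sketch; it is not part of the commit and assumes a candle-core build that includes this change:

    use candle_core::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        // 2 * 3 * 4 = 24 elements in total.
        let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;

        // `()` marks the one dimension to infer: 24 / (4 * 2) = 3.
        let r = t.reshape((4, (), 2))?;
        assert_eq!(r.shape().dims(), &[4, 3, 2]);

        // Plain `usize` tuples keep working as before.
        let r = t.reshape((6, 4))?;
        assert_eq!(r.shape().dims(), &[6, 4]);
        Ok(())
    }
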
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 9617d1ac..b1f56817 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -482,3 +482,171 @@ mod tests {
         assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
     }
 }
+
+pub trait ShapeWithOneHole {
+    fn into_shape(self, el_count: usize) -> Result<Shape>;
+}
+
+impl<S: Into<Shape>> ShapeWithOneHole for S {
+    fn into_shape(self, _el_count: usize) -> Result<Shape> {
+        Ok(self.into())
+    }
+}
+
+impl ShapeWithOneHole for ((),) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        Ok(el_count.into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1) = self;
+        if el_count % d1 != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+        }
+        Ok((el_count / d1, d1).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, ()) = self;
+        if el_count % d1 != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
+        }
+        Ok((d1, el_count / d1).into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1, d2) = self;
+        let d = d1 * d2;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((el_count / d, d1, d2).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, (), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, (), d2) = self;
+        let d = d1 * d2;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, el_count / d, d2).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, ()) = self;
+        let d = d1 * d2;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, el_count / d).into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1, d2, d3) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((el_count / d, d1, d2, d3).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, (), d2, d3) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, el_count / d, d2, d3).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, (), d3) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, el_count / d, d3).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, d3, ()) = self;
+        let d = d1 * d2 * d3;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, d3, el_count / d).into())
+    }
+}
+
+impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let ((), d1, d2, d3, d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((el_count / d, d1, d2, d3, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, (), d2, d3, d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, el_count / d, d2, d3, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, (), d3, d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, el_count / d, d3, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, d3, (), d4) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, d3, el_count / d, d4).into())
+    }
+}
+
+impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
+    fn into_shape(self, el_count: usize) -> Result<Shape> {
+        let (d1, d2, d3, d4, ()) = self;
+        let d = d1 * d2 * d3 * d4;
+        if el_count % d != 0 {
+            crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
+        }
+        Ok((d1, d2, d3, d4, el_count / d).into())
+    }
+}
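
Each impl above resolves the `()` hole by dividing the element count by the product of the known dimensions, and bails out when the division is not exact. Below is a sketch of a unit test one could add to this module to pin that behavior down; the module and test names are hypothetical and this is not part of the commit:

    #[cfg(test)]
    mod one_hole_tests {
        use super::*;

        #[test]
        fn resolve_hole() -> Result<()> {
            // 24 elements, known dims 3 * 4 = 12, so the hole becomes 2.
            let s = ((), 3, 4).into_shape(24)?;
            assert_eq!(s.dims(), &[2, 3, 4]);
            // 12 is not divisible by 5, so the hole cannot be resolved.
            assert!((5, ()).into_shape(12).is_err());
            Ok(())
        }
    }
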
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 1eca694c..6bb3d740 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1685,12 +1685,15 @@ impl Tensor {
         Ok(from_storage(storage, shape, BackpropOp::none(), true))
     }
 
-    // TODO: Do we want to allow target shape using -1 on some dimensions?
     /// Reshape returns a tensor with the target shape provided that the number of elements of the
     /// original tensor is the same.
     /// If the input tensor is contiguous, this is a view on the original data. Otherwise this uses
     /// a new storage and copies the data over, the returned tensor is always contiguous.
     ///
+    /// The shape can be specified using a tuple of `usize` and at most one `()` in which case
+    /// the behavior is the same as when using `-1` in PyTorch: this dimension size is adjusted so
+    /// as to match the number of elements in the tensor.
+    ///
     /// ```rust
     /// # use candle_core::{Tensor, DType, Device, D};
     /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
@@ -1700,10 +1703,14 @@ impl Tensor {
     /// assert_eq!(c.shape().dims(), &[2, 3]);
     ///
     /// let c = a.reshape((3, 2))?;
     /// assert_eq!(c.shape().dims(), &[3, 2]);
+    ///
+    /// let c = a.reshape((2, (), 1))?;
+    /// assert_eq!(c.shape().dims(), &[2, 3, 1]);
+    ///
     /// # Ok::<(), candle_core::Error>(())
     /// ```
-    pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
-        let shape = shape.into();
+    pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
+        let shape = s.into_shape(self.elem_count())?;
         if shape.elem_count() != self.elem_count() {
             return Err(Error::ShapeMismatchBinaryOp {
                 lhs: self.shape().clone(),
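
When the element count is not divisible by the product of the explicit dimensions, the hole cannot be filled and `reshape` returns an error rather than guessing a size. A small illustrative sketch of that error path, using the same public API as the doctest above (the helper name is hypothetical and not part of the commit):

    use candle_core::{DType, Device, Result, Tensor};

    fn hole_error_path() -> Result<()> {
        let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
        // 6 elements are not divisible by 5, so the `()` cannot be resolved.
        assert!(a.reshape((5, ())).is_err());
        // 6 / 2 = 3, so the hole resolves to 3.
        assert_eq!(a.reshape(((), 2))?.shape().dims(), &[3, 2]);
        Ok(())
    }
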
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 4cc7e5fb..64275fda 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -57,10 +57,13 @@ fn main() -> Result<()> {
             #[allow(clippy::redundant_clone)]
             out_dir.clone()
         }
-        Ok(build_dir) =>
-        {
+        Ok(build_dir) => {
             let path = PathBuf::from(build_dir);
-            path.canonicalize().expect(&format!("Directory doesn't exists: {} (the current directory is {})", &path.display(), std::env::current_dir()?.display()))
+            path.canonicalize().expect(&format!(
+                "Directory doesn't exists: {} (the current directory is {})",
+                &path.display(),
+                std::env::current_dir()?.display()
+            ))
         }
     };
     set_cuda_include_dir()?;