author    Laurent Mazare <laurent.mazare@gmail.com>  2023-08-18 16:30:53 +0100
committer GitHub <noreply@github.com>                2023-08-18 16:30:53 +0100
commit    cb069d606323cec02c6bf54185c2fbfffffd4bdf
tree      4af8c5d82a7a8820a82db1ea3fcfad770802c6c1 /candle-core
parent    4f1541526cd52e7da356c26a2752bb187ca38e0b
Add the permute op (similar to PyTorch). (#504)

* Add the permute op (similar to PyTorch).
* Add the backprop for dimension permutation.
Diffstat (limited to 'candle-core')
 candle-core/src/backprop.rs | 10 ++++++++++
 candle-core/src/layout.rs   | 25 +++++++++++++++++++++++++
 candle-core/src/op.rs       |  1 +
 candle-core/src/shape.rs    | 10 ++++++++++
 candle-core/src/tensor.rs   | 36 ++++++++++++++++++++++++++++++++++++
 5 files changed, 82 insertions(+), 0 deletions(-)
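
Before the per-file hunks, here is a minimal end-to-end sketch of what this commit enables: the forward permute plus the inverse-permutation backward pass. It assumes candle's existing Var/backward API; the shapes and values are illustrative only.

use candle_core::{DType, Device, Var};

fn main() -> candle_core::Result<()> {
    let x = Var::zeros((2, 3), DType::F32, &Device::Cpu)?;
    // Forward: swap the two axes (equivalent to a transpose for rank 2).
    let y = x.permute((1, 0))?;
    assert_eq!(y.dims(), &[3, 2]);
    // Backward: the gradient flows back through the inverse permutation,
    // so it lands with x's original shape.
    let grads = y.sum_all()?.backward()?;
    let gx = grads.get(&x).expect("gradient for x");
    assert_eq!(gx.dims(), &[2, 3]);
    Ok(())
}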
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 2a60fe30..ee6f7d75 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -96,6 +96,7 @@ impl Tensor {
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
+ | Op::Permute(node, _)
| Op::Narrow(node, _, _, _)
| Op::Unary(node, _)
| Op::Elu(node, _)
@@ -403,6 +404,15 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
+ Op::Permute(arg, dims) => {
+ let mut inv_dims = vec![0; dims.len()];
+ for (i, &dim_idx) in dims.iter().enumerate() {
+ inv_dims[dim_idx] = i
+ }
+ let arg_grad = grad.permute(inv_dims)?;
+ let sum_grad = grads.or_insert(arg)?;
+ *sum_grad = sum_grad.add(&arg_grad)?
+ }
};
}
}
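
The backward rule above hinges on inverting the permutation: `dims` maps output axis `i` to input axis `dims[i]`, so scattering `i` into `inv_dims[dims[i]]` builds the inverse. A standalone sketch of just that index arithmetic, independent of candle's tensor types:

// Standalone illustration of the inverse-permutation logic in the hunk above.
fn inverse_permutation(dims: &[usize]) -> Vec<usize> {
    let mut inv = vec![0; dims.len()];
    for (i, &dim_idx) in dims.iter().enumerate() {
        // Output axis i came from input axis dim_idx, so the inverse
        // permutation must send axis dim_idx back to position i.
        inv[dim_idx] = i;
    }
    inv
}

fn main() {
    let dims = [2, 3, 1, 0];
    let inv = inverse_permutation(&dims);
    assert_eq!(inv, [3, 2, 0, 1]);
    // Composing the two permutations restores the identity ordering.
    let roundtrip: Vec<usize> = inv.iter().map(|&i| dims[i]).collect();
    assert_eq!(roundtrip, [0, 1, 2, 3]);
}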
diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
index 95dc9667..dc532248 100644
--- a/candle-core/src/layout.rs
+++ b/candle-core/src/layout.rs
@@ -112,6 +112,31 @@ impl Layout {
})
}
+ pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
+ let is_permutation =
+ idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
+ if !is_permutation {
+ crate::bail!(
+ "dimension mismatch in permute, tensor {:?}, dims: {:?}",
+ self.dims(),
+ idxs
+ )
+ }
+ let stride = self.stride();
+ let dims = self.shape().dims();
+ let mut perm_stride = stride.to_vec();
+ let mut perm_dims = dims.to_vec();
+ for (i, &idx) in idxs.iter().enumerate() {
+ perm_stride[i] = stride[idx];
+ perm_dims[i] = dims[idx];
+ }
+ Ok(Self {
+ shape: Shape::from(perm_dims),
+ stride: perm_stride,
+ start_offset: self.start_offset,
+ })
+ }
+
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
let shape = shape.into();
if shape.rank() < self.shape().rank() {
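
Since `Layout::permute` only shuffles `dims` and `stride` (the `start_offset` and the underlying storage are untouched), the result is a zero-copy view. A small sketch of that bookkeeping with hardcoded row-major strides, outside of candle's `Layout` type:

fn main() {
    // Row-major (C-contiguous) strides for shape (2, 3, 4).
    let dims = [2usize, 3, 4];
    let stride = [12usize, 4, 1];
    // Output axis i is taken from input axis idxs[i], mirroring the hunk above.
    let idxs = [2usize, 0, 1];
    let perm_dims: Vec<usize> = idxs.iter().map(|&i| dims[i]).collect();
    let perm_stride: Vec<usize> = idxs.iter().map(|&i| stride[i]).collect();
    assert_eq!(perm_dims, [4, 2, 3]);
    assert_eq!(perm_stride, [1, 12, 4]);
    // Element (a, b, c) of the permuted view reads the original buffer at
    // offset a * 1 + b * 12 + c * 4, so no data has to move.
}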
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index cf99f86e..81ee8d59 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -117,6 +117,7 @@ pub enum Op {
Reshape(Tensor),
ToDevice(Tensor),
Transpose(Tensor, usize, usize),
+ Permute(Tensor, Vec<usize>),
Elu(Tensor, f64),
CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
CustomOp2(
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 83d11c09..d8f8f756 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -345,6 +345,16 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
+ fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+ let d0 = self.0.to_index(shape, op)?;
+ let d1 = self.1.to_index(shape, op)?;
+ let d2 = self.2.to_index(shape, op)?;
+ let d3 = self.3.to_index(shape, op)?;
+ Ok(vec![d0, d1, d2, d3])
+ }
+}
+
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
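
This new four-element `Dims` impl is what lets `permute` accept a plain 4-tuple of axis indices on a rank-4 tensor, as in the doc test further below. A quick sketch of the call site it unlocks; the `zeros`/`DType::F32` setup is just illustrative:

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::zeros((2, 3, 4, 5), DType::F32, &Device::Cpu)?;
    // Each tuple element goes through Dim::to_index, so a 4-tuple of usize
    // resolves to the axis indices [3, 2, 1, 0].
    let p = t.permute((3, 2, 1, 0))?;
    assert_eq!(p.dims(), &[5, 4, 3, 2]);
    Ok(())
}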
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 421c17e0..45aa07bc 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1459,6 +1459,42 @@ impl Tensor {
Ok(Tensor(Arc::new(tensor_)))
}
+ /// Returns a tensor with the same data as the input where the dimensions have been permuted.
+ /// dims must be a permutation, i.e. include each dimension index exactly once.
+ ///
+ /// ```rust
+ /// use candle_core::{Tensor, Device};
+ /// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
+ /// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
+ /// let tensor = tensor.permute((2, 3, 1, 0))?;
+ /// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
+ /// # Ok::<(), candle_core::Error>(())
+ /// ```
+ pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
+ let dims = dims.to_indexes(self.shape(), "permute")?;
+ // O(n^2) permutation check but these arrays are small.
+ let is_permutation =
+ dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
+ if !is_permutation {
+ crate::bail!(
+ "dimension mismatch in permute, tensor {:?}, dims: {:?}",
+ self.dims(),
+ dims
+ )
+ }
+ let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
+ let tensor_ = Tensor_ {
+ id: TensorId::new(),
+ storage: self.storage.clone(),
+ layout: self.layout.permute(&dims)?,
+ op,
+ is_variable: false,
+ dtype: self.dtype,
+ device: self.device.clone(),
+ };
+ Ok(Tensor(Arc::new(tensor_)))
+ }
+
/// Returns true if the data is stored in a C contiguous (aka row major) way.
pub fn is_contiguous(&self) -> bool {
self.layout.is_contiguous()
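
Because the permuted tensor shares storage with the original and only its strides change, a non-trivial permutation typically breaks C-contiguity. A short sketch of that interaction with `is_contiguous`, with illustrative shapes:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::arange(0u32, 24u32, &Device::Cpu)?.reshape((2, 3, 4))?;
    assert!(t.is_contiguous());
    // The view has shape (4, 2, 3) with strides (1, 12, 4), which is not
    // the row-major layout (6, 3, 1), so the result is non-contiguous.
    let p = t.permute((2, 0, 1))?;
    assert!(!p.is_contiguous());
    Ok(())
}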