author    Laurent Mazare <laurent.mazare@gmail.com>    2023-08-18 16:30:53 +0100
committer GitHub <noreply@github.com>                  2023-08-18 16:30:53 +0100
commit    cb069d606323cec02c6bf54185c2fbfffffd4bdf (patch)
tree      4af8c5d82a7a8820a82db1ea3fcfad770802c6c1 /candle-core
parent    4f1541526cd52e7da356c26a2752bb187ca38e0b (diff)
Add the permute op (similar to pytorch). (#504)
* Add the permute op (similar to pytorch).
* Add the backprop for dimension permutation.
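
For orientation before the diff: the new `Tensor::permute` reorders dimensions by adjusting only the layout (shape and strides), so no data is copied. A minimal usage sketch, mirroring the doc-test added in this commit (assuming the crate is imported as `candle_core`):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Build a contiguous (2, 3, 4, 5) tensor and reorder its dimensions.
    let t = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
    let p = t.permute((2, 3, 1, 0))?;
    assert_eq!(p.dims(), &[4, 5, 3, 2]);
    // Only the layout changes; the permuted view shares storage with `t`
    // and is generally no longer C-contiguous.
    assert!(!p.is_contiguous());
    Ok(())
}
```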
Diffstat (limited to 'candle-core')
-rw-r--r--  candle-core/src/backprop.rs | 10
-rw-r--r--  candle-core/src/layout.rs   | 25
-rw-r--r--  candle-core/src/op.rs       |  1
-rw-r--r--  candle-core/src/shape.rs    | 10
-rw-r--r--  candle-core/src/tensor.rs   | 36
5 files changed, 82 insertions, 0 deletions
diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index 2a60fe30..ee6f7d75 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -96,6 +96,7 @@ impl Tensor {
                 | Op::ToDType(node)
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
+                | Op::Permute(node, _)
                 | Op::Narrow(node, _, _, _)
                 | Op::Unary(node, _)
                 | Op::Elu(node, _)
@@ -403,6 +404,15 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
+                Op::Permute(arg, dims) => {
+                    let mut inv_dims = vec![0; dims.len()];
+                    for (i, &dim_idx) in dims.iter().enumerate() {
+                        inv_dims[dim_idx] = i
+                    }
+                    let arg_grad = grad.permute(inv_dims)?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&arg_grad)?
+                }
             };
         }
     }
diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
index 95dc9667..dc532248 100644
--- a/candle-core/src/layout.rs
+++ b/candle-core/src/layout.rs
@@ -112,6 +112,31 @@ impl Layout {
         })
     }
 
+    pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
+        let is_permutation =
+            idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
+        if !is_permutation {
+            crate::bail!(
+                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
+                self.dims(),
+                idxs
+            )
+        }
+        let stride = self.stride();
+        let dims = self.shape().dims();
+        let mut perm_stride = stride.to_vec();
+        let mut perm_dims = dims.to_vec();
+        for (i, &idx) in idxs.iter().enumerate() {
+            perm_stride[i] = stride[idx];
+            perm_dims[i] = dims[idx];
+        }
+        Ok(Self {
+            shape: Shape::from(perm_dims),
+            stride: perm_stride,
+            start_offset: self.start_offset,
+        })
+    }
+
     pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
         let shape = shape.into();
         if shape.rank() < self.shape().rank() {
diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index cf99f86e..81ee8d59 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -117,6 +117,7 @@ pub enum Op {
     Reshape(Tensor),
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
+    Permute(Tensor, Vec<usize>),
     Elu(Tensor, f64),
     CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
     CustomOp2(
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index 83d11c09..d8f8f756 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -345,6 +345,16 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
     }
 }
 
+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let d0 = self.0.to_index(shape, op)?;
+        let d1 = self.1.to_index(shape, op)?;
+        let d2 = self.2.to_index(shape, op)?;
+        let d3 = self.3.to_index(shape, op)?;
+        Ok(vec![d0, d1, d2, d3])
+    }
+}
+
 extract_dims!(dims0, 0, |_: &[usize]| (), ());
 extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
 extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 421c17e0..45aa07bc 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1459,6 +1459,42 @@ impl Tensor {
         Ok(Tensor(Arc::new(tensor_)))
     }
 
+    /// Returns a tensor with the same data as the input where the dimensions have been permuted.
+    /// dims must be a permutation, i.e. include each dimension index exactly once.
+    ///
+    /// ```rust
+    /// use candle_core::{Tensor, Device};
+    /// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
+    /// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
+    /// let tensor = tensor.permute((2, 3, 1, 0))?;
+    /// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
+        let dims = dims.to_indexes(self.shape(), "permute")?;
+        // O(n^2) permutation check but these arrays are small.
+        let is_permutation =
+            dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
+        if !is_permutation {
+            crate::bail!(
+                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
+                self.dims(),
+                dims
+            )
+        }
+        let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage: self.storage.clone(),
+            layout: self.layout.permute(&dims)?,
+            op,
+            is_variable: false,
+            dtype: self.dtype,
+            device: self.device.clone(),
+        };
+        Ok(Tensor(Arc::new(tensor_)))
+    }
+
     /// Returns true if the data is stored in a C contiguous (aka row major) way.
     pub fn is_contiguous(&self) -> bool {
         self.layout.is_contiguous()
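
The backward rule in `backprop.rs` above relies on a simple fact: the gradient of a permute flows back through the inverse permutation. A self-contained sketch of that inversion step (the helper name `invert_permutation` is ours, not part of the commit):

```rust
/// If the forward pass puts input axis dims[i] at output position i,
/// then inv_dims[dims[i]] = i sends each gradient axis back to its source.
fn invert_permutation(dims: &[usize]) -> Vec<usize> {
    let mut inv_dims = vec![0; dims.len()];
    for (i, &dim_idx) in dims.iter().enumerate() {
        inv_dims[dim_idx] = i;
    }
    inv_dims
}

fn main() {
    // The doc-test permutation (2, 3, 1, 0) inverts to (3, 2, 0, 1):
    // permuting by one and then the other restores the original axis order.
    assert_eq!(invert_permutation(&[2, 3, 1, 0]), vec![3, 2, 0, 1]);
}
```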