diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-07 16:24:56 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-07 15:24:56 +0100 |
commit | f53a333ea91233b41dd946c2c30213c79b4d1cb3 (patch) | |
tree | e9e525b6161f7fb31010ac26461f241729b22d75 /candle-core/src/tensor.rs | |
parent | e72ba0b9e755bbac5bd60718765c043bba3a63dc (diff) | |
download | candle-f53a333ea91233b41dd946c2c30213c79b4d1cb3.tar.gz candle-f53a333ea91233b41dd946c2c30213c79b4d1cb3.tar.bz2 candle-f53a333ea91233b41dd946c2c30213c79b4d1cb3.zip |
Simple pad support. (#336)
* Simple pad support.
* Fix the tensor indexing when padding.
Diffstat (limited to 'candle-core/src/tensor.rs')
-rw-r--r-- | candle-core/src/tensor.rs | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index ff381620..f7bd894a 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1759,6 +1759,32 @@ impl Tensor { Ok(from_storage(storage, shape, op, false)) } + pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> { + if left == 0 && right == 0 { + Ok(self.clone()) + } else if left == 0 { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = right; + let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[self, &right], dim) + } else if right == 0 { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = left; + let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[&left, self], dim) + } else { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = left; + let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + dims[dim] = right; + let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[&left, self, &right], dim) + } + } + fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { self.storage.read().unwrap() } |