summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-07 16:24:56 +0200
committerGitHub <noreply@github.com>2023-08-07 15:24:56 +0100
commitf53a333ea91233b41dd946c2c30213c79b4d1cb3 (patch)
treee9e525b6161f7fb31010ac26461f241729b22d75 /candle-core/src
parente72ba0b9e755bbac5bd60718765c043bba3a63dc (diff)
downloadcandle-f53a333ea91233b41dd946c2c30213c79b4d1cb3.tar.gz
candle-f53a333ea91233b41dd946c2c30213c79b4d1cb3.tar.bz2
candle-f53a333ea91233b41dd946c2c30213c79b4d1cb3.zip
Simple pad support. (#336)
* Simple pad support. * Fix the tensor indexing when padding.
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/tensor.rs26
1 files changed, 26 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ff381620..f7bd894a 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1759,6 +1759,32 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
+ pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
+ if left == 0 && right == 0 {
+ Ok(self.clone())
+ } else if left == 0 {
+ let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
+ let mut dims = self.dims().to_vec();
+ dims[dim] = right;
+ let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ Tensor::cat(&[self, &right], dim)
+ } else if right == 0 {
+ let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
+ let mut dims = self.dims().to_vec();
+ dims[dim] = left;
+ let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ Tensor::cat(&[&left, self], dim)
+ } else {
+ let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
+ let mut dims = self.dims().to_vec();
+ dims[dim] = left;
+ let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ dims[dim] = right;
+ let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ Tensor::cat(&[&left, self, &right], dim)
+ }
+ }
+
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}