summaryrefslogtreecommitdiff
path: root/candle-core/src
diff options
context:
space:
mode:
Diffstat (limited to 'candle-core/src')
-rw-r--r--candle-core/src/tensor.rs26
1 files changed, 26 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ff381620..f7bd894a 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1759,6 +1759,32 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
+ pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
+ if left == 0 && right == 0 {
+ Ok(self.clone())
+ } else if left == 0 {
+ let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
+ let mut dims = self.dims().to_vec();
+ dims[dim] = right;
+ let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ Tensor::cat(&[self, &right], dim)
+ } else if right == 0 {
+ let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
+ let mut dims = self.dims().to_vec();
+ dims[dim] = left;
+ let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ Tensor::cat(&[&left, self], dim)
+ } else {
+ let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
+ let mut dims = self.dims().to_vec();
+ dims[dim] = left;
+ let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ dims[dim] = right;
+ let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ Tensor::cat(&[&left, self, &right], dim)
+ }
+ }
+
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}