diff options
Diffstat (limited to 'candle-core/src')
-rw-r--r-- | candle-core/src/tensor.rs | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index ff381620..f7bd894a 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1759,6 +1759,32 @@ impl Tensor { Ok(from_storage(storage, shape, op, false)) } + pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> { + if left == 0 && right == 0 { + Ok(self.clone()) + } else if left == 0 { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = right; + let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[self, &right], dim) + } else if right == 0 { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = left; + let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[&left, self], dim) + } else { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = left; + let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + dims[dim] = right; + let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[&left, self, &right], dim) + } + } + fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { self.storage.read().unwrap() } |