diff options
-rw-r--r-- | candle-core/src/tensor.rs | 26 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/embeddings.rs | 2 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d_blocks.rs | 6 | ||||
-rw-r--r-- | candle-examples/examples/stable-diffusion/utils.rs | 4 |
4 files changed, 31 insertions, 7 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index ff381620..f7bd894a 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1759,6 +1759,32 @@ impl Tensor { Ok(from_storage(storage, shape, op, false)) } + pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> { + if left == 0 && right == 0 { + Ok(self.clone()) + } else if left == 0 { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = right; + let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[self, &right], dim) + } else if right == 0 { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = left; + let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[&left, self], dim) + } else { + let dim = dim.to_index(self.shape(), "pad_with_zeros")?; + let mut dims = self.dims().to_vec(); + dims[dim] = left; + let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + dims[dim] = right; + let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?; + Tensor::cat(&[&left, self, &right], dim) + } + } + fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { self.storage.read().unwrap() } diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs index f8a4f351..848f1760 100644 --- a/candle-examples/examples/stable-diffusion/embeddings.rs +++ b/candle-examples/examples/stable-diffusion/embeddings.rs @@ -57,7 +57,7 @@ impl Timesteps { Tensor::cat(&[&sin, &cos], D::Minus1)? }; if self.num_channels % 2 == 1 { - crate::utils::pad(&emb) // ([0, 1, 0, 0], 'constant', None) + emb.pad_with_zeros(D::Minus2, 0, 1) } else { Ok(emb) } diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs index 8dd6cf26..4d0c80a5 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs @@ -5,7 +5,7 @@ use crate::attention::{ AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig, }; use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig}; -use candle::{Result, Tensor}; +use candle::{Result, Tensor, D}; use candle_nn as nn; #[derive(Debug)] @@ -39,7 +39,9 @@ impl Downsample2D { None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None), Some(conv) => { if self.padding == 0 { - let xs = crate::utils::pad(xs)?; // [0, 1, 0, 1], "constant", Some(0.))?; + let xs = xs + .pad_with_zeros(D::Minus1, 0, 1)? + .pad_with_zeros(D::Minus2, 0, 1)?; conv.forward(&xs) } else { conv.forward(xs) diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs index 4294d823..08b78c04 100644 --- a/candle-examples/examples/stable-diffusion/utils.rs +++ b/candle-examples/examples/stable-diffusion/utils.rs @@ -4,10 +4,6 @@ pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> { todo!() } -pub fn pad(_: &Tensor) -> Result<Tensor> { - todo!() -} - pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> { todo!() } |