summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-core/src/tensor.rs26
-rw-r--r--candle-examples/examples/stable-diffusion/embeddings.rs2
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d_blocks.rs6
-rw-r--r--candle-examples/examples/stable-diffusion/utils.rs4
4 files changed, 31 insertions, 7 deletions
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ff381620..f7bd894a 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1759,6 +1759,32 @@ impl Tensor {
Ok(from_storage(storage, shape, op, false))
}
+ pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
+ if left == 0 && right == 0 {
+ Ok(self.clone())
+ } else if left == 0 {
+ let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
+ let mut dims = self.dims().to_vec();
+ dims[dim] = right;
+ let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ Tensor::cat(&[self, &right], dim)
+ } else if right == 0 {
+ let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
+ let mut dims = self.dims().to_vec();
+ dims[dim] = left;
+ let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ Tensor::cat(&[&left, self], dim)
+ } else {
+ let dim = dim.to_index(self.shape(), "pad_with_zeros")?;
+ let mut dims = self.dims().to_vec();
+ dims[dim] = left;
+ let left = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ dims[dim] = right;
+ let right = Tensor::zeros(dims.as_slice(), self.dtype, self.device())?;
+ Tensor::cat(&[&left, self, &right], dim)
+ }
+ }
+
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}
diff --git a/candle-examples/examples/stable-diffusion/embeddings.rs b/candle-examples/examples/stable-diffusion/embeddings.rs
index f8a4f351..848f1760 100644
--- a/candle-examples/examples/stable-diffusion/embeddings.rs
+++ b/candle-examples/examples/stable-diffusion/embeddings.rs
@@ -57,7 +57,7 @@ impl Timesteps {
Tensor::cat(&[&sin, &cos], D::Minus1)?
};
if self.num_channels % 2 == 1 {
- crate::utils::pad(&emb) // ([0, 1, 0, 0], 'constant', None)
+ emb.pad_with_zeros(D::Minus2, 0, 1)
} else {
Ok(emb)
}
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 8dd6cf26..4d0c80a5 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -5,7 +5,7 @@ use crate::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
-use candle::{Result, Tensor};
+use candle::{Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
@@ -39,7 +39,9 @@ impl Downsample2D {
None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
Some(conv) => {
if self.padding == 0 {
- let xs = crate::utils::pad(xs)?; // [0, 1, 0, 1], "constant", Some(0.))?;
+ let xs = xs
+ .pad_with_zeros(D::Minus1, 0, 1)?
+ .pad_with_zeros(D::Minus2, 0, 1)?;
conv.forward(&xs)
} else {
conv.forward(xs)
diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs
index 4294d823..08b78c04 100644
--- a/candle-examples/examples/stable-diffusion/utils.rs
+++ b/candle-examples/examples/stable-diffusion/utils.rs
@@ -4,10 +4,6 @@ pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
todo!()
}
-pub fn pad(_: &Tensor) -> Result<Tensor> {
- todo!()
-}
-
pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
todo!()
}