summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/stable-diffusion/unet_2d_blocks.rs')
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d_blocks.rs6
1 files changed, 4 insertions, 2 deletions
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 8dd6cf26..4d0c80a5 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -5,7 +5,7 @@ use crate::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
-use candle::{Result, Tensor};
+use candle::{Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
@@ -39,7 +39,9 @@ impl Downsample2D {
None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None),
Some(conv) => {
if self.padding == 0 {
- let xs = crate::utils::pad(xs)?; // [0, 1, 0, 1], "constant", Some(0.))?;
+ let xs = xs
+ .pad_with_zeros(D::Minus1, 0, 1)?
+ .pad_with_zeros(D::Minus2, 0, 1)?;
conv.forward(&xs)
} else {
conv.forward(xs)