diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion/unet_2d_blocks.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d_blocks.rs | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs index 8dd6cf26..4d0c80a5 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs @@ -5,7 +5,7 @@ use crate::attention::{ AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig, }; use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig}; -use candle::{Result, Tensor}; +use candle::{Result, Tensor, D}; use candle_nn as nn; #[derive(Debug)] @@ -39,7 +39,9 @@ impl Downsample2D { None => crate::utils::avg_pool2d(xs), // [2, 2], [2, 2], [0, 0], false, true, None), Some(conv) => { if self.padding == 0 { - let xs = crate::utils::pad(xs)?; // [0, 1, 0, 1], "constant", Some(0.))?; + let xs = xs + .pad_with_zeros(D::Minus1, 0, 1)? + .pad_with_zeros(D::Minus2, 0, 1)?; conv.forward(&xs) } else { conv.forward(xs) |