summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d_blocks.rs2
1 files changed, 2 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 26a1035b..be258acb 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -754,6 +754,7 @@ impl UpBlock2D {
let mut xs = xs.clone();
for (index, resnet) in self.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = xs.contiguous()?;
xs = resnet.forward(&xs, temb)?;
}
match &self.upsampler {
@@ -855,6 +856,7 @@ impl CrossAttnUpBlock2D {
let mut xs = xs.clone();
for (index, resnet) in self.upblock.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
+ xs = xs.contiguous()?;
xs = resnet.forward(&xs, temb)?;
xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
}