summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion/unet_2d.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/stable-diffusion/unet_2d.rs')
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d.rs5
1 files changed, 4 insertions, 1 deletions
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
index 0bc820db..e52ec281 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
//! 2D UNet Denoising Models
//!
//! The 2D Unet models take as input a noisy sample and the current diffusion
@@ -103,6 +102,7 @@ impl UNet2DConditionModel {
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
+ use_flash_attn: bool,
config: UNet2DConditionModelConfig,
) -> Result<Self> {
let n_blocks = config.blocks.len();
@@ -161,6 +161,7 @@ impl UNet2DConditionModel {
in_channels,
out_channels,
Some(time_embed_dim),
+ use_flash_attn,
config,
)?;
Ok(UNetDownBlock::CrossAttn(block))
@@ -190,6 +191,7 @@ impl UNet2DConditionModel {
vs.pp("mid_block"),
bl_channels,
Some(time_embed_dim),
+ use_flash_attn,
mid_cfg,
)?;
@@ -242,6 +244,7 @@ impl UNet2DConditionModel {
prev_out_channels,
out_channels,
Some(time_embed_dim),
+ use_flash_attn,
config,
)?;
Ok(UNetUpBlock::CrossAttn(block))