summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/stable-diffusion/unet_2d_blocks.rs')
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d_blocks.rs7
1 files changed, 6 insertions, 1 deletions
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 2f397815..b7adb2c0 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
//! 2D UNet Building Blocks
//!
use crate::attention::{
@@ -393,6 +392,7 @@ impl UNetMidBlock2DCrossAttn {
vs: nn::VarBuilder,
in_channels: usize,
temb_channels: Option<usize>,
+ use_flash_attn: bool,
config: UNetMidBlock2DCrossAttnConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
@@ -423,6 +423,7 @@ impl UNetMidBlock2DCrossAttn {
in_channels,
n_heads,
in_channels / n_heads,
+ use_flash_attn,
attn_cfg,
)?;
let resnet = ResnetBlock2D::new(
@@ -588,6 +589,7 @@ impl CrossAttnDownBlock2D {
in_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
+ use_flash_attn: bool,
config: CrossAttnDownBlock2DConfig,
) -> Result<Self> {
let downblock = DownBlock2D::new(
@@ -613,6 +615,7 @@ impl CrossAttnDownBlock2D {
out_channels,
n_heads,
out_channels / n_heads,
+ use_flash_attn,
cfg,
)
})
@@ -789,6 +792,7 @@ impl CrossAttnUpBlock2D {
prev_output_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
+ use_flash_attn: bool,
config: CrossAttnUpBlock2DConfig,
) -> Result<Self> {
let upblock = UpBlock2D::new(
@@ -815,6 +819,7 @@ impl CrossAttnUpBlock2D {
out_channels,
n_heads,
out_channels / n_heads,
+ use_flash_attn,
cfg,
)
})