diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion/unet_2d_blocks.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d_blocks.rs | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs index 2f397815..b7adb2c0 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs @@ -1,4 +1,3 @@ -#![allow(dead_code)] //! 2D UNet Building Blocks //! use crate::attention::{ @@ -393,6 +392,7 @@ impl UNetMidBlock2DCrossAttn { vs: nn::VarBuilder, in_channels: usize, temb_channels: Option<usize>, + use_flash_attn: bool, config: UNetMidBlock2DCrossAttnConfig, ) -> Result<Self> { let vs_resnets = vs.pp("resnets"); @@ -423,6 +423,7 @@ impl UNetMidBlock2DCrossAttn { in_channels, n_heads, in_channels / n_heads, + use_flash_attn, attn_cfg, )?; let resnet = ResnetBlock2D::new( @@ -588,6 +589,7 @@ impl CrossAttnDownBlock2D { in_channels: usize, out_channels: usize, temb_channels: Option<usize>, + use_flash_attn: bool, config: CrossAttnDownBlock2DConfig, ) -> Result<Self> { let downblock = DownBlock2D::new( @@ -613,6 +615,7 @@ impl CrossAttnDownBlock2D { out_channels, n_heads, out_channels / n_heads, + use_flash_attn, cfg, ) }) @@ -789,6 +792,7 @@ impl CrossAttnUpBlock2D { prev_output_channels: usize, out_channels: usize, temb_channels: Option<usize>, + use_flash_attn: bool, config: CrossAttnUpBlock2DConfig, ) -> Result<Self> { let upblock = UpBlock2D::new( @@ -815,6 +819,7 @@ impl CrossAttnUpBlock2D { out_channels, n_heads, out_channels / n_heads, + use_flash_attn, cfg, ) }) |