diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion/unet_2d.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d.rs | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs index 0bc820db..e52ec281 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d.rs @@ -1,4 +1,3 @@ -#![allow(dead_code)] //! 2D UNet Denoising Models //! //! The 2D Unet models take as input a noisy sample and the current diffusion @@ -103,6 +102,7 @@ impl UNet2DConditionModel { vs: nn::VarBuilder, in_channels: usize, out_channels: usize, + use_flash_attn: bool, config: UNet2DConditionModelConfig, ) -> Result<Self> { let n_blocks = config.blocks.len(); @@ -161,6 +161,7 @@ impl UNet2DConditionModel { in_channels, out_channels, Some(time_embed_dim), + use_flash_attn, config, )?; Ok(UNetDownBlock::CrossAttn(block)) @@ -190,6 +191,7 @@ impl UNet2DConditionModel { vs.pp("mid_block"), bl_channels, Some(time_embed_dim), + use_flash_attn, mid_cfg, )?; @@ -242,6 +244,7 @@ impl UNet2DConditionModel { prev_out_channels, out_channels, Some(time_embed_dim), + use_flash_attn, config, )?; Ok(UNetUpBlock::CrossAttn(block)) |