diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion/unet_2d.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/unet_2d.rs | 25 |
1 files changed, 18 insertions, 7 deletions
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs index eb2dbf10..81bd9547 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d.rs @@ -12,7 +12,9 @@ use candle_nn::Module; #[derive(Debug, Clone, Copy)] pub struct BlockConfig { pub out_channels: usize, - pub use_cross_attn: bool, + /// When `None` no cross-attn is used, when `Some(d)` then cross-attn is used and `d` is the + /// number of transformer blocks to be used. + pub use_cross_attn: Option<usize>, pub attention_head_dim: usize, } @@ -41,22 +43,22 @@ impl Default for UNet2DConditionModelConfig { blocks: vec![ BlockConfig { out_channels: 320, - use_cross_attn: true, + use_cross_attn: Some(1), attention_head_dim: 8, }, BlockConfig { out_channels: 640, - use_cross_attn: true, + use_cross_attn: Some(1), attention_head_dim: 8, }, BlockConfig { out_channels: 1280, - use_cross_attn: true, + use_cross_attn: Some(1), attention_head_dim: 8, }, BlockConfig { out_channels: 1280, - use_cross_attn: false, + use_cross_attn: None, attention_head_dim: 8, }, ], @@ -149,13 +151,14 @@ impl UNet2DConditionModel { downsample_padding: config.downsample_padding, ..Default::default() }; - if use_cross_attn { + if let Some(transformer_layers_per_block) = use_cross_attn { let config = CrossAttnDownBlock2DConfig { downblock: db_cfg, attn_num_head_channels: attention_head_dim, cross_attention_dim: config.cross_attention_dim, sliced_attention_size, use_linear_projection: config.use_linear_projection, + transformer_layers_per_block, }; let block = CrossAttnDownBlock2D::new( vs_db.pp(&i.to_string()), @@ -179,6 +182,11 @@ impl UNet2DConditionModel { }) .collect::<Result<Vec<_>>>()?; + // https://github.com/huggingface/diffusers/blob/a76f2ad538e73b34d5fe7be08c8eb8ab38c7e90c/src/diffusers/models/unet_2d_condition.py#L462 + let mid_transformer_layers_per_block = match config.blocks.last() { + None => 1, + Some(block) => block.use_cross_attn.unwrap_or(1), + }; let mid_cfg = UNetMidBlock2DCrossAttnConfig { resnet_eps: config.norm_eps, output_scale_factor: config.mid_block_scale_factor, @@ -186,8 +194,10 @@ impl UNet2DConditionModel { attn_num_head_channels: bl_attention_head_dim, resnet_groups: Some(config.norm_num_groups), use_linear_projection: config.use_linear_projection, + transformer_layers_per_block: mid_transformer_layers_per_block, ..Default::default() }; + let mid_block = UNetMidBlock2DCrossAttn::new( vs.pp("mid_block"), bl_channels, @@ -231,13 +241,14 @@ impl UNet2DConditionModel { add_upsample: i < n_blocks - 1, ..Default::default() }; - if use_cross_attn { + if let Some(transformer_layers_per_block) = use_cross_attn { let config = CrossAttnUpBlock2DConfig { upblock: ub_cfg, attn_num_head_channels: attention_head_dim, cross_attention_dim: config.cross_attention_dim, sliced_attention_size, use_linear_projection: config.use_linear_projection, + transformer_layers_per_block, }; let block = CrossAttnUpBlock2D::new( vs_ub.pp(&i.to_string()), |