Diffstat (limited to 'candle-examples/examples/stable-diffusion/unet_2d_blocks.rs')

 candle-examples/examples/stable-diffusion/unet_2d_blocks.rs | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 65341e74..1db65222 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -366,6 +366,7 @@ pub struct UNetMidBlock2DCrossAttnConfig {
     pub cross_attn_dim: usize,
     pub sliced_attention_size: Option<usize>,
     pub use_linear_projection: bool,
+    pub transformer_layers_per_block: usize,
 }
 
 impl Default for UNetMidBlock2DCrossAttnConfig {
@@ -379,6 +380,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
             cross_attn_dim: 1280,
             sliced_attention_size: None, // Sliced attention disabled
             use_linear_projection: false,
+            transformer_layers_per_block: 1,
         }
     }
 }
@@ -414,7 +416,7 @@ impl UNetMidBlock2DCrossAttn {
         let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
         let n_heads = config.attn_num_head_channels;
         let attn_cfg = SpatialTransformerConfig {
-            depth: 1,
+            depth: config.transformer_layers_per_block,
             num_groups: resnet_groups,
             context_dim: Some(config.cross_attn_dim),
             sliced_attention_size: config.sliced_attention_size,
@@ -565,6 +567,7 @@ pub struct CrossAttnDownBlock2DConfig {
     // attention_type: "default"
     pub sliced_attention_size: Option<usize>,
     pub use_linear_projection: bool,
+    pub transformer_layers_per_block: usize,
 }
 
 impl Default for CrossAttnDownBlock2DConfig {
@@ -575,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
             cross_attention_dim: 1280,
             sliced_attention_size: None,
             use_linear_projection: false,
+            transformer_layers_per_block: 1,
         }
     }
 }
@@ -605,7 +609,7 @@ impl CrossAttnDownBlock2D {
         )?;
         let n_heads = config.attn_num_head_channels;
         let cfg = SpatialTransformerConfig {
-            depth: 1,
+            depth: config.transformer_layers_per_block,
             context_dim: Some(config.cross_attention_dim),
             num_groups: config.downblock.resnet_groups,
             sliced_attention_size: config.sliced_attention_size,
@@ -767,6 +771,7 @@ pub struct CrossAttnUpBlock2DConfig {
     // attention_type: "default"
     pub sliced_attention_size: Option<usize>,
     pub use_linear_projection: bool,
+    pub transformer_layers_per_block: usize,
 }
 
 impl Default for CrossAttnUpBlock2DConfig {
@@ -777,6 +782,7 @@ impl Default for CrossAttnUpBlock2DConfig {
             cross_attention_dim: 1280,
             sliced_attention_size: None,
             use_linear_projection: false,
+            transformer_layers_per_block: 1,
         }
     }
 }
@@ -809,7 +815,7 @@ impl CrossAttnUpBlock2D {
         )?;
         let n_heads = config.attn_num_head_channels;
         let cfg = SpatialTransformerConfig {
-            depth: 1,
+            depth: config.transformer_layers_per_block,
             context_dim: Some(config.cross_attention_dim),
             num_groups: config.upblock.resnet_groups,
             sliced_attention_size: config.sliced_attention_size,
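
For context, a minimal usage sketch in Rust: with the new field in place, a caller can request several stacked transformer layers per attention block instead of the previously hard-coded single layer, and each block forwards the value as depth in its SpatialTransformerConfig. The config structs are the ones defined in unet_2d_blocks.rs above, and each has a Default impl that sets transformer_layers_per_block to 1, preserving the old behaviour; the concrete values 10 and 2 below are illustrative, not taken from this change.

// Usage sketch (hypothetical values). Assumes the config structs from
// unet_2d_blocks.rs are in scope.
fn deeper_attention_configs() -> (UNetMidBlock2DCrossAttnConfig, CrossAttnDownBlock2DConfig) {
    // Mid block with ten stacked transformer layers per attention block.
    let mid = UNetMidBlock2DCrossAttnConfig {
        transformer_layers_per_block: 10,
        ..Default::default()
    };
    // Down block with two transformer layers per attention block.
    let down = CrossAttnDownBlock2DConfig {
        transformer_layers_per_block: 2,
        ..Default::default()
    };
    (mid, down)
}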