summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/stable-diffusion/unet_2d_blocks.rs')
-rw-r--r--  candle-examples/examples/stable-diffusion/unet_2d_blocks.rs  12
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 65341e74..1db65222 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -366,6 +366,7 @@ pub struct UNetMidBlock2DCrossAttnConfig {
pub cross_attn_dim: usize,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
+ pub transformer_layers_per_block: usize,
}
impl Default for UNetMidBlock2DCrossAttnConfig {
@@ -379,6 +380,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
cross_attn_dim: 1280,
sliced_attention_size: None, // Sliced attention disabled
use_linear_projection: false,
+ transformer_layers_per_block: 1,
}
}
}
@@ -414,7 +416,7 @@ impl UNetMidBlock2DCrossAttn {
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
let n_heads = config.attn_num_head_channels;
let attn_cfg = SpatialTransformerConfig {
- depth: 1,
+ depth: config.transformer_layers_per_block,
num_groups: resnet_groups,
context_dim: Some(config.cross_attn_dim),
sliced_attention_size: config.sliced_attention_size,
@@ -565,6 +567,7 @@ pub struct CrossAttnDownBlock2DConfig {
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
+ pub transformer_layers_per_block: usize,
}
impl Default for CrossAttnDownBlock2DConfig {
@@ -575,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
+ transformer_layers_per_block: 1,
}
}
}
@@ -605,7 +609,7 @@ impl CrossAttnDownBlock2D {
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
- depth: 1,
+ depth: config.transformer_layers_per_block,
context_dim: Some(config.cross_attention_dim),
num_groups: config.downblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
@@ -767,6 +771,7 @@ pub struct CrossAttnUpBlock2DConfig {
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
+ pub transformer_layers_per_block: usize,
}
impl Default for CrossAttnUpBlock2DConfig {
@@ -777,6 +782,7 @@ impl Default for CrossAttnUpBlock2DConfig {
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
+ transformer_layers_per_block: 1,
}
}
}
@@ -809,7 +815,7 @@ impl CrossAttnUpBlock2D {
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
- depth: 1,
+ depth: config.transformer_layers_per_block,
context_dim: Some(config.cross_attention_dim),
num_groups: config.upblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,