Diffstat (limited to 'candle-examples/examples/stable-diffusion')
-rw-r--r--  candle-examples/examples/stable-diffusion/attention.rs         |   8
-rw-r--r--  candle-examples/examples/stable-diffusion/clip.rs              |  30
-rw-r--r--  candle-examples/examples/stable-diffusion/main.rs              | 170
-rw-r--r--  candle-examples/examples/stable-diffusion/stable_diffusion.rs  | 128
-rw-r--r--  candle-examples/examples/stable-diffusion/unet_2d.rs           |  25
-rw-r--r--  candle-examples/examples/stable-diffusion/unet_2d_blocks.rs    |  12
6 files changed, 295 insertions(+), 78 deletions(-)
diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs
index 58f5e87e..1ae1bfc3 100644
--- a/candle-examples/examples/stable-diffusion/attention.rs
+++ b/candle-examples/examples/stable-diffusion/attention.rs
@@ -208,9 +208,9 @@ impl CrossAttention {
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let query = self.to_q.forward(xs)?;
- let context = context.unwrap_or(xs);
- let key = self.to_k.forward(context)?;
- let value = self.to_v.forward(context)?;
+ let context = context.unwrap_or(xs).contiguous()?;
+ let key = self.to_k.forward(&context)?;
+ let value = self.to_v.forward(&context)?;
let query = self.reshape_heads_to_batch_dim(&query)?;
let key = self.reshape_heads_to_batch_dim(&key)?;
let value = self.reshape_heads_to_batch_dim(&value)?;
@@ -473,7 +473,7 @@ impl AttentionBlock {
let num_heads = channels / num_head_channels;
let group_norm =
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
- let (q_path, k_path, v_path, out_path) = if vs.dtype() == DType::F16 {
+ let (q_path, k_path, v_path, out_path) = if vs.contains_tensor("to_q.weight") {
("to_q", "to_k", "to_v", "to_out.0")
} else {
("query", "key", "value", "proj_attn")
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
index 2927a404..d26c1c46 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -69,6 +69,36 @@ impl Config {
activation: Activation::Gelu,
}
}
+
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder/config.json
+ pub fn sdxl() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 768,
+ intermediate_size: 3072,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ projection_dim: 768,
+ activation: Activation::QuickGelu,
+ }
+ }
+
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder_2/config.json
+ pub fn sdxl2() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1280,
+ intermediate_size: 5120,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 32,
+ num_attention_heads: 20,
+ projection_dim: 1280,
+ activation: Activation::Gelu,
+ }
+ }
}
// CLIP Text Model
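Note: both SDXL text-encoder configs pad with "!" rather than "<|endoftext|>", following the diffusers reference ("!" is id 0 in the CLIP BPE vocab). A minimal sketch, assuming the tokenizers crate API this example already uses, of how pad_with resolves to a token id (pad_id is a hypothetical helper):

    use tokenizers::Tokenizer;

    // Returns the id used for prompt padding; `pad_with` mirrors the config field.
    fn pad_id(tokenizer: &Tokenizer, pad_with: Option<&str>) -> Option<u32> {
        tokenizer
            .get_vocab(true)
            .get(pad_with.unwrap_or("<|endoftext|>"))
            .copied()
    }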
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 1443986c..8372edcd 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -17,7 +17,7 @@ mod utils;
mod vae;
use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor};
+use candle::{DType, Device, IndexOp, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
@@ -102,12 +102,16 @@ struct Args {
enum StableDiffusionVersion {
V1_5,
V2_1,
+ Xl,
}
+#[allow(unused)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
Tokenizer,
+ Tokenizer2,
Clip,
+ Clip2,
Unet,
Vae,
}
@@ -115,6 +119,7 @@ enum ModelFile {
impl StableDiffusionVersion {
fn repo(&self) -> &'static str {
match self {
+ Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
}
@@ -122,7 +127,7 @@ impl StableDiffusionVersion {
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 => {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -134,7 +139,7 @@ impl StableDiffusionVersion {
fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 => {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -146,7 +151,7 @@ impl StableDiffusionVersion {
fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 => {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
@@ -155,12 +160,21 @@ impl StableDiffusionVersion {
}
}
}
+
+ fn clip2_file(&self, use_f16: bool) -> &'static str {
+ match self {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
+ if use_f16 {
+ "text_encoder_2/model.fp16.safetensors"
+ } else {
+ "text_encoder_2/model.safetensors"
+ }
+ }
+ }
+ }
}
impl ModelFile {
- const TOKENIZER_REPO: &str = "openai/clip-vit-base-patch32";
- const TOKENIZER_PATH: &str = "tokenizer.json";
-
fn get(
&self,
filename: Option<String>,
@@ -172,8 +186,24 @@ impl ModelFile {
Some(filename) => Ok(std::path::PathBuf::from(filename)),
None => {
let (repo, path) = match self {
- Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
+ Self::Tokenizer => {
+ let tokenizer_repo = match version {
+ StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
+ "openai/clip-vit-base-patch32"
+ }
+ StableDiffusionVersion::Xl => {
+ // This seems similar to the patch32 version except for some very small
+ // differences in the split regex.
+ "openai/clip-vit-large-patch14"
+ }
+ };
+ (tokenizer_repo, "tokenizer.json")
+ }
+ Self::Tokenizer2 => {
+ ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json")
+ }
Self::Clip => (version.repo(), version.clip_file(use_f16)),
+ Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
Self::Unet => (version.repo(), version.unet_file(use_f16)),
Self::Vae => (version.repo(), version.vae_file(use_f16)),
};
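Note: once the (repo, path) pair is selected, the file is fetched from the Hugging Face Hub or served from the local cache. A minimal sketch of that step, assuming the hf-hub sync API this example depends on (fetch is a hypothetical helper mirroring the tail of ModelFile::get):

    use hf_hub::api::sync::Api;

    fn fetch(repo: &str, path: &str) -> anyhow::Result<std::path::PathBuf> {
        // Downloads on first use, then resolves from the cache.
        let filename = Api::new()?.model(repo.to_string()).get(path)?;
        Ok(filename)
    }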
@@ -211,6 +241,71 @@ fn output_filename(
}
}
+#[allow(clippy::too_many_arguments)]
+fn text_embeddings(
+ prompt: &str,
+ uncond_prompt: &str,
+ tokenizer: Option<String>,
+ clip_weights: Option<String>,
+ sd_version: StableDiffusionVersion,
+ sd_config: &stable_diffusion::StableDiffusionConfig,
+ use_f16: bool,
+ device: &Device,
+ dtype: DType,
+ first: bool,
+) -> Result<Tensor> {
+ let tokenizer_file = if first {
+ ModelFile::Tokenizer
+ } else {
+ ModelFile::Tokenizer2
+ };
+ let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?;
+ let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+ let pad_id = match &sd_config.clip.pad_with {
+ Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
+ None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
+ };
+ println!("Running with prompt \"{prompt}\".");
+ let mut tokens = tokenizer
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while tokens.len() < sd_config.clip.max_position_embeddings {
+ tokens.push(pad_id)
+ }
+ let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
+
+ let mut uncond_tokens = tokenizer
+ .encode(uncond_prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+ uncond_tokens.push(pad_id)
+ }
+ let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+
+ println!("Building the Clip transformer.");
+ let clip_weights_file = if first {
+ ModelFile::Clip
+ } else {
+ ModelFile::Clip2
+ };
+ let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?;
+ let clip_config = if first {
+ &sd_config.clip
+ } else {
+ sd_config.clip2.as_ref().unwrap()
+ };
+ let text_model =
+ stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
+ let text_embeddings = text_model.forward(&tokens)?;
+ let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+ let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
+ Ok(text_embeddings)
+}
+
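Note: the tensor returned by text_embeddings stacks the unconditional and conditional embeddings along dim 0, so a single UNet forward pass serves both halves of classifier-free guidance. A hypothetical helper sketching how such a doubled batch is typically recombined downstream (the guidance formula follows the usual diffusers convention and is not part of this diff):

    use candle::{Result, Tensor};

    fn apply_cfg(noise_pred: &Tensor, guidance_scale: f64) -> Result<Tensor> {
        // Split the doubled batch back into [uncond, cond].
        let chunks = noise_pred.chunk(2, 0)?;
        let (uncond, cond) = (&chunks[0], &chunks[1]);
        // uncond + scale * (cond - uncond)
        uncond + ((cond - uncond)? * guidance_scale)?
    }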
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
@@ -252,46 +347,37 @@ fn run(args: Args) -> Result<()> {
StableDiffusionVersion::V2_1 => {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
}
+ StableDiffusionVersion::Xl => {
+ stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
+ }
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
- let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version, use_f16)?;
- let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
- let pad_id = match &sd_config.clip.pad_with {
- Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
- None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
- };
- println!("Running with prompt \"{prompt}\".");
- let mut tokens = tokenizer
- .encode(prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
- while tokens.len() < sd_config.clip.max_position_embeddings {
- tokens.push(pad_id)
- }
- let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
-
- let mut uncond_tokens = tokenizer
- .encode(uncond_prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
- while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
- uncond_tokens.push(pad_id)
- }
- let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
-
- println!("Building the Clip transformer.");
- let text_embeddings = {
- let clip_weights = ModelFile::Clip.get(clip_weights, sd_version, false)?;
- let text_model = sd_config.build_clip_transformer(&clip_weights, &device, DType::F32)?;
- let text_embeddings = text_model.forward(&tokens)?;
- let uncond_embeddings = text_model.forward(&uncond_tokens)?;
- Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
+ let which = match sd_version {
+ StableDiffusionVersion::Xl => vec![true, false],
+ _ => vec![true],
};
+ let text_embeddings = which
+ .iter()
+ .map(|first| {
+ text_embeddings(
+ &prompt,
+ &uncond_prompt,
+ tokenizer.clone(),
+ clip_weights.clone(),
+ sd_version,
+ &sd_config,
+ use_f16,
+ &device,
+ dtype,
+ *first,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
+ println!("{text_embeddings:?}");
println!("Building the autoencoder.");
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
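Note: for SDXL both encoders run (which = [true, false]) and their hidden states are concatenated along the last dimension, matching the SDXL UNet's cross_attention_dim of 2048; for v1.5/v2.1 the vector has a single element, so the cat is a no-op. A shape sketch (sizes illustrative):

    use candle::{DType, Device, Result, Tensor, D};

    fn concat_sketch() -> Result<Tensor> {
        let dev = Device::Cpu;
        // dim 0: [uncond, cond]; dim 1: 77 tokens; dim 2: hidden size.
        let emb1 = Tensor::zeros((2, 77, 768), DType::F32, &dev)?; // text_encoder
        let emb2 = Tensor::zeros((2, 77, 1280), DType::F32, &dev)?; // text_encoder_2
        Tensor::cat(&[&emb1, &emb2], D::Minus1) // (2, 77, 2048)
    }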
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index e159fa0a..cffc00d8 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -8,6 +8,7 @@ pub struct StableDiffusionConfig {
pub width: usize,
pub height: usize,
pub clip: clip::Config,
+ pub clip2: Option<clip::Config>,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
scheduler: ddim::DDIMSchedulerConfig,
@@ -27,10 +28,10 @@ impl StableDiffusionConfig {
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
- bc(320, true, 8),
- bc(640, true, 8),
- bc(1280, true, 8),
- bc(1280, false, 8),
+ bc(320, Some(1), 8),
+ bc(640, Some(1), 8),
+ bc(1280, Some(1), 8),
+ bc(1280, None, 8),
],
center_input_sample: false,
cross_attention_dim: 768,
@@ -51,7 +52,7 @@ impl StableDiffusionConfig {
norm_num_groups: 32,
};
let height = if let Some(height) = height {
- assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
512
@@ -68,6 +69,7 @@ impl StableDiffusionConfig {
width,
height,
clip: clip::Config::v1_5(),
+ clip2: None,
autoencoder,
scheduler: Default::default(),
unet,
@@ -88,10 +90,10 @@ impl StableDiffusionConfig {
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
- bc(320, true, 5),
- bc(640, true, 10),
- bc(1280, true, 20),
- bc(1280, false, 20),
+ bc(320, Some(1), 5),
+ bc(640, Some(1), 10),
+ bc(1280, Some(1), 20),
+ bc(1280, None, 20),
],
center_input_sample: false,
cross_attention_dim: 1024,
@@ -118,7 +120,7 @@ impl StableDiffusionConfig {
};
let height = if let Some(height) = height {
- assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
768
@@ -135,6 +137,7 @@ impl StableDiffusionConfig {
width,
height,
clip: clip::Config::v2_1(),
+ clip2: None,
autoencoder,
scheduler,
unet,
@@ -155,6 +158,87 @@ impl StableDiffusionConfig {
)
}
+ fn sdxl_(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ prediction_type: PredictionType,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![
+ bc(320, None, 5),
+ bc(640, Some(2), 10),
+ bc(1280, Some(10), 20),
+ ],
+ center_input_sample: false,
+ cross_attention_dim: 2048,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: true,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let scheduler = ddim::DDIMSchedulerConfig {
+ prediction_type,
+ ..Default::default()
+ };
+
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
+ height
+ } else {
+ 1024
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 1024
+ };
+
+ Self {
+ width,
+ height,
+ clip: clip::Config::sdxl(),
+ clip2: Some(clip::Config::sdxl2()),
+ autoencoder,
+ scheduler,
+ unet,
+ }
+ }
+
+ pub fn sdxl(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ Self::sdxl_(
+ sliced_attention_size,
+ height,
+ width,
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
+ PredictionType::Epsilon,
+ )
+ }
+
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,
@@ -193,17 +277,17 @@ impl StableDiffusionConfig {
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
+}
- pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
- &self,
- clip_weights: P,
- device: &Device,
- dtype: DType,
- ) -> Result<clip::ClipTextTransformer> {
- let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
- let weights = weights.deserialize()?;
- let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
- let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
- Ok(text_model)
- }
+pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
+ clip: &clip::Config,
+ clip_weights: P,
+ device: &Device,
+ dtype: DType,
+) -> Result<clip::ClipTextTransformer> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
+ let weights = weights.deserialize()?;
+ let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+ let text_model = clip::ClipTextTransformer::new(vs, clip)?;
+ Ok(text_model)
}
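Note: build_clip_transformer is now a free function taking the config explicitly, so the same code path builds either SDXL encoder. A usage sketch, assuming the items defined in stable_diffusion.rs are in scope (weight paths illustrative):

    use candle::{DType, Device};

    fn build_both(device: &Device) -> anyhow::Result<()> {
        let cfg = StableDiffusionConfig::sdxl(None, None, None);
        let _enc1 = build_clip_transformer(
            &cfg.clip, "text_encoder/model.safetensors", device, DType::F32)?;
        let clip2 = cfg.clip2.as_ref().expect("sdxl always sets clip2");
        let _enc2 = build_clip_transformer(
            clip2, "text_encoder_2/model.safetensors", device, DType::F32)?;
        Ok(())
    }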
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
index eb2dbf10..81bd9547 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -12,7 +12,9 @@ use candle_nn::Module;
#[derive(Debug, Clone, Copy)]
pub struct BlockConfig {
pub out_channels: usize,
- pub use_cross_attn: bool,
+ /// When `None` no cross-attn is used, when `Some(d)` then cross-attn is used and `d` is the
+ /// number of transformer blocks to be used.
+ pub use_cross_attn: Option<usize>,
pub attention_head_dim: usize,
}
@@ -41,22 +43,22 @@ impl Default for UNet2DConditionModelConfig {
blocks: vec![
BlockConfig {
out_channels: 320,
- use_cross_attn: true,
+ use_cross_attn: Some(1),
attention_head_dim: 8,
},
BlockConfig {
out_channels: 640,
- use_cross_attn: true,
+ use_cross_attn: Some(1),
attention_head_dim: 8,
},
BlockConfig {
out_channels: 1280,
- use_cross_attn: true,
+ use_cross_attn: Some(1),
attention_head_dim: 8,
},
BlockConfig {
out_channels: 1280,
- use_cross_attn: false,
+ use_cross_attn: None,
attention_head_dim: 8,
},
],
@@ -149,13 +151,14 @@ impl UNet2DConditionModel {
downsample_padding: config.downsample_padding,
..Default::default()
};
- if use_cross_attn {
+ if let Some(transformer_layers_per_block) = use_cross_attn {
let config = CrossAttnDownBlock2DConfig {
downblock: db_cfg,
attn_num_head_channels: attention_head_dim,
cross_attention_dim: config.cross_attention_dim,
sliced_attention_size,
use_linear_projection: config.use_linear_projection,
+ transformer_layers_per_block,
};
let block = CrossAttnDownBlock2D::new(
vs_db.pp(&i.to_string()),
@@ -179,6 +182,11 @@ impl UNet2DConditionModel {
})
.collect::<Result<Vec<_>>>()?;
+ // https://github.com/huggingface/diffusers/blob/a76f2ad538e73b34d5fe7be08c8eb8ab38c7e90c/src/diffusers/models/unet_2d_condition.py#L462
+ let mid_transformer_layers_per_block = match config.blocks.last() {
+ None => 1,
+ Some(block) => block.use_cross_attn.unwrap_or(1),
+ };
let mid_cfg = UNetMidBlock2DCrossAttnConfig {
resnet_eps: config.norm_eps,
output_scale_factor: config.mid_block_scale_factor,
@@ -186,8 +194,10 @@ impl UNet2DConditionModel {
attn_num_head_channels: bl_attention_head_dim,
resnet_groups: Some(config.norm_num_groups),
use_linear_projection: config.use_linear_projection,
+ transformer_layers_per_block: mid_transformer_layers_per_block,
..Default::default()
};
+
let mid_block = UNetMidBlock2DCrossAttn::new(
vs.pp("mid_block"),
bl_channels,
@@ -231,13 +241,14 @@ impl UNet2DConditionModel {
add_upsample: i < n_blocks - 1,
..Default::default()
};
- if use_cross_attn {
+ if let Some(transformer_layers_per_block) = use_cross_attn {
let config = CrossAttnUpBlock2DConfig {
upblock: ub_cfg,
attn_num_head_channels: attention_head_dim,
cross_attention_dim: config.cross_attention_dim,
sliced_attention_size,
use_linear_projection: config.use_linear_projection,
+ transformer_layers_per_block,
};
let block = CrossAttnUpBlock2D::new(
vs_ub.pp(&i.to_string()),
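Note: BlockConfig::use_cross_attn changed from bool to Option<usize>, folding the per-block transformer depth into the same field. A standalone sketch with the SDXL values from this diff, including the mid-block rule that inherits the depth of the deepest down block:

    #[allow(dead_code)]
    #[derive(Debug, Clone, Copy)]
    struct BlockConfig {
        out_channels: usize,
        use_cross_attn: Option<usize>, // None = plain block, Some(d) = d transformer layers
        attention_head_dim: usize,
    }

    // Equivalent to the match over config.blocks.last() above.
    fn mid_depth(blocks: &[BlockConfig]) -> usize {
        blocks.last().and_then(|b| b.use_cross_attn).unwrap_or(1)
    }

    fn main() {
        let blocks = [
            BlockConfig { out_channels: 320, use_cross_attn: None, attention_head_dim: 5 },
            BlockConfig { out_channels: 640, use_cross_attn: Some(2), attention_head_dim: 10 },
            BlockConfig { out_channels: 1280, use_cross_attn: Some(10), attention_head_dim: 20 },
        ];
        assert_eq!(mid_depth(&blocks), 10);
    }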
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 65341e74..1db65222 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -366,6 +366,7 @@ pub struct UNetMidBlock2DCrossAttnConfig {
pub cross_attn_dim: usize,
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
+ pub transformer_layers_per_block: usize,
}
impl Default for UNetMidBlock2DCrossAttnConfig {
@@ -379,6 +380,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
cross_attn_dim: 1280,
sliced_attention_size: None, // Sliced attention disabled
use_linear_projection: false,
+ transformer_layers_per_block: 1,
}
}
}
@@ -414,7 +416,7 @@ impl UNetMidBlock2DCrossAttn {
let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
let n_heads = config.attn_num_head_channels;
let attn_cfg = SpatialTransformerConfig {
- depth: 1,
+ depth: config.transformer_layers_per_block,
num_groups: resnet_groups,
context_dim: Some(config.cross_attn_dim),
sliced_attention_size: config.sliced_attention_size,
@@ -565,6 +567,7 @@ pub struct CrossAttnDownBlock2DConfig {
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
+ pub transformer_layers_per_block: usize,
}
impl Default for CrossAttnDownBlock2DConfig {
@@ -575,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
+ transformer_layers_per_block: 1,
}
}
}
@@ -605,7 +609,7 @@ impl CrossAttnDownBlock2D {
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
- depth: 1,
+ depth: config.transformer_layers_per_block,
context_dim: Some(config.cross_attention_dim),
num_groups: config.downblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
@@ -767,6 +771,7 @@ pub struct CrossAttnUpBlock2DConfig {
// attention_type: "default"
pub sliced_attention_size: Option<usize>,
pub use_linear_projection: bool,
+ pub transformer_layers_per_block: usize,
}
impl Default for CrossAttnUpBlock2DConfig {
@@ -777,6 +782,7 @@ impl Default for CrossAttnUpBlock2DConfig {
cross_attention_dim: 1280,
sliced_attention_size: None,
use_linear_projection: false,
+ transformer_layers_per_block: 1,
}
}
}
@@ -809,7 +815,7 @@ impl CrossAttnUpBlock2D {
)?;
let n_heads = config.attn_num_head_channels;
let cfg = SpatialTransformerConfig {
- depth: 1,
+ depth: config.transformer_layers_per_block,
context_dim: Some(config.cross_attention_dim),
num_groups: config.upblock.resnet_groups,
sliced_attention_size: config.sliced_attention_size,
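Note: all three attention block configs gain transformer_layers_per_block with a default of 1, so the v1.5/v2.1 builders that construct these configs with ..Default::default() keep the previous single-transformer depth; only the SDXL path overrides it. A compatibility sketch (local struct copy for illustration, other fields elided):

    struct CrossAttnUpBlock2DConfig {
        transformer_layers_per_block: usize,
    }

    impl Default for CrossAttnUpBlock2DConfig {
        fn default() -> Self {
            Self { transformer_layers_per_block: 1 }
        }
    }

    fn main() {
        let cfg = CrossAttnUpBlock2DConfig::default();
        // Pre-SDXL behavior: one transformer block per attention layer.
        assert_eq!(cfg.transformer_layers_per_block, 1);
    }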