summary | refs | log | tree | commit | diff
path: root/candle-transformers/src/models/mmdit/model.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/mmdit/model.rs')
-rw-r--r-- candle-transformers/src/models/mmdit/model.rs 49
1 files changed, 40 insertions, 9 deletions
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs
index 5b5c90b0..c7b4deed 100644
--- a/candle-transformers/src/models/mmdit/model.rs
+++ b/candle-transformers/src/models/mmdit/model.rs
@@ -1,10 +1,15 @@
-// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206).
+// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206),
+// as well as the MMDiT-X variant introduced for Stable Diffusion 3.5-medium (https://huggingface.co/stabilityai/stable-diffusion-3.5-medium)
// This follows the implementation of the MMDiT model in the ComfyUI repository.
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1
+// with MMDiT-X support following the Stability-AI/sd3.5 repository.
+// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/mmditx.py#L1
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
-use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock};
+use super::blocks::{
+ ContextQkvOnlyJointBlock, FinalLayer, JointBlock, MMDiTJointBlock, MMDiTXJointBlock,
+};
use super::embedding::{
PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
};
@@ -37,6 +42,20 @@ impl Config {
}
}
+ pub fn sd3_5_medium() -> Self {
+ Self {
+ patch_size: 2,
+ in_channels: 16,
+ out_channels: 16,
+ depth: 24,
+ head_size: 64,
+ adm_in_channels: 2048,
+ pos_embed_max_size: 384,
+ context_embed_size: 4096,
+ frequency_embedding_size: 256,
+ }
+ }
+
pub fn sd3_5_large() -> Self {
Self {
patch_size: 2,
@@ -138,7 +157,7 @@ impl MMDiT {
}
pub struct MMDiTCore {
- joint_blocks: Vec<JointBlock>,
+ joint_blocks: Vec<Box<dyn JointBlock>>,
context_qkv_only_joint_block: ContextQkvOnlyJointBlock,
final_layer: FinalLayer,
}
@@ -155,12 +174,24 @@ impl MMDiTCore {
) -> Result<Self> {
let mut joint_blocks = Vec::with_capacity(depth - 1);
for i in 0..depth - 1 {
- joint_blocks.push(JointBlock::new(
- hidden_size,
- num_heads,
- use_flash_attn,
- vb.pp(format!("joint_blocks.{}", i)),
- )?);
+ let joint_block_vb_pp = format!("joint_blocks.{}", i);
+ let joint_block: Box<dyn JointBlock> =
+ if vb.contains_tensor(&format!("{}.x_block.attn2.qkv.weight", joint_block_vb_pp)) {
+ Box::new(MMDiTXJointBlock::new(
+ hidden_size,
+ num_heads,
+ use_flash_attn,
+ vb.pp(&joint_block_vb_pp),
+ )?)
+ } else {
+ Box::new(MMDiTJointBlock::new(
+ hidden_size,
+ num_heads,
+ use_flash_attn,
+ vb.pp(&joint_block_vb_pp),
+ )?)
+ };
+ joint_blocks.push(joint_block);
}
Ok(Self {