diff options
Diffstat (limited to 'candle-transformers/src/models/mmdit/model.rs')
-rw-r--r-- | candle-transformers/src/models/mmdit/model.rs | 49 |
1 files changed, 40 insertions, 9 deletions
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs index 5b5c90b0..c7b4deed 100644 --- a/candle-transformers/src/models/mmdit/model.rs +++ b/candle-transformers/src/models/mmdit/model.rs @@ -1,10 +1,15 @@ -// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206). +// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206), +// as well as the MMDiT-X variant introduced for Stable Diffusion 3.5-medium (https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) // This follows the implementation of the MMDiT model in the ComfyUI repository. // https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1 +// with MMDiT-X support following the Stability-AI/sd3.5 repository. +// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/mmditx.py#L1 use candle::{Module, Result, Tensor, D}; use candle_nn as nn; -use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock}; +use super::blocks::{ + ContextQkvOnlyJointBlock, FinalLayer, JointBlock, MMDiTJointBlock, MMDiTXJointBlock, +}; use super::embedding::{ PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder, }; @@ -37,6 +42,20 @@ impl Config { } } + pub fn sd3_5_medium() -> Self { + Self { + patch_size: 2, + in_channels: 16, + out_channels: 16, + depth: 24, + head_size: 64, + adm_in_channels: 2048, + pos_embed_max_size: 384, + context_embed_size: 4096, + frequency_embedding_size: 256, + } + } + pub fn sd3_5_large() -> Self { Self { patch_size: 2, @@ -138,7 +157,7 @@ impl MMDiT { } pub struct MMDiTCore { - joint_blocks: Vec<JointBlock>, + joint_blocks: Vec<Box<dyn JointBlock>>, context_qkv_only_joint_block: ContextQkvOnlyJointBlock, final_layer: FinalLayer, } @@ -155,12 +174,24 @@ impl MMDiTCore { ) -> Result<Self> { let mut joint_blocks = Vec::with_capacity(depth - 1); for i in 0..depth - 1 { - joint_blocks.push(JointBlock::new( - hidden_size, - num_heads, - use_flash_attn, - vb.pp(format!("joint_blocks.{}", i)), - )?); + let joint_block_vb_pp = format!("joint_blocks.{}", i); + let joint_block: Box<dyn JointBlock> = + if vb.contains_tensor(&format!("{}.x_block.attn2.qkv.weight", joint_block_vb_pp)) { + Box::new(MMDiTXJointBlock::new( + hidden_size, + num_heads, + use_flash_attn, + vb.pp(&joint_block_vb_pp), + )?) + } else { + Box::new(MMDiTJointBlock::new( + hidden_size, + num_heads, + use_flash_attn, + vb.pp(&joint_block_vb_pp), + )?) + }; + joint_blocks.push(joint_block); } Ok(Self { |