summaryrefslogtreecommitdiff
path: root/candle-transformers
diff options
context:
space:
mode:
authorCzxck001 <10724409+Czxck001@users.noreply.github.com>2024-10-29 22:19:07 -0700
committerGitHub <noreply@github.com>2024-10-30 06:19:07 +0100
commitd232e132f6af552c351bb046a38df4bce009c8aa (patch)
treedf10d4ac4b73b32716cba3bcc38bfa07b0d91a83 /candle-transformers
parent139ff56aeb1a6bbf0ed742f936a7a96bebccfa30 (diff)
downloadcandle-d232e132f6af552c351bb046a38df4bce009c8aa.tar.gz
candle-d232e132f6af552c351bb046a38df4bce009c8aa.tar.bz2
candle-d232e132f6af552c351bb046a38df4bce009c8aa.zip
Support sd3.5 medium and MMDiT-X (#2587)
* extract attn out of joint_attn * further adjust attn and joint_attn * add mmdit-x support * support sd3.5-medium in the example * update README.md
Diffstat (limited to 'candle-transformers')
-rw-r--r--candle-transformers/src/models/mmdit/blocks.rs191
-rw-r--r--candle-transformers/src/models/mmdit/model.rs49
2 files changed, 217 insertions, 23 deletions
diff --git a/candle-transformers/src/models/mmdit/blocks.rs b/candle-transformers/src/models/mmdit/blocks.rs
index a1777f91..912e2498 100644
--- a/candle-transformers/src/models/mmdit/blocks.rs
+++ b/candle-transformers/src/models/mmdit/blocks.rs
@@ -36,7 +36,6 @@ impl Module for LayerNormNoAffine {
impl DiTBlock {
pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
- // {'hidden_size': 1536, 'num_heads': 24}
let norm1 = LayerNormNoAffine::new(1e-6);
let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
let norm2 = LayerNormNoAffine::new(1e-6);
@@ -103,6 +102,117 @@ impl DiTBlock {
}
}
+pub struct SelfAttnModulateIntermediates {
+ gate_msa: Tensor,
+ shift_mlp: Tensor,
+ scale_mlp: Tensor,
+ gate_mlp: Tensor,
+ gate_msa2: Tensor,
+}
+
+pub struct SelfAttnDiTBlock {
+ norm1: LayerNormNoAffine,
+ attn: AttnProjections,
+ attn2: AttnProjections,
+ norm2: LayerNormNoAffine,
+ mlp: Mlp,
+ ada_ln_modulation: nn::Sequential,
+}
+
+impl SelfAttnDiTBlock {
+ pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+ let norm1 = LayerNormNoAffine::new(1e-6);
+ let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
+ let attn2 = AttnProjections::new(hidden_size, num_heads, vb.pp("attn2"))?;
+ let norm2 = LayerNormNoAffine::new(1e-6);
+ let mlp_ratio = 4;
+ let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
+ let n_mods = 9;
+ let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
+ hidden_size,
+ n_mods * hidden_size,
+ vb.pp("adaLN_modulation.1"),
+ )?);
+
+ Ok(Self {
+ norm1,
+ attn,
+ attn2,
+ norm2,
+ mlp,
+ ada_ln_modulation,
+ })
+ }
+
+ pub fn pre_attention(
+ &self,
+ x: &Tensor,
+ c: &Tensor,
+ ) -> Result<(Qkv, Qkv, SelfAttnModulateIntermediates)> {
+ let modulation = self.ada_ln_modulation.forward(c)?;
+ let chunks = modulation.chunk(9, D::Minus1)?;
+ let (
+ shift_msa,
+ scale_msa,
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ shift_msa2,
+ scale_msa2,
+ gate_msa2,
+ ) = (
+ chunks[0].clone(),
+ chunks[1].clone(),
+ chunks[2].clone(),
+ chunks[3].clone(),
+ chunks[4].clone(),
+ chunks[5].clone(),
+ chunks[6].clone(),
+ chunks[7].clone(),
+ chunks[8].clone(),
+ );
+
+ let norm_x = self.norm1.forward(x)?;
+ let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
+ let qkv = self.attn.pre_attention(&modulated_x)?;
+
+ let modulated_x2 = modulate(&norm_x, &shift_msa2, &scale_msa2)?;
+ let qkv2 = self.attn2.pre_attention(&modulated_x2)?;
+
+ Ok((
+ qkv,
+ qkv2,
+ SelfAttnModulateIntermediates {
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ gate_msa2,
+ },
+ ))
+ }
+
+ pub fn post_attention(
+ &self,
+ attn: &Tensor,
+ attn2: &Tensor,
+ x: &Tensor,
+ mod_interm: &SelfAttnModulateIntermediates,
+ ) -> Result<Tensor> {
+ let attn_out = self.attn.post_attention(attn)?;
+ let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
+ let attn_out2 = self.attn2.post_attention(attn2)?;
+ let x = x.add(&attn_out2.broadcast_mul(&mod_interm.gate_msa2.unsqueeze(1)?)?)?;
+
+ let norm_x = self.norm2.forward(&x)?;
+ let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
+ let mlp_out = self.mlp.forward(&modulated_x)?;
+ let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
+ Ok(x)
+ }
+}
+
pub struct QkvOnlyDiTBlock {
norm1: LayerNormNoAffine,
attn: QkvOnlyAttnProjections,
@@ -190,14 +300,18 @@ fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?)
}
-pub struct JointBlock {
+pub trait JointBlock {
+ fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)>;
+}
+
+pub struct MMDiTJointBlock {
x_block: DiTBlock,
context_block: DiTBlock,
num_heads: usize,
use_flash_attn: bool,
}
-impl JointBlock {
+impl MMDiTJointBlock {
pub fn new(
hidden_size: usize,
num_heads: usize,
@@ -214,8 +328,10 @@ impl JointBlock {
use_flash_attn,
})
}
+}
- pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
+impl JointBlock for MMDiTJointBlock {
+ fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
let (context_attn, x_attn) =
@@ -228,6 +344,49 @@ impl JointBlock {
}
}
+pub struct MMDiTXJointBlock {
+ x_block: SelfAttnDiTBlock,
+ context_block: DiTBlock,
+ num_heads: usize,
+ use_flash_attn: bool,
+}
+
+impl MMDiTXJointBlock {
+ pub fn new(
+ hidden_size: usize,
+ num_heads: usize,
+ use_flash_attn: bool,
+ vb: nn::VarBuilder,
+ ) -> Result<Self> {
+ let x_block = SelfAttnDiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
+ let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
+
+ Ok(Self {
+ x_block,
+ context_block,
+ num_heads,
+ use_flash_attn,
+ })
+ }
+}
+
+impl JointBlock for MMDiTXJointBlock {
+ fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
+ let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
+ let (x_qkv, x_qkv2, x_interm) = self.x_block.pre_attention(x, c)?;
+ let (context_attn, x_attn) =
+ joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
+ let x_attn2 = attn(&x_qkv2, self.num_heads, self.use_flash_attn)?;
+ let context_out =
+ self.context_block
+ .post_attention(&context_attn, context, &context_interm)?;
+ let x_out = self
+ .x_block
+ .post_attention(&x_attn, &x_attn2, x, &x_interm)?;
+ Ok((context_out, x_out))
+ }
+}
+
pub struct ContextQkvOnlyJointBlock {
x_block: DiTBlock,
context_block: QkvOnlyDiTBlock,
@@ -309,26 +468,30 @@ fn joint_attn(
v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?,
};
- let (batch_size, seqlen, _) = qkv.q.dims3()?;
+ let seqlen = qkv.q.dim(1)?;
+ let attn = attn(&qkv, num_heads, use_flash_attn)?;
+ let context_qkv_seqlen = context_qkv.q.dim(1)?;
+ let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
+ let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;
+
+ Ok((context_attn, x_attn))
+}
+
+fn attn(qkv: &Qkv, num_heads: usize, use_flash_attn: bool) -> Result<Tensor> {
+ let batch_size = qkv.q.dim(0)?;
+ let seqlen = qkv.q.dim(1)?;
let qkv = Qkv {
q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
- v: qkv.v,
+ v: qkv.v.clone(),
};
let headdim = qkv.q.dim(D::Minus1)?;
let softmax_scale = 1.0 / (headdim as f64).sqrt();
-
let attn = if use_flash_attn {
flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
} else {
flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
};
-
- let attn = attn.reshape((batch_size, seqlen, ()))?;
- let context_qkv_seqlen = context_qkv.q.dim(1)?;
- let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
- let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;
-
- Ok((context_attn, x_attn))
+ attn.reshape((batch_size, seqlen, ()))
}
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs
index 5b5c90b0..c7b4deed 100644
--- a/candle-transformers/src/models/mmdit/model.rs
+++ b/candle-transformers/src/models/mmdit/model.rs
@@ -1,10 +1,15 @@
-// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206).
+// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206),
+// as well as the MMDiT-X variant introduced for Stable Diffusion 3.5-medium (https://huggingface.co/stabilityai/stable-diffusion-3.5-medium)
// This follows the implementation of the MMDiT model in the ComfyUI repository.
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1
+// with MMDiT-X support following the Stability-AI/sd3.5 repository.
+// https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/mmditx.py#L1
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
-use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock};
+use super::blocks::{
+ ContextQkvOnlyJointBlock, FinalLayer, JointBlock, MMDiTJointBlock, MMDiTXJointBlock,
+};
use super::embedding::{
PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
};
@@ -37,6 +42,20 @@ impl Config {
}
}
+ pub fn sd3_5_medium() -> Self {
+ Self {
+ patch_size: 2,
+ in_channels: 16,
+ out_channels: 16,
+ depth: 24,
+ head_size: 64,
+ adm_in_channels: 2048,
+ pos_embed_max_size: 384,
+ context_embed_size: 4096,
+ frequency_embedding_size: 256,
+ }
+ }
+
pub fn sd3_5_large() -> Self {
Self {
patch_size: 2,
@@ -138,7 +157,7 @@ impl MMDiT {
}
pub struct MMDiTCore {
- joint_blocks: Vec<JointBlock>,
+ joint_blocks: Vec<Box<dyn JointBlock>>,
context_qkv_only_joint_block: ContextQkvOnlyJointBlock,
final_layer: FinalLayer,
}
@@ -155,12 +174,24 @@ impl MMDiTCore {
) -> Result<Self> {
let mut joint_blocks = Vec::with_capacity(depth - 1);
for i in 0..depth - 1 {
- joint_blocks.push(JointBlock::new(
- hidden_size,
- num_heads,
- use_flash_attn,
- vb.pp(format!("joint_blocks.{}", i)),
- )?);
+ let joint_block_vb_pp = format!("joint_blocks.{}", i);
+ let joint_block: Box<dyn JointBlock> =
+ if vb.contains_tensor(&format!("{}.x_block.attn2.qkv.weight", joint_block_vb_pp)) {
+ Box::new(MMDiTXJointBlock::new(
+ hidden_size,
+ num_heads,
+ use_flash_attn,
+ vb.pp(&joint_block_vb_pp),
+ )?)
+ } else {
+ Box::new(MMDiTJointBlock::new(
+ hidden_size,
+ num_heads,
+ use_flash_attn,
+ vb.pp(&joint_block_vb_pp),
+ )?)
+ };
+ joint_blocks.push(joint_block);
}
Ok(Self {