author    Czxck001 <10724409+Czxck001@users.noreply.github.com>  2024-08-05 10:26:15 -0700
committer GitHub <noreply@github.com>  2024-08-05 19:26:15 +0200
commit    dfdce2b6022ee5328c1a0d7305bdd3e3e4d3bc75 (patch)
tree      af547f9ac03bc53ed82d7b290eecabcc656b4ce1 /candle-transformers
parent    500c9f288214cf817c8c52c7c9a6187cb279c563 (diff)
Add the MMDiT model of Stable Diffusion 3 (#2397)
* add mmdit of stable diffusion 3, lint, add comments
* correct a misplaced comment
* fix cargo fmt
* fix clippy error
* use bail! instead of assert!
* use get_on_dim in splitting qkv
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/mmdit/blocks.rs       294
-rw-r--r--  candle-transformers/src/models/mmdit/embedding.rs    197
-rw-r--r--  candle-transformers/src/models/mmdit/mod.rs            4
-rw-r--r--  candle-transformers/src/models/mmdit/model.rs        173
-rw-r--r--  candle-transformers/src/models/mmdit/projections.rs   94
-rw-r--r--  candle-transformers/src/models/mod.rs                   1
6 files changed, 763 insertions(+), 0 deletions(-)
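As a usage illustration only (not part of this commit): a minimal sketch of how the new model might be instantiated and run, assuming the SD3 weights are already loaded into a candle_nn::VarBuilder in f16 (the timestep embedding is produced in f16, so the weights are expected in that dtype). The input shapes and the text-context length of 154 are illustrative assumptions, not values taken from the commit.

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::mmdit::model::{Config, MMDiT};

fn run_mmdit_once(vb: VarBuilder) -> Result<Tensor> {
    let device = Device::Cpu;
    let mmdit = MMDiT::new(&Config::sd3(), vb)?;
    // Hypothetical inputs, chosen only to match the shapes documented in model.rs:
    // x: (N, C, H, W) latents, t: (N,) timesteps, y: pooled conditioning of size adm_in_channels,
    // context: per-token text embeddings of size context_embed_size.
    let x = Tensor::zeros((1, 16, 64, 64), DType::F16, &device)?;
    let t = Tensor::from_vec(vec![500f32], (1,), &device)?;
    let y = Tensor::zeros((1, 2048), DType::F16, &device)?;
    let context = Tensor::zeros((1, 154, 4096), DType::F16, &device)?;
    mmdit.forward(&x, &t, &y, &context) // (1, 16, 64, 64) denoising prediction
}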
diff --git a/candle-transformers/src/models/mmdit/blocks.rs b/candle-transformers/src/models/mmdit/blocks.rs
new file mode 100644
index 00000000..e2b924a0
--- /dev/null
+++ b/candle-transformers/src/models/mmdit/blocks.rs
@@ -0,0 +1,294 @@
+use candle::{Module, Result, Tensor, D};
+use candle_nn as nn;
+
+use super::projections::{AttnProjections, Mlp, Qkv, QkvOnlyAttnProjections};
+
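+// Modulation tensors produced in DiTBlock::pre_attention and consumed again in post_attention:
+// the MSA gate scales the attention residual, and the MLP shift/scale/gate drive the feed-forward path.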
+pub struct ModulateIntermediates {
+ gate_msa: Tensor,
+ shift_mlp: Tensor,
+ scale_mlp: Tensor,
+ gate_mlp: Tensor,
+}
+
+pub struct DiTBlock {
+ norm1: LayerNormNoAffine,
+ attn: AttnProjections,
+ norm2: LayerNormNoAffine,
+ mlp: Mlp,
+ ada_ln_modulation: nn::Sequential,
+}
+
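+// LayerNorm without learnable affine parameters: the forward pass builds a LayerNorm with an
+// all-ones weight and no bias, so only normalization (with the given eps) is applied.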
+pub struct LayerNormNoAffine {
+ eps: f64,
+}
+
+impl LayerNormNoAffine {
+ pub fn new(eps: f64) -> Self {
+ Self { eps }
+ }
+}
+
+impl Module for LayerNormNoAffine {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ nn::LayerNorm::new_no_bias(Tensor::ones_like(x)?, self.eps).forward(x)
+ }
+}
+
+impl DiTBlock {
+ pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+ // {'hidden_size': 1536, 'num_heads': 24}
+ let norm1 = LayerNormNoAffine::new(1e-6);
+ let attn = AttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
+ let norm2 = LayerNormNoAffine::new(1e-6);
+ let mlp_ratio = 4;
+ let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
+ let n_mods = 6;
+ let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
+ hidden_size,
+ n_mods * hidden_size,
+ vb.pp("adaLN_modulation.1"),
+ )?);
+
+ Ok(Self {
+ norm1,
+ attn,
+ norm2,
+ mlp,
+ ada_ln_modulation,
+ })
+ }
+
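+ // Projects the conditioning vector into six adaLN parameters (shift/scale/gate for the attention
+ // and MLP paths), modulates the normalized input with the attention shift/scale, and returns the
+ // QKV projections together with the modulation tensors needed by post_attention.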
+ pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<(Qkv, ModulateIntermediates)> {
+ let modulation = self.ada_ln_modulation.forward(c)?;
+ let chunks = modulation.chunk(6, D::Minus1)?;
+ let (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = (
+ chunks[0].clone(),
+ chunks[1].clone(),
+ chunks[2].clone(),
+ chunks[3].clone(),
+ chunks[4].clone(),
+ chunks[5].clone(),
+ );
+
+ let norm_x = self.norm1.forward(x)?;
+ let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
+ let qkv = self.attn.pre_attention(&modulated_x)?;
+
+ Ok((
+ qkv,
+ ModulateIntermediates {
+ gate_msa,
+ shift_mlp,
+ scale_mlp,
+ gate_mlp,
+ },
+ ))
+ }
+
+ pub fn post_attention(
+ &self,
+ attn: &Tensor,
+ x: &Tensor,
+ mod_interm: &ModulateIntermediates,
+ ) -> Result<Tensor> {
+ let attn_out = self.attn.post_attention(attn)?;
+ let x = x.add(&attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?)?;
+
+ let norm_x = self.norm2.forward(&x)?;
+ let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
+ let mlp_out = self.mlp.forward(&modulated_x)?;
+ let x = x.add(&mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?)?;
+
+ Ok(x)
+ }
+}
+
+pub struct QkvOnlyDiTBlock {
+ norm1: LayerNormNoAffine,
+ attn: QkvOnlyAttnProjections,
+ ada_ln_modulation: nn::Sequential,
+}
+
+impl QkvOnlyDiTBlock {
+ pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+ let norm1 = LayerNormNoAffine::new(1e-6);
+ let attn = QkvOnlyAttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
+ let n_mods = 2;
+ let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
+ hidden_size,
+ n_mods * hidden_size,
+ vb.pp("adaLN_modulation.1"),
+ )?);
+
+ Ok(Self {
+ norm1,
+ attn,
+ ada_ln_modulation,
+ })
+ }
+
+ pub fn pre_attention(&self, x: &Tensor, c: &Tensor) -> Result<Qkv> {
+ let modulation = self.ada_ln_modulation.forward(c)?;
+ let chunks = modulation.chunk(2, D::Minus1)?;
+ let (shift_msa, scale_msa) = (chunks[0].clone(), chunks[1].clone());
+
+ let norm_x = self.norm1.forward(x)?;
+ let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
+ self.attn.pre_attention(&modulated_x)
+ }
+}
+
+pub struct FinalLayer {
+ norm_final: LayerNormNoAffine,
+ linear: nn::Linear,
+ ada_ln_modulation: nn::Sequential,
+}
+
+impl FinalLayer {
+ pub fn new(
+ hidden_size: usize,
+ patch_size: usize,
+ out_channels: usize,
+ vb: nn::VarBuilder,
+ ) -> Result<Self> {
+ let norm_final = LayerNormNoAffine::new(1e-6);
+ let linear = nn::linear(
+ hidden_size,
+ patch_size * patch_size * out_channels,
+ vb.pp("linear"),
+ )?;
+ let ada_ln_modulation = nn::seq().add(nn::Activation::Silu).add(nn::linear(
+ hidden_size,
+ 2 * hidden_size,
+ vb.pp("adaLN_modulation.1"),
+ )?);
+
+ Ok(Self {
+ norm_final,
+ linear,
+ ada_ln_modulation,
+ })
+ }
+
+ pub fn forward(&self, x: &Tensor, c: &Tensor) -> Result<Tensor> {
+ let modulation = self.ada_ln_modulation.forward(c)?;
+ let chunks = modulation.chunk(2, D::Minus1)?;
+ let (shift, scale) = (chunks[0].clone(), chunks[1].clone());
+
+ let norm_x = self.norm_final.forward(x)?;
+ let modulated_x = modulate(&norm_x, &shift, &scale)?;
+ let output = self.linear.forward(&modulated_x)?;
+
+ Ok(output)
+ }
+}
+
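+// adaLN modulation: x * (1 + scale) + shift, broadcasting shift and scale over the sequence dimension.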
+fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> Result<Tensor> {
+ let shift = shift.unsqueeze(1)?;
+ let scale = scale.unsqueeze(1)?;
+ let scale_plus_one = scale.add(&Tensor::ones_like(&scale)?)?;
+ shift.broadcast_add(&x.broadcast_mul(&scale_plus_one)?)
+}
+
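+// A pair of DiT blocks, one for the context (text) stream and one for the latent stream; their QKV
+// projections are concatenated and attended jointly, then each stream runs its own post-attention path.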
+pub struct JointBlock {
+ x_block: DiTBlock,
+ context_block: DiTBlock,
+ num_heads: usize,
+}
+
+impl JointBlock {
+ pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+ let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
+ let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
+
+ Ok(Self {
+ x_block,
+ context_block,
+ num_heads,
+ })
+ }
+
+ pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
+ let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
+ let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
+ let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
+ let context_out =
+ self.context_block
+ .post_attention(&context_attn, context, &context_interm)?;
+ let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
+ Ok((context_out, x_out))
+ }
+}
+
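+// Variant used as the last joint block: the context stream only contributes its QKV to the joint
+// attention and its output is discarded, so its block needs neither an output projection nor an MLP.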
+pub struct ContextQkvOnlyJointBlock {
+ x_block: DiTBlock,
+ context_block: QkvOnlyDiTBlock,
+ num_heads: usize,
+}
+
+impl ContextQkvOnlyJointBlock {
+ pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+ let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
+ let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
+ Ok(Self {
+ x_block,
+ context_block,
+ num_heads,
+ })
+ }
+
+ pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
+ let context_qkv = self.context_block.pre_attention(context, c)?;
+ let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
+
+ let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
+
+ let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
+ Ok(x_out)
+ }
+}
+
+// A QKV-attention that is compatible with the interface of candle_flash_attn::flash_attn
+// Flash attention regards q, k, v dimensions as (batch_size, seqlen, nheads, headdim)
+fn flash_compatible_attention(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+) -> Result<Tensor> {
+ let q_dims_for_matmul = q.transpose(1, 2)?.dims().to_vec();
+ let rank = q_dims_for_matmul.len();
+ let q = q.transpose(1, 2)?.flatten_to(rank - 3)?;
+ let k = k.transpose(1, 2)?.flatten_to(rank - 3)?;
+ let v = v.transpose(1, 2)?.flatten_to(rank - 3)?;
+ let attn_weights = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
+ let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
+ attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2)
+}
+
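+// Runs attention over the concatenation of the context and latent sequences, then splits the result
+// back into the context part and the latent (x) part along the sequence dimension.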
+fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> {
+ let qkv = Qkv {
+ q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?,
+ k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?,
+ v: Tensor::cat(&[&context_qkv.v, &x_qkv.v], 1)?,
+ };
+
+ let (batch_size, seqlen, _) = qkv.q.dims3()?;
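+ // q and k are reshaped to (batch, seqlen, num_heads, head_dim); v already has this layout
+ // because split_qkv leaves it unflattened.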
+ let qkv = Qkv {
+ q: qkv.q.reshape((batch_size, seqlen, num_heads, ()))?,
+ k: qkv.k.reshape((batch_size, seqlen, num_heads, ()))?,
+ v: qkv.v,
+ };
+
+ let headdim = qkv.q.dim(D::Minus1)?;
+ let softmax_scale = 1.0 / (headdim as f64).sqrt();
+ // let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?;
+ let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?;
+
+ let attn = attn.reshape((batch_size, seqlen, ()))?;
+ let context_qkv_seqlen = context_qkv.q.dim(1)?;
+ let context_attn = attn.narrow(1, 0, context_qkv_seqlen)?;
+ let x_attn = attn.narrow(1, context_qkv_seqlen, seqlen - context_qkv_seqlen)?;
+
+ Ok((context_attn, x_attn))
+}
diff --git a/candle-transformers/src/models/mmdit/embedding.rs b/candle-transformers/src/models/mmdit/embedding.rs
new file mode 100644
index 00000000..6e200b18
--- /dev/null
+++ b/candle-transformers/src/models/mmdit/embedding.rs
@@ -0,0 +1,197 @@
+use candle::{bail, DType, Module, Result, Tensor};
+use candle_nn as nn;
+
+pub struct PatchEmbedder {
+ proj: nn::Conv2d,
+}
+
+impl PatchEmbedder {
+ pub fn new(
+ patch_size: usize,
+ in_channels: usize,
+ embed_dim: usize,
+ vb: nn::VarBuilder,
+ ) -> Result<Self> {
+ let proj = nn::conv2d(
+ in_channels,
+ embed_dim,
+ patch_size,
+ nn::Conv2dConfig {
+ stride: patch_size,
+ ..Default::default()
+ },
+ vb.pp("proj"),
+ )?;
+
+ Ok(Self { proj })
+ }
+}
+
+impl Module for PatchEmbedder {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x = self.proj.forward(x)?;
+
+ // flatten spatial dim and transpose to channels last
+ let (b, c, h, w) = x.dims4()?;
+ x.reshape((b, c, h * w))?.transpose(1, 2)
+ }
+}
+
+pub struct Unpatchifier {
+ patch_size: usize,
+ out_channels: usize,
+}
+
+impl Unpatchifier {
+ pub fn new(patch_size: usize, out_channels: usize) -> Result<Self> {
+ Ok(Self {
+ patch_size,
+ out_channels,
+ })
+ }
+
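+ // Folds the sequence of per-patch predictions back into an image-shaped tensor of
+ // (N, out_channels, patch_size * h, patch_size * w), where h and w are the input size in patches.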
+ pub fn unpatchify(&self, x: &Tensor, h: usize, w: usize) -> Result<Tensor> {
+ let h = (h + 1) / self.patch_size;
+ let w = (w + 1) / self.patch_size;
+
+ let x = x.reshape((
+ x.dim(0)?,
+ h,
+ w,
+ self.patch_size,
+ self.patch_size,
+ self.out_channels,
+ ))?;
+ let x = x.permute((0, 5, 1, 3, 2, 4))?; // "nhwpqc->nchpwq"
+ x.reshape((
+ x.dim(0)?,
+ self.out_channels,
+ self.patch_size * h,
+ self.patch_size * w,
+ ))
+ }
+}
+
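+// Learned positional embedding over a pos_embed_max_size x pos_embed_max_size grid of patches;
+// get_cropped_pos_embed takes a centered crop matching the actual latent resolution.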
+pub struct PositionEmbedder {
+ pos_embed: Tensor,
+ patch_size: usize,
+ pos_embed_max_size: usize,
+}
+
+impl PositionEmbedder {
+ pub fn new(
+ hidden_size: usize,
+ patch_size: usize,
+ pos_embed_max_size: usize,
+ vb: nn::VarBuilder,
+ ) -> Result<Self> {
+ let pos_embed = vb.get(
+ (1, pos_embed_max_size * pos_embed_max_size, hidden_size),
+ "pos_embed",
+ )?;
+ Ok(Self {
+ pos_embed,
+ patch_size,
+ pos_embed_max_size,
+ })
+ }
+ pub fn get_cropped_pos_embed(&self, h: usize, w: usize) -> Result<Tensor> {
+ let h = (h + 1) / self.patch_size;
+ let w = (w + 1) / self.patch_size;
+
+ if h > self.pos_embed_max_size || w > self.pos_embed_max_size {
+ bail!("Input size is too large for the position embedding")
+ }
+
+ let top = (self.pos_embed_max_size - h) / 2;
+ let left = (self.pos_embed_max_size - w) / 2;
+
+ let pos_embed =
+ self.pos_embed
+ .reshape((1, self.pos_embed_max_size, self.pos_embed_max_size, ()))?;
+ let pos_embed = pos_embed.narrow(1, top, h)?.narrow(2, left, w)?;
+ pos_embed.reshape((1, h * w, ()))
+ }
+}
+
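+// Maps scalar diffusion timesteps to sinusoidal frequency embeddings followed by a two-layer MLP.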
+pub struct TimestepEmbedder {
+ mlp: nn::Sequential,
+ frequency_embedding_size: usize,
+}
+
+impl TimestepEmbedder {
+ pub fn new(
+ hidden_size: usize,
+ frequency_embedding_size: usize,
+ vb: nn::VarBuilder,
+ ) -> Result<Self> {
+ let mlp = nn::seq()
+ .add(nn::linear(
+ frequency_embedding_size,
+ hidden_size,
+ vb.pp("mlp.0"),
+ )?)
+ .add(nn::Activation::Silu)
+ .add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);
+
+ Ok(Self {
+ mlp,
+ frequency_embedding_size,
+ })
+ }
+
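+ // Sinusoidal timestep embedding: frequencies decay geometrically from 1 down to roughly
+ // 1/max_period, and the embedding concatenates cos and sin of t times each frequency
+ // (computed in f32, returned in f16).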
+ fn timestep_embedding(t: &Tensor, dim: usize, max_period: f64) -> Result<Tensor> {
+ if dim % 2 != 0 {
+ bail!("Embedding dimension must be even")
+ }
+
+ if t.dtype() != DType::F32 && t.dtype() != DType::F64 {
+ bail!("Input tensor must be floating point")
+ }
+
+ let half = dim / 2;
+ let freqs = Tensor::arange(0f32, half as f32, t.device())?
+ .to_dtype(candle::DType::F32)?
+ .mul(&Tensor::full(
+ (-f64::ln(max_period) / half as f64) as f32,
+ half,
+ t.device(),
+ )?)?
+ .exp()?;
+
+ let args = t
+ .unsqueeze(1)?
+ .to_dtype(candle::DType::F32)?
+ .matmul(&freqs.unsqueeze(0)?)?;
+ let embedding = Tensor::cat(&[args.cos()?, args.sin()?], 1)?;
+ embedding.to_dtype(candle::DType::F16)
+ }
+}
+
+impl Module for TimestepEmbedder {
+ fn forward(&self, t: &Tensor) -> Result<Tensor> {
+ let t_freq = Self::timestep_embedding(t, self.frequency_embedding_size, 10000.0)?;
+ self.mlp.forward(&t_freq)
+ }
+}
+
+pub struct VectorEmbedder {
+ mlp: nn::Sequential,
+}
+
+impl VectorEmbedder {
+ pub fn new(input_dim: usize, hidden_size: usize, vb: nn::VarBuilder) -> Result<Self> {
+ let mlp = nn::seq()
+ .add(nn::linear(input_dim, hidden_size, vb.pp("mlp.0"))?)
+ .add(nn::Activation::Silu)
+ .add(nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?);
+
+ Ok(Self { mlp })
+ }
+}
+
+impl Module for VectorEmbedder {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ self.mlp.forward(x)
+ }
+}
diff --git a/candle-transformers/src/models/mmdit/mod.rs b/candle-transformers/src/models/mmdit/mod.rs
new file mode 100644
index 00000000..9c4db6e0
--- /dev/null
+++ b/candle-transformers/src/models/mmdit/mod.rs
@@ -0,0 +1,4 @@
+pub mod blocks;
+pub mod embedding;
+pub mod model;
+pub mod projections;
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs
new file mode 100644
index 00000000..1523836c
--- /dev/null
+++ b/candle-transformers/src/models/mmdit/model.rs
@@ -0,0 +1,173 @@
+// Implement the MMDiT model originally introduced for Stable Diffusion 3 (https://arxiv.org/abs/2403.03206).
+// This follows the implementation of the MMDiT model in the ComfyUI repository.
+// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L1
+use candle::{Module, Result, Tensor, D};
+use candle_nn as nn;
+
+use super::blocks::{ContextQkvOnlyJointBlock, FinalLayer, JointBlock};
+use super::embedding::{
+ PatchEmbedder, PositionEmbedder, TimestepEmbedder, Unpatchifier, VectorEmbedder,
+};
+
+#[derive(Debug, Clone)]
+pub struct Config {
+ pub patch_size: usize,
+ pub in_channels: usize,
+ pub out_channels: usize,
+ pub depth: usize,
+ pub head_size: usize,
+ pub adm_in_channels: usize,
+ pub pos_embed_max_size: usize,
+ pub context_embed_size: usize,
+ pub frequency_embedding_size: usize,
+}
+
+impl Config {
+ pub fn sd3() -> Self {
+ Self {
+ patch_size: 2,
+ in_channels: 16,
+ out_channels: 16,
+ depth: 24,
+ head_size: 64,
+ adm_in_channels: 2048,
+ pos_embed_max_size: 192,
+ context_embed_size: 4096,
+ frequency_embedding_size: 256,
+ }
+ }
+}
+
+pub struct MMDiT {
+ core: MMDiTCore,
+ patch_embedder: PatchEmbedder,
+ pos_embedder: PositionEmbedder,
+ timestep_embedder: TimestepEmbedder,
+ vector_embedder: VectorEmbedder,
+ context_embedder: nn::Linear,
+ unpatchifier: Unpatchifier,
+}
+
+impl MMDiT {
+ pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result<Self> {
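+ // Following the reference implementation, the number of attention heads equals the depth,
+ // so the hidden size is head_size * depth (64 * 24 = 1536 for Config::sd3()).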
+ let hidden_size = cfg.head_size * cfg.depth;
+ let core = MMDiTCore::new(
+ cfg.depth,
+ hidden_size,
+ cfg.depth,
+ cfg.patch_size,
+ cfg.out_channels,
+ vb.clone(),
+ )?;
+ let patch_embedder = PatchEmbedder::new(
+ cfg.patch_size,
+ cfg.in_channels,
+ hidden_size,
+ vb.pp("x_embedder"),
+ )?;
+ let pos_embedder = PositionEmbedder::new(
+ hidden_size,
+ cfg.patch_size,
+ cfg.pos_embed_max_size,
+ vb.clone(),
+ )?;
+ let timestep_embedder = TimestepEmbedder::new(
+ hidden_size,
+ cfg.frequency_embedding_size,
+ vb.pp("t_embedder"),
+ )?;
+ let vector_embedder =
+ VectorEmbedder::new(cfg.adm_in_channels, hidden_size, vb.pp("y_embedder"))?;
+ let context_embedder = nn::linear(
+ cfg.context_embed_size,
+ hidden_size,
+ vb.pp("context_embedder"),
+ )?;
+ let unpatchifier = Unpatchifier::new(cfg.patch_size, cfg.out_channels)?;
+
+ Ok(Self {
+ core,
+ patch_embedder,
+ pos_embedder,
+ timestep_embedder,
+ vector_embedder,
+ context_embedder,
+ unpatchifier,
+ })
+ }
+
+ pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
+ // Following the convention of the ComfyUI implementation.
+ // https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
+ //
+ // Forward pass of DiT.
+ // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+ // t: (N,) tensor of diffusion timesteps
+ // y: (N, adm_in_channels) tensor of pooled conditioning (the class-label input in the original DiT)
+ let h = x.dim(D::Minus2)?;
+ let w = x.dim(D::Minus1)?;
+ let cropped_pos_embed = self.pos_embedder.get_cropped_pos_embed(h, w)?;
+ let x = self
+ .patch_embedder
+ .forward(x)?
+ .broadcast_add(&cropped_pos_embed)?;
+ let c = self.timestep_embedder.forward(t)?;
+ let y = self.vector_embedder.forward(y)?;
+ let c = (c + y)?;
+ let context = self.context_embedder.forward(context)?;
+
+ let x = self.core.forward(&context, &x, &c)?;
+ let x = self.unpatchifier.unpatchify(&x, h, w)?;
+ x.narrow(2, 0, h)?.narrow(3, 0, w)
+ }
+}
+
+pub struct MMDiTCore {
+ joint_blocks: Vec<JointBlock>,
+ context_qkv_only_joint_block: ContextQkvOnlyJointBlock,
+ final_layer: FinalLayer,
+}
+
+impl MMDiTCore {
+ pub fn new(
+ depth: usize,
+ hidden_size: usize,
+ num_heads: usize,
+ patch_size: usize,
+ out_channels: usize,
+ vb: nn::VarBuilder,
+ ) -> Result<Self> {
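+ // The first depth - 1 blocks are full joint blocks; the last block only needs the context QKV,
+ // since its context output is never used (see forward below).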
+ let mut joint_blocks = Vec::with_capacity(depth - 1);
+ for i in 0..depth - 1 {
+ joint_blocks.push(JointBlock::new(
+ hidden_size,
+ num_heads,
+ vb.pp(format!("joint_blocks.{}", i)),
+ )?);
+ }
+
+ Ok(Self {
+ joint_blocks,
+ context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
+ hidden_size,
+ num_heads,
+ vb.pp(format!("joint_blocks.{}", depth - 1)),
+ )?,
+ final_layer: FinalLayer::new(
+ hidden_size,
+ patch_size,
+ out_channels,
+ vb.pp("final_layer"),
+ )?,
+ })
+ }
+
+ pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
+ let (mut context, mut x) = (context.clone(), x.clone());
+ for joint_block in &self.joint_blocks {
+ (context, x) = joint_block.forward(&context, &x, c)?;
+ }
+ let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
+ self.final_layer.forward(&x, c)
+ }
+}
diff --git a/candle-transformers/src/models/mmdit/projections.rs b/candle-transformers/src/models/mmdit/projections.rs
new file mode 100644
index 00000000..1077398f
--- /dev/null
+++ b/candle-transformers/src/models/mmdit/projections.rs
@@ -0,0 +1,94 @@
+use candle::{Module, Result, Tensor};
+use candle_nn as nn;
+
+pub struct Qkv {
+ pub q: Tensor,
+ pub k: Tensor,
+ pub v: Tensor,
+}
+
+pub struct Mlp {
+ fc1: nn::Linear,
+ act: nn::Activation,
+ fc2: nn::Linear,
+}
+
+impl Mlp {
+ pub fn new(
+ in_features: usize,
+ hidden_features: usize,
+ vb: candle_nn::VarBuilder,
+ ) -> Result<Self> {
+ let fc1 = nn::linear(in_features, hidden_features, vb.pp("fc1"))?;
+ let act = nn::Activation::GeluPytorchTanh;
+ let fc2 = nn::linear(hidden_features, in_features, vb.pp("fc2"))?;
+
+ Ok(Self { fc1, act, fc2 })
+ }
+}
+
+impl Module for Mlp {
+ fn forward(&self, x: &Tensor) -> Result<Tensor> {
+ let x = self.fc1.forward(x)?;
+ let x = self.act.forward(&x)?;
+ self.fc2.forward(&x)
+ }
+}
+
+pub struct QkvOnlyAttnProjections {
+ qkv: nn::Linear,
+ head_dim: usize,
+}
+
+impl QkvOnlyAttnProjections {
+ pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+ // {'dim': 1536, 'num_heads': 24}
+ let head_dim = dim / num_heads;
+ let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
+ Ok(Self { qkv, head_dim })
+ }
+
+ pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
+ let qkv = self.qkv.forward(x)?;
+ split_qkv(&qkv, self.head_dim)
+ }
+}
+
+pub struct AttnProjections {
+ head_dim: usize,
+ qkv: nn::Linear,
+ proj: nn::Linear,
+}
+
+impl AttnProjections {
+ pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+ let head_dim = dim / num_heads;
+ let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
+ let proj = nn::linear(dim, dim, vb.pp("proj"))?;
+ Ok(Self {
+ head_dim,
+ qkv,
+ proj,
+ })
+ }
+
+ pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
+ let qkv = self.qkv.forward(x)?;
+ split_qkv(&qkv, self.head_dim)
+ }
+
+ pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
+ self.proj.forward(x)
+ }
+}
+
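+// Splits the fused (batch, seq, 3 * dim) projection into q, k and v. q and k are flattened back to
+// (batch, seq, dim), while v keeps the (batch, seq, num_heads, head_dim) layout expected by joint_attn.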
+fn split_qkv(qkv: &Tensor, head_dim: usize) -> Result<Qkv> {
+ let (batch_size, seq_len, _) = qkv.dims3()?;
+ let qkv = qkv.reshape((batch_size, seq_len, 3, (), head_dim))?;
+ let q = qkv.get_on_dim(2, 0)?;
+ let q = q.reshape((batch_size, seq_len, ()))?;
+ let k = qkv.get_on_dim(2, 1)?;
+ let k = k.reshape((batch_size, seq_len, ()))?;
+ let v = qkv.get_on_dim(2, 2)?;
+ Ok(Qkv { q, k, v })
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 18c8833d..c0de550b 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -32,6 +32,7 @@ pub mod metavoice;
pub mod mistral;
pub mod mixformer;
pub mod mixtral;
+pub mod mmdit;
pub mod mobilenetv4;
pub mod mobileone;
pub mod moondream;