summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/mmdit/model.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/mmdit/model.rs')
-rw-r--r--candle-transformers/src/models/mmdit/model.rs26
1 files changed, 22 insertions, 4 deletions
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs
index c7b4deed..21897aa3 100644
--- a/candle-transformers/src/models/mmdit/model.rs
+++ b/candle-transformers/src/models/mmdit/model.rs
@@ -130,7 +130,14 @@ impl MMDiT {
})
}
- pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
+ pub fn forward(
+ &self,
+ x: &Tensor,
+ t: &Tensor,
+ y: &Tensor,
+ context: &Tensor,
+ skip_layers: Option<&[usize]>,
+ ) -> Result<Tensor> {
// Following the convention of the ComfyUI implementation.
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
//
@@ -150,7 +157,7 @@ impl MMDiT {
let c = (c + y)?;
let context = self.context_embedder.forward(context)?;
- let x = self.core.forward(&context, &x, &c)?;
+ let x = self.core.forward(&context, &x, &c, skip_layers)?;
let x = self.unpatchifier.unpatchify(&x, h, w)?;
x.narrow(2, 0, h)?.narrow(3, 0, w)
}
@@ -211,9 +218,20 @@ impl MMDiTCore {
})
}
- pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
+ pub fn forward(
+ &self,
+ context: &Tensor,
+ x: &Tensor,
+ c: &Tensor,
+ skip_layers: Option<&[usize]>,
+ ) -> Result<Tensor> {
let (mut context, mut x) = (context.clone(), x.clone());
- for joint_block in &self.joint_blocks {
+ for (i, joint_block) in self.joint_blocks.iter().enumerate() {
+ if let Some(skip_layers) = &skip_layers {
+ if skip_layers.contains(&i) {
+ continue;
+ }
+ }
(context, x) = joint_block.forward(&context, &x, c)?;
}
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;