diff options
Diffstat (limited to 'candle-transformers/src/models/mmdit/model.rs')
-rw-r--r-- | candle-transformers/src/models/mmdit/model.rs | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs index c7b4deed..21897aa3 100644 --- a/candle-transformers/src/models/mmdit/model.rs +++ b/candle-transformers/src/models/mmdit/model.rs @@ -130,7 +130,14 @@ impl MMDiT { }) } - pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> { + pub fn forward( + &self, + x: &Tensor, + t: &Tensor, + y: &Tensor, + context: &Tensor, + skip_layers: Option<&[usize]>, + ) -> Result<Tensor> { // Following the convention of the ComfyUI implementation. // https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919 // @@ -150,7 +157,7 @@ impl MMDiT { let c = (c + y)?; let context = self.context_embedder.forward(context)?; - let x = self.core.forward(&context, &x, &c)?; + let x = self.core.forward(&context, &x, &c, skip_layers)?; let x = self.unpatchifier.unpatchify(&x, h, w)?; x.narrow(2, 0, h)?.narrow(3, 0, w) } @@ -211,9 +218,20 @@ impl MMDiTCore { }) } - pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> { + pub fn forward( + &self, + context: &Tensor, + x: &Tensor, + c: &Tensor, + skip_layers: Option<&[usize]>, + ) -> Result<Tensor> { let (mut context, mut x) = (context.clone(), x.clone()); - for joint_block in &self.joint_blocks { + for (i, joint_block) in self.joint_blocks.iter().enumerate() { + if let Some(skip_layers) = &skip_layers { + if skip_layers.contains(&i) { + continue; + } + } (context, x) = joint_block.forward(&context, &x, c)?; } let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?; |