From 530ab96036604b125276433b67ebb840e841aede Mon Sep 17 00:00:00 2001
From: Czxck001 <10724409+Czxck001@users.noreply.github.com>
Date: Fri, 1 Nov 2024 10:10:40 -0700
Subject: Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium
 (#2590)

* support skip layer guidance (slg) for stable diffusion 3.5 medium

* Tweak the comments formatting.

* Proper error message.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
---
 .../examples/stable-diffusion-3/main.rs            | 27 ++++++++++++++--
 .../examples/stable-diffusion-3/sampling.rs        | 36 +++++++++++++++++++---
 candle-transformers/src/models/mmdit/model.rs      | 26 +++++++++++++---
 3 files changed, 79 insertions(+), 10 deletions(-)

diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
index 9ad057e3..8c9a78d2 100644
--- a/candle-examples/examples/stable-diffusion-3/main.rs
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -75,14 +75,19 @@ struct Args {
     #[arg(long)]
     num_inference_steps: Option<usize>,
 
-    // CFG scale.
+    /// CFG scale.
     #[arg(long)]
     cfg_scale: Option<f64>,
 
-    // Time shift factor (alpha).
+    /// Time shift factor (alpha).
     #[arg(long, default_value_t = 3.0)]
     time_shift: f64,
 
+    /// Use Skip Layer Guidance (SLG) for the sampling.
+    /// Currently only supports Stable Diffusion 3.5 Medium.
+    #[arg(long)]
+    use_slg: bool,
+
     /// The seed to use when generating random samples.
     #[arg(long)]
     seed: Option<u64>,
@@ -105,6 +110,7 @@ fn main() -> Result<()> {
         time_shift,
         seed,
         which,
+        use_slg,
     } = Args::parse();
 
     let _guard = if tracing {
@@ -211,6 +217,22 @@ fn main() -> Result<()> {
     if let Some(seed) = seed {
         device.set_seed(seed)?;
     }
+
+    let slg_config = if use_slg {
+        match which {
+            // https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/sd3_infer.py#L388-L394
+            Which::V3_5Medium => Some(sampling::SkipLayerGuidanceConfig {
+                scale: 2.5,
+                start: 0.01,
+                end: 0.2,
+                layers: vec![7, 8, 9],
+            }),
+            _ => anyhow::bail!("--use-slg can only be used with 3.5-medium"),
+        }
+    } else {
+        None
+    };
+
     let start_time = std::time::Instant::now();
     let x = {
         let mmdit = MMDiT::new(
@@ -227,6 +249,7 @@ fn main() -> Result<()> {
             time_shift,
             height,
             width,
+            slg_config,
         )?
     };
     let dt = start_time.elapsed().as_secs_f32();
diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs
index cd881b6a..5e234371 100644
--- a/candle-examples/examples/stable-diffusion-3/sampling.rs
+++ b/candle-examples/examples/stable-diffusion-3/sampling.rs
@@ -1,8 +1,15 @@
 use anyhow::{Ok, Result};
-use candle::{DType, Tensor};
+use candle::{DType, IndexOp, Tensor};
 
 use candle_transformers::models::flux;
-use candle_transformers::models::mmdit::model::MMDiT; // for the get_noise function
+use candle_transformers::models::mmdit::model::MMDiT;
+
+pub struct SkipLayerGuidanceConfig {
+    pub scale: f64,
+    pub start: f64,
+    pub end: f64,
+    pub layers: Vec<usize>,
+}
 
 #[allow(clippy::too_many_arguments)]
 pub fn euler_sample(
@@ -14,6 +21,7 @@ pub fn euler_sample(
     time_shift: f64,
     height: usize,
     width: usize,
+    slg_config: Option<SkipLayerGuidanceConfig>,
 ) -> Result<Tensor> {
     let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
     let sigmas = (0..=num_inference_steps)
@@ -22,7 +30,7 @@ pub fn euler_sample(
         .map(|x| time_snr_shift(time_shift, x))
         .collect::<Vec<f64>>();
 
-    for window in sigmas.windows(2) {
+    for (step, window) in sigmas.windows(2).enumerate() {
         let (s_curr, s_prev) = match window {
             [a, b] => (a, b),
             _ => continue,
@@ -34,8 +42,28 @@ pub fn euler_sample(
             &Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
             y,
             context,
+            None,
         )?;
-        x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?;
+
+        let mut guidance = apply_cfg(cfg_scale, &noise_pred)?;
+
+        if let Some(slg_config) = slg_config.as_ref() {
+            if (num_inference_steps as f64) * slg_config.start < (step as f64)
+                && (step as f64) < (num_inference_steps as f64) * slg_config.end
+            {
+                let slg_noise_pred = mmdit.forward(
+                    &x,
+                    &Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
+                    &y.i(..1)?,
+                    &context.i(..1)?,
+                    Some(&slg_config.layers),
+                )?;
+                guidance = (guidance
+                    + (slg_config.scale * (noise_pred.i(..1)? - slg_noise_pred.i(..1))?)?)?;
+            }
+        }
+
+        x = (x + (guidance * (*s_prev - *s_curr))?)?;
     }
     Ok(x)
 }
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs
index c7b4deed..21897aa3 100644
--- a/candle-transformers/src/models/mmdit/model.rs
+++ b/candle-transformers/src/models/mmdit/model.rs
@@ -130,7 +130,14 @@ impl MMDiT {
         })
     }
 
-    pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &self,
+        x: &Tensor,
+        t: &Tensor,
+        y: &Tensor,
+        context: &Tensor,
+        skip_layers: Option<&[usize]>,
+    ) -> Result<Tensor> {
         // Following the convention of the ComfyUI implementation.
         // https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
         //
@@ -150,7 +157,7 @@ impl MMDiT {
         let c = (c + y)?;
         let context = self.context_embedder.forward(context)?;
 
-        let x = self.core.forward(&context, &x, &c)?;
+        let x = self.core.forward(&context, &x, &c, skip_layers)?;
         let x = self.unpatchifier.unpatchify(&x, h, w)?;
         x.narrow(2, 0, h)?.narrow(3, 0, w)
     }
@@ -211,9 +218,20 @@ impl MMDiTCore {
         })
     }
 
-    pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &self,
+        context: &Tensor,
+        x: &Tensor,
+        c: &Tensor,
+        skip_layers: Option<&[usize]>,
+    ) -> Result<Tensor> {
         let (mut context, mut x) = (context.clone(), x.clone());
-        for joint_block in &self.joint_blocks {
+        for (i, joint_block) in self.joint_blocks.iter().enumerate() {
+            if let Some(skip_layers) = &skip_layers {
+                if skip_layers.contains(&i) {
+                    continue;
+                }
+            }
             (context, x) = joint_block.forward(&context, &x, c)?;
         }
         let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;
-- 
cgit v1.2.3