summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Czxck001 <10724409+Czxck001@users.noreply.github.com> 2024-11-01 10:10:40 -0700
committer: GitHub <noreply@github.com> 2024-11-01 18:10:40 +0100
commit: 530ab96036604b125276433b67ebb840e841aede (patch)
tree: 86042bf29e8c4c0d6fe0ad14e907aefc2e359d83
parent: 7ac0de15a9fafe59d9f97fb6d90662790488433e (diff)
download: candle-530ab96036604b125276433b67ebb840e841aede.tar.gz
candle-530ab96036604b125276433b67ebb840e841aede.tar.bz2
candle-530ab96036604b125276433b67ebb840e841aede.zip
Support Skip Layer Guidance (SLG) for Stable Diffusion 3.5 Medium (#2590)
* support skip layer guidance (slg) for stable diffusion 3.5 medium
* Tweak the comments formatting.
* Proper error message.
* Cosmetic tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
-rw-r--r-- candle-examples/examples/stable-diffusion-3/main.rs 27
-rw-r--r-- candle-examples/examples/stable-diffusion-3/sampling.rs 36
-rw-r--r-- candle-transformers/src/models/mmdit/model.rs 26
3 files changed, 79 insertions, 10 deletions
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
index 9ad057e3..8c9a78d2 100644
--- a/candle-examples/examples/stable-diffusion-3/main.rs
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -75,14 +75,19 @@ struct Args {
#[arg(long)]
num_inference_steps: Option<usize>,
- // CFG scale.
+ /// CFG scale.
#[arg(long)]
cfg_scale: Option<f64>,
- // Time shift factor (alpha).
+ /// Time shift factor (alpha).
#[arg(long, default_value_t = 3.0)]
time_shift: f64,
+ /// Use Skip Layer Guidance (SLG) for the sampling.
+ /// Currently only supports Stable Diffusion 3.5 Medium.
+ #[arg(long)]
+ use_slg: bool,
+
/// The seed to use when generating random samples.
#[arg(long)]
seed: Option<u64>,
@@ -105,6 +110,7 @@ fn main() -> Result<()> {
time_shift,
seed,
which,
+ use_slg,
} = Args::parse();
let _guard = if tracing {
@@ -211,6 +217,22 @@ fn main() -> Result<()> {
if let Some(seed) = seed {
device.set_seed(seed)?;
}
+
+ let slg_config = if use_slg {
+ match which {
+ // https://github.com/Stability-AI/sd3.5/blob/4e484e05308d83fb77ae6f680028e6c313f9da54/sd3_infer.py#L388-L394
+ Which::V3_5Medium => Some(sampling::SkipLayerGuidanceConfig {
+ scale: 2.5,
+ start: 0.01,
+ end: 0.2,
+ layers: vec![7, 8, 9],
+ }),
+ _ => anyhow::bail!("--use-slg can only be used with 3.5-medium"),
+ }
+ } else {
+ None
+ };
+
let start_time = std::time::Instant::now();
let x = {
let mmdit = MMDiT::new(
@@ -227,6 +249,7 @@ fn main() -> Result<()> {
time_shift,
height,
width,
+ slg_config,
)?
};
let dt = start_time.elapsed().as_secs_f32();
diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs
index cd881b6a..5e234371 100644
--- a/candle-examples/examples/stable-diffusion-3/sampling.rs
+++ b/candle-examples/examples/stable-diffusion-3/sampling.rs
@@ -1,8 +1,15 @@
use anyhow::{Ok, Result};
-use candle::{DType, Tensor};
+use candle::{DType, IndexOp, Tensor};
use candle_transformers::models::flux;
-use candle_transformers::models::mmdit::model::MMDiT; // for the get_noise function
+use candle_transformers::models::mmdit::model::MMDiT;
+
+pub struct SkipLayerGuidanceConfig {
+ pub scale: f64,
+ pub start: f64,
+ pub end: f64,
+ pub layers: Vec<usize>,
+}
#[allow(clippy::too_many_arguments)]
pub fn euler_sample(
@@ -14,6 +21,7 @@ pub fn euler_sample(
time_shift: f64,
height: usize,
width: usize,
+ slg_config: Option<SkipLayerGuidanceConfig>,
) -> Result<Tensor> {
let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
let sigmas = (0..=num_inference_steps)
@@ -22,7 +30,7 @@ pub fn euler_sample(
.map(|x| time_snr_shift(time_shift, x))
.collect::<Vec<f64>>();
- for window in sigmas.windows(2) {
+ for (step, window) in sigmas.windows(2).enumerate() {
let (s_curr, s_prev) = match window {
[a, b] => (a, b),
_ => continue,
@@ -34,8 +42,28 @@ pub fn euler_sample(
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
y,
context,
+ None,
)?;
- x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?;
+
+ let mut guidance = apply_cfg(cfg_scale, &noise_pred)?;
+
+ if let Some(slg_config) = slg_config.as_ref() {
+ if (num_inference_steps as f64) * slg_config.start < (step as f64)
+ && (step as f64) < (num_inference_steps as f64) * slg_config.end
+ {
+ let slg_noise_pred = mmdit.forward(
+ &x,
+ &Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
+ &y.i(..1)?,
+ &context.i(..1)?,
+ Some(&slg_config.layers),
+ )?;
+ guidance = (guidance
+ + (slg_config.scale * (noise_pred.i(..1)? - slg_noise_pred.i(..1))?)?)?;
+ }
+ }
+
+ x = (x + (guidance * (*s_prev - *s_curr))?)?;
}
Ok(x)
}
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs
index c7b4deed..21897aa3 100644
--- a/candle-transformers/src/models/mmdit/model.rs
+++ b/candle-transformers/src/models/mmdit/model.rs
@@ -130,7 +130,14 @@ impl MMDiT {
})
}
- pub fn forward(&self, x: &Tensor, t: &Tensor, y: &Tensor, context: &Tensor) -> Result<Tensor> {
+ pub fn forward(
+ &self,
+ x: &Tensor,
+ t: &Tensor,
+ y: &Tensor,
+ context: &Tensor,
+ skip_layers: Option<&[usize]>,
+ ) -> Result<Tensor> {
// Following the convention of the ComfyUI implementation.
// https://github.com/comfyanonymous/ComfyUI/blob/78e133d0415784924cd2674e2ee48f3eeca8a2aa/comfy/ldm/modules/diffusionmodules/mmdit.py#L919
//
@@ -150,7 +157,7 @@ impl MMDiT {
let c = (c + y)?;
let context = self.context_embedder.forward(context)?;
- let x = self.core.forward(&context, &x, &c)?;
+ let x = self.core.forward(&context, &x, &c, skip_layers)?;
let x = self.unpatchifier.unpatchify(&x, h, w)?;
x.narrow(2, 0, h)?.narrow(3, 0, w)
}
@@ -211,9 +218,20 @@ impl MMDiTCore {
})
}
- pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<Tensor> {
+ pub fn forward(
+ &self,
+ context: &Tensor,
+ x: &Tensor,
+ c: &Tensor,
+ skip_layers: Option<&[usize]>,
+ ) -> Result<Tensor> {
let (mut context, mut x) = (context.clone(), x.clone());
- for joint_block in &self.joint_blocks {
+ for (i, joint_block) in self.joint_blocks.iter().enumerate() {
+ if let Some(skip_layers) = &skip_layers {
+ if skip_layers.contains(&i) {
+ continue;
+ }
+ }
(context, x) = joint_block.forward(&context, &x, c)?;
}
let x = self.context_qkv_only_joint_block.forward(&context, &x, c)?;