summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/stable-diffusion-3/sampling.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/stable-diffusion-3/sampling.rs')
-rw-r--r-- candle-examples/examples/stable-diffusion-3/sampling.rs 36
1 files changed, 32 insertions, 4 deletions
diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs
index cd881b6a..5e234371 100644
--- a/candle-examples/examples/stable-diffusion-3/sampling.rs
+++ b/candle-examples/examples/stable-diffusion-3/sampling.rs
@@ -1,8 +1,15 @@
use anyhow::{Ok, Result};
-use candle::{DType, Tensor};
+use candle::{DType, IndexOp, Tensor};
use candle_transformers::models::flux;
-use candle_transformers::models::mmdit::model::MMDiT; // for the get_noise function
+use candle_transformers::models::mmdit::model::MMDiT;
+
+pub struct SkipLayerGuidanceConfig {
+ pub scale: f64,
+ pub start: f64,
+ pub end: f64,
+ pub layers: Vec<usize>,
+}
#[allow(clippy::too_many_arguments)]
pub fn euler_sample(
@@ -14,6 +21,7 @@ pub fn euler_sample(
time_shift: f64,
height: usize,
width: usize,
+ slg_config: Option<SkipLayerGuidanceConfig>,
) -> Result<Tensor> {
let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
let sigmas = (0..=num_inference_steps)
@@ -22,7 +30,7 @@ pub fn euler_sample(
.map(|x| time_snr_shift(time_shift, x))
.collect::<Vec<f64>>();
- for window in sigmas.windows(2) {
+ for (step, window) in sigmas.windows(2).enumerate() {
let (s_curr, s_prev) = match window {
[a, b] => (a, b),
_ => continue,
@@ -34,8 +42,28 @@ pub fn euler_sample(
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
y,
context,
+ None,
)?;
- x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?;
+
+ let mut guidance = apply_cfg(cfg_scale, &noise_pred)?;
+
+ if let Some(slg_config) = slg_config.as_ref() {
+ if (num_inference_steps as f64) * slg_config.start < (step as f64)
+ && (step as f64) < (num_inference_steps as f64) * slg_config.end
+ {
+ let slg_noise_pred = mmdit.forward(
+ &x,
+ &Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
+ &y.i(..1)?,
+ &context.i(..1)?,
+ Some(&slg_config.layers),
+ )?;
+ guidance = (guidance
+ + (slg_config.scale * (noise_pred.i(..1)? - slg_noise_pred.i(..1))?)?)?;
+ }
+ }
+
+ x = (x + (guidance * (*s_prev - *s_curr))?)?;
}
Ok(x)
}