diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion-3/sampling.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion-3/sampling.rs | 36 |
1 files changed, 32 insertions, 4 deletions
diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs index cd881b6a..5e234371 100644 --- a/candle-examples/examples/stable-diffusion-3/sampling.rs +++ b/candle-examples/examples/stable-diffusion-3/sampling.rs @@ -1,8 +1,15 @@ use anyhow::{Ok, Result}; -use candle::{DType, Tensor}; +use candle::{DType, IndexOp, Tensor}; use candle_transformers::models::flux; -use candle_transformers::models::mmdit::model::MMDiT; // for the get_noise function +use candle_transformers::models::mmdit::model::MMDiT; + +pub struct SkipLayerGuidanceConfig { + pub scale: f64, + pub start: f64, + pub end: f64, + pub layers: Vec<usize>, +} #[allow(clippy::too_many_arguments)] pub fn euler_sample( @@ -14,6 +21,7 @@ pub fn euler_sample( time_shift: f64, height: usize, width: usize, + slg_config: Option<SkipLayerGuidanceConfig>, ) -> Result<Tensor> { let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?; let sigmas = (0..=num_inference_steps) @@ -22,7 +30,7 @@ pub fn euler_sample( .map(|x| time_snr_shift(time_shift, x)) .collect::<Vec<f64>>(); - for window in sigmas.windows(2) { + for (step, window) in sigmas.windows(2).enumerate() { let (s_curr, s_prev) = match window { [a, b] => (a, b), _ => continue, @@ -34,8 +42,28 @@ pub fn euler_sample( &Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?, y, context, + None, )?; - x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?; + + let mut guidance = apply_cfg(cfg_scale, &noise_pred)?; + + if let Some(slg_config) = slg_config.as_ref() { + if (num_inference_steps as f64) * slg_config.start < (step as f64) + && (step as f64) < (num_inference_steps as f64) * slg_config.end + { + let slg_noise_pred = mmdit.forward( + &x, + &Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?, + &y.i(..1)?, + &context.i(..1)?, + Some(&slg_config.layers), + )?; + guidance = (guidance + + (slg_config.scale * (noise_pred.i(..1)? - slg_noise_pred.i(..1))?)?)?; + } + } + + x = (x + (guidance * (*s_prev - *s_curr))?)?; } Ok(x) } |