author     Edwin Cheng <edwin0cheng@gmail.com>       2023-12-03 15:37:10 +0800
committer  GitHub <noreply@github.com>               2023-12-03 08:37:10 +0100
commit     37bf1ed012926b4180d2e9068b84118e1eb6e26d
tree       73aa11818463fbe5bf39f4978dd6a5538a1a01a5 /candle-transformers
parent     dd40edfe73b796ad12f6246319cde2e603c4dc56
Stable Diffusion Turbo Support (#1395)
* Add support for SD Turbo
* Set Leading as default in euler_ancestral discrete
* Use the appropriate default values for n_steps and guidance_scale.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers')
4 files changed, 168 insertions, 35 deletions
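
The heart of the change: scheduling now goes through the new `SchedulerConfig` and `Scheduler` traits, so `StableDiffusionConfig::build_scheduler` returns a `Box<dyn Scheduler>` and the DDIM and Euler-ancestral schedulers share a single call site. Below is a minimal sketch of a denoising loop written against the new trait objects; the CPU device, the `(1, 4, 64, 64)` latent shape, and the identity stand-in for the UNet forward pass are illustrative assumptions, not part of this patch.

```rust
use candle::{Device, Result, Tensor};
use candle_transformers::models::stable_diffusion::schedulers::Scheduler;
use candle_transformers::models::stable_diffusion::StableDiffusionConfig;

// Denoising loop sketched against the trait object that build_scheduler
// returns after this patch. The latent shape and the identity stand-in
// for the UNet forward pass are illustrative assumptions.
fn sample(cfg: &StableDiffusionConfig, n_steps: usize) -> Result<Tensor> {
    let scheduler = cfg.build_scheduler(n_steps)?; // Box<dyn Scheduler>
    let device = Device::Cpu;
    let mut latents = Tensor::randn(0f32, 1f32, (1, 4, 64, 64), &device)?;
    // Euler-ancestral with Leading spacing reports (sigma^2 + 1).sqrt()
    // here; DDIM reports its plain init_noise_sigma.
    latents = (latents * scheduler.init_noise_sigma())?;
    for &timestep in scheduler.timesteps() {
        // K-LMS style schedulers rescale the model input; DDIM returns it as-is.
        let input = scheduler.scale_model_input(latents.clone(), timestep)?;
        let noise_pred = input; // placeholder for the UNet forward pass
        latents = scheduler.step(&noise_pred, timestep, &latents)?;
    }
    Ok(latents)
}
```

Boxing the scheduler behind a trait keeps `StableDiffusionConfig` independent of any concrete scheduler type, which is what lets the new `sdxl_turbo` constructor below swap in the Euler-ancestral scheduler without touching callers.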
diff --git a/candle-transformers/src/models/stable_diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs
index b9426094..d804ed56 100644
--- a/candle-transformers/src/models/stable_diffusion/ddim.rs
+++ b/candle-transformers/src/models/stable_diffusion/ddim.rs
@@ -7,7 +7,9 @@
 //!
 //! Denoising Diffusion Implicit Models, J. Song et al, 2020.
 //! https://arxiv.org/abs/2010.02502
-use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing};
+use super::schedulers::{
+    betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing,
+};
 use candle::{Result, Tensor};
 
 /// The configuration for the DDIM scheduler.
@@ -48,6 +50,12 @@ impl Default for DDIMSchedulerConfig {
     }
 }
 
+impl SchedulerConfig for DDIMSchedulerConfig {
+    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
+        Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?))
+    }
+}
+
 /// The DDIM scheduler.
 #[derive(Debug, Clone)]
 pub struct DDIMScheduler {
@@ -63,7 +71,7 @@ impl DDIMScheduler {
     /// Creates a new DDIM scheduler given the number of steps to be
     /// used for inference as well as the number of steps that was used
     /// during training.
-    pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
+    fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
         let step_ratio = config.train_timesteps / inference_steps;
         let timesteps: Vec<usize> = match config.timestep_spacing {
             TimestepSpacing::Leading => (0..(inference_steps))
@@ -115,19 +123,11 @@
             config,
         })
     }
+}
 
-    pub fn timesteps(&self) -> &[usize] {
-        self.timesteps.as_slice()
-    }
-
-    /// Ensures interchangeability with schedulers that need to scale the denoising model input
-    /// depending on the current timestep.
-    pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
-        Ok(sample)
-    }
-
+impl Scheduler for DDIMScheduler {
     /// Performs a backward step during inference.
-    pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+    fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
         let timestep = if timestep >= self.alphas_cumprod.len() {
             timestep - 1
         } else {
@@ -186,7 +186,17 @@
         }
     }
 
-    pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+    /// Ensures interchangeability with schedulers that need to scale the denoising model input
+    /// depending on the current timestep.
+    fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
+        Ok(sample)
+    }
+
+    fn timesteps(&self) -> &[usize] {
+        self.timesteps.as_slice()
+    }
+
+    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
         let timestep = if timestep >= self.alphas_cumprod.len() {
             timestep - 1
         } else {
@@ -197,7 +207,7 @@
         (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
     }
 
-    pub fn init_noise_sigma(&self) -> f64 {
+    fn init_noise_sigma(&self) -> f64 {
         self.init_noise_sigma
     }
 }
diff --git a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
index 7acbf040..85e86e6e 100644
--- a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
+++ b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
@@ -8,7 +8,10 @@
 ///
 /// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
 use super::{
-    schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing},
+    schedulers::{
+        betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig,
+        TimestepSpacing,
+    },
     utils::interp,
 };
 use candle::{bail, Error, Result, Tensor};
@@ -43,11 +46,20 @@ impl Default for EulerAncestralDiscreteSchedulerConfig {
             steps_offset: 1,
             prediction_type: PredictionType::Epsilon,
             train_timesteps: 1000,
-            timestep_spacing: TimestepSpacing::Trailing,
+            timestep_spacing: TimestepSpacing::Leading,
         }
     }
 }
 
+impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig {
+    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
+        Ok(Box::new(EulerAncestralDiscreteScheduler::new(
+            inference_steps,
+            *self,
+        )?))
+    }
+}
+
 /// The EulerAncestral Discrete scheduler.
 #[derive(Debug, Clone)]
 pub struct EulerAncestralDiscreteScheduler {
@@ -138,8 +150,10 @@ impl EulerAncestralDiscreteScheduler {
             config,
         })
     }
+}
 
-    pub fn timesteps(&self) -> &[usize] {
+impl Scheduler for EulerAncestralDiscreteScheduler {
+    fn timesteps(&self) -> &[usize] {
         self.timesteps.as_slice()
     }
 
@@ -147,7 +161,7 @@ impl EulerAncestralDiscreteScheduler {
     /// depending on the current timestep.
     ///
    /// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm
-    pub fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
+    fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
         let step_index = match self.timesteps.iter().position(|&t| t == timestep) {
             Some(i) => i,
             None => bail!("timestep out of this schedulers bounds: {timestep}"),
@@ -162,7 +176,7 @@
     }
 
     /// Performs a backward step during inference.
-    pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+    fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
         let step_index = self
             .timesteps
             .iter()
@@ -197,7 +211,7 @@
         prev_sample + noise * sigma_up
     }
 
-    pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
         let step_index = self
             .timesteps
             .iter()
@@ -212,7 +226,7 @@
         original + (noise * *sigma)?
     }
 
-    pub fn init_noise_sigma(&self) -> f64 {
+    fn init_noise_sigma(&self) -> f64 {
         match self.config.timestep_spacing {
             TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma,
             TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(),
diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs
index cad24524..30f23975 100644
--- a/candle-transformers/src/models/stable_diffusion/mod.rs
+++ b/candle-transformers/src/models/stable_diffusion/mod.rs
@@ -11,9 +11,13 @@ pub mod unet_2d_blocks;
 pub mod utils;
 pub mod vae;
 
+use std::sync::Arc;
+
 use candle::{DType, Device, Result};
 use candle_nn as nn;
 
+use self::schedulers::{Scheduler, SchedulerConfig};
+
 #[derive(Clone, Debug)]
 pub struct StableDiffusionConfig {
     pub width: usize,
@@ -22,7 +26,7 @@ pub struct StableDiffusionConfig {
     pub clip2: Option<clip::Config>,
     autoencoder: vae::AutoEncoderKLConfig,
     unet: unet_2d::UNet2DConditionModelConfig,
-    scheduler: ddim::DDIMSchedulerConfig,
+    scheduler: Arc<dyn SchedulerConfig>,
 }
 
 impl StableDiffusionConfig {
@@ -76,13 +80,18 @@ impl StableDiffusionConfig {
             512
         };
 
-        Self {
+        let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
+            prediction_type: schedulers::PredictionType::Epsilon,
+            ..Default::default()
+        });
+
+        StableDiffusionConfig {
             width,
             height,
             clip: clip::Config::v1_5(),
             clip2: None,
             autoencoder,
-            scheduler: Default::default(),
+            scheduler,
             unet,
         }
     }
@@ -125,10 +134,10 @@ impl StableDiffusionConfig {
             latent_channels: 4,
             norm_num_groups: 32,
         };
-        let scheduler = ddim::DDIMSchedulerConfig {
+        let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
             prediction_type,
             ..Default::default()
-        };
+        });
 
         let height = if let Some(height) = height {
             assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -144,7 +153,7 @@ impl StableDiffusionConfig {
             768
         };
 
-        Self {
+        StableDiffusionConfig {
             width,
             height,
             clip: clip::Config::v2_1(),
@@ -206,10 +215,10 @@ impl StableDiffusionConfig {
             latent_channels: 4,
             norm_num_groups: 32,
         };
-        let scheduler = ddim::DDIMSchedulerConfig {
+        let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
             prediction_type,
             ..Default::default()
-        };
+        });
 
         let height = if let Some(height) = height {
             assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -225,6 +234,76 @@
             1024
         };
 
+        StableDiffusionConfig {
+            width,
+            height,
+            clip: clip::Config::sdxl(),
+            clip2: Some(clip::Config::sdxl2()),
+            autoencoder,
+            scheduler,
+            unet,
+        }
+    }
+
+    fn sdxl_turbo_(
+        sliced_attention_size: Option<usize>,
+        height: Option<usize>,
+        width: Option<usize>,
+        prediction_type: schedulers::PredictionType,
+    ) -> Self {
+        let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+            out_channels,
+            use_cross_attn,
+            attention_head_dim,
+        };
+        // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json
+        let unet = unet_2d::UNet2DConditionModelConfig {
+            blocks: vec![
+                bc(320, None, 5),
+                bc(640, Some(2), 10),
+                bc(1280, Some(10), 20),
+            ],
+            center_input_sample: false,
+            cross_attention_dim: 2048,
+            downsample_padding: 1,
+            flip_sin_to_cos: true,
+            freq_shift: 0.,
+            layers_per_block: 2,
+            mid_block_scale_factor: 1.,
+            norm_eps: 1e-5,
+            norm_num_groups: 32,
+            sliced_attention_size,
+            use_linear_projection: true,
+        };
+        // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json
+        let autoencoder = vae::AutoEncoderKLConfig {
+            block_out_channels: vec![128, 256, 512, 512],
+            layers_per_block: 2,
+            latent_channels: 4,
+            norm_num_groups: 32,
+        };
+        let scheduler = Arc::new(
+            euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
+                prediction_type,
+                timestep_spacing: schedulers::TimestepSpacing::Trailing,
+                ..Default::default()
+            },
+        );
+
+        let height = if let Some(height) = height {
+            assert_eq!(height % 8, 0, "height has to be divisible by 8");
+            height
+        } else {
+            512
+        };
+
+        let width = if let Some(width) = width {
+            assert_eq!(width % 8, 0, "width has to be divisible by 8");
+            width
+        } else {
+            512
+        };
+
         Self {
             width,
             height,
@@ -250,6 +329,20 @@
         )
     }
 
+    pub fn sdxl_turbo(
+        sliced_attention_size: Option<usize>,
+        height: Option<usize>,
+        width: Option<usize>,
+    ) -> Self {
+        Self::sdxl_turbo_(
+            sliced_attention_size,
+            height,
+            width,
+            // https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json
+            schedulers::PredictionType::Epsilon,
+        )
+    }
+
     pub fn ssd1b(
         sliced_attention_size: Option<usize>,
         height: Option<usize>,
@@ -286,9 +379,9 @@ impl StableDiffusionConfig {
             latent_channels: 4,
             norm_num_groups: 32,
         };
-        let scheduler = ddim::DDIMSchedulerConfig {
+        let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
             ..Default::default()
-        };
+        });
 
         let height = if let Some(height) = height {
             assert_eq!(height % 8, 0, "height has to be divisible by 8");
@@ -348,8 +441,8 @@ impl StableDiffusionConfig {
         Ok(unet)
     }
 
-    pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
-        ddim::DDIMScheduler::new(n_steps, self.scheduler)
+    pub fn build_scheduler(&self, n_steps: usize) -> Result<Box<dyn Scheduler>> {
+        self.scheduler.build(n_steps)
     }
 }
 
diff --git a/candle-transformers/src/models/stable_diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs
index f414bde7..0f0441e0 100644
--- a/candle-transformers/src/models/stable_diffusion/schedulers.rs
+++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs
@@ -3,9 +3,25 @@
 //!
 //! Noise schedulers can be used to set the trade-off between
 //! inference speed and quality.
-
 use candle::{Result, Tensor};
 
+pub trait SchedulerConfig: std::fmt::Debug {
+    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>>;
+}
+
+/// This trait represents a scheduler for the diffusion process.
+pub trait Scheduler {
+    fn timesteps(&self) -> &[usize];
+
+    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor>;
+
+    fn init_noise_sigma(&self) -> f64;
+
+    fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
+
+    fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
+}
+
 /// This represents how beta ranges from its minimum value to the maximum
 /// during training.
 #[derive(Debug, Clone, Copy)]
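
End to end, the new constructor plus the trait-based factory make SD Turbo a drop-in configuration. A short usage sketch under the defaults this patch wires in (512x512, Euler-ancestral scheduling with Trailing spacing and Epsilon prediction); running a single inference step reflects Turbo's low step counts per the commit message, not a prescribed value.

```rust
use candle::Result;
use candle_transformers::models::stable_diffusion::schedulers::Scheduler;
use candle_transformers::models::stable_diffusion::StableDiffusionConfig;

fn main() -> Result<()> {
    // None/None/None keeps the 512x512 defaults; the constructor wires in
    // the Euler-ancestral scheduler with Trailing timestep spacing and
    // Epsilon prediction, per the config files linked in the diff above.
    let cfg = StableDiffusionConfig::sdxl_turbo(None, None, None);
    // SD Turbo targets very few denoising steps; one step is illustrative.
    let scheduler = cfg.build_scheduler(1)?;
    println!("init_noise_sigma = {}", scheduler.init_noise_sigma());
    println!("timesteps = {:?}", scheduler.timesteps());
    Ok(())
}
```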