diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-06 21:44:00 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-06 20:44:00 +0100 |
commit | 1c062bf06ba504a076b329c965c625be0ec67c1d (patch) | |
tree | 299cd6603c49726bc72596e416657d9f336ae82b /candle-examples/examples/stable-diffusion | |
parent | d34039e35267b3f4de83770f8da4ea31491bcec5 (diff) | |
download | candle-1c062bf06ba504a076b329c965c625be0ec67c1d.tar.gz candle-1c062bf06ba504a076b329c965c625be0ec67c1d.tar.bz2 candle-1c062bf06ba504a076b329c965c625be0ec67c1d.zip |
Add the ddim scheduler. (#330)
Diffstat (limited to 'candle-examples/examples/stable-diffusion')
5 files changed, 445 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion/ddim.rs b/candle-examples/examples/stable-diffusion/ddim.rs new file mode 100644 index 00000000..9afff5aa --- /dev/null +++ b/candle-examples/examples/stable-diffusion/ddim.rs @@ -0,0 +1,181 @@ +#![allow(dead_code)] +//! # Denoising Diffusion Implicit Models +//! +//! The Denoising Diffusion Implicit Models (DDIM) is a simple scheduler +//! similar to Denoising Diffusion Probabilistic Models (DDPM). The DDPM +//! generative process is the reverse of a Markovian process, DDIM generalizes +//! this to non-Markovian guidance. +//! +//! Denoising Diffusion Implicit Models, J. Song et al, 2020. +//! https://arxiv.org/abs/2010.02502 +use crate::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType}; +use candle::{Result, Tensor}; + +/// The configuration for the DDIM scheduler. +#[derive(Debug, Clone, Copy)] +pub struct DDIMSchedulerConfig { + /// The value of beta at the beginning of training. + pub beta_start: f64, + /// The value of beta at the end of training. + pub beta_end: f64, + /// How beta evolved during training. + pub beta_schedule: BetaSchedule, + /// The amount of noise to be added at each step. + pub eta: f64, + /// Adjust the indexes of the inference schedule by this value. + pub steps_offset: usize, + /// prediction type of the scheduler function, one of `epsilon` (predicting + /// the noise of the diffusion process), `sample` (directly predicting the noisy sample`) + /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) + pub prediction_type: PredictionType, + /// number of diffusion steps used to train the model + pub train_timesteps: usize, +} + +impl Default for DDIMSchedulerConfig { + fn default() -> Self { + Self { + beta_start: 0.00085f64, + beta_end: 0.012f64, + beta_schedule: BetaSchedule::ScaledLinear, + eta: 0., + steps_offset: 1, + prediction_type: PredictionType::Epsilon, + train_timesteps: 1000, + } + } +} + +/// The DDIM scheduler. +#[derive(Debug, Clone)] +pub struct DDIMScheduler { + timesteps: Vec<usize>, + alphas_cumprod: Vec<f64>, + step_ratio: usize, + init_noise_sigma: f64, + pub config: DDIMSchedulerConfig, +} + +// clip_sample: False, set_alpha_to_one: False +impl DDIMScheduler { + /// Creates a new DDIM scheduler given the number of steps to be + /// used for inference as well as the number of steps that was used + /// during training. + pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> { + let step_ratio = config.train_timesteps / inference_steps; + let timesteps: Vec<usize> = (0..(inference_steps)) + .map(|s| s * step_ratio + config.steps_offset) + .rev() + .collect(); + let betas = match config.beta_schedule { + BetaSchedule::ScaledLinear => crate::utils::linspace( + config.beta_start.sqrt(), + config.beta_end.sqrt(), + config.train_timesteps, + )? + .sqr()?, + BetaSchedule::Linear => { + crate::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)? + } + BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?, + }; + let betas = betas.to_vec1::<f64>()?; + let mut alphas_cumprod = Vec::with_capacity(betas.len()); + for &beta in betas.iter() { + let alpha = 1.0 - beta; + alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64)) + } + Ok(Self { + alphas_cumprod, + timesteps, + step_ratio, + init_noise_sigma: 1., + config, + }) + } + + pub fn timesteps(&self) -> &[usize] { + self.timesteps.as_slice() + } + + /// Ensures interchangeability with schedulers that need to scale the denoising model input + /// depending on the current timestep. + pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { + sample + } + + /// Performs a backward step during inference. + pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> { + let timestep = if timestep >= self.alphas_cumprod.len() { + timestep - 1 + } else { + timestep + }; + // https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195 + let prev_timestep = if timestep > self.step_ratio { + timestep - self.step_ratio + } else { + 0 + }; + + let alpha_prod_t = self.alphas_cumprod[timestep]; + let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]; + let beta_prod_t = 1. - alpha_prod_t; + let beta_prod_t_prev = 1. - alpha_prod_t_prev; + + let (pred_original_sample, pred_epsilon) = match self.config.prediction_type { + PredictionType::Epsilon => { + let pred_original_sample = ((sample - (model_output * beta_prod_t.sqrt())?)? + * (1. / alpha_prod_t.sqrt()))?; + (pred_original_sample, model_output.clone()) + } + PredictionType::VPrediction => { + let pred_original_sample = + ((sample * alpha_prod_t.sqrt())? - (model_output * beta_prod_t.sqrt())?)?; + let pred_epsilon = + ((model_output * alpha_prod_t.sqrt())? + (sample * beta_prod_t.sqrt())?)?; + (pred_original_sample, pred_epsilon) + } + PredictionType::Sample => { + let pred_original_sample = model_output.clone(); + let pred_epsilon = ((sample - &pred_original_sample * alpha_prod_t.sqrt())? + * (1. / beta_prod_t.sqrt()))?; + (pred_original_sample, pred_epsilon) + } + }; + + let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev); + let std_dev_t = self.config.eta * variance.sqrt(); + + let pred_sample_direction = + (pred_epsilon * (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt())?; + let prev_sample = + ((pred_original_sample * alpha_prod_t_prev.sqrt())? + pred_sample_direction)?; + if self.config.eta > 0. { + &prev_sample + + Tensor::randn( + 0f32, + std_dev_t as f32, + prev_sample.shape(), + prev_sample.device(), + )? + } else { + Ok(prev_sample) + } + } + + pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> { + let timestep = if timestep >= self.alphas_cumprod.len() { + timestep - 1 + } else { + timestep + }; + let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt(); + let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt(); + (original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)? + } + + pub fn init_noise_sigma(&self) -> f64 { + self.init_noise_sigma + } +} diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs index 31848f38..33a1dc71 100644 --- a/candle-examples/examples/stable-diffusion/main.rs +++ b/candle-examples/examples/stable-diffusion/main.rs @@ -3,8 +3,11 @@ extern crate intel_mkl_src; mod attention; mod clip; +mod ddim; mod embeddings; mod resnet; +mod schedulers; +mod stable_diffusion; mod unet_2d; mod unet_2d_blocks; mod utils; diff --git a/candle-examples/examples/stable-diffusion/schedulers.rs b/candle-examples/examples/stable-diffusion/schedulers.rs new file mode 100644 index 00000000..3f6a1d72 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/schedulers.rs @@ -0,0 +1,45 @@ +#![allow(dead_code)] +//! # Diffusion pipelines and models +//! +//! Noise schedulers can be used to set the trade-off between +//! inference speed and quality. + +use candle::{Result, Tensor}; + +/// This represents how beta ranges from its minimum value to the maximum +/// during training. +#[derive(Debug, Clone, Copy)] +pub enum BetaSchedule { + /// Linear interpolation. + Linear, + /// Linear interpolation of the square root of beta. + ScaledLinear, + /// Glide cosine schedule + SquaredcosCapV2, +} + +#[derive(Debug, Clone, Copy)] +pub enum PredictionType { + Epsilon, + VPrediction, + Sample, +} + +/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of +/// `(1-beta)` over time from `t = [0,1]`. +/// +/// Contains a function `alpha_bar` that takes an argument `t` and transforms it to the cumulative product of `(1-beta)` +/// up to that part of the diffusion process. +pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> { + let alpha_bar = |time_step: usize| { + f64::cos((time_step as f64 + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2) + }; + let mut betas = Vec::with_capacity(num_diffusion_timesteps); + for i in 0..num_diffusion_timesteps { + let t1 = i / num_diffusion_timesteps; + let t2 = (i + 1) / num_diffusion_timesteps; + betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta)); + } + let betas_len = betas.len(); + Tensor::from_vec(betas, betas_len, &candle::Device::Cpu) +} diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs new file mode 100644 index 00000000..c250ed56 --- /dev/null +++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs @@ -0,0 +1,212 @@ +#![allow(dead_code)] +use crate::schedulers::PredictionType; +use crate::{clip, ddim, unet_2d, vae}; +use candle::{DType, Device, Result}; +use candle_nn as nn; + +#[derive(Clone, Debug)] +pub struct StableDiffusionConfig { + pub width: usize, + pub height: usize, + pub clip: clip::Config, + autoencoder: vae::AutoEncoderKLConfig, + unet: unet_2d::UNet2DConditionModelConfig, + scheduler: ddim::DDIMSchedulerConfig, +} + +impl StableDiffusionConfig { + pub fn v1_5( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + ) -> Self { + let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + }; + // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json + let unet = unet_2d::UNet2DConditionModelConfig { + blocks: vec![ + bc(320, true, 8), + bc(640, true, 8), + bc(1280, true, 8), + bc(1280, false, 8), + ], + center_input_sample: false, + cross_attention_dim: 768, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0., + layers_per_block: 2, + mid_block_scale_factor: 1., + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: false, + }; + let autoencoder = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 4, + norm_num_groups: 32, + }; + let height = if let Some(height) = height { + assert_eq!(height % 8, 0, "heigh has to be divisible by 8"); + height + } else { + 512 + }; + + let width = if let Some(width) = width { + assert_eq!(width % 8, 0, "width has to be divisible by 8"); + width + } else { + 512 + }; + + Self { + width, + height, + clip: clip::Config::v1_5(), + autoencoder, + scheduler: Default::default(), + unet, + } + } + + fn v2_1_( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + prediction_type: PredictionType, + ) -> Self { + let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + }; + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json + let unet = unet_2d::UNet2DConditionModelConfig { + blocks: vec![ + bc(320, true, 5), + bc(640, true, 10), + bc(1280, true, 20), + bc(1280, false, 20), + ], + center_input_sample: false, + cross_attention_dim: 1024, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0., + layers_per_block: 2, + mid_block_scale_factor: 1., + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: true, + }; + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json + let autoencoder = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 4, + norm_num_groups: 32, + }; + let scheduler = ddim::DDIMSchedulerConfig { + prediction_type, + ..Default::default() + }; + + let height = if let Some(height) = height { + assert_eq!(height % 8, 0, "heigh has to be divisible by 8"); + height + } else { + 768 + }; + + let width = if let Some(width) = width { + assert_eq!(width % 8, 0, "width has to be divisible by 8"); + width + } else { + 768 + }; + + Self { + width, + height, + clip: clip::Config::v2_1(), + autoencoder, + scheduler, + unet, + } + } + + pub fn v2_1( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + ) -> Self { + // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json + Self::v2_1_( + sliced_attention_size, + height, + width, + PredictionType::VPrediction, + ) + } + + pub fn v2_1_inpaint( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + ) -> Self { + // https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json + // This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction + // type being "epsilon" by default and not "v_prediction". + Self::v2_1_( + sliced_attention_size, + height, + width, + PredictionType::Epsilon, + ) + } + + pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> { + let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? }; + let weights = weights.deserialize()?; + let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device); + // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json + let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?; + Ok(autoencoder) + } + + pub fn build_unet( + &self, + unet_weights: &str, + device: &Device, + in_channels: usize, + ) -> Result<unet_2d::UNet2DConditionModel> { + let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? }; + let weights = weights.deserialize()?; + let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device); + let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?; + Ok(unet) + } + + pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> { + ddim::DDIMScheduler::new(n_steps, self.scheduler) + } + + pub fn build_clip_transformer( + &self, + clip_weights: &str, + device: &Device, + ) -> Result<clip::ClipTextTransformer> { + let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? }; + let weights = weights.deserialize()?; + let vs = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device); + let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?; + Ok(text_model) + } +} diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs index aa49e560..50ee48e9 100644 --- a/candle-examples/examples/stable-diffusion/utils.rs +++ b/candle-examples/examples/stable-diffusion/utils.rs @@ -15,3 +15,7 @@ pub fn pad(_: &Tensor) -> Result<Tensor> { pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> { todo!() } + +pub fn linspace(_: f64, _: f64, _: usize) -> Result<Tensor> { + todo!() +} |