author    Edwin Cheng <edwin0cheng@gmail.com>    2023-12-03 03:59:23 +0800
committer GitHub <noreply@github.com>            2023-12-02 19:59:23 +0000
commit    dd40edfe73b796ad12f6246319cde2e603c4dc56 (patch)
tree      c21e67e715a933f272a4f8cc1e082cf609296172 /candle-transformers
parent    5aa1a65dab7164ca85b93a0d737589c27a9f4dc1 (diff)
Add Euler Ancestral Discrete Scheduler (#1390)
* Add Euler Ancestral Discrete Scheduler
* Fix a bug in init_noise_sigma generation
* Minor fixes
* Use partition_point instead of a custom bsearch
* Fix some clippy lints

Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/stable_diffusion/ddim.rs                      |  33
-rw-r--r--  candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs  | 221
-rw-r--r--  candle-transformers/src/models/stable_diffusion/mod.rs                       |   1
-rw-r--r--  candle-transformers/src/models/stable_diffusion/schedulers.rs                |  16
-rw-r--r--  candle-transformers/src/models/stable_diffusion/utils.rs                     |  46
5 files changed, 312 insertions, 5 deletions
diff --git a/candle-transformers/src/models/stable_diffusion/ddim.rs b/candle-transformers/src/models/stable_diffusion/ddim.rs
index 916b7349..b9426094 100644
--- a/candle-transformers/src/models/stable_diffusion/ddim.rs
+++ b/candle-transformers/src/models/stable_diffusion/ddim.rs
@@ -7,7 +7,7 @@
//!
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
//! https://arxiv.org/abs/2010.02502
-use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
+use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing};
use candle::{Result, Tensor};
/// The configuration for the DDIM scheduler.
@@ -29,6 +29,8 @@ pub struct DDIMSchedulerConfig {
pub prediction_type: PredictionType,
/// number of diffusion steps used to train the model
pub train_timesteps: usize,
+ /// time step spacing for the diffusion process
+ pub timestep_spacing: TimestepSpacing,
}
impl Default for DDIMSchedulerConfig {
@@ -41,6 +43,7 @@ impl Default for DDIMSchedulerConfig {
steps_offset: 1,
prediction_type: PredictionType::Epsilon,
train_timesteps: 1000,
+ timestep_spacing: TimestepSpacing::Leading,
}
}
}
@@ -62,10 +65,30 @@ impl DDIMScheduler {
/// during training.
pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
let step_ratio = config.train_timesteps / inference_steps;
- let timesteps: Vec<usize> = (0..(inference_steps))
- .map(|s| s * step_ratio + config.steps_offset)
- .rev()
- .collect();
+ let timesteps: Vec<usize> = match config.timestep_spacing {
+ TimestepSpacing::Leading => (0..(inference_steps))
+ .map(|s| s * step_ratio + config.steps_offset)
+ .rev()
+ .collect(),
+ TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| {
+ if *n > step_ratio {
+ Some(n - step_ratio)
+ } else {
+ None
+ }
+ })
+ .map(|n| n - 1)
+ .collect(),
+ TimestepSpacing::Linspace => {
+ super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)?
+ .to_vec1::<f64>()?
+ .iter()
+ .map(|&f| f as usize)
+ .rev()
+ .collect()
+ }
+ };
+
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => super::utils::linspace(
config.beta_start.sqrt(),
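
To make the three spacings concrete: with train_timesteps = 1000, inference_steps = 10 and steps_offset = 1 (the DDIM defaults apart from the step count), the arithmetic above yields the sequences checked below. This standalone sketch is not part of the patch; it emulates the Linspace variant with plain float math instead of a Tensor, and matches the Table 2 annotations referenced in schedulers.rs further down.

fn main() {
    let (train, steps, offset) = (1000usize, 10usize, 1usize);
    let ratio = train / steps;

    // Leading: offset plus multiples of the ratio, counted from the start.
    let leading: Vec<usize> = (0..steps).map(|s| s * ratio + offset).rev().collect();
    assert_eq!(leading, vec![901, 801, 701, 601, 501, 401, 301, 201, 101, 1]);

    // Trailing: multiples of the ratio counted back from the end, shifted down by one.
    let trailing: Vec<usize> =
        std::iter::successors(Some(train), |n| (*n > ratio).then(|| n - ratio))
            .map(|n| n - 1)
            .collect();
    assert_eq!(trailing, vec![999, 899, 799, 699, 599, 499, 399, 299, 199, 99]);

    // Linspace: evenly spaced floats over [0, train - 1], truncated to integers.
    let linspace: Vec<usize> = (0..steps)
        .map(|i| (i as f64 * (train as f64 - 1.0) / (steps as f64 - 1.0)) as usize)
        .rev()
        .collect();
    assert_eq!(linspace, vec![999, 888, 777, 666, 555, 444, 333, 222, 111, 0]);
}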
diff --git a/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
new file mode 100644
index 00000000..7acbf040
--- /dev/null
+++ b/candle-transformers/src/models/stable_diffusion/euler_ancestral_discrete.rs
@@ -0,0 +1,221 @@
+//! Ancestral sampling with Euler method steps.
+//!
+//! Reference implementation in Rust:
+//!
+//! https://github.com/pykeio/diffusers/blob/250b9ad1898af41e76a74c0d8d4292652823338a/src/schedulers/euler_ancestral_discrete.rs
+//!
+//! Based on the original [`k-diffusion` implementation by Katherine Crowson][kd].
+///
+/// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
+use super::{
+ schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType, TimestepSpacing},
+ utils::interp,
+};
+use candle::{bail, Error, Result, Tensor};
+
+/// The configuration for the EulerAncestral Discrete scheduler.
+#[derive(Debug, Clone, Copy)]
+pub struct EulerAncestralDiscreteSchedulerConfig {
+ /// The value of beta at the beginning of training.
+ pub beta_start: f64,
+ /// The value of beta at the end of training.
+ pub beta_end: f64,
+ /// How beta evolved during training.
+ pub beta_schedule: BetaSchedule,
+ /// Adjust the indexes of the inference schedule by this value.
+ pub steps_offset: usize,
+ /// prediction type of the scheduler function, one of `epsilon` (predicting
+ /// the noise of the diffusion process), `sample` (directly predicting the noisy sample)
+ /// or `v_prediction` (see section 2.4 of https://imagen.research.google/video/paper.pdf)
+ pub prediction_type: PredictionType,
+ /// number of diffusion steps used to train the model
+ pub train_timesteps: usize,
+ /// time step spacing for the diffusion process
+ pub timestep_spacing: TimestepSpacing,
+}
+
+impl Default for EulerAncestralDiscreteSchedulerConfig {
+ fn default() -> Self {
+ Self {
+ beta_start: 0.00085f64,
+ beta_end: 0.012f64,
+ beta_schedule: BetaSchedule::ScaledLinear,
+ steps_offset: 1,
+ prediction_type: PredictionType::Epsilon,
+ train_timesteps: 1000,
+ timestep_spacing: TimestepSpacing::Trailing,
+ }
+ }
+}
+
+/// The EulerAncestral Discrete scheduler.
+#[derive(Debug, Clone)]
+pub struct EulerAncestralDiscreteScheduler {
+ timesteps: Vec<usize>,
+ sigmas: Vec<f64>,
+ init_noise_sigma: f64,
+ pub config: EulerAncestralDiscreteSchedulerConfig,
+}
+
+// Fixed options from the diffusers implementation: clip_sample = false, set_alpha_to_one = false
+impl EulerAncestralDiscreteScheduler {
+ /// Creates a new EulerAncestral Discrete scheduler given the number of steps to be
+ /// used for inference as well as the number of steps that was used
+ /// during training.
+ pub fn new(
+ inference_steps: usize,
+ config: EulerAncestralDiscreteSchedulerConfig,
+ ) -> Result<Self> {
+ let step_ratio = config.train_timesteps / inference_steps;
+ let timesteps: Vec<usize> = match config.timestep_spacing {
+ TimestepSpacing::Leading => (0..(inference_steps))
+ .map(|s| s * step_ratio + config.steps_offset)
+ .rev()
+ .collect(),
+ TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| {
+ if *n > step_ratio {
+ Some(n - step_ratio)
+ } else {
+ None
+ }
+ })
+ .map(|n| n - 1)
+ .collect(),
+ TimestepSpacing::Linspace => {
+ super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)?
+ .to_vec1::<f64>()?
+ .iter()
+ .map(|&f| f as usize)
+ .rev()
+ .collect()
+ }
+ };
+
+ let betas = match config.beta_schedule {
+ BetaSchedule::ScaledLinear => super::utils::linspace(
+ config.beta_start.sqrt(),
+ config.beta_end.sqrt(),
+ config.train_timesteps,
+ )?
+ .sqr()?,
+ BetaSchedule::Linear => {
+ super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
+ }
+ BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
+ };
+ let betas = betas.to_vec1::<f64>()?;
+ let mut alphas_cumprod = Vec::with_capacity(betas.len());
+ for &beta in betas.iter() {
+ let alpha = 1.0 - beta;
+ alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
+ }
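+ // Convert the cumulative alpha products into k-diffusion sigmas:
+ // sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t).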
+ let sigmas: Vec<f64> = alphas_cumprod
+ .iter()
+ .map(|&f| ((1. - f) / f).sqrt())
+ .collect();
+
+ let sigmas_xa: Vec<_> = (0..sigmas.len()).map(|i| i as f64).collect();
+
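+ // Look up sigma at each selected timestep by linearly interpolating the
+ // training-time schedule (the np.interp counterpart); a trailing 0.0 is
+ // appended below for the final, fully denoised step.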
+ let mut sigmas_int = interp(
+ &timesteps.iter().map(|&t| t as f64).collect::<Vec<_>>(),
+ &sigmas_xa,
+ &sigmas,
+ );
+ sigmas_int.push(0.0);
+
+ // standard deviation of the initial noise distribution
+ // f64 does not implement Ord, so there is no `Iterator::max`; reduce with an explicit comparison instead
+ let init_noise_sigma = *sigmas_int
+ .iter()
+ .chain(std::iter::once(&0.0))
+ .reduce(|a, b| if a > b { a } else { b })
+ .expect("init_noise_sigma could not be reduced from sigmas - this should never happen");
+
+ Ok(Self {
+ sigmas: sigmas_int,
+ timesteps,
+ init_noise_sigma,
+ config,
+ })
+ }
+
+ pub fn timesteps(&self) -> &[usize] {
+ self.timesteps.as_slice()
+ }
+
+ /// Ensures interchangeability with schedulers that need to scale the denoising model input
+ /// depending on the current timestep.
+ ///
+ /// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm
+ pub fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
+ let step_index = match self.timesteps.iter().position(|&t| t == timestep) {
+ Some(i) => i,
+ None => bail!("timestep out of this scheduler's bounds: {timestep}"),
+ };
+
+ let sigma = self
+ .sigmas
+ .get(step_index)
+ .expect("step_index out of sigma bounds - this shouldn't happen");
+
+ sample / ((sigma.powi(2) + 1.).sqrt())
+ }
+
+ /// Performs a backward step during inference.
+ pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
+ let step_index = self
+ .timesteps
+ .iter()
+ .position(|&p| p == timestep)
+ .ok_or_else(|| Error::Msg("timestep out of this scheduler's bounds".to_string()))?;
+
+ let sigma_from = &self.sigmas[step_index];
+ let sigma_to = &self.sigmas[step_index + 1];
+
+ // 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ let pred_original_sample = match self.config.prediction_type {
+ PredictionType::Epsilon => (sample - (model_output * *sigma_from))?,
+ PredictionType::VPrediction => {
+ ((model_output * (-sigma_from / (sigma_from.powi(2) + 1.0).sqrt()))?
+ + (sample / (sigma_from.powi(2) + 1.0))?)?
+ }
+ PredictionType::Sample => bail!("prediction_type not implemented yet: sample"),
+ };
+
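+ // Ancestral noise split (as in k-diffusion): sigma_up and sigma_down satisfy
+ // sigma_down^2 + sigma_up^2 = sigma_to^2, so the next sample keeps the correct
+ // marginal variance after fresh noise is re-injected.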
+ let sigma_up = (sigma_to.powi(2) * (sigma_from.powi(2) - sigma_to.powi(2))
+ / sigma_from.powi(2))
+ .sqrt();
+ let sigma_down = (sigma_to.powi(2) - sigma_up.powi(2)).sqrt();
+
+ // 2. convert to a ODE derivative
+ let derivative = ((sample - pred_original_sample)? / *sigma_from)?;
+ let dt = sigma_down - *sigma_from;
+ let prev_sample = (sample + derivative * dt)?;
+
+ let noise = prev_sample.randn_like(0.0, 1.0)?;
+
+ prev_sample + noise * sigma_up
+ }
+
+ pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
+ let step_index = self
+ .timesteps
+ .iter()
+ .position(|&p| p == timestep)
+ .ok_or_else(|| Error::Msg("timestep out of this scheduler's bounds".to_string()))?;
+
+ let sigma = self
+ .sigmas
+ .get(step_index)
+ .expect("step_index out of sigma bounds - this shouldn't happen");
+
+ original + (noise * *sigma)?
+ }
+
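+ // Leading spacing scales the initial latents like a model input, i.e. by
+ // sqrt(sigma_max^2 + 1); Trailing and Linspace use sigma_max directly.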
+ pub fn init_noise_sigma(&self) -> f64 {
+ match self.config.timestep_spacing {
+ TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma,
+ TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(),
+ }
+ }
+}
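
A minimal usage sketch for the new scheduler (not part of this commit): `denoise` is a hypothetical stand-in for a UNet forward pass, and the latent shape and step count of 30 are arbitrary assumptions.

use candle::{Device, Result, Tensor};
use candle_transformers::models::stable_diffusion::euler_ancestral_discrete::{
    EulerAncestralDiscreteScheduler, EulerAncestralDiscreteSchedulerConfig,
};

fn sample(denoise: impl Fn(&Tensor, usize) -> Result<Tensor>) -> Result<Tensor> {
    let scheduler = EulerAncestralDiscreteScheduler::new(
        30,
        EulerAncestralDiscreteSchedulerConfig::default(),
    )?;
    // Start from pure noise scaled by the initial sigma.
    let mut latents = (Tensor::randn(0f32, 1f32, (1, 4, 64, 64), &Device::Cpu)?
        * scheduler.init_noise_sigma())?;
    for &t in scheduler.timesteps() {
        // Scale the model input as required by the K-LMS-style formulation.
        let input = scheduler.scale_model_input(latents.clone(), t)?;
        let noise_pred = denoise(&input, t)?; // hypothetical UNet call
        // One ancestral Euler step: deterministic update plus re-injected noise.
        latents = scheduler.step(&noise_pred, t, &latents)?;
    }
    Ok(latents)
}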
diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs
index 66ef7149..cad24524 100644
--- a/candle-transformers/src/models/stable_diffusion/mod.rs
+++ b/candle-transformers/src/models/stable_diffusion/mod.rs
@@ -3,6 +3,7 @@ pub mod clip;
pub mod ddim;
pub mod ddpm;
pub mod embeddings;
+pub mod euler_ancestral_discrete;
pub mod resnet;
pub mod schedulers;
pub mod unet_2d;
diff --git a/candle-transformers/src/models/stable_diffusion/schedulers.rs b/candle-transformers/src/models/stable_diffusion/schedulers.rs
index 3f6a1d72..f414bde7 100644
--- a/candle-transformers/src/models/stable_diffusion/schedulers.rs
+++ b/candle-transformers/src/models/stable_diffusion/schedulers.rs
@@ -25,6 +25,22 @@ pub enum PredictionType {
Sample,
}
+/// Time step spacing for the diffusion process.
+///
+/// "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+#[derive(Debug, Clone, Copy)]
+pub enum TimestepSpacing {
+ Leading,
+ Linspace,
+ Trailing,
+}
+
+impl Default for TimestepSpacing {
+ fn default() -> Self {
+ Self::Leading
+ }
+}
+
/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
/// `(1-beta)` over time from `t = [0,1]`.
///
diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs
index cef06f1c..5b5fa0f7 100644
--- a/candle-transformers/src/models/stable_diffusion/utils.rs
+++ b/candle-transformers/src/models/stable_diffusion/utils.rs
@@ -13,3 +13,49 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
Tensor::from_vec(vs, steps, &Device::Cpu)
}
}
+
+/// A linear interpolator for a sorted array of x and y values.
+struct LinearInterpolator<'x, 'y> {
+ xp: &'x [f64],
+ fp: &'y [f64],
+ cache: usize,
+}
+
+impl<'x, 'y> LinearInterpolator<'x, 'y> {
+ fn accel_find(&mut self, x: f64) -> usize {
+ let xidx = self.cache;
+ if x < self.xp[xidx] {
+ self.cache = self.xp[0..xidx].partition_point(|o| *o < x);
+ self.cache = self.cache.saturating_sub(1);
+ } else if x >= self.xp[xidx + 1] {
+ self.cache = self.xp[xidx..self.xp.len()].partition_point(|o| *o < x) + xidx;
+ self.cache = self.cache.saturating_sub(1);
+ }
+
+ self.cache
+ }
+
+ fn eval(&mut self, x: f64) -> f64 {
+ if x < self.xp[0] || x > self.xp[self.xp.len() - 1] {
+ return f64::NAN;
+ }
+
+ let idx = self.accel_find(x);
+
+ let x_l = self.xp[idx];
+ let x_h = self.xp[idx + 1];
+ let y_l = self.fp[idx];
+ let y_h = self.fp[idx + 1];
+ let dx = x_h - x_l;
+ if dx > 0.0 {
+ y_l + (x - x_l) / dx * (y_h - y_l)
+ } else {
+ f64::NAN
+ }
+ }
+}
+
+pub fn interp(x: &[f64], xp: &[f64], fp: &[f64]) -> Vec<f64> {
+ let mut interpolator = LinearInterpolator { xp, fp, cache: 0 };
+ x.iter().map(|&x| interpolator.eval(x)).collect()
+}
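
A quick sanity check for `interp`, written as a unit test next to the function (a sketch, not part of the patch): in-range queries are interpolated linearly, while out-of-range queries return NaN rather than clamping to the endpoints as np.interp would.

#[cfg(test)]
mod tests {
    use super::interp;

    #[test]
    fn interp_linear_and_out_of_range() {
        let xp = [0.0, 1.0, 2.0];
        let fp = [0.0, 10.0, 20.0];
        // Halfway between neighbors yields the midpoint of their values.
        assert_eq!(interp(&[0.5, 1.5], &xp, &fp), vec![5.0, 15.0]);
        // Queries outside [xp[0], xp[last]] produce NaN (unlike np.interp, which clamps).
        assert!(interp(&[-1.0], &xp, &fp)[0].is_nan());
    }
}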