diff options
author | Edwin Cheng <edwin0cheng@gmail.com> | 2023-12-03 03:59:23 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-02 19:59:23 +0000 |
commit | dd40edfe73b796ad12f6246319cde2e603c4dc56 (patch) | |
tree | c21e67e715a933f272a4f8cc1e082cf609296172 /candle-transformers/src/models/stable_diffusion/utils.rs | |
parent | 5aa1a65dab7164ca85b93a0d737589c27a9f4dc1 (diff) | |
download | candle-dd40edfe73b796ad12f6246319cde2e603c4dc56.tar.gz candle-dd40edfe73b796ad12f6246319cde2e603c4dc56.tar.bz2 candle-dd40edfe73b796ad12f6246319cde2e603c4dc56.zip |
Add Euler Ancestral Discrete Scheduler (#1390)
* Add Euler Ancestral Discrete Scheduler
* Fix a bug of init_noise_sigma generation
* minor fixes
* use partition_point instead of custom bsearch
* Fix some clippy lints.
---------
Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers/src/models/stable_diffusion/utils.rs')
-rw-r--r-- | candle-transformers/src/models/stable_diffusion/utils.rs | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs index cef06f1c..5b5fa0f7 100644 --- a/candle-transformers/src/models/stable_diffusion/utils.rs +++ b/candle-transformers/src/models/stable_diffusion/utils.rs @@ -13,3 +13,49 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> { Tensor::from_vec(vs, steps, &Device::Cpu) } } + +/// A linear interpolator for a sorted array of x and y values. +struct LinearInterpolator<'x, 'y> { + xp: &'x [f64], + fp: &'y [f64], + cache: usize, +} + +impl<'x, 'y> LinearInterpolator<'x, 'y> { + fn accel_find(&mut self, x: f64) -> usize { + let xidx = self.cache; + if x < self.xp[xidx] { + self.cache = self.xp[0..xidx].partition_point(|o| *o < x); + self.cache = self.cache.saturating_sub(1); + } else if x >= self.xp[xidx + 1] { + self.cache = self.xp[xidx..self.xp.len()].partition_point(|o| *o < x) + xidx; + self.cache = self.cache.saturating_sub(1); + } + + self.cache + } + + fn eval(&mut self, x: f64) -> f64 { + if x < self.xp[0] || x > self.xp[self.xp.len() - 1] { + return f64::NAN; + } + + let idx = self.accel_find(x); + + let x_l = self.xp[idx]; + let x_h = self.xp[idx + 1]; + let y_l = self.fp[idx]; + let y_h = self.fp[idx + 1]; + let dx = x_h - x_l; + if dx > 0.0 { + y_l + (x - x_l) / dx * (y_h - y_l) + } else { + f64::NAN + } + } +} + +pub fn interp(x: &[f64], xp: &[f64], fp: &[f64]) -> Vec<f64> { + let mut interpolator = LinearInterpolator { xp, fp, cache: 0 }; + x.iter().map(|&x| interpolator.eval(x)).collect() +} |