summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/stable_diffusion/utils.rs
diff options
context:
space:
mode:
authorEdwin Cheng <edwin0cheng@gmail.com>2023-12-03 03:59:23 +0800
committerGitHub <noreply@github.com>2023-12-02 19:59:23 +0000
commitdd40edfe73b796ad12f6246319cde2e603c4dc56 (patch)
treec21e67e715a933f272a4f8cc1e082cf609296172 /candle-transformers/src/models/stable_diffusion/utils.rs
parent5aa1a65dab7164ca85b93a0d737589c27a9f4dc1 (diff)
downloadcandle-dd40edfe73b796ad12f6246319cde2e603c4dc56.tar.gz
candle-dd40edfe73b796ad12f6246319cde2e603c4dc56.tar.bz2
candle-dd40edfe73b796ad12f6246319cde2e603c4dc56.zip
Add Euler Ancestral Discrete Scheduler (#1390)
* Add Euler Ancestral Discrete Scheduler * Fix a bug of init_noise_sigma generation * minor fixes * use partition_point instead of custom bsearch * Fix some clippy lints. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers/src/models/stable_diffusion/utils.rs')
-rw-r--r--candle-transformers/src/models/stable_diffusion/utils.rs46
1 files changed, 46 insertions, 0 deletions
diff --git a/candle-transformers/src/models/stable_diffusion/utils.rs b/candle-transformers/src/models/stable_diffusion/utils.rs
index cef06f1c..5b5fa0f7 100644
--- a/candle-transformers/src/models/stable_diffusion/utils.rs
+++ b/candle-transformers/src/models/stable_diffusion/utils.rs
@@ -13,3 +13,49 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
Tensor::from_vec(vs, steps, &Device::Cpu)
}
}
+
+/// A linear interpolator for a sorted array of x and y values.
+struct LinearInterpolator<'x, 'y> {
+ xp: &'x [f64],
+ fp: &'y [f64],
+ cache: usize,
+}
+
+impl<'x, 'y> LinearInterpolator<'x, 'y> {
+ fn accel_find(&mut self, x: f64) -> usize {
+ let xidx = self.cache;
+ if x < self.xp[xidx] {
+ self.cache = self.xp[0..xidx].partition_point(|o| *o < x);
+ self.cache = self.cache.saturating_sub(1);
+ } else if x >= self.xp[xidx + 1] {
+ self.cache = self.xp[xidx..self.xp.len()].partition_point(|o| *o < x) + xidx;
+ self.cache = self.cache.saturating_sub(1);
+ }
+
+ self.cache
+ }
+
+ fn eval(&mut self, x: f64) -> f64 {
+ if x < self.xp[0] || x > self.xp[self.xp.len() - 1] {
+ return f64::NAN;
+ }
+
+ let idx = self.accel_find(x);
+
+ let x_l = self.xp[idx];
+ let x_h = self.xp[idx + 1];
+ let y_l = self.fp[idx];
+ let y_h = self.fp[idx + 1];
+ let dx = x_h - x_l;
+ if dx > 0.0 {
+ y_l + (x - x_l) / dx * (y_h - y_l)
+ } else {
+ f64::NAN
+ }
+ }
+}
+
+pub fn interp(x: &[f64], xp: &[f64], fp: &[f64]) -> Vec<f64> {
+ let mut interpolator = LinearInterpolator { xp, fp, cache: 0 };
+ x.iter().map(|&x| interpolator.eval(x)).collect()
+}