summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/wuerstchen/main.rs26
-rw-r--r--candle-transformers/src/models/wuerstchen/ddpm.rs4
2 files changed, 25 insertions, 5 deletions
diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
index 12e4c10e..d40cfbd9 100644
--- a/candle-examples/examples/wuerstchen/main.rs
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -14,7 +14,7 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
-const GUIDANCE_SCALE: f64 = 7.5;
+const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
const RESOLUTION_MULTIPLE: f64 = 42.67;
const PRIOR_CIN: usize = 16;
@@ -288,16 +288,32 @@ fn run(args: Args) -> Result<()> {
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
let b_size = 1;
for idx in 0..num_samples {
- let latents = Tensor::randn(
+ let mut latents = Tensor::randn(
0f32,
1f32,
(b_size, PRIOR_CIN, latent_height, latent_width),
&device,
)?;
- // TODO: latents denoising loop, use the scheduler values.
- let ratio = Tensor::ones(1, DType::F32, &device)?;
- let prior = prior.forward(&latents, &ratio, &prior_text_embeddings)?;
+ let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
+ let timesteps = prior_scheduler.timesteps();
+ println!("prior denoising");
+ for (index, &t) in timesteps.iter().enumerate() {
+ let start_time = std::time::Instant::now();
+ if index == timesteps.len() - 1 {
+ continue;
+ }
+ let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+ let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
+ let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
+ let noise_pred = noise_pred.chunk(2, 0)?;
+ let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]);
+ let noise_pred = (noise_pred_uncond
+ + ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?;
+ latents = prior_scheduler.step(&noise_pred, t, &latents)?;
+ let dt = start_time.elapsed().as_secs_f32();
+ println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
+ }
let latents = ((latents * 42.)? - 1.)?;
/*
let timesteps = scheduler.timesteps();
diff --git a/candle-transformers/src/models/wuerstchen/ddpm.rs b/candle-transformers/src/models/wuerstchen/ddpm.rs
index f4f16bfb..80640072 100644
--- a/candle-transformers/src/models/wuerstchen/ddpm.rs
+++ b/candle-transformers/src/models/wuerstchen/ddpm.rs
@@ -38,6 +38,10 @@ impl DDPMWScheduler {
})
}
+ pub fn timesteps(&self) -> &[f64] {
+ &self.timesteps
+ }
+
fn alpha_cumprod(&self, t: f64) -> f64 {
let scaler = self.config.scaler;
let s = self.config.s;