diff options
-rw-r--r-- | candle-examples/examples/wuerstchen/main.rs | 26 | ||||
-rw-r--r-- | candle-transformers/src/models/wuerstchen/ddpm.rs | 4 |
2 files changed, 25 insertions, 5 deletions
diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs index 12e4c10e..d40cfbd9 100644 --- a/candle-examples/examples/wuerstchen/main.rs +++ b/candle-examples/examples/wuerstchen/main.rs @@ -14,7 +14,7 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D}; use clap::Parser; use tokenizers::Tokenizer; -const GUIDANCE_SCALE: f64 = 7.5; +const PRIOR_GUIDANCE_SCALE: f64 = 8.0; const RESOLUTION_MULTIPLE: f64 = 42.67; const PRIOR_CIN: usize = 16; @@ -288,16 +288,32 @@ fn run(args: Args) -> Result<()> { let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize; let b_size = 1; for idx in 0..num_samples { - let latents = Tensor::randn( + let mut latents = Tensor::randn( 0f32, 1f32, (b_size, PRIOR_CIN, latent_height, latent_width), &device, )?; - // TODO: latents denoising loop, use the scheduler values. - let ratio = Tensor::ones(1, DType::F32, &device)?; - let prior = prior.forward(&latents, &ratio, &prior_text_embeddings)?; + let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?; + let timesteps = prior_scheduler.timesteps(); + println!("prior denoising"); + for (index, &t) in timesteps.iter().enumerate() { + let start_time = std::time::Instant::now(); + if index == timesteps.len() - 1 { + continue; + } + let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?; + let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?; + let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?; + let noise_pred = noise_pred.chunk(2, 0)?; + let (noise_pred_text, noise_pred_uncond) = (&noise_pred[0], &noise_pred[1]); + let noise_pred = (noise_pred_uncond + + ((noise_pred_text - noise_pred_uncond)? * PRIOR_GUIDANCE_SCALE)?)?; + latents = prior_scheduler.step(&noise_pred, t, &latents)?; + let dt = start_time.elapsed().as_secs_f32(); + println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt); + } let latents = ((latents * 42.)? - 1.)?; /* let timesteps = scheduler.timesteps(); diff --git a/candle-transformers/src/models/wuerstchen/ddpm.rs b/candle-transformers/src/models/wuerstchen/ddpm.rs index f4f16bfb..80640072 100644 --- a/candle-transformers/src/models/wuerstchen/ddpm.rs +++ b/candle-transformers/src/models/wuerstchen/ddpm.rs @@ -38,6 +38,10 @@ impl DDPMWScheduler { }) } + pub fn timesteps(&self) -> &[f64] { + &self.timesteps + } + fn alpha_cumprod(&self, t: f64) -> f64 { let scaler = self.config.scaler; let s = self.config.s; |