author    Edwin Cheng <edwin0cheng@gmail.com>    2023-12-03 15:37:10 +0800
committer GitHub <noreply@github.com>            2023-12-03 08:37:10 +0100
commit    37bf1ed012926b4180d2e9068b84118e1eb6e26d
tree      73aa11818463fbe5bf39f4978dd6a5538a1a01a5 /candle-examples/examples/stable-diffusion/main.rs
parent    dd40edfe73b796ad12f6246319cde2e603c4dc56
Stable Diffusion Turbo Support (#1395)
* Add support for SD Turbo.
* Set `Leading` as the default in the Euler Ancestral discrete scheduler.
* Use the appropriate per-version default values for `n_steps` and `guidance_scale` (see the sketch after the diffstat).
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/stable-diffusion/main.rs')
-rw-r--r-- candle-examples/examples/stable-diffusion/main.rs | 121
1 file changed, 90 insertions(+), 31 deletions(-)
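The third bullet is the main behavioral change: `n_steps` and `guidance_scale` become `Option`-valued CLI flags that fall back to per-version defaults, 30 steps and a guidance scale of 7.5 for SD 1.5/2.1/XL, but 1 step and a guidance scale of 0 for SDXL Turbo, which is distilled for single-step, guidance-free sampling. Below is a minimal, self-contained sketch of that defaulting pattern; `resolve_defaults` is a hypothetical helper, not the example's actual code, which inlines the same matches in `run`.

```rust
// Sketch of the defaulting pattern the commit adopts: both CLI flags are
// Option-valued and fall back to per-version defaults when not given.
#[derive(Clone, Copy)]
enum StableDiffusionVersion {
    V1_5,
    V2_1,
    Xl,
    Turbo,
}

fn resolve_defaults(
    version: StableDiffusionVersion,
    n_steps: Option<usize>,
    guidance_scale: Option<f64>,
) -> (usize, f64) {
    use StableDiffusionVersion::*;
    // SDXL Turbo is distilled for single-step sampling without
    // classifier-free guidance, hence 1 step and a scale of 0.
    let n_steps = n_steps.unwrap_or(match version {
        V1_5 | V2_1 | Xl => 30,
        Turbo => 1,
    });
    let guidance_scale = guidance_scale.unwrap_or(match version {
        V1_5 | V2_1 | Xl => 7.5,
        Turbo => 0.,
    });
    (n_steps, guidance_scale)
}

fn main() {
    let (steps, cfg) = resolve_defaults(StableDiffusionVersion::Turbo, None, None);
    println!("turbo defaults: {steps} step(s), guidance scale {cfg}");
}
```

Assuming the example's existing `--sd-version` flag simply gains a `turbo` value (that flag is outside the hunks shown here), a Turbo run would then need neither `--n-steps` nor `--guidance-scale` on the command line.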
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 3e6de34d..8c3ca2ee 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -11,8 +11,6 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D};
 use clap::Parser;
 use tokenizers::Tokenizer;
 
-const GUIDANCE_SCALE: f64 = 7.5;
-
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -63,8 +61,8 @@ struct Args {
     sliced_attention_size: Option<usize>,
 
     /// The number of steps to run the diffusion for.
-    #[arg(long, default_value_t = 30)]
-    n_steps: usize,
+    #[arg(long)]
+    n_steps: Option<usize>,
 
     /// The number of samples to generate.
     #[arg(long, default_value_t = 1)]
@@ -87,6 +85,9 @@ struct Args {
     #[arg(long)]
     use_f16: bool,
 
+    #[arg(long)]
+    guidance_scale: Option<f64>,
+
     #[arg(long, value_name = "FILE")]
     img2img: Option<String>,
 
@@ -102,6 +103,7 @@ enum StableDiffusionVersion {
     V1_5,
     V2_1,
     Xl,
+    Turbo,
 }
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -120,12 +122,13 @@ impl StableDiffusionVersion {
             Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
             Self::V2_1 => "stabilityai/stable-diffusion-2-1",
             Self::V1_5 => "runwayml/stable-diffusion-v1-5",
+            Self::Turbo => "stabilityai/sdxl-turbo",
         }
     }
 
     fn unet_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 | Self::Xl => {
+            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
                 if use_f16 {
                     "unet/diffusion_pytorch_model.fp16.safetensors"
                 } else {
@@ -137,7 +140,7 @@ impl StableDiffusionVersion {
 
     fn vae_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 | Self::Xl => {
+            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
                 if use_f16 {
                     "vae/diffusion_pytorch_model.fp16.safetensors"
                 } else {
@@ -149,7 +152,7 @@ impl StableDiffusionVersion {
 
     fn clip_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 | Self::Xl => {
+            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
                 if use_f16 {
                     "text_encoder/model.fp16.safetensors"
                 } else {
@@ -161,7 +164,7 @@ impl StableDiffusionVersion {
 
     fn clip2_file(&self, use_f16: bool) -> &'static str {
         match self {
-            Self::V1_5 | Self::V2_1 | Self::Xl => {
+            Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
                 if use_f16 {
                     "text_encoder_2/model.fp16.safetensors"
                 } else {
@@ -189,7 +192,7 @@ impl ModelFile {
                     StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
                         "openai/clip-vit-base-patch32"
                     }
-                    StableDiffusionVersion::Xl => {
+                    StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
                         // This seems similar to the patch32 version except some very small
                         // difference in the split regex.
                         "openai/clip-vit-large-patch14"
@@ -206,7 +209,11 @@ impl ModelFile {
                 Self::Vae => {
                     // Override for SDXL when using f16 weights.
                     // See https://github.com/huggingface/candle/issues/1060
-                    if version == StableDiffusionVersion::Xl && use_f16 {
+                    if matches!(
+                        version,
+                        StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo,
+                    ) && use_f16
+                    {
                         (
                             "madebyollin/sdxl-vae-fp16-fix",
                             "diffusion_pytorch_model.safetensors",
@@ -261,6 +268,7 @@ fn text_embeddings(
     use_f16: bool,
     device: &Device,
     dtype: DType,
+    use_guide_scale: bool,
     first: bool,
 ) -> Result<Tensor> {
     let tokenizer_file = if first {
@@ -285,16 +293,6 @@ fn text_embeddings(
     }
     let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
 
-    let mut uncond_tokens = tokenizer
-        .encode(uncond_prompt, true)
-        .map_err(E::msg)?
-        .get_ids()
-        .to_vec();
-    while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
-        uncond_tokens.push(pad_id)
-    }
-    let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
-
     println!("Building the Clip transformer.");
     let clip_weights_file = if first {
         ModelFile::Clip
@@ -310,8 +308,24 @@ fn text_embeddings(
     let text_model =
         stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
     let text_embeddings = text_model.forward(&tokens)?;
-    let uncond_embeddings = text_model.forward(&uncond_tokens)?;
-    let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
+
+    let text_embeddings = if use_guide_scale {
+        let mut uncond_tokens = tokenizer
+            .encode(uncond_prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+            uncond_tokens.push(pad_id)
+        }
+
+        let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+        let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+
+        Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
+    } else {
+        text_embeddings.to_dtype(dtype)?
+    };
     Ok(text_embeddings)
 }
 
@@ -356,6 +370,7 @@ fn run(args: Args) -> Result<()> {
         unet_weights,
         tracing,
         use_f16,
+        guidance_scale,
         use_flash_attn,
         img2img,
         img2img_strength,
@@ -374,6 +389,24 @@ fn run(args: Args) -> Result<()> {
         None
     };
 
+    let guidance_scale = match guidance_scale {
+        Some(guidance_scale) => guidance_scale,
+        None => match sd_version {
+            StableDiffusionVersion::V1_5
+            | StableDiffusionVersion::V2_1
+            | StableDiffusionVersion::Xl => 7.5,
+            StableDiffusionVersion::Turbo => 0.,
+        },
+    };
+    let n_steps = match n_steps {
+        Some(n_steps) => n_steps,
+        None => match sd_version {
+            StableDiffusionVersion::V1_5
+            | StableDiffusionVersion::V2_1
+            | StableDiffusionVersion::Xl => 30,
+            StableDiffusionVersion::Turbo => 1,
+        },
+    };
     let dtype = if use_f16 { DType::F16 } else { DType::F32 };
     let sd_config = match sd_version {
         StableDiffusionVersion::V1_5 => {
@@ -385,13 +418,19 @@ fn run(args: Args) -> Result<()> {
         StableDiffusionVersion::Xl => {
             stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
         }
+        StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo(
+            sliced_attention_size,
+            height,
+            width,
+        ),
     };
 
     let scheduler = sd_config.build_scheduler(n_steps)?;
     let device = candle_examples::device(cpu)?;
+    let use_guide_scale = guidance_scale > 1.0;
 
     let which = match sd_version {
-        StableDiffusionVersion::Xl => vec![true, false],
+        StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false],
         _ => vec![true],
     };
     let text_embeddings = which
@@ -407,10 +446,12 @@ fn run(args: Args) -> Result<()> {
                 use_f16,
                 &device,
                 dtype,
+                use_guide_scale,
                 *first,
             )
         })
         .collect::<Result<Vec<_>>>()?;
+
     let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
     println!("{text_embeddings:?}");
 
@@ -434,11 +475,19 @@ fn run(args: Args) -> Result<()> {
         0
     };
     let bsize = 1;
+
+    let vae_scale = match sd_version {
+        StableDiffusionVersion::V1_5
+        | StableDiffusionVersion::V2_1
+        | StableDiffusionVersion::Xl => 0.18215,
+        StableDiffusionVersion::Turbo => 0.13025,
+    };
+
     for idx in 0..num_samples {
         let timesteps = scheduler.timesteps();
         let latents = match &init_latent_dist {
             Some(init_latent_dist) => {
-                let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
+                let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
                 if t_start < timesteps.len() {
                     let noise = latents.randn_like(0f64, 1f64)?;
                     scheduler.add_noise(&latents, noise, timesteps[t_start])?
@@ -465,21 +514,31 @@ fn run(args: Args) -> Result<()> {
                 continue;
             }
             let start_time = std::time::Instant::now();
-            let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+            let latent_model_input = if use_guide_scale {
+                Tensor::cat(&[&latents, &latents], 0)?
+            } else {
+                latents.clone()
+            };
 
             let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
             let noise_pred =
                 unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
-            let noise_pred = noise_pred.chunk(2, 0)?;
-            let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
-            let noise_pred =
-                (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
+
+            let noise_pred = if use_guide_scale {
+                let noise_pred = noise_pred.chunk(2, 0)?;
+                let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
+
+                (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)?
+            } else {
+                noise_pred
+            };
+
             latents = scheduler.step(&noise_pred, timestep, &latents)?;
             let dt = start_time.elapsed().as_secs_f32();
             println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
 
             if args.intermediary_images {
-                let image = vae.decode(&(&latents / 0.18215)?)?;
+                let image = vae.decode(&(&latents / vae_scale)?)?;
                 let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
                 let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
                 let image_filename =
@@ -493,7 +552,7 @@ fn run(args: Args) -> Result<()> {
             idx + 1,
             num_samples
         );
-        let image = vae.decode(&(&latents / 0.18215)?)?;
+        let image = vae.decode(&(&latents / vae_scale)?)?;
         let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
         let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
         let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
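In the denoising loop, the diff replaces the hard-coded `GUIDANCE_SCALE` constant with a runtime switch, `use_guide_scale = guidance_scale > 1.0`: when guidance is off, the latents are no longer duplicated and the unconditional branch is skipped entirely, which is what keeps the single-step Turbo path cheap. The guidance arithmetic itself is unchanged; below is a sketch of it on plain `f32` slices rather than candle tensors (in the example the two branches are batched with `Tensor::cat` and split back with `chunk`), with `guide` as a hypothetical helper name.

```rust
// Classifier-free guidance: push the prediction away from the unconditional
// direction by `guidance_scale`. Sketch on plain slices, not candle Tensors.
fn guide(noise_uncond: &[f32], noise_text: &[f32], guidance_scale: f32) -> Vec<f32> {
    noise_uncond
        .iter()
        .zip(noise_text.iter())
        .map(|(u, t)| u + guidance_scale * (t - u))
        .collect()
}

fn main() {
    let uncond = [0.1f32, -0.2, 0.3];
    let text = [0.4f32, 0.0, -0.1];
    // With guidance_scale <= 1.0 the example never runs the unconditional
    // pass; with 7.5 (the SD 1.5/2.1/XL default) the text direction dominates.
    println!("{:?}", guide(&uncond, &text, 7.5));
}
```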
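The diff also turns the fixed `0.18215` latent scaling factor into a per-version `vae_scale`, since the sdxl-turbo VAE uses `0.13025`. The convention is symmetric, as this scalar sketch shows (a plain `f64` stands in for a latent tensor): multiply by the factor after VAE encoding (the img2img path) and divide by it before decoding.

```rust
// Per-version VAE latent scaling factor, mirroring the diff above.
fn vae_scale(is_turbo: bool) -> f64 {
    if is_turbo { 0.13025 } else { 0.18215 }
}

fn main() {
    let scale = vae_scale(true);
    let latent = 1.0f64; // scalar stand-in for an encoded latent
    let scaled = latent * scale; // applied right after sampling from the VAE encoder
    let decode_input = scaled / scale; // applied right before vae.decode(...)
    assert!((decode_input - latent).abs() < f64::EPSILON);
    println!("turbo vae scale = {scale}");
}
```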