author    | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-06 22:39:53 +0200
committer | GitHub <noreply@github.com> | 2023-08-06 21:39:53 +0100
commit    | 141df4ad2b80f0691450f32b28bcc2314301050b (patch)
tree      | 357a62dd4bce15a2f43dc86d507f62748d61a820 /candle-examples/examples/stable-diffusion/main.rs
parent    | 166bfd5847144abec227836e497b509625470535 (diff)
download  | candle-141df4ad2b80f0691450f32b28bcc2314301050b.tar.gz
          | candle-141df4ad2b80f0691450f32b28bcc2314301050b.tar.bz2
          | candle-141df4ad2b80f0691450f32b28bcc2314301050b.zip
Main diffusion loop for the SD example. (#332)
Diffstat (limited to 'candle-examples/examples/stable-diffusion/main.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/main.rs | 244
1 file changed, 239 insertions(+), 5 deletions(-)
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 33a1dc71..2203b03a 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -13,21 +13,255 @@ mod unet_2d_blocks;
 mod utils;
 mod vae;
 
-use anyhow::Result;
+use anyhow::{Error as E, Result};
+use candle::{DType, Device, Tensor};
 use clap::Parser;
+use tokenizers::Tokenizer;
 
-#[derive(Parser, Debug)]
+const GUIDANCE_SCALE: f64 = 7.5;
+
+#[derive(Parser)]
 #[command(author, version, about, long_about = None)]
 struct Args {
+    /// The prompt to be used for image generation.
+    #[arg(
+        long,
+        default_value = "A very realistic photo of a rusty robot walking on a sandy beach"
+    )]
+    prompt: String,
+
+    #[arg(long, default_value = "")]
+    uncond_prompt: String,
+
     /// Run on CPU rather than on GPU.
     #[arg(long)]
     cpu: bool,
 
+    /// The height in pixels of the generated image.
     #[arg(long)]
-    prompt: String,
+    height: Option<usize>,
+
+    /// The width in pixels of the generated image.
+    #[arg(long)]
+    width: Option<usize>,
+
+    /// The UNet weight file, in .ot or .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    unet_weights: Option<String>,
+
+    /// The CLIP weight file, in .ot or .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    clip_weights: Option<String>,
+
+    /// The VAE weight file, in .ot or .safetensors format.
+    #[arg(long, value_name = "FILE")]
+    vae_weights: Option<String>,
+
+    #[arg(
+        long,
+        value_name = "FILE",
+        default_value = "data/bpe_simple_vocab_16e6.txt"
+    )]
+    /// The file specifying the vocabulary to be used for tokenization.
+    vocab_file: String,
+
+    /// The size of the sliced attention, or 0 for automatic slicing (disabled by default).
+    #[arg(long)]
+    sliced_attention_size: Option<usize>,
+
+    /// The number of steps to run the diffusion for.
+    #[arg(long, default_value_t = 30)]
+    n_steps: usize,
+
+    /// The number of samples to generate.
+    #[arg(long, default_value_t = 1)]
+    num_samples: i64,
+
+    /// The name of the final image to generate.
+    #[arg(long, value_name = "FILE", default_value = "sd_final.png")]
+    final_image: String,
+
+    #[arg(long, value_enum, default_value = "v2-1")]
+    sd_version: StableDiffusionVersion,
+
+    /// Generate intermediary images at each step.
+    #[arg(long, action)]
+    intermediary_images: bool,
 }
 
-fn main() -> Result<()> {
-    let _args = Args::parse();
+#[derive(Debug, Clone, Copy, clap::ValueEnum)]
+enum StableDiffusionVersion {
+    V1_5,
+    V2_1,
+}
+
+impl Args {
+    fn clip_weights(&self) -> String {
+        match &self.clip_weights {
+            Some(w) => w.clone(),
+            None => match self.sd_version {
+                StableDiffusionVersion::V1_5 => "data/pytorch_model.safetensors".to_string(),
+                StableDiffusionVersion::V2_1 => "data/clip_v2.1.safetensors".to_string(),
+            },
+        }
+    }
+
+    fn vae_weights(&self) -> String {
+        match &self.vae_weights {
+            Some(w) => w.clone(),
+            None => match self.sd_version {
+                StableDiffusionVersion::V1_5 => "data/vae.safetensors".to_string(),
+                StableDiffusionVersion::V2_1 => "data/vae_v2.1.safetensors".to_string(),
+            },
+        }
+    }
+
+    fn unet_weights(&self) -> String {
+        match &self.unet_weights {
+            Some(w) => w.clone(),
+            None => match self.sd_version {
+                StableDiffusionVersion::V1_5 => "data/unet.safetensors".to_string(),
+                StableDiffusionVersion::V2_1 => "data/unet_v2.1.safetensors".to_string(),
+            },
+        }
+    }
+}
+
+fn output_filename(
+    basename: &str,
+    sample_idx: i64,
+    num_samples: i64,
+    timestep_idx: Option<usize>,
+) -> String {
+    let filename = if num_samples > 1 {
+        match basename.rsplit_once('.') {
+            None => format!("{basename}.{sample_idx}.png"),
+            Some((filename_no_extension, extension)) => {
+                format!("{filename_no_extension}.{sample_idx}.{extension}")
+            }
+        }
+    } else {
+        basename.to_string()
+    };
+    match timestep_idx {
+        None => filename,
+        Some(timestep_idx) => match filename.rsplit_once('.') {
+            None => format!("{filename}-{timestep_idx}.png"),
+            Some((filename_no_extension, extension)) => {
+                format!("{filename_no_extension}-{timestep_idx}.{extension}")
+            }
+        },
+    }
+}
+
+fn run(args: Args) -> Result<()> {
+    let clip_weights = args.clip_weights();
+    let vae_weights = args.vae_weights();
+    let unet_weights = args.unet_weights();
+    let Args {
+        prompt,
+        uncond_prompt,
+        cpu,
+        height,
+        width,
+        n_steps,
+        vocab_file,
+        final_image,
+        sliced_attention_size,
+        num_samples,
+        sd_version,
+        ..
+    } = args;
+    let sd_config = match sd_version {
+        StableDiffusionVersion::V1_5 => {
+            stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
+        }
+        StableDiffusionVersion::V2_1 => {
+            stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
+        }
+    };
+
+    let scheduler = sd_config.build_scheduler(n_steps)?;
+    let device = candle_examples::device(cpu)?;
+
+    let tokenizer = Tokenizer::from_file(vocab_file).map_err(E::msg)?;
+    println!("Running with prompt \"{prompt}\".");
+    let tokens = tokenizer
+        .encode(prompt, true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
+
+    let uncond_tokens = tokenizer
+        .encode(uncond_prompt, true)
+        .map_err(E::msg)?
+        .get_ids()
+        .to_vec();
+    let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
+
+    println!("Building the Clip transformer.");
+    let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?;
+    let text_embeddings = text_model.forward(&tokens)?;
+    let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+    let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
+
+    println!("Building the autoencoder.");
+    let vae = sd_config.build_vae(&vae_weights, &device)?;
+    println!("Building the unet.");
+    let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
+
+    let bsize = 1;
+    for idx in 0..num_samples {
+        let mut latents = Tensor::randn(
+            0f32,
+            1f32,
+            (bsize, 4, sd_config.height / 8, sd_config.width / 8),
+            &device,
+        )?;
+
+        // scale the initial noise by the standard deviation required by the scheduler
+        latents = (latents * scheduler.init_noise_sigma())?;
+
+        for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
+            println!("Timestep {timestep_index}/{n_steps}");
+            let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+
+            let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
+            let noise_pred =
+                unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
+            let noise_pred = noise_pred.chunk(2, 0)?;
+            let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
+            let noise_pred =
+                (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
+            latents = scheduler.step(&noise_pred, timestep, &latents)?;
+
+            if args.intermediary_images {
+                let image = vae.decode(&(&latents / 0.18215)?)?;
+                let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+                let _image = (image * 255.)?.to_dtype(DType::U8);
+                let _image_filename =
+                    output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
+                // TODO: save image.
+            }
+        }
+
+        println!(
+            "Generating the final image for sample {}/{}.",
+            idx + 1,
+            num_samples
+        );
+        let image = vae.decode(&(&latents / 0.18215)?)?;
+        // TODO: Add the clamping between 0 and 1.
+        let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
+        let _image = (image * 255.)?.to_dtype(DType::U8);
+        let _image_filename = output_filename(&final_image, idx + 1, num_samples, None);
+        // TODO: save image.
+    }
     Ok(())
 }
+
+fn main() -> Result<()> {
+    let args = Args::parse();
+    run(args)
+}
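A note on the naming scheme implemented by `output_filename`: when more than one sample is requested, the sample index is spliced in before the file extension, and for intermediary images the timestep index is appended after a dash. The hypothetical test below illustrates the expected outputs; it is not part of the commit.

```rust
// Hypothetical test for output_filename; the expected strings follow the
// rsplit_once logic above and are illustrative only.
#[test]
fn output_filename_scheme() {
    // A single sample with no timestep keeps the basename unchanged.
    assert_eq!(output_filename("sd_final.png", 1, 1, None), "sd_final.png");
    // Multiple samples: the sample index lands before the extension.
    assert_eq!(output_filename("sd_final.png", 2, 3, None), "sd_final.2.png");
    // An intermediary image additionally gets the timestep index after a dash.
    assert_eq!(
        output_filename("sd_final.png", 2, 3, Some(15)),
        "sd_final.2-15.png"
    );
}
```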
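The inner loop implements classifier-free guidance: the latents are duplicated (`Tensor::cat(&[&latents, &latents], 0)`) so that a single UNet forward pass over the concatenated [unconditional; text] embeddings yields both noise predictions, which are then blended with `GUIDANCE_SCALE` (7.5). A minimal sketch of the blending step, factored into a hypothetical `apply_guidance` helper that is not part of the commit:

```rust
use candle::{Result, Tensor};

// Sketch of the classifier-free guidance blend used in the loop above.
// `noise_pred_both` is the UNet output for the batch of two: the
// unconditional prediction first, the text-conditioned prediction second.
fn apply_guidance(noise_pred_both: &Tensor, guidance_scale: f64) -> Result<Tensor> {
    // Split the batch of two back into its unconditional and text halves.
    let chunks = noise_pred_both.chunk(2, 0)?;
    let (uncond, text) = (&chunks[0], &chunks[1]);
    // uncond + scale * (text - uncond): a scale above 1 extrapolates past the
    // text-conditioned prediction, away from the unconditional one.
    uncond + ((text - uncond)? * guidance_scale)?
}
```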
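Both the intermediary and final decoding stages apply the same post-processing: the latents are divided by 0.18215, the scale factor applied to latents during Stable Diffusion VAE training, and the decoder output, roughly in [-1, 1], is mapped to [0, 1] and then to 8-bit values (the clamp is still marked TODO in the commit). A sketch of that sequence, assuming the example's `vae` module exposes an `AutoEncoderKL` type with the `decode` method used above:

```rust
use candle::{DType, Device, Result, Tensor};

// Sketch of the latent-to-pixels conversion from the loop above; the
// `AutoEncoderKL` type name is an assumption about the example's vae module.
fn latents_to_u8(vae: &vae::AutoEncoderKL, latents: &Tensor) -> Result<Tensor> {
    // Undo the 0.18215 latent scaling before decoding.
    let image = vae.decode(&(latents / 0.18215)?)?;
    // Map the decoder's [-1, 1] range to [0, 1]; clamping is still a TODO.
    let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
    // Scale to [0, 255] and convert to unsigned 8-bit values.
    (image * 255.)?.to_dtype(DType::U8)
}
```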