Diffstat (limited to 'candle-examples/examples/stable-diffusion-3/main.rs')
-rw-r--r--   candle-examples/examples/stable-diffusion-3/main.rs   185
1 file changed, 185 insertions(+), 0 deletions(-)
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
new file mode 100644
index 00000000..164ae420
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -0,0 +1,185 @@
+mod clip;
+mod sampling;
+mod vae;
+
+use candle::{DType, IndexOp, Tensor};
+use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT};
+
+use crate::clip::StableDiffusion3TripleClipWithTokenizer;
+use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
+
+use anyhow::{Ok, Result};
+use clap::Parser;
+
+#[derive(Parser)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// The prompt to be used for image generation.
+    #[arg(
+        long,
+        default_value = "A cute rusty robot holding a candle torch in its hand, \
+        with glowing neon text \"LETS GO RUSTY\" displayed on its chest, \
+        bright background, high quality, 4k"
+    )]
+    prompt: String,
+
+    /// The prompt used for the unconditional part of classifier-free guidance.
+    #[arg(long, default_value = "")]
+    uncond_prompt: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// The CUDA device ID to use.
+    #[arg(long, default_value = "0")]
+    cuda_device_id: usize,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// Use flash_attn to accelerate the attention operations in the MMDiT.
+    #[arg(long)]
+    use_flash_attn: bool,
+
+    /// The height in pixels of the generated image.
+    #[arg(long, default_value_t = 1024)]
+    height: usize,
+
+    /// The width in pixels of the generated image.
+    #[arg(long, default_value_t = 1024)]
+    width: usize,
+
+    /// The number of inference (sampling) steps.
+    #[arg(long, default_value_t = 28)]
+    num_inference_steps: usize,
+
+    /// CFG (classifier-free guidance) scale.
+    #[arg(long, default_value_t = 4.0)]
+    cfg_scale: f64,
+
+    /// Time shift factor (alpha).
+    #[arg(long, default_value_t = 3.0)]
+    time_shift: f64,
+
+    /// The seed to use when generating random samples.
+    #[arg(long)]
+    seed: Option<u64>,
+}
+
+fn main() -> Result<()> {
+    let args = Args::parse();
+    run(args)
+}
+
+fn run(args: Args) -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let Args {
+        prompt,
+        uncond_prompt,
+        cpu,
+        cuda_device_id,
+        tracing,
+        use_flash_attn,
+        height,
+        width,
+        num_inference_steps,
+        cfg_scale,
+        time_shift,
+        seed,
+    } = args;
+
+    let _guard = if tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
+    // TODO: Support and test on Metal.
+    let device = if cpu {
+        candle::Device::Cpu
+    } else {
+        candle::Device::cuda_if_available(cuda_device_id)?
+    };
+
+    let api = hf_hub::api::sync::Api::new()?;
+    let sai_repo = {
+        let name = "stabilityai/stable-diffusion-3-medium";
+        api.repo(hf_hub::Repo::model(name.to_string()))
+    };
+    let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
+    let vb_fp16 = unsafe {
+        candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F16, &device)?
+    };
+
+    let (context, y) = {
+        let vb_fp32 = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(
+                &[model_file.clone()],
+                DType::F32,
+                &device,
+            )?
+        };
+        let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
+            vb_fp16.pp("text_encoders"),
+            vb_fp32.pp("text_encoders"),
+        )?;
+        let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
+        let (context_uncond, y_uncond) =
+            triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
+        (
+            Tensor::cat(&[context, context_uncond], 0)?,
+            Tensor::cat(&[y, y_uncond], 0)?,
+        )
+    };
+
+    let x = {
+        let mmdit = MMDiT::new(
+            &MMDiTConfig::sd3_medium(),
+            use_flash_attn,
+            vb_fp16.pp("model.diffusion_model"),
+        )?;
+
+        if let Some(seed) = seed {
+            device.set_seed(seed)?;
+        }
+        let start_time = std::time::Instant::now();
+        let x = sampling::euler_sample(
+            &mmdit,
+            &y,
+            &context,
+            num_inference_steps,
+            cfg_scale,
+            time_shift,
+            height,
+            width,
+        )?;
+        let dt = start_time.elapsed().as_secs_f32();
+        println!(
+            "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
+            dt,
+            num_inference_steps as f32 / dt
+        );
+        x
+    };
+
+    let img = {
+        let vb_vae = vb_fp16
+            .clone()
+            .rename_f(sd3_vae_vb_rename)
+            .pp("first_stage_model");
+        let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
+
+        // Apply the TAESD3 scale factor; this seems to significantly improve image quality.
+        // https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
+        autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
+    };
+    let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
+    candle_examples::save_image(&img.i(0)?, "out.jpg")?;
+    Ok(())
+}
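Note: `sampling::euler_sample` lives in the sibling `sampling.rs`, which is not part of this diff. As a rough illustration of what the `time_shift` (alpha) and `num_inference_steps` arguments control, below is a small standalone Rust sketch of a time-shifted sigma schedule of the kind used for SD3 rectified-flow sampling; the shift formula alpha * t / (1 + (alpha - 1) * t) follows the SD3 paper, and the actual schedule and Euler loop in `sampling.rs` may differ.

// Standalone sketch (assumed behaviour, not copied from sampling.rs):
// build a linear sigma schedule from 1.0 down to 0.0 and warp it with the
// time-shift factor alpha; each Euler step would then move the latent by
// (sigma_next - sigma_cur) times the CFG-combined MMDiT velocity prediction.
fn time_shift(alpha: f64, t: f64) -> f64 {
    alpha * t / (1.0 + (alpha - 1.0) * t)
}

fn main() {
    let num_inference_steps = 28usize; // matches the --num-inference-steps default above
    let alpha = 3.0; // matches the --time-shift default above
    let sigmas: Vec<f64> = (0..=num_inference_steps)
        .map(|i| 1.0 - i as f64 / num_inference_steps as f64)
        .map(|t| time_shift(alpha, t))
        .collect();
    println!("shifted sigmas: {sigmas:?}");
}

Assuming the usual candle-examples layout and feature flags, the example itself would be invoked along the lines of `cargo run --example stable-diffusion-3 --release --features cuda -- --prompt "..."`, with the remaining options (`--num-inference-steps`, `--cfg-scale`, `--time-shift`, `--use-flash-attn`, ...) taken from the clap derive above.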