author    Czxck001 <10724409+Czxck001@users.noreply.github.com>  2024-10-13 13:08:40 -0700
committer GitHub <noreply@github.com>  2024-10-13 22:08:40 +0200
commit    ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e (patch)
tree      8f61fd8b9a4c86b08e50328d051e0acec3945fb3 /candle-examples/examples
parent    0d96ec31e8be03f844ed0aed636d6217dee9c7bc (diff)
Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example
Add get_qkv_linear to handle different dimensionality in linears
Add stable diffusion 3 example
Add use_quant_conv and use_post_quant_conv for vae in stable diffusion
adapt existing AutoEncoderKLConfig to the change
add forward_until_encoder_layer to ClipTextTransformer
rename sd3 config to sd3_medium in mmdit; minor clean-up
Enable flash-attn for mmdit impl when the feature is enabled.
Add sd3 example codebase
add document
crediting references
pass the cargo fmt test
pass the clippy test
* fix typos
* expose cfg_scale and time_shift as options
* Replace the sample image with JPG version. Change image output format accordingly.
* make meaningful error messages
* remove the tail-end assignment in sd3_vae_vb_rename
* remove the CUDA requirement
* use default_value in clap args
* add use_flash_attn to turn on/off flash-attn for MMDiT at runtime
* resolve clippy errors and warnings
* use default_value_t
* Pin the web-sys dependency.
* Clippy fix.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--  candle-examples/examples/stable-diffusion-3/README.md                      54
-rw-r--r--  candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg  bin 0 -> 83401 bytes
-rw-r--r--  candle-examples/examples/stable-diffusion-3/clip.rs                        201
-rw-r--r--  candle-examples/examples/stable-diffusion-3/main.rs                        185
-rw-r--r--  candle-examples/examples/stable-diffusion-3/sampling.rs                    55
-rw-r--r--  candle-examples/examples/stable-diffusion-3/vae.rs                         93
6 files changed, 588 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion-3/README.md b/candle-examples/examples/stable-diffusion-3/README.md
new file mode 100644
index 00000000..746a31fa
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/README.md
@@ -0,0 +1,54 @@
+# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3 Medium
+
+![](assets/stable-diffusion-3.jpg)
+
+*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*
+
+Stable Diffusion 3 Medium is a text-to-image model based on the Multimodal Diffusion Transformer (MMDiT) architecture.
+
+- [huggingface repo](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
+- [research paper](https://arxiv.org/pdf/2403.03206)
+- [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium)
+
+## Getting access to the weights
+
+The weights of Stable Diffusion 3 Medium are released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting the [repo on HuggingFace Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium) to gain access to the weights for your HuggingFace account.
+
+On the first run, the weights are automatically downloaded from the HuggingFace Hub. You might be prompted to configure a [HuggingFace User Access Token](https://huggingface.co/docs/hub/en/security-tokens) (recommended) on your computer if you haven't done so before. After the download, the weights are [cached](https://huggingface.co/docs/datasets/en/cache) and remain accessible locally.
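+
+For example, one way to configure a token is with the `huggingface-cli` tool from `huggingface_hub` (a sketch; any of the other methods described in the token documentation work as well):
+
+```shell
+huggingface-cli login
+```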
+
+## Running the model
+
+```shell
+cargo run --example stable-diffusion-3 --release --features=cuda -- \
+  --height 1024 --width 1024 \
+  --prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k'
+```
+
+To display the other available options:
+
+```shell
+cargo run --example stable-diffusion-3 --release --features=cuda -- --help
+```
+
+If your GPU supports it, Flash-Attention is strongly recommended: it can greatly improve inference speed, since MMDiT is a transformer model that relies heavily on attention. To use [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) in the demo, you need both the `flash-attn` compile-time feature and the `--use-flash-attn` runtime flag.
+
+```shell
+cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- --use-flash-attn ...
+```
+
+## Performance Benchmark
+
+The benchmark below generates a 1024-by-1024 image with 28 steps of Euler sampling and measures the average speed (iterations per second).
+
+[candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) are built from commit [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc).
+
+System specs (desktop with a PCIe 5 x8/x8 dual-GPU setup):
+
+- Operating System: Ubuntu 23.10
+- CPU: i9 12900K w/o overclocking
+- RAM: 64G dual-channel DDR5 @ 4800 MT/s
+
+| Speed (iter/s) | w/o flash-attn | w/ flash-attn |
+| -------------- | -------------- | ------------- |
+| RTX 3090 Ti    | 0.83           | 2.15          |
+| RTX 4090       | 1.72           | 4.06          |
diff --git a/candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg b/candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg
new file mode 100644
index 00000000..58ca16c3
Binary files differ
diff --git a/candle-examples/examples/stable-diffusion-3/clip.rs b/candle-examples/examples/stable-diffusion-3/clip.rs
new file mode 100644
index 00000000..77263d96
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/clip.rs
@@ -0,0 +1,201 @@
+use anyhow::{Error as E, Ok, Result};
+use candle::{DType, IndexOp, Module, Tensor, D};
+use candle_transformers::models::{stable_diffusion, t5};
+use tokenizers::tokenizer::Tokenizer;
+
+struct ClipWithTokenizer {
+    clip: stable_diffusion::clip::ClipTextTransformer,
+    config: stable_diffusion::clip::Config,
+    tokenizer: Tokenizer,
+    max_position_embeddings: usize,
+}
+
+impl ClipWithTokenizer {
+    fn new(
+        vb: candle_nn::VarBuilder,
+        config: stable_diffusion::clip::Config,
+        tokenizer_path: &str,
+        max_position_embeddings: usize,
+    ) -> Result<Self> {
+        let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?;
+        let path_buf = hf_hub::api::sync::Api::new()?
+            .model(tokenizer_path.to_string())
+            .get("tokenizer.json")?;
+        let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg(
+            "Failed to serialize huggingface PathBuf of CLIP tokenizer",
+        ))?)
+        .map_err(E::msg)?;
+        Ok(Self {
+            clip,
+            config,
+            tokenizer,
+            max_position_embeddings,
+        })
+    }
+
+    fn encode_text_to_embedding(
+        &self,
+        prompt: &str,
+        device: &candle::Device,
+    ) -> Result<(Tensor, Tensor)> {
+        let pad_id = match &self.config.pad_with {
+            Some(padding) => *self
+                .tokenizer
+                .get_vocab(true)
+                .get(padding.as_str())
+                .ok_or(E::msg("Failed to tokenize CLIP padding."))?,
+            None => *self
+                .tokenizer
+                .get_vocab(true)
+                .get("<|endoftext|>")
+                .ok_or(E::msg("Failed to tokenize CLIP end-of-text."))?,
+        };
+
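+        // Tokenize the prompt, then pad it out to the fixed CLIP context length.
+        // The index of the final (EOS) token is recorded first: the pooled text
+        // embedding is read from that position below.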
+        let mut tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+
+        let eos_position = tokens.len() - 1;
+
+        while tokens.len() < self.max_position_embeddings {
+            tokens.push(pad_id)
+        }
+        let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
+        let (text_embeddings, text_embeddings_penultimate) = self
+            .clip
+            .forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
+        let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;
+
+        Ok((text_embeddings_penultimate, text_embeddings_pooled))
+    }
+}
+
+struct T5WithTokenizer {
+    t5: t5::T5EncoderModel,
+    tokenizer: Tokenizer,
+    max_position_embeddings: usize,
+}
+
+impl T5WithTokenizer {
+    fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result<Self> {
+        let api = hf_hub::api::sync::Api::new()?;
+        let repo = api.repo(hf_hub::Repo::with_revision(
+            "google/t5-v1_1-xxl".to_string(),
+            hf_hub::RepoType::Model,
+            "refs/pr/2".to_string(),
+        ));
+        let config_filename = repo.get("config.json")?;
+        let config = std::fs::read_to_string(config_filename)?;
+        let config: t5::Config = serde_json::from_str(&config)?;
+        let model = t5::T5EncoderModel::load(vb, &config)?;
+
+        let tokenizer_filename = api
+            .model("lmz/mt5-tokenizers".to_string())
+            .get("t5-v1_1-xxl.tokenizer.json")?;
+
+        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+        Ok(Self {
+            t5: model,
+            tokenizer,
+            max_position_embeddings,
+        })
+    }
+
+    fn encode_text_to_embedding(
+        &mut self,
+        prompt: &str,
+        device: &candle::Device,
+    ) -> Result<Tensor> {
+        let mut tokens = self
+            .tokenizer
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        tokens.resize(self.max_position_embeddings, 0);
+        let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+        let embeddings = self.t5.forward(&input_token_ids)?;
+        Ok(embeddings)
+    }
+}
+
+pub struct StableDiffusion3TripleClipWithTokenizer {
+    clip_l: ClipWithTokenizer,
+    clip_g: ClipWithTokenizer,
+    clip_g_text_projection: candle_nn::Linear,
+    t5: T5WithTokenizer,
+}
+
+impl StableDiffusion3TripleClipWithTokenizer {
+    pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
+        let max_position_embeddings = 77usize;
+        let clip_l = ClipWithTokenizer::new(
+            vb_fp16.pp("clip_l.transformer"),
+            stable_diffusion::clip::Config::sdxl(),
+            "openai/clip-vit-large-patch14",
+            max_position_embeddings,
+        )?;
+
+        let clip_g = ClipWithTokenizer::new(
+            vb_fp16.pp("clip_g.transformer"),
+            stable_diffusion::clip::Config::sdxl2(),
+            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
+            max_position_embeddings,
+        )?;
+
+        let text_projection = candle_nn::linear_no_bias(
+            1280,
+            1280,
+            vb_fp16.pp("clip_g.transformer.text_projection"),
+        )?;
+
+        // The current T5 implementation does not support fp16, so we use a fp32 VarBuilder
+        // for T5. This is a temporary workaround until the T5 implementation supports fp16.
+        // Also see:
+        // https://github.com/huggingface/candle/issues/2480
+        // https://github.com/huggingface/candle/pull/2481
+        let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;
+
+        Ok(Self {
+            clip_l,
+            clip_g,
+            clip_g_text_projection: text_projection,
+            t5,
+        })
+    }
+
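+    // Builds the sequence context and the pooled vector `y` consumed by MMDiT:
+    // the penultimate CLIP-L (768-dim) and CLIP-G (1280-dim) hidden states are
+    // concatenated to 2048 features and zero-padded to 4096 to match the T5-XXL
+    // width, then the T5 embeddings are appended along the sequence dimension;
+    // `y` concatenates the two pooled CLIP embeddings.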
+    pub fn encode_text_to_embedding(
+        &mut self,
+        prompt: &str,
+        device: &candle::Device,
+    ) -> Result<(Tensor, Tensor)> {
+        let (clip_l_embeddings, clip_l_embeddings_pooled) =
+            self.clip_l.encode_text_to_embedding(prompt, device)?;
+        let (clip_g_embeddings, clip_g_embeddings_pooled) =
+            self.clip_g.encode_text_to_embedding(prompt, device)?;
+
+        let clip_g_embeddings_pooled = self
+            .clip_g_text_projection
+            .forward(&clip_g_embeddings_pooled.unsqueeze(0)?)?
+            .squeeze(0)?;
+
+        let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)?
+            .unsqueeze(0)?;
+        let clip_embeddings_concat = Tensor::cat(
+            &[&clip_l_embeddings, &clip_g_embeddings],
+            D::Minus1,
+        )?
+        .pad_with_zeros(D::Minus1, 0, 2048)?;
+
+        let t5_embeddings = self
+            .t5
+            .encode_text_to_embedding(prompt, device)?
+            .to_dtype(DType::F16)?;
+        let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;
+
+        Ok((context, y))
+    }
+}
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
new file mode 100644
index 00000000..164ae420
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -0,0 +1,185 @@
+mod clip;
+mod sampling;
+mod vae;
+
+use candle::{DType, IndexOp, Tensor};
+use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT};
+
+use crate::clip::StableDiffusion3TripleClipWithTokenizer;
+use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
+
+use anyhow::{Ok, Result};
+use clap::Parser;
+
+#[derive(Parser)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// The prompt to be used for image generation.
+    #[arg(
+        long,
+        default_value = "A cute rusty robot holding a candle torch in its hand, \
+        with glowing neon text \"LETS GO RUSTY\" displayed on its chest, \
+        bright background, high quality, 4k"
+    )]
+    prompt: String,
+
+    /// The prompt used for the unconditional branch of classifier-free guidance.
+    #[arg(long, default_value = "")]
+    uncond_prompt: String,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// The CUDA device ID to use.
+    #[arg(long, default_value = "0")]
+    cuda_device_id: usize,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    /// Use flash_attn to accelerate the attention operations in the MMDiT.
+    #[arg(long)]
+    use_flash_attn: bool,
+
+    /// The height in pixels of the generated image.
+    #[arg(long, default_value_t = 1024)]
+    height: usize,
+
+    /// The width in pixels of the generated image.
+    #[arg(long, default_value_t = 1024)]
+    width: usize,
+
+    /// The number of sampling steps.
+    #[arg(long, default_value_t = 28)]
+    num_inference_steps: usize,
+
+    /// CFG scale.
+    #[arg(long, default_value_t = 4.0)]
+    cfg_scale: f64,
+
+    /// Time shift factor (alpha).
+    #[arg(long, default_value_t = 3.0)]
+    time_shift: f64,
+
+    /// The seed to use when generating random samples.
+    #[arg(long)]
+    seed: Option<u64>,
+}
+
+fn main() -> Result<()> {
+    let args = Args::parse();
+    run(args)
+}
+
+fn run(args: Args) -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let Args {
+        prompt,
+        uncond_prompt,
+        cpu,
+        cuda_device_id,
+        tracing,
+        use_flash_attn,
+        height,
+        width,
+        num_inference_steps,
+        cfg_scale,
+        time_shift,
+        seed,
+    } = args;
+
+    let _guard = if tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
+    // TODO: Support and test on Metal.
+    let device = if cpu {
+        candle::Device::Cpu
+    } else {
+        candle::Device::cuda_if_available(cuda_device_id)?
+    };
+
+    let api = hf_hub::api::sync::Api::new()?;
+    let sai_repo = {
+        let name = "stabilityai/stable-diffusion-3-medium";
+        api.repo(hf_hub::Repo::model(name.to_string()))
+    };
+    let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
+    let vb_fp16 = unsafe {
+        candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F16, &device)?
+    };
+
+    let (context, y) = {
+        let vb_fp32 = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(
+                &[model_file.clone()],
+                DType::F32,
+                &device,
+            )?
+        };
+        let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
+            vb_fp16.pp("text_encoders"),
+            vb_fp32.pp("text_encoders"),
+        )?;
+        let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
+        let (context_uncond, y_uncond) =
+            triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
+        (
+            Tensor::cat(&[context, context_uncond], 0)?,
+            Tensor::cat(&[y, y_uncond], 0)?,
+        )
+    };
+
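+    // The conditional and unconditional embeddings are concatenated along the
+    // batch dimension, so a single MMDiT forward pass evaluates both branches
+    // of classifier-free guidance at once (see apply_cfg in sampling.rs).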
+    let x = {
+        let mmdit = MMDiT::new(
+            &MMDiTConfig::sd3_medium(),
+            use_flash_attn,
+            vb_fp16.pp("model.diffusion_model"),
+        )?;
+
+        if let Some(seed) = seed {
+            device.set_seed(seed)?;
+        }
+        let start_time = std::time::Instant::now();
+        let x = sampling::euler_sample(
+            &mmdit,
+            &y,
+            &context,
+            num_inference_steps,
+            cfg_scale,
+            time_shift,
+            height,
+            width,
+        )?;
+        let dt = start_time.elapsed().as_secs_f32();
+        println!(
+            "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
+            dt,
+            num_inference_steps as f32 / dt
+        );
+        x
+    };
+
+    let img = {
+        let vb_vae = vb_fp16
+            .clone()
+            .rename_f(sd3_vae_vb_rename)
+            .pp("first_stage_model");
+        let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
+
+        // Apply the TAESD3 scale factor, which seems to significantly improve image quality.
+        // https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
+        autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
+    };
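+    // Map the decoded image from [-1, 1] to [0, 255] and convert to u8 before saving.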
+    let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
+    candle_examples::save_image(&img.i(0)?, "out.jpg")?;
+    Ok(())
+}
diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs
new file mode 100644
index 00000000..147d8e73
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/sampling.rs
@@ -0,0 +1,55 @@
+use anyhow::{Ok, Result};
+use candle::{DType, Tensor};
+
+use candle_transformers::models::flux; // for the get_noise function
+use candle_transformers::models::mmdit::model::MMDiT;
+
+#[allow(clippy::too_many_arguments)]
+pub fn euler_sample(
+    mmdit: &MMDiT,
+    y: &Tensor,
+    context: &Tensor,
+    num_inference_steps: usize,
+    cfg_scale: f64,
+    time_shift: f64,
+    height: usize,
+    width: usize,
+) -> Result<Tensor> {
+    let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
+    let sigmas = (0..=num_inference_steps)
+        .map(|x| x as f64 / num_inference_steps as f64)
+        .rev()
+        .map(|x| time_snr_shift(time_shift, x))
+        .collect::<Vec<f64>>();
+
+    for window in sigmas.windows(2) {
+        let (s_curr, s_prev) = match window {
+            [a, b] => (a, b),
+            _ => continue,
+        };
+
+        let timestep = (*s_curr) * 1000.0;
+        let noise_pred = mmdit.forward(
+            &Tensor::cat(&[x.clone(), x.clone()], 0)?,
+            &Tensor::full(timestep, (2,), x.device())?.contiguous()?,
+            y,
+            context,
+        )?;
+        x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?;
+    }
+    Ok(x)
+}
+
+// The "resolution-dependent shifting of timestep schedules" recommended in the SD3 tech report,
+// https://arxiv.org/pdf/2403.03206
+// following the implementation in ComfyUI:
+// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/comfy/model_sampling.py#L181
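+// For example, with the default alpha = 3.0, the midpoint t = 0.5 maps to
+// 3.0 * 0.5 / (1.0 + 2.0 * 0.5) = 0.75, shifting the schedule toward higher
+// noise levels; alpha = 1.0 leaves the schedule unchanged.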
+fn time_snr_shift(alpha: f64, t: f64) -> f64 {
+    alpha * t / (1.0 + (alpha - 1.0) * t)
+}
+
+fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result<Tensor> {
+    Ok(((cfg_scale * noise_pred.narrow(0, 0, 1)?)?
+        - ((cfg_scale - 1.0) * noise_pred.narrow(0, 1, 1)?)?)?)
+}
diff --git a/candle-examples/examples/stable-diffusion-3/vae.rs b/candle-examples/examples/stable-diffusion-3/vae.rs
new file mode 100644
index 00000000..708e472e
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/vae.rs
@@ -0,0 +1,93 @@
+use anyhow::{Ok, Result};
+use candle_transformers::models::stable_diffusion::vae;
+
+pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result<vae::AutoEncoderKL> {
+    let config = vae::AutoEncoderKLConfig {
+        block_out_channels: vec![128, 256, 512, 512],
+        layers_per_block: 2,
+        latent_channels: 16,
+        norm_num_groups: 32,
+        use_quant_conv: false,
+        use_post_quant_conv: false,
+    };
+    Ok(vae::AutoEncoderKL::new(vb, 3, 3, config)?)
+}
+
+pub fn sd3_vae_vb_rename(name: &str) -> String {
+    let parts: Vec<&str> = name.split('.').collect();
+    let mut result = Vec::new();
+    let mut i = 0;
+
+    while i < parts.len() {
+        match parts[i] {
+            "down_blocks" => {
+                result.push("down");
+            }
+            "mid_block" => {
+                result.push("mid");
+            }
+            "up_blocks" => {
+                result.push("up");
+                match parts[i + 1] {
+                    // Reverse the order of up_blocks.
+                    "0" => result.push("3"),
+                    "1" => result.push("2"),
+                    "2" => result.push("1"),
+                    "3" => result.push("0"),
+                    _ => {}
+                }
+                i += 1; // Skip the number after up_blocks.
+            }
+            "resnets" => {
+                if i > 0 && parts[i - 1] == "mid_block" {
+                    match parts[i + 1] {
+                        "0" => result.push("block_1"),
+                        "1" => result.push("block_2"),
+                        _ => {}
+                    }
+                    i += 1; // Skip the number after resnets.
+                } else {
+                    result.push("block");
+                }
+            }
+            "downsamplers" => {
+                result.push("downsample");
+                i += 1; // Skip the 0 after downsamplers.
+            }
+            "conv_shortcut" => {
+                result.push("nin_shortcut");
+            }
+            "attentions" => {
+                if parts[i + 1] == "0" {
+                    result.push("attn_1")
+                }
+                i += 1; // Skip the number after attentions.
+            }
+            "group_norm" => {
+                result.push("norm");
+            }
+            "query" => {
+                result.push("q");
+            }
+            "key" => {
+                result.push("k");
+            }
+            "value" => {
+                result.push("v");
+            }
+            "proj_attn" => {
+                result.push("proj_out");
+            }
+            "conv_norm_out" => {
+                result.push("norm_out");
+            }
+            "upsamplers" => {
+                result.push("upsample");
+                i += 1; // Skip the 0 after upsamplers.
+            }
+            part => result.push(part),
+        }
+        i += 1;
+    }
+    result.join(".")
+}
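To illustrate the renaming scheme, here is a hypothetical snippet (not part of the commit) showing how a few diffusers-style VAE parameter names map onto the SD3 checkpoint layout, assuming `sd3_vae_vb_rename` is in scope:

```rust
// up_blocks indices are reversed, mid_block resnets become block_1/block_2,
// and attention sublayers collapse to attn_1 with shortened q/k/v names.
assert_eq!(
    sd3_vae_vb_rename("decoder.up_blocks.0.resnets.0.conv1.weight"),
    "decoder.up.3.block.0.conv1.weight"
);
assert_eq!(
    sd3_vae_vb_rename("encoder.mid_block.resnets.1.conv2.weight"),
    "encoder.mid.block_2.conv2.weight"
);
assert_eq!(
    sd3_vae_vb_rename("decoder.mid_block.attentions.0.group_norm.weight"),
    "decoder.mid.attn_1.norm.weight"
);
```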