diff options
author | Czxck001 <10724409+Czxck001@users.noreply.github.com> | 2024-10-13 13:08:40 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-13 22:08:40 +0200 |
commit | ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e (patch) | |
tree | 8f61fd8b9a4c86b08e50328d051e0acec3945fb3 /candle-examples/examples/stable-diffusion-3/vae.rs | |
parent | 0d96ec31e8be03f844ed0aed636d6217dee9c7bc (diff) | |
download | candle-ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e.tar.gz candle-ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e.tar.bz2 candle-ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e.zip |
Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example
Add get_qkv_linear to handle different dimensionality in linears
Add stable diffusion 3 example
Add use_quant_conv and use_post_quant_conv for vae in stable diffusion
adapt existing AutoEncoderKLConfig to the change
add forward_until_encoder_layer to ClipTextTransformer
rename sd3 config to sd3_medium in mmdit; minor clean-up
Enable flash-attn for mmdit impl when the feature is enabled.
Add sd3 example codebase
add document
crediting references
pass the cargo fmt test
pass the clippy test
* fix typos
* expose cfg_scale and time_shift as options
* Replace the sample image with JPG version. Change image output format accordingly.
* make meaningful error messages
* remove the tail-end assignment in sd3_vae_vb_rename
* remove the CUDA requirement
* use default_value in clap args
* add use_flash_attn to turn on/off flash-attn for MMDiT at runtime
* resolve clippy errors and warnings
* use default_value_t
* Pin the web-sys dependency.
* Clippy fix.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/stable-diffusion-3/vae.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion-3/vae.rs | 93 |
1 files changed, 93 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion-3/vae.rs b/candle-examples/examples/stable-diffusion-3/vae.rs new file mode 100644 index 00000000..708e472e --- /dev/null +++ b/candle-examples/examples/stable-diffusion-3/vae.rs @@ -0,0 +1,93 @@ +use anyhow::{Ok, Result}; +use candle_transformers::models::stable_diffusion::vae; + +pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result<vae::AutoEncoderKL> { + let config = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 16, + norm_num_groups: 32, + use_quant_conv: false, + use_post_quant_conv: false, + }; + Ok(vae::AutoEncoderKL::new(vb, 3, 3, config)?) +} + +pub fn sd3_vae_vb_rename(name: &str) -> String { + let parts: Vec<&str> = name.split('.').collect(); + let mut result = Vec::new(); + let mut i = 0; + + while i < parts.len() { + match parts[i] { + "down_blocks" => { + result.push("down"); + } + "mid_block" => { + result.push("mid"); + } + "up_blocks" => { + result.push("up"); + match parts[i + 1] { + // Reverse the order of up_blocks. + "0" => result.push("3"), + "1" => result.push("2"), + "2" => result.push("1"), + "3" => result.push("0"), + _ => {} + } + i += 1; // Skip the number after up_blocks. + } + "resnets" => { + if i > 0 && parts[i - 1] == "mid_block" { + match parts[i + 1] { + "0" => result.push("block_1"), + "1" => result.push("block_2"), + _ => {} + } + i += 1; // Skip the number after resnets. + } else { + result.push("block"); + } + } + "downsamplers" => { + result.push("downsample"); + i += 1; // Skip the 0 after downsamplers. + } + "conv_shortcut" => { + result.push("nin_shortcut"); + } + "attentions" => { + if parts[i + 1] == "0" { + result.push("attn_1") + } + i += 1; // Skip the number after attentions. + } + "group_norm" => { + result.push("norm"); + } + "query" => { + result.push("q"); + } + "key" => { + result.push("k"); + } + "value" => { + result.push("v"); + } + "proj_attn" => { + result.push("proj_out"); + } + "conv_norm_out" => { + result.push("norm_out"); + } + "upsamplers" => { + result.push("upsample"); + i += 1; // Skip the 0 after upsamplers. + } + part => result.push(part), + } + i += 1; + } + result.join(".") +} |