summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion-3/vae.rs
diff options
context:
space:
mode:
authorCzxck001 <10724409+Czxck001@users.noreply.github.com>2024-10-13 13:08:40 -0700
committerGitHub <noreply@github.com>2024-10-13 22:08:40 +0200
commitca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e (patch)
tree8f61fd8b9a4c86b08e50328d051e0acec3945fb3 /candle-examples/examples/stable-diffusion-3/vae.rs
parent0d96ec31e8be03f844ed0aed636d6217dee9c7bc (diff)
downloadcandle-ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e.tar.gz
candle-ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e.tar.bz2
candle-ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e.zip
Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example Add get_qkv_linear to handle different dimensionality in linears Add stable diffusion 3 example Add use_quant_conv and use_post_quant_conv for vae in stable diffusion adapt existing AutoEncoderKLConfig to the change add forward_until_encoder_layer to ClipTextTransformer rename sd3 config to sd3_medium in mmdit; minor clean-up Enable flash-attn for mmdit impl when the feature is enabled. Add sd3 example codebase add document crediting references pass the cargo fmt test pass the clippy test * fix typos * expose cfg_scale and time_shift as options * Replace the sample image with JPG version. Change image output format accordingly. * make meaningful error messages * remove the tail-end assignment in sd3_vae_vb_rename * remove the CUDA requirement * use default_value in clap args * add use_flash_attn to turn on/off flash-attn for MMDiT at runtime * resolve clippy errors and warnings * use default_value_t * Pin the web-sys dependency. * Clippy fix. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/stable-diffusion-3/vae.rs')
-rw-r--r--candle-examples/examples/stable-diffusion-3/vae.rs93
1 files changed, 93 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion-3/vae.rs b/candle-examples/examples/stable-diffusion-3/vae.rs
new file mode 100644
index 00000000..708e472e
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/vae.rs
@@ -0,0 +1,93 @@
+use anyhow::{Ok, Result};
+use candle_transformers::models::stable_diffusion::vae;
+
+pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result<vae::AutoEncoderKL> {
+ let config = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 16,
+ norm_num_groups: 32,
+ use_quant_conv: false,
+ use_post_quant_conv: false,
+ };
+ Ok(vae::AutoEncoderKL::new(vb, 3, 3, config)?)
+}
+
+pub fn sd3_vae_vb_rename(name: &str) -> String {
+ let parts: Vec<&str> = name.split('.').collect();
+ let mut result = Vec::new();
+ let mut i = 0;
+
+ while i < parts.len() {
+ match parts[i] {
+ "down_blocks" => {
+ result.push("down");
+ }
+ "mid_block" => {
+ result.push("mid");
+ }
+ "up_blocks" => {
+ result.push("up");
+ match parts[i + 1] {
+ // Reverse the order of up_blocks.
+ "0" => result.push("3"),
+ "1" => result.push("2"),
+ "2" => result.push("1"),
+ "3" => result.push("0"),
+ _ => {}
+ }
+ i += 1; // Skip the number after up_blocks.
+ }
+ "resnets" => {
+ if i > 0 && parts[i - 1] == "mid_block" {
+ match parts[i + 1] {
+ "0" => result.push("block_1"),
+ "1" => result.push("block_2"),
+ _ => {}
+ }
+ i += 1; // Skip the number after resnets.
+ } else {
+ result.push("block");
+ }
+ }
+ "downsamplers" => {
+ result.push("downsample");
+ i += 1; // Skip the 0 after downsamplers.
+ }
+ "conv_shortcut" => {
+ result.push("nin_shortcut");
+ }
+ "attentions" => {
+ if parts[i + 1] == "0" {
+ result.push("attn_1")
+ }
+ i += 1; // Skip the number after attentions.
+ }
+ "group_norm" => {
+ result.push("norm");
+ }
+ "query" => {
+ result.push("q");
+ }
+ "key" => {
+ result.push("k");
+ }
+ "value" => {
+ result.push("v");
+ }
+ "proj_attn" => {
+ result.push("proj_out");
+ }
+ "conv_norm_out" => {
+ result.push("norm_out");
+ }
+ "upsamplers" => {
+ result.push("upsample");
+ i += 1; // Skip the 0 after upsamplers.
+ }
+ part => result.push(part),
+ }
+ i += 1;
+ }
+ result.join(".")
+}