summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion-3/vae.rs
blob: 708e472eff175c53951f39a87ddcdccdf0feeb38 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use anyhow::{Ok, Result};
use candle_transformers::models::stable_diffusion::vae;

/// Builds the `AutoEncoderKL` used by the Stable Diffusion 3 example from the
/// weights exposed through `vb`.
///
/// The configuration is fixed: four blocks with output channels
/// `[128, 256, 512, 512]`, two layers per block, 16 latent channels, 32 norm
/// groups, and both the quant and post-quant convolutions disabled.
///
/// # Errors
///
/// Propagates any error returned by `vae::AutoEncoderKL::new`.
pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result<vae::AutoEncoderKL> {
    let block_out_channels = vec![128, 256, 512, 512];
    let sd3_config = vae::AutoEncoderKLConfig {
        block_out_channels,
        layers_per_block: 2,
        latent_channels: 16,
        norm_num_groups: 32,
        use_quant_conv: false,
        use_post_quant_conv: false,
    };
    // NOTE(review): the two `3` arguments are presumably the input/output image
    // channel counts — confirm against `AutoEncoderKL::new`.
    let autoencoder = vae::AutoEncoderKL::new(vb, 3, 3, sd3_config)?;
    Ok(autoencoder)
}

/// Rewrites a dot-separated name by translating individual segments to a
/// different naming scheme and returns the re-joined result.
///
/// NOTE(review): presumably this maps the names requested by the VAE model onto
/// the key layout of the SD3 checkpoint — confirm against the call site.
///
/// Segment translations:
///
/// * `down_blocks` -> `down`, `mid_block` -> `mid`
/// * `up_blocks.N` -> `up.(3 - N)` for `N` in `0..=3` (the block order is
///   reversed); any other index is dropped
/// * `mid_block.resnets.0` / `.1` -> `mid.block_1` / `mid.block_2`; any other
///   index is dropped
/// * `resnets` anywhere else -> `block` (the following index is kept as is)
/// * `downsamplers.<n>` -> `downsample`, `upsamplers.<n>` -> `upsample`
///   (the segment after them is always dropped)
/// * `attentions.0` -> `attn_1`; `attentions.<other>` is dropped entirely
/// * `conv_shortcut` -> `nin_shortcut`, `group_norm` -> `norm`,
///   `query`/`key`/`value` -> `q`/`k`/`v`, `proj_attn` -> `proj_out`,
///   `conv_norm_out` -> `norm_out`
///
/// Every other segment is copied through unchanged.
///
/// A name that ends right after `up_blocks`, `mid_block.resnets` or
/// `attentions` (i.e. the expected index is missing) is handled gracefully:
/// the missing index simply contributes nothing to the output instead of
/// causing an out-of-bounds panic.
pub fn sd3_vae_vb_rename(name: &str) -> String {
    let parts: Vec<&str> = name.split('.').collect();
    // Every input segment yields at most one output segment overall, so the
    // input length is an upper bound for the output length.
    let mut result: Vec<&str> = Vec::with_capacity(parts.len());
    let mut i = 0;

    while i < parts.len() {
        // Checked look-ahead: `None` when the current segment is the last one.
        let next = parts.get(i + 1).copied();

        match parts[i] {
            "down_blocks" => result.push("down"),
            "mid_block" => result.push("mid"),
            "up_blocks" => {
                result.push("up");
                match next {
                    // Reverse the order of up_blocks.
                    Some("0") => result.push("3"),
                    Some("1") => result.push("2"),
                    Some("2") => result.push("1"),
                    Some("3") => result.push("0"),
                    _ => {}
                }
                i += 1; // Skip the number after up_blocks.
            }
            "resnets" => {
                if i > 0 && parts[i - 1] == "mid_block" {
                    match next {
                        Some("0") => result.push("block_1"),
                        Some("1") => result.push("block_2"),
                        _ => {}
                    }
                    i += 1; // Skip the number after resnets.
                } else {
                    result.push("block");
                }
            }
            "downsamplers" => {
                result.push("downsample");
                i += 1; // Skip the 0 after downsamplers.
            }
            "conv_shortcut" => result.push("nin_shortcut"),
            "attentions" => {
                if next == Some("0") {
                    result.push("attn_1");
                }
                i += 1; // Skip the number after attentions.
            }
            "group_norm" => result.push("norm"),
            "query" => result.push("q"),
            "key" => result.push("k"),
            "value" => result.push("v"),
            "proj_attn" => result.push("proj_out"),
            "conv_norm_out" => result.push("norm_out"),
            "upsamplers" => {
                result.push("upsample");
                i += 1; // Skip the 0 after upsamplers.
            }
            part => result.push(part),
        }
        i += 1;
    }
    result.join(".")
}