summaryrefslogtreecommitdiff
path: root/candle-examples/examples
diff options
context:
space:
mode:
authorCzxck001 <10724409+Czxck001@users.noreply.github.com>2024-10-13 13:08:40 -0700
committerGitHub <noreply@github.com>2024-10-13 22:08:40 +0200
commitca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e (patch)
tree8f61fd8b9a4c86b08e50328d051e0acec3945fb3 /candle-examples/examples
parent0d96ec31e8be03f844ed0aed636d6217dee9c7bc (diff)
downloadcandle-ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e.tar.gz
candle-ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e.tar.bz2
candle-ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e.zip
Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example Add get_qkv_linear to handle different dimensionality in linears Add stable diffusion 3 example Add use_quant_conv and use_post_quant_conv for vae in stable diffusion adapt existing AutoEncoderKLConfig to the change add forward_until_encoder_layer to ClipTextTransformer rename sd3 config to sd3_medium in mmdit; minor clean-up Enable flash-attn for mmdit impl when the feature is enabled. Add sd3 example codebase add document crediting references pass the cargo fmt test pass the clippy test * fix typos * expose cfg_scale and time_shift as options * Replace the sample image with JPG version. Change image output format accordingly. * make meaningful error messages * remove the tail-end assignment in sd3_vae_vb_rename * remove the CUDA requirement * use default_value in clap args * add use_flash_attn to turn on/off flash-attn for MMDiT at runtime * resolve clippy errors and warnings * use default_value_t * Pin the web-sys dependency. * Clippy fix. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples')
-rw-r--r--candle-examples/examples/stable-diffusion-3/README.md54
-rw-r--r--candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpgbin0 -> 83401 bytes
-rw-r--r--candle-examples/examples/stable-diffusion-3/clip.rs201
-rw-r--r--candle-examples/examples/stable-diffusion-3/main.rs185
-rw-r--r--candle-examples/examples/stable-diffusion-3/sampling.rs55
-rw-r--r--candle-examples/examples/stable-diffusion-3/vae.rs93
6 files changed, 588 insertions, 0 deletions
diff --git a/candle-examples/examples/stable-diffusion-3/README.md b/candle-examples/examples/stable-diffusion-3/README.md
new file mode 100644
index 00000000..746a31fa
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/README.md
@@ -0,0 +1,54 @@
+# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3 Medium
+
+![](assets/stable-diffusion-3.jpg)
+
+*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k*
+
+Stable Diffusion 3 Medium is a text-to-image model based on Multimodal Diffusion Transformer (MMDiT) architecture.
+
+- [huggingface repo](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
+- [research paper](https://arxiv.org/pdf/2403.03206)
+- [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium)
+
+## Getting access to the weights
+
+The weights of Stable Diffusion 3 Medium is released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting the [repo on HuggingFace Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium) to gain access to the weights for your HuggingFace account.
+
+On the first run, the weights will be automatically downloaded from the Huggingface Hub. You might be prompted to configure a [Huggingface User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens) (recommended) on your computer if you haven't done that before. After the download, the weights will be [cached](https://huggingface.co/docs/datasets/en/cache) and remain accessible locally.
+
+## Running the model
+
+```shell
+cargo run --example stable-diffusion-3 --release --features=cuda -- \
+ --height 1024 --width 1024 \
+ --prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k'
+```
+
+To display other options available,
+
+```shell
+cargo run --example stable-diffusion-3 --release --features=cuda -- --help
+```
+
+If GPU supports, Flash-Attention is a strongly recommended feature as it can greatly improve the speed of inference, as MMDiT is a transformer model heavily depends on attentions. To utilize [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) in the demo, you will need both `--features flash-attn` and `--use-flash-attn`.
+
+```shell
+cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- --use-flash-attn ...
+```
+
+## Performance Benchmark
+
+Below benchmark is done by generating 1024-by-1024 image from 28 steps of Euler sampling and measure the average speed (iteration per seconds).
+
+[candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) is based on the commit of [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc).
+
+System specs (Desktop PCIE 5 x8/x8 dual-GPU setup):
+
+- Operating System: Ubuntu 23.10
+- CPU: i9 12900K w/o overclocking.
+- RAM: 64G dual-channel DDR5 @ 4800 MT/s
+
+| Speed (iter/s) | w/o flash-attn | w/ flash-attn |
+| -------------- | -------------- | ------------- |
+| RTX 3090 Ti | 0.83 | 2.15 |
+| RTX 4090 | 1.72 | 4.06 |
diff --git a/candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg b/candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg
new file mode 100644
index 00000000..58ca16c3
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg
Binary files differ
diff --git a/candle-examples/examples/stable-diffusion-3/clip.rs b/candle-examples/examples/stable-diffusion-3/clip.rs
new file mode 100644
index 00000000..77263d96
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/clip.rs
@@ -0,0 +1,201 @@
+use anyhow::{Error as E, Ok, Result};
+use candle::{DType, IndexOp, Module, Tensor, D};
+use candle_transformers::models::{stable_diffusion, t5};
+use tokenizers::tokenizer::Tokenizer;
+
+struct ClipWithTokenizer {
+ clip: stable_diffusion::clip::ClipTextTransformer,
+ config: stable_diffusion::clip::Config,
+ tokenizer: Tokenizer,
+ max_position_embeddings: usize,
+}
+
+impl ClipWithTokenizer {
+ fn new(
+ vb: candle_nn::VarBuilder,
+ config: stable_diffusion::clip::Config,
+ tokenizer_path: &str,
+ max_position_embeddings: usize,
+ ) -> Result<Self> {
+ let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?;
+ let path_buf = hf_hub::api::sync::Api::new()?
+ .model(tokenizer_path.to_string())
+ .get("tokenizer.json")?;
+ let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg(
+ "Failed to serialize huggingface PathBuf of CLIP tokenizer",
+ ))?)
+ .map_err(E::msg)?;
+ Ok(Self {
+ clip,
+ config,
+ tokenizer,
+ max_position_embeddings,
+ })
+ }
+
+ fn encode_text_to_embedding(
+ &self,
+ prompt: &str,
+ device: &candle::Device,
+ ) -> Result<(Tensor, Tensor)> {
+ let pad_id = match &self.config.pad_with {
+ Some(padding) => *self
+ .tokenizer
+ .get_vocab(true)
+ .get(padding.as_str())
+ .ok_or(E::msg("Failed to tokenize CLIP padding."))?,
+ None => *self
+ .tokenizer
+ .get_vocab(true)
+ .get("<|endoftext|>")
+ .ok_or(E::msg("Failed to tokenize CLIP end-of-text."))?,
+ };
+
+ let mut tokens = self
+ .tokenizer
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+
+ let eos_position = tokens.len() - 1;
+
+ while tokens.len() < self.max_position_embeddings {
+ tokens.push(pad_id)
+ }
+ let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
+ let (text_embeddings, text_embeddings_penultimate) = self
+ .clip
+ .forward_until_encoder_layer(&tokens, usize::MAX, -2)?;
+ let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?;
+
+ Ok((text_embeddings_penultimate, text_embeddings_pooled))
+ }
+}
+
+struct T5WithTokenizer {
+ t5: t5::T5EncoderModel,
+ tokenizer: Tokenizer,
+ max_position_embeddings: usize,
+}
+
+impl T5WithTokenizer {
+ fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result<Self> {
+ let api = hf_hub::api::sync::Api::new()?;
+ let repo = api.repo(hf_hub::Repo::with_revision(
+ "google/t5-v1_1-xxl".to_string(),
+ hf_hub::RepoType::Model,
+ "refs/pr/2".to_string(),
+ ));
+ let config_filename = repo.get("config.json")?;
+ let config = std::fs::read_to_string(config_filename)?;
+ let config: t5::Config = serde_json::from_str(&config)?;
+ let model = t5::T5EncoderModel::load(vb, &config)?;
+
+ let tokenizer_filename = api
+ .model("lmz/mt5-tokenizers".to_string())
+ .get("t5-v1_1-xxl.tokenizer.json")?;
+
+ let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+ Ok(Self {
+ t5: model,
+ tokenizer,
+ max_position_embeddings,
+ })
+ }
+
+ fn encode_text_to_embedding(
+ &mut self,
+ prompt: &str,
+ device: &candle::Device,
+ ) -> Result<Tensor> {
+ let mut tokens = self
+ .tokenizer
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ tokens.resize(self.max_position_embeddings, 0);
+ let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+ let embeddings = self.t5.forward(&input_token_ids)?;
+ Ok(embeddings)
+ }
+}
+
+pub struct StableDiffusion3TripleClipWithTokenizer {
+ clip_l: ClipWithTokenizer,
+ clip_g: ClipWithTokenizer,
+ clip_g_text_projection: candle_nn::Linear,
+ t5: T5WithTokenizer,
+}
+
+impl StableDiffusion3TripleClipWithTokenizer {
+ pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
+ let max_position_embeddings = 77usize;
+ let clip_l = ClipWithTokenizer::new(
+ vb_fp16.pp("clip_l.transformer"),
+ stable_diffusion::clip::Config::sdxl(),
+ "openai/clip-vit-large-patch14",
+ max_position_embeddings,
+ )?;
+
+ let clip_g = ClipWithTokenizer::new(
+ vb_fp16.pp("clip_g.transformer"),
+ stable_diffusion::clip::Config::sdxl2(),
+ "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
+ max_position_embeddings,
+ )?;
+
+ let text_projection = candle_nn::linear_no_bias(
+ 1280,
+ 1280,
+ vb_fp16.pp("clip_g.transformer.text_projection"),
+ )?;
+
+ // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
+ // This is a temporary workaround until the T5 implementation is updated to support fp16.
+ // Also see:
+ // https://github.com/huggingface/candle/issues/2480
+ // https://github.com/huggingface/candle/pull/2481
+ let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;
+
+ Ok(Self {
+ clip_l,
+ clip_g,
+ clip_g_text_projection: text_projection,
+ t5,
+ })
+ }
+
+ pub fn encode_text_to_embedding(
+ &mut self,
+ prompt: &str,
+ device: &candle::Device,
+ ) -> Result<(Tensor, Tensor)> {
+ let (clip_l_embeddings, clip_l_embeddings_pooled) =
+ self.clip_l.encode_text_to_embedding(prompt, device)?;
+ let (clip_g_embeddings, clip_g_embeddings_pooled) =
+ self.clip_g.encode_text_to_embedding(prompt, device)?;
+
+ let clip_g_embeddings_pooled = self
+ .clip_g_text_projection
+ .forward(&clip_g_embeddings_pooled.unsqueeze(0)?)?
+ .squeeze(0)?;
+
+ let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)?
+ .unsqueeze(0)?;
+ let clip_embeddings_concat = Tensor::cat(
+ &[&clip_l_embeddings, &clip_g_embeddings],
+ D::Minus1,
+ )?
+ .pad_with_zeros(D::Minus1, 0, 2048)?;
+
+ let t5_embeddings = self
+ .t5
+ .encode_text_to_embedding(prompt, device)?
+ .to_dtype(DType::F16)?;
+ let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;
+
+ Ok((context, y))
+ }
+}
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
new file mode 100644
index 00000000..164ae420
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -0,0 +1,185 @@
+mod clip;
+mod sampling;
+mod vae;
+
+use candle::{DType, IndexOp, Tensor};
+use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT};
+
+use crate::clip::StableDiffusion3TripleClipWithTokenizer;
+use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
+
+use anyhow::{Ok, Result};
+use clap::Parser;
+
+#[derive(Parser)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+ /// The prompt to be used for image generation.
+ #[arg(
+ long,
+ default_value = "A cute rusty robot holding a candle torch in its hand, \
+ with glowing neon text \"LETS GO RUSTY\" displayed on its chest, \
+ bright background, high quality, 4k"
+ )]
+ prompt: String,
+
+ #[arg(long, default_value = "")]
+ uncond_prompt: String,
+
+ /// Run on CPU rather than on GPU.
+ #[arg(long)]
+ cpu: bool,
+
+ /// The CUDA device ID to use.
+ #[arg(long, default_value = "0")]
+ cuda_device_id: usize,
+
+ /// Enable tracing (generates a trace-timestamp.json file).
+ #[arg(long)]
+ tracing: bool,
+
+ /// Use flash_attn to accelerate attention operation in the MMDiT.
+ #[arg(long)]
+ use_flash_attn: bool,
+
+ /// The height in pixels of the generated image.
+ #[arg(long, default_value_t = 1024)]
+ height: usize,
+
+ /// The width in pixels of the generated image.
+ #[arg(long, default_value_t = 1024)]
+ width: usize,
+
+ /// The seed to use when generating random samples.
+ #[arg(long, default_value_t = 28)]
+ num_inference_steps: usize,
+
+ // CFG scale.
+ #[arg(long, default_value_t = 4.0)]
+ cfg_scale: f64,
+
+ // Time shift factor (alpha).
+ #[arg(long, default_value_t = 3.0)]
+ time_shift: f64,
+
+ /// The seed to use when generating random samples.
+ #[arg(long)]
+ seed: Option<u64>,
+}
+
+fn main() -> Result<()> {
+ let args = Args::parse();
+ // Your main code here
+ run(args)
+}
+
+fn run(args: Args) -> Result<()> {
+ use tracing_chrome::ChromeLayerBuilder;
+ use tracing_subscriber::prelude::*;
+
+ let Args {
+ prompt,
+ uncond_prompt,
+ cpu,
+ cuda_device_id,
+ tracing,
+ use_flash_attn,
+ height,
+ width,
+ num_inference_steps,
+ cfg_scale,
+ time_shift,
+ seed,
+ } = args;
+
+ let _guard = if tracing {
+ let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+ tracing_subscriber::registry().with(chrome_layer).init();
+ Some(guard)
+ } else {
+ None
+ };
+
+ // TODO: Support and test on Metal.
+ let device = if cpu {
+ candle::Device::Cpu
+ } else {
+ candle::Device::cuda_if_available(cuda_device_id)?
+ };
+
+ let api = hf_hub::api::sync::Api::new()?;
+ let sai_repo = {
+ let name = "stabilityai/stable-diffusion-3-medium";
+ api.repo(hf_hub::Repo::model(name.to_string()))
+ };
+ let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
+ let vb_fp16 = unsafe {
+ candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F16, &device)?
+ };
+
+ let (context, y) = {
+ let vb_fp32 = unsafe {
+ candle_nn::VarBuilder::from_mmaped_safetensors(
+ &[model_file.clone()],
+ DType::F32,
+ &device,
+ )?
+ };
+ let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
+ vb_fp16.pp("text_encoders"),
+ vb_fp32.pp("text_encoders"),
+ )?;
+ let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
+ let (context_uncond, y_uncond) =
+ triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
+ (
+ Tensor::cat(&[context, context_uncond], 0)?,
+ Tensor::cat(&[y, y_uncond], 0)?,
+ )
+ };
+
+ let x = {
+ let mmdit = MMDiT::new(
+ &MMDiTConfig::sd3_medium(),
+ use_flash_attn,
+ vb_fp16.pp("model.diffusion_model"),
+ )?;
+
+ if let Some(seed) = seed {
+ device.set_seed(seed)?;
+ }
+ let start_time = std::time::Instant::now();
+ let x = sampling::euler_sample(
+ &mmdit,
+ &y,
+ &context,
+ num_inference_steps,
+ cfg_scale,
+ time_shift,
+ height,
+ width,
+ )?;
+ let dt = start_time.elapsed().as_secs_f32();
+ println!(
+ "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
+ dt,
+ num_inference_steps as f32 / dt
+ );
+ x
+ };
+
+ let img = {
+ let vb_vae = vb_fp16
+ .clone()
+ .rename_f(sd3_vae_vb_rename)
+ .pp("first_stage_model");
+ let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
+
+ // Apply TAESD3 scale factor. Seems to be significantly improving the quality of the image.
+ // https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
+ autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
+ };
+ let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
+ candle_examples::save_image(&img.i(0)?, "out.jpg")?;
+ Ok(())
+}
diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs
new file mode 100644
index 00000000..147d8e73
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/sampling.rs
@@ -0,0 +1,55 @@
+use anyhow::{Ok, Result};
+use candle::{DType, Tensor};
+
+use candle_transformers::models::flux;
+use candle_transformers::models::mmdit::model::MMDiT; // for the get_noise function
+
+#[allow(clippy::too_many_arguments)]
+pub fn euler_sample(
+ mmdit: &MMDiT,
+ y: &Tensor,
+ context: &Tensor,
+ num_inference_steps: usize,
+ cfg_scale: f64,
+ time_shift: f64,
+ height: usize,
+ width: usize,
+) -> Result<Tensor> {
+ let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?;
+ let sigmas = (0..=num_inference_steps)
+ .map(|x| x as f64 / num_inference_steps as f64)
+ .rev()
+ .map(|x| time_snr_shift(time_shift, x))
+ .collect::<Vec<f64>>();
+
+ for window in sigmas.windows(2) {
+ let (s_curr, s_prev) = match window {
+ [a, b] => (a, b),
+ _ => continue,
+ };
+
+ let timestep = (*s_curr) * 1000.0;
+ let noise_pred = mmdit.forward(
+ &Tensor::cat(&[x.clone(), x.clone()], 0)?,
+ &Tensor::full(timestep, (2,), x.device())?.contiguous()?,
+ y,
+ context,
+ )?;
+ x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?;
+ }
+ Ok(x)
+}
+
+// The "Resolution-dependent shifting of timestep schedules" recommended in the SD3 tech report paper
+// https://arxiv.org/pdf/2403.03206
+// Following the implementation in ComfyUI:
+// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/
+// comfy/model_sampling.py#L181
+fn time_snr_shift(alpha: f64, t: f64) -> f64 {
+ alpha * t / (1.0 + (alpha - 1.0) * t)
+}
+
+fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result<Tensor> {
+ Ok(((cfg_scale * noise_pred.narrow(0, 0, 1)?)?
+ - ((cfg_scale - 1.0) * noise_pred.narrow(0, 1, 1)?)?)?)
+}
diff --git a/candle-examples/examples/stable-diffusion-3/vae.rs b/candle-examples/examples/stable-diffusion-3/vae.rs
new file mode 100644
index 00000000..708e472e
--- /dev/null
+++ b/candle-examples/examples/stable-diffusion-3/vae.rs
@@ -0,0 +1,93 @@
+use anyhow::{Ok, Result};
+use candle_transformers::models::stable_diffusion::vae;
+
+pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result<vae::AutoEncoderKL> {
+ let config = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 16,
+ norm_num_groups: 32,
+ use_quant_conv: false,
+ use_post_quant_conv: false,
+ };
+ Ok(vae::AutoEncoderKL::new(vb, 3, 3, config)?)
+}
+
+pub fn sd3_vae_vb_rename(name: &str) -> String {
+ let parts: Vec<&str> = name.split('.').collect();
+ let mut result = Vec::new();
+ let mut i = 0;
+
+ while i < parts.len() {
+ match parts[i] {
+ "down_blocks" => {
+ result.push("down");
+ }
+ "mid_block" => {
+ result.push("mid");
+ }
+ "up_blocks" => {
+ result.push("up");
+ match parts[i + 1] {
+ // Reverse the order of up_blocks.
+ "0" => result.push("3"),
+ "1" => result.push("2"),
+ "2" => result.push("1"),
+ "3" => result.push("0"),
+ _ => {}
+ }
+ i += 1; // Skip the number after up_blocks.
+ }
+ "resnets" => {
+ if i > 0 && parts[i - 1] == "mid_block" {
+ match parts[i + 1] {
+ "0" => result.push("block_1"),
+ "1" => result.push("block_2"),
+ _ => {}
+ }
+ i += 1; // Skip the number after resnets.
+ } else {
+ result.push("block");
+ }
+ }
+ "downsamplers" => {
+ result.push("downsample");
+ i += 1; // Skip the 0 after downsamplers.
+ }
+ "conv_shortcut" => {
+ result.push("nin_shortcut");
+ }
+ "attentions" => {
+ if parts[i + 1] == "0" {
+ result.push("attn_1")
+ }
+ i += 1; // Skip the number after attentions.
+ }
+ "group_norm" => {
+ result.push("norm");
+ }
+ "query" => {
+ result.push("q");
+ }
+ "key" => {
+ result.push("k");
+ }
+ "value" => {
+ result.push("v");
+ }
+ "proj_attn" => {
+ result.push("proj_out");
+ }
+ "conv_norm_out" => {
+ result.push("norm_out");
+ }
+ "upsamplers" => {
+ result.push("upsample");
+ i += 1; // Skip the 0 after upsamplers.
+ }
+ part => result.push(part),
+ }
+ i += 1;
+ }
+ result.join(".")
+}