author     Laurent Mazare <laurent.mazare@gmail.com>  2023-08-29 09:00:04 +0100
committer  GitHub <noreply@github.com>  2023-08-29 09:00:04 +0100
commit     33c23c19b6f4821c00a47758f7841baf52ba9081
tree       71ad06837519d92a52d6a0021887bf1db4e29360 /candle-examples/examples/stable-diffusion
parent     49326fb9252b67d1e46c69565da93106ee4b71a0
Preliminary support for SDXL. (#647)
* Preliminary support for SDXL.
* More SDXL support.
* More SDXL.
* Use the proper clip config.
* Querying for existing tensors.
* More robust test.
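In outline: SDXL conditions its UNet on two CLIP text encoders rather than one, so the example now loads both, embeds the prompt with each, and joins the two embeddings on the hidden dimension. A minimal candle sketch of that join, with shapes assumed from the two text-encoder configs added in this diff (768 + 1280 = 2048, the SDXL UNet's cross_attention_dim):

    use candle::{DType, Device, Result, Tensor, D};

    // Sketch only: [batch = 2 (uncond + cond), seq = 77, hidden].
    fn sketch_joint_embedding() -> Result<Tensor> {
        let dev = Device::Cpu;
        let emb1 = Tensor::zeros((2, 77, 768), DType::F32, &dev)?; // text_encoder
        let emb2 = Tensor::zeros((2, 77, 1280), DType::F32, &dev)?; // text_encoder_2
        Tensor::cat(&[emb1, emb2], D::Minus1) // -> [2, 77, 2048]
    }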
Diffstat (limited to 'candle-examples/examples/stable-diffusion')
-rw-r--r--  candle-examples/examples/stable-diffusion/attention.rs         |   2
-rw-r--r--  candle-examples/examples/stable-diffusion/clip.rs              |  30
-rw-r--r--  candle-examples/examples/stable-diffusion/main.rs              | 170
-rw-r--r--  candle-examples/examples/stable-diffusion/stable_diffusion.rs  | 108
4 files changed, 253 insertions, 57 deletions
diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs
index 58f5e87e..797542aa 100644
--- a/candle-examples/examples/stable-diffusion/attention.rs
+++ b/candle-examples/examples/stable-diffusion/attention.rs
@@ -473,7 +473,7 @@ impl AttentionBlock {
let num_heads = channels / num_head_channels;
let group_norm =
nn::group_norm(config.num_groups, channels, config.eps, vs.pp("group_norm"))?;
- let (q_path, k_path, v_path, out_path) = if vs.dtype() == DType::F16 {
+ let (q_path, k_path, v_path, out_path) = if vs.contains_tensor("to_q.weight") {
("to_q", "to_k", "to_v", "to_out.0")
} else {
("query", "key", "value", "proj_attn")
diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs
index 2927a404..d26c1c46 100644
--- a/candle-examples/examples/stable-diffusion/clip.rs
+++ b/candle-examples/examples/stable-diffusion/clip.rs
@@ -69,6 +69,36 @@ impl Config {
activation: Activation::Gelu,
}
}
+
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder/config.json
+ pub fn sdxl() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 768,
+ intermediate_size: 3072,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 12,
+ num_attention_heads: 12,
+ projection_dim: 768,
+ activation: Activation::QuickGelu,
+ }
+ }
+
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder_2/config.json
+ pub fn sdxl2() -> Self {
+ Self {
+ vocab_size: 49408,
+ embed_dim: 1280,
+ intermediate_size: 5120,
+ max_position_embeddings: 77,
+ pad_with: Some("!".to_string()),
+ num_hidden_layers: 32,
+ num_attention_heads: 20,
+ projection_dim: 1280,
+ activation: Activation::Gelu,
+ }
+ }
}
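The two configs are deliberately asymmetric: sdxl() describes the familiar 768-wide, 12-layer tower with QuickGelu, while sdxl2() describes the 1280-wide, 32-layer OpenCLIP tower with plain Gelu. A small sketch, assuming it lives in this module (function name invented):

    fn sketch_sdxl_configs() {
        let cfg1 = Config::sdxl(); // text_encoder: 768-wide, 12 layers
        let cfg2 = Config::sdxl2(); // text_encoder_2: 1280-wide, 32 layers
        // The widths concatenate to the SDXL UNet's cross-attention width.
        assert_eq!(cfg1.embed_dim + cfg2.embed_dim, 2048);
    }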
// CLIP Text Model
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 1443986c..8372edcd 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -17,7 +17,7 @@ mod utils;
mod vae;
use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor};
+use candle::{DType, Device, IndexOp, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
@@ -102,12 +102,16 @@ struct Args {
enum StableDiffusionVersion {
V1_5,
V2_1,
+ Xl,
}
+#[allow(unused)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
Tokenizer,
+ Tokenizer2,
Clip,
+ Clip2,
Unet,
Vae,
}
@@ -115,6 +119,7 @@ enum ModelFile {
impl StableDiffusionVersion {
fn repo(&self) -> &'static str {
match self {
+ Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
}
@@ -122,7 +127,7 @@ impl StableDiffusionVersion {
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 => {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -134,7 +139,7 @@ impl StableDiffusionVersion {
fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 => {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -146,7 +151,7 @@ impl StableDiffusionVersion {
fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 => {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
@@ -155,12 +160,21 @@ impl StableDiffusionVersion {
}
}
}
+
+ fn clip2_file(&self, use_f16: bool) -> &'static str {
+ match self {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
+ if use_f16 {
+ "text_encoder_2/model.fp16.safetensors"
+ } else {
+ "text_encoder_2/model.safetensors"
+ }
+ }
+ }
+ }
}
impl ModelFile {
- const TOKENIZER_REPO: &str = "openai/clip-vit-base-patch32";
- const TOKENIZER_PATH: &str = "tokenizer.json";
-
fn get(
&self,
filename: Option<String>,
@@ -172,8 +186,24 @@ impl ModelFile {
Some(filename) => Ok(std::path::PathBuf::from(filename)),
None => {
let (repo, path) = match self {
- Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
+ Self::Tokenizer => {
+ let tokenizer_repo = match version {
+ StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
+ "openai/clip-vit-base-patch32"
+ }
+ StableDiffusionVersion::Xl => {
+ // This seems similar to the patch32 version, except for a very small
+ // difference in the split regex.
+ "openai/clip-vit-large-patch14"
+ }
+ };
+ (tokenizer_repo, "tokenizer.json")
+ }
+ Self::Tokenizer2 => {
+ ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json")
+ }
Self::Clip => (version.repo(), version.clip_file(use_f16)),
+ Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
Self::Unet => (version.repo(), version.unet_file(use_f16)),
Self::Vae => (version.repo(), version.vae_file(use_f16)),
};
@@ -211,6 +241,71 @@ fn output_filename(
}
}
+#[allow(clippy::too_many_arguments)]
+fn text_embeddings(
+ prompt: &str,
+ uncond_prompt: &str,
+ tokenizer: Option<String>,
+ clip_weights: Option<String>,
+ sd_version: StableDiffusionVersion,
+ sd_config: &stable_diffusion::StableDiffusionConfig,
+ use_f16: bool,
+ device: &Device,
+ dtype: DType,
+ first: bool,
+) -> Result<Tensor> {
+ let tokenizer_file = if first {
+ ModelFile::Tokenizer
+ } else {
+ ModelFile::Tokenizer2
+ };
+ let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?;
+ let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+ let pad_id = match &sd_config.clip.pad_with {
+ Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
+ None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
+ };
+ println!("Running with prompt \"{prompt}\".");
+ let mut tokens = tokenizer
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while tokens.len() < sd_config.clip.max_position_embeddings {
+ tokens.push(pad_id)
+ }
+ let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
+
+ let mut uncond_tokens = tokenizer
+ .encode(uncond_prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+ uncond_tokens.push(pad_id)
+ }
+ let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+
+ println!("Building the Clip transformer.");
+ let clip_weights_file = if first {
+ ModelFile::Clip
+ } else {
+ ModelFile::Clip2
+ };
+ let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?;
+ let clip_config = if first {
+ &sd_config.clip
+ } else {
+ sd_config.clip2.as_ref().unwrap()
+ };
+ let text_model =
+ stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
+ let text_embeddings = text_model.forward(&tokens)?;
+ let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+ let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
+ Ok(text_embeddings)
+}
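Read as a contract: the helper returns the unconditional and conditional embeddings stacked on the batch axis, shape [2, 77, hidden], and the reworked run() below calls it once per encoder (first = true, then false for SDXL). A hypothetical invocation:

    // Hypothetical call; None means "fetch the default files from the hub".
    let emb = text_embeddings(
        "a photo of a cat",
        "",
        None,
        None,
        StableDiffusionVersion::Xl,
        &sd_config,
        false, // use_f16
        &device,
        DType::F32,
        true, // first encoder -> [2, 77, 768]
    )?;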
+
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
@@ -252,46 +347,37 @@ fn run(args: Args) -> Result<()> {
StableDiffusionVersion::V2_1 => {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
}
+ StableDiffusionVersion::Xl => {
+ stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
+ }
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
- let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version, use_f16)?;
- let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
- let pad_id = match &sd_config.clip.pad_with {
- Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
- None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
- };
- println!("Running with prompt \"{prompt}\".");
- let mut tokens = tokenizer
- .encode(prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
- while tokens.len() < sd_config.clip.max_position_embeddings {
- tokens.push(pad_id)
- }
- let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
-
- let mut uncond_tokens = tokenizer
- .encode(uncond_prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
- while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
- uncond_tokens.push(pad_id)
- }
- let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
-
- println!("Building the Clip transformer.");
- let text_embeddings = {
- let clip_weights = ModelFile::Clip.get(clip_weights, sd_version, false)?;
- let text_model = sd_config.build_clip_transformer(&clip_weights, &device, DType::F32)?;
- let text_embeddings = text_model.forward(&tokens)?;
- let uncond_embeddings = text_model.forward(&uncond_tokens)?;
- Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
+ let which = match sd_version {
+ StableDiffusionVersion::Xl => vec![true, false],
+ _ => vec![true],
};
+ let text_embeddings = which
+ .iter()
+ .map(|first| {
+ text_embeddings(
+ &prompt,
+ &uncond_prompt,
+ tokenizer.clone(),
+ clip_weights.clone(),
+ sd_version,
+ &sd_config,
+ use_f16,
+ &device,
+ dtype,
+ *first,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
+ println!("{text_embeddings:?}");
println!("Building the autoencoder.");
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
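The rework drives the helper with a small per-version plan: [true] for v1.5/v2.1, [true, false] for SDXL, then joins the per-encoder results along the last axis. A detached sketch of that plan (helper and closure invented for illustration):

    use candle::{Result, Tensor, D};

    fn join_embeddings(sdxl: bool, embed: impl Fn(bool) -> Result<Tensor>) -> Result<Tensor> {
        let which: &[bool] = if sdxl { &[true, false] } else { &[true] };
        let embs = which
            .iter()
            .map(|first| embed(*first))
            .collect::<Result<Vec<_>>>()?;
        Tensor::cat(&embs, D::Minus1)
    }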
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index e159fa0a..bed60161 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -8,6 +8,7 @@ pub struct StableDiffusionConfig {
pub width: usize,
pub height: usize,
pub clip: clip::Config,
+ pub clip2: Option<clip::Config>,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
scheduler: ddim::DDIMSchedulerConfig,
@@ -51,7 +52,7 @@ impl StableDiffusionConfig {
norm_num_groups: 32,
};
let height = if let Some(height) = height {
- assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
512
@@ -68,6 +69,7 @@ impl StableDiffusionConfig {
width,
height,
clip: clip::Config::v1_5(),
+ clip2: None,
autoencoder,
scheduler: Default::default(),
unet,
@@ -118,7 +120,7 @@ impl StableDiffusionConfig {
};
let height = if let Some(height) = height {
- assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
768
@@ -135,6 +137,7 @@ impl StableDiffusionConfig {
width,
height,
clip: clip::Config::v2_1(),
+ clip2: None,
autoencoder,
scheduler,
unet,
@@ -155,6 +158,83 @@ impl StableDiffusionConfig {
)
}
+ fn sdxl_(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ prediction_type: PredictionType,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![bc(320, false, 5), bc(640, false, 10), bc(1280, true, 20)],
+ center_input_sample: false,
+ cross_attention_dim: 2048,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: true,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let scheduler = ddim::DDIMSchedulerConfig {
+ prediction_type,
+ ..Default::default()
+ };
+
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
+ height
+ } else {
+ 1024
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 1024
+ };
+
+ Self {
+ width,
+ height,
+ clip: clip::Config::sdxl(),
+ clip2: Some(clip::Config::sdxl2()),
+ autoencoder,
+ scheduler,
+ unet,
+ }
+ }
+
+ pub fn sdxl(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ Self::sdxl_(
+ sliced_attention_size,
+ height,
+ width,
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
+ PredictionType::Epsilon,
+ )
+ }
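A minimal sketch of building the new config: the default geometry is 1024x1024 and both CLIP configs are populated.

    let sd_config = stable_diffusion::StableDiffusionConfig::sdxl(None, None, None);
    assert_eq!((sd_config.height, sd_config.width), (1024, 1024));
    assert!(sd_config.clip2.is_some());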
+
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,
@@ -193,17 +273,17 @@ impl StableDiffusionConfig {
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
+}
- pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
- &self,
- clip_weights: P,
- device: &Device,
- dtype: DType,
- ) -> Result<clip::ClipTextTransformer> {
- let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
- let weights = weights.deserialize()?;
- let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
- let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
- Ok(text_model)
- }
+pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
+ clip: &clip::Config,
+ clip_weights: P,
+ device: &Device,
+ dtype: DType,
+) -> Result<clip::ClipTextTransformer> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
+ let weights = weights.deserialize()?;
+ let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+ let text_model = clip::ClipTextTransformer::new(vs, clip)?;
+ Ok(text_model)
}
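With build_clip_transformer now a free function taking the config explicitly, the same code path builds either text tower. A sketch of the two-call pattern (weight paths hypothetical):

    let t1 = build_clip_transformer(&sd_config.clip, "clip1.safetensors", &device, DType::F32)?;
    let t2 = build_clip_transformer(
        sd_config.clip2.as_ref().expect("sdxl configs set clip2"),
        "clip2.safetensors",
        &device,
        DType::F32,
    )?;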