Diffstat (limited to 'candle-examples/examples/stable-diffusion/main.rs')
-rw-r--r--  candle-examples/examples/stable-diffusion/main.rs  170
1 file changed, 128 insertions(+), 42 deletions(-)
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 1443986c..8372edcd 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -17,7 +17,7 @@ mod utils;
mod vae;
use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor};
+use candle::{DType, Device, IndexOp, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
@@ -102,12 +102,16 @@ struct Args {
enum StableDiffusionVersion {
V1_5,
V2_1,
+ Xl,
}
+#[allow(unused)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelFile {
Tokenizer,
+ Tokenizer2,
Clip,
+ Clip2,
Unet,
Vae,
}
@@ -115,6 +119,7 @@ enum ModelFile {
impl StableDiffusionVersion {
fn repo(&self) -> &'static str {
match self {
+ Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
}
@@ -122,7 +127,7 @@ impl StableDiffusionVersion {
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 => {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -134,7 +139,7 @@ impl StableDiffusionVersion {
fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 => {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -146,7 +151,7 @@ impl StableDiffusionVersion {
fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 => {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
@@ -155,12 +160,21 @@ impl StableDiffusionVersion {
}
}
}
+
+ fn clip2_file(&self, use_f16: bool) -> &'static str {
+ match self {
+ Self::V1_5 | Self::V2_1 | Self::Xl => {
+ if use_f16 {
+ "text_encoder_2/model.fp16.safetensors"
+ } else {
+ "text_encoder_2/model.safetensors"
+ }
+ }
+ }
+ }
}
impl ModelFile {
- const TOKENIZER_REPO: &str = "openai/clip-vit-base-patch32";
- const TOKENIZER_PATH: &str = "tokenizer.json";
-
fn get(
&self,
filename: Option<String>,
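
A quick sanity check on the new version plumbing: a minimal sketch that only restates what the match arms above return for the new Xl variant (the assert_eq! calls are for illustration only).

    let version = StableDiffusionVersion::Xl;
    assert_eq!(version.repo(), "stabilityai/stable-diffusion-xl-base-1.0");
    assert_eq!(version.clip_file(true), "text_encoder/model.fp16.safetensors");
    assert_eq!(version.clip2_file(true), "text_encoder_2/model.fp16.safetensors");
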
@@ -172,8 +186,24 @@ impl ModelFile {
Some(filename) => Ok(std::path::PathBuf::from(filename)),
None => {
let (repo, path) = match self {
- Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
+ Self::Tokenizer => {
+ let tokenizer_repo = match version {
+ StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
+ "openai/clip-vit-base-patch32"
+ }
+ StableDiffusionVersion::Xl => {
+ // This seems similar to the patch32 version, except for some
+ // very small differences in the split regex.
+ "openai/clip-vit-large-patch14"
+ }
+ };
+ (tokenizer_repo, "tokenizer.json")
+ }
+ Self::Tokenizer2 => {
+ ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json")
+ }
Self::Clip => (version.repo(), version.clip_file(use_f16)),
+ Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
Self::Unet => (version.repo(), version.unet_file(use_f16)),
Self::Vae => (version.repo(), version.vae_file(use_f16)),
};
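
The (repo, path) pair selected above is resolved to a local file further down in ModelFile::get, outside this hunk. A minimal sketch of that resolution step, assuming the hf-hub crate's blocking API; the fetch helper here is hypothetical and for illustration only.

    use hf_hub::api::sync::Api;

    fn fetch(repo: &str, path: &str) -> anyhow::Result<std::path::PathBuf> {
        // Download the file from the Hub (or reuse the local cache) and
        // return the path to the cached copy.
        let filename = Api::new()?.model(repo.to_string()).get(path)?;
        Ok(filename)
    }
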
@@ -211,6 +241,71 @@ fn output_filename(
}
}
+#[allow(clippy::too_many_arguments)]
+fn text_embeddings(
+ prompt: &str,
+ uncond_prompt: &str,
+ tokenizer: Option<String>,
+ clip_weights: Option<String>,
+ sd_version: StableDiffusionVersion,
+ sd_config: &stable_diffusion::StableDiffusionConfig,
+ use_f16: bool,
+ device: &Device,
+ dtype: DType,
+ first: bool,
+) -> Result<Tensor> {
+ let tokenizer_file = if first {
+ ModelFile::Tokenizer
+ } else {
+ ModelFile::Tokenizer2
+ };
+ let tokenizer = tokenizer_file.get(tokenizer, sd_version, use_f16)?;
+ let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
+ let pad_id = match &sd_config.clip.pad_with {
+ Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
+ None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
+ };
+ println!("Running with prompt \"{prompt}\".");
+ let mut tokens = tokenizer
+ .encode(prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while tokens.len() < sd_config.clip.max_position_embeddings {
+ tokens.push(pad_id)
+ }
+ let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
+
+ let mut uncond_tokens = tokenizer
+ .encode(uncond_prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+ uncond_tokens.push(pad_id)
+ }
+ let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+
+ println!("Building the Clip transformer.");
+ let clip_weights_file = if first {
+ ModelFile::Clip
+ } else {
+ ModelFile::Clip2
+ };
+ let clip_weights = clip_weights_file.get(clip_weights, sd_version, false)?;
+ let clip_config = if first {
+ &sd_config.clip
+ } else {
+ sd_config.clip2.as_ref().unwrap()
+ };
+ let text_model =
+ stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
+ let text_embeddings = text_model.forward(&tokens)?;
+ let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+ let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
+ Ok(text_embeddings)
+}
+
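Each text_embeddings call returns the unconditional and conditional embeddings stacked along the batch dimension; for SDXL, run() below calls it once per encoder and concatenates the results on the last dimension. A hedged shape sketch, assuming the usual 77-token context and hidden sizes of 768 and 1280 for the two SDXL text encoders (the real values come from sd_config):

    use candle::{DType, Device, Result, Tensor, D};

    fn sdxl_embedding_shapes() -> Result<()> {
        let dev = Device::Cpu;
        // Stand-ins for the two encoders' outputs: (uncond + cond, tokens, hidden).
        let first = Tensor::zeros((2, 77, 768), DType::F32, &dev)?;
        let second = Tensor::zeros((2, 77, 1280), DType::F32, &dev)?;
        // Concatenating on the last dimension yields the joint embedding.
        let joint = Tensor::cat(&[first, second], D::Minus1)?;
        assert_eq!(joint.dims(), &[2, 77, 2048]);
        Ok(())
    }
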
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
@@ -252,46 +347,37 @@ fn run(args: Args) -> Result<()> {
StableDiffusionVersion::V2_1 => {
stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
}
+ StableDiffusionVersion::Xl => {
+ stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
+ }
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
- let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version, use_f16)?;
- let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
- let pad_id = match &sd_config.clip.pad_with {
- Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
- None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(),
- };
- println!("Running with prompt \"{prompt}\".");
- let mut tokens = tokenizer
- .encode(prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
- while tokens.len() < sd_config.clip.max_position_embeddings {
- tokens.push(pad_id)
- }
- let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
-
- let mut uncond_tokens = tokenizer
- .encode(uncond_prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
- while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
- uncond_tokens.push(pad_id)
- }
- let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
-
- println!("Building the Clip transformer.");
- let text_embeddings = {
- let clip_weights = ModelFile::Clip.get(clip_weights, sd_version, false)?;
- let text_model = sd_config.build_clip_transformer(&clip_weights, &device, DType::F32)?;
- let text_embeddings = text_model.forward(&tokens)?;
- let uncond_embeddings = text_model.forward(&uncond_tokens)?;
- Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
+ let which = match sd_version {
+ StableDiffusionVersion::Xl => vec![true, false],
+ _ => vec![true],
};
+ let text_embeddings = which
+ .iter()
+ .map(|first| {
+ text_embeddings(
+ &prompt,
+ &uncond_prompt,
+ tokenizer.clone(),
+ clip_weights.clone(),
+ sd_version,
+ &sd_config,
+ use_f16,
+ &device,
+ dtype,
+ *first,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
+ println!("{text_embeddings:?}");
println!("Building the autoencoder.");
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;