summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion/main.rs
diff options
context:
space:
mode:
authorEdwin Cheng <edwin0cheng@gmail.com>2023-12-03 15:37:10 +0800
committerGitHub <noreply@github.com>2023-12-03 08:37:10 +0100
commit37bf1ed012926b4180d2e9068b84118e1eb6e26d (patch)
tree73aa11818463fbe5bf39f4978dd6a5538a1a01a5 /candle-examples/examples/stable-diffusion/main.rs
parentdd40edfe73b796ad12f6246319cde2e603c4dc56 (diff)
downloadcandle-37bf1ed012926b4180d2e9068b84118e1eb6e26d.tar.gz
candle-37bf1ed012926b4180d2e9068b84118e1eb6e26d.tar.bz2
candle-37bf1ed012926b4180d2e9068b84118e1eb6e26d.zip
Stable Diffusion Turbo Support (#1395)
* Add support for SD Turbo * Set Leading as default in euler_ancestral discrete * Use the appropriate default values for n_steps and guidance_scale. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-examples/examples/stable-diffusion/main.rs')
-rw-r--r--candle-examples/examples/stable-diffusion/main.rs121
1 files changed, 90 insertions, 31 deletions
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 3e6de34d..8c3ca2ee 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -11,8 +11,6 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
-const GUIDANCE_SCALE: f64 = 7.5;
-
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -63,8 +61,8 @@ struct Args {
sliced_attention_size: Option<usize>,
/// The number of steps to run the diffusion for.
- #[arg(long, default_value_t = 30)]
- n_steps: usize,
+ #[arg(long)]
+ n_steps: Option<usize>,
/// The number of samples to generate.
#[arg(long, default_value_t = 1)]
@@ -87,6 +85,9 @@ struct Args {
#[arg(long)]
use_f16: bool,
+ #[arg(long)]
+ guidance_scale: Option<f64>,
+
#[arg(long, value_name = "FILE")]
img2img: Option<String>,
@@ -102,6 +103,7 @@ enum StableDiffusionVersion {
V1_5,
V2_1,
Xl,
+ Turbo,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@@ -120,12 +122,13 @@ impl StableDiffusionVersion {
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
+ Self::Turbo => "stabilityai/sdxl-turbo",
}
}
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 | Self::Xl => {
+ Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -137,7 +140,7 @@ impl StableDiffusionVersion {
fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 | Self::Xl => {
+ Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
@@ -149,7 +152,7 @@ impl StableDiffusionVersion {
fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 | Self::Xl => {
+ Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
@@ -161,7 +164,7 @@ impl StableDiffusionVersion {
fn clip2_file(&self, use_f16: bool) -> &'static str {
match self {
- Self::V1_5 | Self::V2_1 | Self::Xl => {
+ Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
if use_f16 {
"text_encoder_2/model.fp16.safetensors"
} else {
@@ -189,7 +192,7 @@ impl ModelFile {
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
"openai/clip-vit-base-patch32"
}
- StableDiffusionVersion::Xl => {
+ StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
// This seems similar to the patch32 version except some very small
// difference in the split regex.
"openai/clip-vit-large-patch14"
@@ -206,7 +209,11 @@ impl ModelFile {
Self::Vae => {
// Override for SDXL when using f16 weights.
// See https://github.com/huggingface/candle/issues/1060
- if version == StableDiffusionVersion::Xl && use_f16 {
+ if matches!(
+ version,
+ StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo,
+ ) && use_f16
+ {
(
"madebyollin/sdxl-vae-fp16-fix",
"diffusion_pytorch_model.safetensors",
@@ -261,6 +268,7 @@ fn text_embeddings(
use_f16: bool,
device: &Device,
dtype: DType,
+ use_guide_scale: bool,
first: bool,
) -> Result<Tensor> {
let tokenizer_file = if first {
@@ -285,16 +293,6 @@ fn text_embeddings(
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
- let mut uncond_tokens = tokenizer
- .encode(uncond_prompt, true)
- .map_err(E::msg)?
- .get_ids()
- .to_vec();
- while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
- uncond_tokens.push(pad_id)
- }
- let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
-
println!("Building the Clip transformer.");
let clip_weights_file = if first {
ModelFile::Clip
@@ -310,8 +308,24 @@ fn text_embeddings(
let text_model =
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
let text_embeddings = text_model.forward(&tokens)?;
- let uncond_embeddings = text_model.forward(&uncond_tokens)?;
- let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
+
+ let text_embeddings = if use_guide_scale {
+ let mut uncond_tokens = tokenizer
+ .encode(uncond_prompt, true)
+ .map_err(E::msg)?
+ .get_ids()
+ .to_vec();
+ while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
+ uncond_tokens.push(pad_id)
+ }
+
+ let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
+ let uncond_embeddings = text_model.forward(&uncond_tokens)?;
+
+ Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
+ } else {
+ text_embeddings.to_dtype(dtype)?
+ };
Ok(text_embeddings)
}
@@ -356,6 +370,7 @@ fn run(args: Args) -> Result<()> {
unet_weights,
tracing,
use_f16,
+ guidance_scale,
use_flash_attn,
img2img,
img2img_strength,
@@ -374,6 +389,24 @@ fn run(args: Args) -> Result<()> {
None
};
+ let guidance_scale = match guidance_scale {
+ Some(guidance_scale) => guidance_scale,
+ None => match sd_version {
+ StableDiffusionVersion::V1_5
+ | StableDiffusionVersion::V2_1
+ | StableDiffusionVersion::Xl => 7.5,
+ StableDiffusionVersion::Turbo => 0.,
+ },
+ };
+ let n_steps = match n_steps {
+ Some(n_steps) => n_steps,
+ None => match sd_version {
+ StableDiffusionVersion::V1_5
+ | StableDiffusionVersion::V2_1
+ | StableDiffusionVersion::Xl => 30,
+ StableDiffusionVersion::Turbo => 1,
+ },
+ };
let dtype = if use_f16 { DType::F16 } else { DType::F32 };
let sd_config = match sd_version {
StableDiffusionVersion::V1_5 => {
@@ -385,13 +418,19 @@ fn run(args: Args) -> Result<()> {
StableDiffusionVersion::Xl => {
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
}
+ StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo(
+ sliced_attention_size,
+ height,
+ width,
+ ),
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
+ let use_guide_scale = guidance_scale > 1.0;
let which = match sd_version {
- StableDiffusionVersion::Xl => vec![true, false],
+ StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false],
_ => vec![true],
};
let text_embeddings = which
@@ -407,10 +446,12 @@ fn run(args: Args) -> Result<()> {
use_f16,
&device,
dtype,
+ use_guide_scale,
*first,
)
})
.collect::<Result<Vec<_>>>()?;
+
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
println!("{text_embeddings:?}");
@@ -434,11 +475,19 @@ fn run(args: Args) -> Result<()> {
0
};
let bsize = 1;
+
+ let vae_scale = match sd_version {
+ StableDiffusionVersion::V1_5
+ | StableDiffusionVersion::V2_1
+ | StableDiffusionVersion::Xl => 0.18215,
+ StableDiffusionVersion::Turbo => 0.13025,
+ };
+
for idx in 0..num_samples {
let timesteps = scheduler.timesteps();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
- let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
+ let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
if t_start < timesteps.len() {
let noise = latents.randn_like(0f64, 1f64)?;
scheduler.add_noise(&latents, noise, timesteps[t_start])?
@@ -465,21 +514,31 @@ fn run(args: Args) -> Result<()> {
continue;
}
let start_time = std::time::Instant::now();
- let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
+ let latent_model_input = if use_guide_scale {
+ Tensor::cat(&[&latents, &latents], 0)?
+ } else {
+ latents.clone()
+ };
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
let noise_pred =
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
- let noise_pred = noise_pred.chunk(2, 0)?;
- let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
- let noise_pred =
- (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
+
+ let noise_pred = if use_guide_scale {
+ let noise_pred = noise_pred.chunk(2, 0)?;
+ let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
+
+ (noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)?
+ } else {
+ noise_pred
+ };
+
latents = scheduler.step(&noise_pred, timestep, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
if args.intermediary_images {
- let image = vae.decode(&(&latents / 0.18215)?)?;
+ let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename =
@@ -493,7 +552,7 @@ fn run(args: Args) -> Result<()> {
idx + 1,
num_samples
);
- let image = vae.decode(&(&latents / 0.18215)?)?;
+ let image = vae.decode(&(&latents / vae_scale)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);