//! Stable Diffusion
//!
//! Stable Diffusion is a latent text-to-image diffusion model capable of
//! generating photo-realistic images given any text input.
//!
//! - 💻 [Original Repository](https://github.com/CompVis/stable-diffusion)
//! - 🤗 [Hugging Face](https://huggingface.co/runwayml/stable-diffusion-v1-5)
//! - The default scheduler for the v1.5, v2.1 and XL 1.0 versions is the Denoising Diffusion Implicit Model (DDIM) scheduler. The original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim). The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
//!
//!
//! # Example
//!
//! _"A rusty robot holding a fire torch in its hand."_ Generated by Stable Diffusion XL using Rust and [candle](https://github.com/huggingface/candle).
//!
//! ```bash
//! # example running with cuda
//! # see the candle-examples/examples/stable-diffusion for all options
//! cargo run --example stable-diffusion --release --features=cuda,cudnn \
//! -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
//!
//! # with sd-turbo
//! cargo run --example stable-diffusion --release --features=cuda,cudnn \
//! -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" \
//! --sd-version turbo
//!
//! # with flash attention.
//! # feature flag: `--features flash-attn`
//! # cli flag: `--use-flash-attn`.
//! # flash-attention-v2 is only compatible with Ampere, Ada,
//! # or Hopper GPUs (e.g., A100/H100, RTX 3090/4090).
//! cargo run --example stable-diffusion --release --features=cuda,cudnn \
//! -- --prompt "a cosmonaut on a horse (hd, realistic, high-def)" \
//! --use-flash-attn
//! ```
pub mod attention;
pub mod clip;
pub mod ddim;
pub mod ddpm;
pub mod embeddings;
pub mod euler_ancestral_discrete;
pub mod resnet;
pub mod schedulers;
pub mod unet_2d;
pub mod unet_2d_blocks;
pub mod uni_pc;
pub mod utils;
pub mod vae;
use std::sync::Arc;
use candle::{DType, Device, Result};
use candle_nn as nn;
use self::schedulers::{Scheduler, SchedulerConfig};
#[derive(Clone, Debug)]
pub struct StableDiffusionConfig {
pub width: usize,
pub height: usize,
pub clip: clip::Config,
pub clip2: Option,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
scheduler: Arc,
}
impl StableDiffusionConfig {
pub fn v1_5(
sliced_attention_size: Option,
height: Option,
width: Option,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, Some(1), 8),
bc(640, Some(1), 8),
bc(1280, Some(1), 8),
bc(1280, None, 8),
],
center_input_sample: false,
cross_attention_dim: 768,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: false,
};
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
use_quant_conv: true,
use_post_quant_conv: true,
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
512
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
512
};
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type: schedulers::PredictionType::Epsilon,
..Default::default()
});
StableDiffusionConfig {
width,
height,
clip: clip::Config::v1_5(),
clip2: None,
autoencoder,
scheduler,
unet,
}
}
fn v2_1_(
sliced_attention_size: Option,
height: Option,
width: Option,
prediction_type: schedulers::PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, Some(1), 5),
bc(640, Some(1), 10),
bc(1280, Some(1), 20),
bc(1280, None, 20),
],
center_input_sample: false,
cross_attention_dim: 1024,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
use_quant_conv: true,
use_post_quant_conv: true,
};
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
});
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
768
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
768
};
StableDiffusionConfig {
width,
height,
clip: clip::Config::v2_1(),
clip2: None,
autoencoder,
scheduler,
unet,
}
}
pub fn v2_1(
sliced_attention_size: Option,
height: Option,
width: Option,
) -> Self {
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json
Self::v2_1_(
sliced_attention_size,
height,
width,
schedulers::PredictionType::VPrediction,
)
}
fn sdxl_(
sliced_attention_size: Option,
height: Option,
width: Option,
prediction_type: schedulers::PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, None, 5),
bc(640, Some(2), 10),
bc(1280, Some(10), 20),
],
center_input_sample: false,
cross_attention_dim: 2048,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
use_quant_conv: true,
use_post_quant_conv: true,
};
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
});
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
1024
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
1024
};
StableDiffusionConfig {
width,
height,
clip: clip::Config::sdxl(),
clip2: Some(clip::Config::sdxl2()),
autoencoder,
scheduler,
unet,
}
}
fn sdxl_turbo_(
sliced_attention_size: Option,
height: Option,
width: Option,
prediction_type: schedulers::PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, None, 5),
bc(640, Some(2), 10),
bc(1280, Some(10), 20),
],
center_input_sample: false,
cross_attention_dim: 2048,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
use_quant_conv: true,
use_post_quant_conv: true,
};
let scheduler = Arc::new(
euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
prediction_type,
timestep_spacing: schedulers::TimestepSpacing::Trailing,
..Default::default()
},
);
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
512
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
512
};
Self {
width,
height,
clip: clip::Config::sdxl(),
clip2: Some(clip::Config::sdxl2()),
autoencoder,
scheduler,
unet,
}
}
pub fn sdxl(
sliced_attention_size: Option,
height: Option,
width: Option,
) -> Self {
Self::sdxl_(
sliced_attention_size,
height,
width,
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
schedulers::PredictionType::Epsilon,
)
}
pub fn sdxl_turbo(
sliced_attention_size: Option,
height: Option,
width: Option,
) -> Self {
Self::sdxl_turbo_(
sliced_attention_size,
height,
width,
// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json
schedulers::PredictionType::Epsilon,
)
}
pub fn ssd1b(
sliced_attention_size: Option,
height: Option,
width: Option,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, None, 5),
bc(640, Some(2), 10),
bc(1280, Some(10), 20),
],
center_input_sample: false,
cross_attention_dim: 2048,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
use_quant_conv: true,
use_post_quant_conv: true,
};
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
..Default::default()
});
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
1024
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
1024
};
Self {
width,
height,
clip: clip::Config::ssd1b(),
clip2: Some(clip::Config::ssd1b2()),
autoencoder,
scheduler,
unet,
}
}
pub fn build_vae>(
&self,
vae_weights: P,
device: &Device,
dtype: DType,
) -> Result {
let vs_ae =
unsafe { nn::VarBuilder::from_mmaped_safetensors(&[vae_weights], dtype, device)? };
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
Ok(autoencoder)
}
pub fn build_unet>(
&self,
unet_weights: P,
device: &Device,
in_channels: usize,
use_flash_attn: bool,
dtype: DType,
) -> Result {
let vs_unet =
unsafe { nn::VarBuilder::from_mmaped_safetensors(&[unet_weights], dtype, device)? };
let unet = unet_2d::UNet2DConditionModel::new(
vs_unet,
in_channels,
4,
use_flash_attn,
self.unet.clone(),
)?;
Ok(unet)
}
pub fn build_scheduler(&self, n_steps: usize) -> Result> {
self.scheduler.build(n_steps)
}
}
pub fn build_clip_transformer>(
clip: &clip::Config,
clip_weights: P,
device: &Device,
dtype: DType,
) -> Result {
let vs = unsafe { nn::VarBuilder::from_mmaped_safetensors(&[clip_weights], dtype, device)? };
let text_model = clip::ClipTextTransformer::new(vs, clip)?;
Ok(text_model)
}