summary | refs | log | tree | commit | diff
path: root/candle-examples/examples/stable-diffusion/stable_diffusion.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/stable-diffusion/stable_diffusion.rs')
-rw-r--r-- candle-examples/examples/stable-diffusion/stable_diffusion.rs 128
1 files changed, 106 insertions, 22 deletions
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index e159fa0a..cffc00d8 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -8,6 +8,7 @@ pub struct StableDiffusionConfig {
pub width: usize,
pub height: usize,
pub clip: clip::Config,
+ pub clip2: Option<clip::Config>,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
scheduler: ddim::DDIMSchedulerConfig,
@@ -27,10 +28,10 @@ impl StableDiffusionConfig {
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
- bc(320, true, 8),
- bc(640, true, 8),
- bc(1280, true, 8),
- bc(1280, false, 8),
+ bc(320, Some(1), 8),
+ bc(640, Some(1), 8),
+ bc(1280, Some(1), 8),
+ bc(1280, None, 8),
],
center_input_sample: false,
cross_attention_dim: 768,
@@ -51,7 +52,7 @@ impl StableDiffusionConfig {
norm_num_groups: 32,
};
let height = if let Some(height) = height {
- assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
512
@@ -68,6 +69,7 @@ impl StableDiffusionConfig {
width,
height,
clip: clip::Config::v1_5(),
+ clip2: None,
autoencoder,
scheduler: Default::default(),
unet,
@@ -88,10 +90,10 @@ impl StableDiffusionConfig {
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
- bc(320, true, 5),
- bc(640, true, 10),
- bc(1280, true, 20),
- bc(1280, false, 20),
+ bc(320, Some(1), 5),
+ bc(640, Some(1), 10),
+ bc(1280, Some(1), 20),
+ bc(1280, None, 20),
],
center_input_sample: false,
cross_attention_dim: 1024,
@@ -118,7 +120,7 @@ impl StableDiffusionConfig {
};
let height = if let Some(height) = height {
- assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
768
@@ -135,6 +137,7 @@ impl StableDiffusionConfig {
width,
height,
clip: clip::Config::v2_1(),
+ clip2: None,
autoencoder,
scheduler,
unet,
@@ -155,6 +158,87 @@ impl StableDiffusionConfig {
)
}
+ fn sdxl_(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ prediction_type: PredictionType,
+ ) -> Self {
+ let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
+ out_channels,
+ use_cross_attn,
+ attention_head_dim,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
+ let unet = unet_2d::UNet2DConditionModelConfig {
+ blocks: vec![
+ bc(320, None, 5),
+ bc(640, Some(2), 10),
+ bc(1280, Some(10), 20),
+ ],
+ center_input_sample: false,
+ cross_attention_dim: 2048,
+ downsample_padding: 1,
+ flip_sin_to_cos: true,
+ freq_shift: 0.,
+ layers_per_block: 2,
+ mid_block_scale_factor: 1.,
+ norm_eps: 1e-5,
+ norm_num_groups: 32,
+ sliced_attention_size,
+ use_linear_projection: true,
+ };
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
+ let autoencoder = vae::AutoEncoderKLConfig {
+ block_out_channels: vec![128, 256, 512, 512],
+ layers_per_block: 2,
+ latent_channels: 4,
+ norm_num_groups: 32,
+ };
+ let scheduler = ddim::DDIMSchedulerConfig {
+ prediction_type,
+ ..Default::default()
+ };
+
+ let height = if let Some(height) = height {
+ assert_eq!(height % 8, 0, "height has to be divisible by 8");
+ height
+ } else {
+ 1024
+ };
+
+ let width = if let Some(width) = width {
+ assert_eq!(width % 8, 0, "width has to be divisible by 8");
+ width
+ } else {
+ 1024
+ };
+
+ Self {
+ width,
+ height,
+ clip: clip::Config::sdxl(),
+ clip2: Some(clip::Config::sdxl2()),
+ autoencoder,
+ scheduler,
+ unet,
+ }
+ }
+
+ pub fn sdxl(
+ sliced_attention_size: Option<usize>,
+ height: Option<usize>,
+ width: Option<usize>,
+ ) -> Self {
+ Self::sdxl_(
+ sliced_attention_size,
+ height,
+ width,
+ // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
+ PredictionType::Epsilon,
+ )
+ }
+
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,
@@ -193,17 +277,17 @@ impl StableDiffusionConfig {
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
+}
- pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
- &self,
- clip_weights: P,
- device: &Device,
- dtype: DType,
- ) -> Result<clip::ClipTextTransformer> {
- let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
- let weights = weights.deserialize()?;
- let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
- let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
- Ok(text_model)
- }
+pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
+ clip: &clip::Config,
+ clip_weights: P,
+ device: &Device,
+ dtype: DType,
+) -> Result<clip::ClipTextTransformer> {
+ let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
+ let weights = weights.deserialize()?;
+ let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+ let text_model = clip::ClipTextTransformer::new(vs, clip)?;
+ Ok(text_model)
}