diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion/stable_diffusion.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/stable_diffusion.rs | 128 |
1 files changed, 106 insertions, 22 deletions
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs index e159fa0a..cffc00d8 100644 --- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs +++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs @@ -8,6 +8,7 @@ pub struct StableDiffusionConfig { pub width: usize, pub height: usize, pub clip: clip::Config, + pub clip2: Option<clip::Config>, autoencoder: vae::AutoEncoderKLConfig, unet: unet_2d::UNet2DConditionModelConfig, scheduler: ddim::DDIMSchedulerConfig, @@ -27,10 +28,10 @@ impl StableDiffusionConfig { // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json let unet = unet_2d::UNet2DConditionModelConfig { blocks: vec![ - bc(320, true, 8), - bc(640, true, 8), - bc(1280, true, 8), - bc(1280, false, 8), + bc(320, Some(1), 8), + bc(640, Some(1), 8), + bc(1280, Some(1), 8), + bc(1280, None, 8), ], center_input_sample: false, cross_attention_dim: 768, @@ -51,7 +52,7 @@ impl StableDiffusionConfig { norm_num_groups: 32, }; let height = if let Some(height) = height { - assert_eq!(height % 8, 0, "heigh has to be divisible by 8"); + assert_eq!(height % 8, 0, "height has to be divisible by 8"); height } else { 512 @@ -68,6 +69,7 @@ impl StableDiffusionConfig { width, height, clip: clip::Config::v1_5(), + clip2: None, autoencoder, scheduler: Default::default(), unet, @@ -88,10 +90,10 @@ impl StableDiffusionConfig { // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json let unet = unet_2d::UNet2DConditionModelConfig { blocks: vec![ - bc(320, true, 5), - bc(640, true, 10), - bc(1280, true, 20), - bc(1280, false, 20), + bc(320, Some(1), 5), + bc(640, Some(1), 10), + bc(1280, Some(1), 20), + bc(1280, None, 20), ], center_input_sample: false, cross_attention_dim: 1024, @@ -118,7 +120,7 @@ impl StableDiffusionConfig { }; let height = if let Some(height) = height { - assert_eq!(height % 8, 0, "heigh has to be divisible by 8"); + assert_eq!(height % 8, 0, "height has to be divisible by 8"); height } else { 768 @@ -135,6 +137,7 @@ impl StableDiffusionConfig { width, height, clip: clip::Config::v2_1(), + clip2: None, autoencoder, scheduler, unet, @@ -155,6 +158,87 @@ impl StableDiffusionConfig { ) } + fn sdxl_( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + prediction_type: PredictionType, + ) -> Self { + let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { + out_channels, + use_cross_attn, + attention_head_dim, + }; + // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json + let unet = unet_2d::UNet2DConditionModelConfig { + blocks: vec![ + bc(320, None, 5), + bc(640, Some(2), 10), + bc(1280, Some(10), 20), + ], + center_input_sample: false, + cross_attention_dim: 2048, + downsample_padding: 1, + flip_sin_to_cos: true, + freq_shift: 0., + layers_per_block: 2, + mid_block_scale_factor: 1., + norm_eps: 1e-5, + norm_num_groups: 32, + sliced_attention_size, + use_linear_projection: true, + }; + // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json + let autoencoder = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 4, + norm_num_groups: 32, + }; + let scheduler = ddim::DDIMSchedulerConfig { + prediction_type, + ..Default::default() + }; + + let height = if let Some(height) = height { + assert_eq!(height % 8, 0, "height has to be divisible by 8"); + height + } else { + 1024 + }; + + let width = if let Some(width) = width { + assert_eq!(width % 8, 0, "width has to be divisible by 8"); + width + } else { + 1024 + }; + + Self { + width, + height, + clip: clip::Config::sdxl(), + clip2: Some(clip::Config::sdxl2()), + autoencoder, + scheduler, + unet, + } + } + + pub fn sdxl( + sliced_attention_size: Option<usize>, + height: Option<usize>, + width: Option<usize>, + ) -> Self { + Self::sdxl_( + sliced_attention_size, + height, + width, + // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json + PredictionType::Epsilon, + ) + } + pub fn build_vae<P: AsRef<std::path::Path>>( &self, vae_weights: P, @@ -193,17 +277,17 @@ impl StableDiffusionConfig { pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> { ddim::DDIMScheduler::new(n_steps, self.scheduler) } +} - pub fn build_clip_transformer<P: AsRef<std::path::Path>>( - &self, - clip_weights: P, - device: &Device, - dtype: DType, - ) -> Result<clip::ClipTextTransformer> { - let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? }; - let weights = weights.deserialize()?; - let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device); - let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?; - Ok(text_model) - } +pub fn build_clip_transformer<P: AsRef<std::path::Path>>( + clip: &clip::Config, + clip_weights: P, + device: &Device, + dtype: DType, +) -> Result<clip::ClipTextTransformer> { + let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? }; + let weights = weights.deserialize()?; + let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device); + let text_model = clip::ClipTextTransformer::new(vs, clip)?; + Ok(text_model) } |