summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion/stable_diffusion.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/stable-diffusion/stable_diffusion.rs')
-rw-r--r--candle-examples/examples/stable-diffusion/stable_diffusion.rs22
1 files changed, 13 insertions, 9 deletions
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index bed60161..cffc00d8 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -28,10 +28,10 @@ impl StableDiffusionConfig {
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
- bc(320, true, 8),
- bc(640, true, 8),
- bc(1280, true, 8),
- bc(1280, false, 8),
+ bc(320, Some(1), 8),
+ bc(640, Some(1), 8),
+ bc(1280, Some(1), 8),
+ bc(1280, None, 8),
],
center_input_sample: false,
cross_attention_dim: 768,
@@ -90,10 +90,10 @@ impl StableDiffusionConfig {
// https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
- bc(320, true, 5),
- bc(640, true, 10),
- bc(1280, true, 20),
- bc(1280, false, 20),
+ bc(320, Some(1), 5),
+ bc(640, Some(1), 10),
+ bc(1280, Some(1), 20),
+ bc(1280, None, 20),
],
center_input_sample: false,
cross_attention_dim: 1024,
@@ -171,7 +171,11 @@ impl StableDiffusionConfig {
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
- blocks: vec![bc(320, false, 5), bc(640, false, 10), bc(1280, true, 20)],
+ blocks: vec![
+ bc(320, None, 5),
+ bc(640, Some(2), 10),
+ bc(1280, Some(10), 20),
+ ],
center_input_sample: false,
cross_attention_dim: 2048,
downsample_padding: 1,