diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion/stable_diffusion.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/stable_diffusion.rs | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs index bed60161..cffc00d8 100644 --- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs +++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs @@ -28,10 +28,10 @@ impl StableDiffusionConfig { // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json let unet = unet_2d::UNet2DConditionModelConfig { blocks: vec![ - bc(320, true, 8), - bc(640, true, 8), - bc(1280, true, 8), - bc(1280, false, 8), + bc(320, Some(1), 8), + bc(640, Some(1), 8), + bc(1280, Some(1), 8), + bc(1280, None, 8), ], center_input_sample: false, cross_attention_dim: 768, @@ -90,10 +90,10 @@ impl StableDiffusionConfig { // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json let unet = unet_2d::UNet2DConditionModelConfig { blocks: vec![ - bc(320, true, 5), - bc(640, true, 10), - bc(1280, true, 20), - bc(1280, false, 20), + bc(320, Some(1), 5), + bc(640, Some(1), 10), + bc(1280, Some(1), 20), + bc(1280, None, 20), ], center_input_sample: false, cross_attention_dim: 1024, @@ -171,7 +171,11 @@ impl StableDiffusionConfig { }; // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json let unet = unet_2d::UNet2DConditionModelConfig { - blocks: vec![bc(320, false, 5), bc(640, false, 10), bc(1280, true, 20)], + blocks: vec![ + bc(320, None, 5), + bc(640, Some(2), 10), + bc(1280, Some(10), 20), + ], center_input_sample: false, cross_attention_dim: 2048, downsample_padding: 1, |