summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-10-09 06:52:28 +0100
committerGitHub <noreply@github.com>2023-10-09 06:52:28 +0100
commit4d04ac83c7fd65c6f6e5f08ab87f27ab6e5af21b (patch)
tree0910fa0a968832046b6e7aea67c18432a23ae383 /candle-examples/examples/stable-diffusion
parent392fe02fba96658bafc73100e80bf68d54e4e23f (diff)
downloadcandle-4d04ac83c7fd65c6f6e5f08ab87f27ab6e5af21b.tar.gz
candle-4d04ac83c7fd65c6f6e5f08ab87f27ab6e5af21b.tar.bz2
candle-4d04ac83c7fd65c6f6e5f08ab87f27ab6e5af21b.zip
Override the repo for SDXL f16 vae weights. (#1064)
* Override the repo for SDXL f16 vae weights. * Slightly simpler change.
Diffstat (limited to 'candle-examples/examples/stable-diffusion')
-rw-r--r--candle-examples/examples/stable-diffusion/main.rs15
1 files changed, 13 insertions, 2 deletions
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 6a08d9c8..74e816d4 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -97,7 +97,7 @@ struct Args {
img2img_strength: f64,
}
-#[derive(Debug, Clone, Copy, clap::ValueEnum)]
+#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
enum StableDiffusionVersion {
V1_5,
V2_1,
@@ -204,7 +204,18 @@ impl ModelFile {
Self::Clip => (version.repo(), version.clip_file(use_f16)),
Self::Clip2 => (version.repo(), version.clip2_file(use_f16)),
Self::Unet => (version.repo(), version.unet_file(use_f16)),
- Self::Vae => (version.repo(), version.vae_file(use_f16)),
+ Self::Vae => {
+ // Override for SDXL when using f16 weights.
+ // See https://github.com/huggingface/candle/issues/1060
+ if version == StableDiffusionVersion::Xl && use_f16 {
+ (
+ "madebyollin/sdxl-vae-fp16-fix",
+ "diffusion_pytorch_model.safetensors",
+ )
+ } else {
+ (version.repo(), version.vae_file(use_f16))
+ }
+ }
};
let filename = Api::new()?.model(repo.to_string()).get(path)?;
Ok(filename)