| field | value | date |
|---|---|---|
| author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-11 19:57:06 +0200 |
| committer | GitHub <noreply@github.com> | 2023-08-11 18:57:06 +0100 |
| commit | 1d0157bbc4807f993cecc0de7dbbe0f305a68cd4 (patch) | |
| tree | 492d559f4b09b5332127a1af27be49be62ab5d93 /candle-examples/examples/stable-diffusion | |
| parent | 91dbf907d3ee45dd4777efa82c1f431907ce8125 (diff) | |
| download | candle-1d0157bbc4807f993cecc0de7dbbe0f305a68cd4.tar.gz, candle-1d0157bbc4807f993cecc0de7dbbe0f305a68cd4.tar.bz2, candle-1d0157bbc4807f993cecc0de7dbbe0f305a68cd4.zip | |
Stable diffusion: retrieve the model files from the HF hub. (#414)
* Retrieve the model files from the HF hub in the stable diffusion example.
* Add to the readme.
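
As a quick illustration of the behaviour this commit introduces (a sketch, not part of the commit itself): when no weight file is passed on the command line, the example now resolves it through the hf-hub crate's synchronous API and reuses the locally cached copy on later runs. The repository and file names below are the v1.5 defaults taken from the diff; everything else is just boilerplate for a standalone program.

```rust
// Minimal sketch of the fallback path wired in by this commit: pull a weight file
// from the Hugging Face Hub and reuse the local cache afterwards. Assumes the
// `hf-hub` crate with its synchronous (`api::sync`) API, as used by the example.
use hf_hub::api::sync::Api;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let api = Api::new()?;
    // Repo and file names taken from the v1.5 defaults in the diff below.
    let unet = api
        .model("runwayml/stable-diffusion-v1-5".to_string())
        .get("unet/diffusion_pytorch_model.safetensors")?;
    println!("UNet weights cached at {}", unet.display());
    Ok(())
}
```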
Diffstat (limited to 'candle-examples/examples/stable-diffusion')

| mode | file | changes |
|---|---|---|
| -rw-r--r-- | candle-examples/examples/stable-diffusion/main.rs | 91 |
| -rw-r--r-- | candle-examples/examples/stable-diffusion/stable_diffusion.rs | 14 |

2 files changed, 71 insertions(+), 34 deletions(-)
```diff
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index ac31e855..5ec40f7d 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -45,21 +45,21 @@ struct Args {
     #[arg(long)]
     width: Option<usize>,
 
-    /// The UNet weight file, in .ot or .safetensors format.
+    /// The UNet weight file, in .safetensors format.
     #[arg(long, value_name = "FILE")]
     unet_weights: Option<String>,
 
-    /// The CLIP weight file, in .ot or .safetensors format.
+    /// The CLIP weight file, in .safetensors format.
     #[arg(long, value_name = "FILE")]
     clip_weights: Option<String>,
 
-    /// The VAE weight file, in .ot or .safetensors format.
+    /// The VAE weight file, in .safetensors format.
     #[arg(long, value_name = "FILE")]
     vae_weights: Option<String>,
 
     #[arg(long, value_name = "FILE")]
     /// The file specifying the tokenizer to used for tokenization.
-    tokenizer: String,
+    tokenizer: Option<String>,
 
     /// The size of the sliced attention or 0 for automatic slicing (disabled by default)
     #[arg(long)]
@@ -91,34 +91,63 @@ enum StableDiffusionVersion {
     V2_1,
 }
 
-impl Args {
-    fn clip_weights(&self) -> String {
-        match &self.clip_weights {
-            Some(w) => w.clone(),
-            None => match self.sd_version {
-                StableDiffusionVersion::V1_5 => "data/pytorch_model.safetensors".to_string(),
-                StableDiffusionVersion::V2_1 => "data/clip_v2.1.safetensors".to_string(),
-            },
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum ModelFile {
+    Tokenizer,
+    Clip,
+    Unet,
+    Vae,
+}
+
+impl StableDiffusionVersion {
+    fn repo(&self) -> &'static str {
+        match self {
+            Self::V2_1 => "stabilityai/stable-diffusion-2-1",
+            Self::V1_5 => "runwayml/stable-diffusion-v1-5",
+        }
+    }
+
+    fn unet_file(&self) -> &'static str {
+        match self {
+            Self::V1_5 | Self::V2_1 => "unet/diffusion_pytorch_model.safetensors",
         }
     }
 
-    fn vae_weights(&self) -> String {
-        match &self.vae_weights {
-            Some(w) => w.clone(),
-            None => match self.sd_version {
-                StableDiffusionVersion::V1_5 => "data/vae.safetensors".to_string(),
-                StableDiffusionVersion::V2_1 => "data/vae_v2.1.safetensors".to_string(),
-            },
+    fn vae_file(&self) -> &'static str {
+        match self {
+            Self::V1_5 | Self::V2_1 => "vae/diffusion_pytorch_model.safetensors",
         }
     }
 
-    fn unet_weights(&self) -> String {
-        match &self.unet_weights {
-            Some(w) => w.clone(),
-            None => match self.sd_version {
-                StableDiffusionVersion::V1_5 => "data/unet.safetensors".to_string(),
-                StableDiffusionVersion::V2_1 => "data/unet_v2.1.safetensors".to_string(),
-            },
+    fn clip_file(&self) -> &'static str {
+        match self {
+            Self::V1_5 | Self::V2_1 => "text_encoder/model.safetensors",
+        }
+    }
+}
+
+impl ModelFile {
+    const TOKENIZER_REPO: &str = "openai/clip-vit-base-patch32";
+    const TOKENIZER_PATH: &str = "tokenizer.json";
+
+    fn get(
+        &self,
+        filename: Option<String>,
+        version: StableDiffusionVersion,
+    ) -> Result<std::path::PathBuf> {
+        use hf_hub::api::sync::Api;
+        match filename {
+            Some(filename) => Ok(std::path::PathBuf::from(filename)),
+            None => {
+                let (repo, path) = match self {
+                    Self::Tokenizer => (Self::TOKENIZER_REPO, Self::TOKENIZER_PATH),
+                    Self::Clip => (version.repo(), version.clip_file()),
+                    Self::Unet => (version.repo(), version.unet_file()),
+                    Self::Vae => (version.repo(), version.vae_file()),
+                };
+                let filename = Api::new()?.model(repo.to_string()).get(path)?;
+                Ok(filename)
+            }
         }
     }
 }
@@ -151,9 +180,6 @@ fn output_filename(
 }
 
 fn run(args: Args) -> Result<()> {
-    let clip_weights = args.clip_weights();
-    let vae_weights = args.vae_weights();
-    let unet_weights = args.unet_weights();
     let Args {
         prompt,
         uncond_prompt,
@@ -166,6 +192,9 @@ fn run(args: Args) -> Result<()> {
         sliced_attention_size,
         num_samples,
         sd_version,
+        clip_weights,
+        vae_weights,
+        unet_weights,
         ..
     } = args;
     let sd_config = match sd_version {
@@ -180,6 +209,7 @@ fn run(args: Args) -> Result<()> {
     let scheduler = sd_config.build_scheduler(n_steps)?;
 
     let device = candle_examples::device(cpu)?;
+    let tokenizer = ModelFile::Tokenizer.get(tokenizer, sd_version)?;
     let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
     let pad_id = match &sd_config.clip.pad_with {
         Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(),
@@ -207,14 +237,17 @@ fn run(args: Args) -> Result<()> {
     let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), &device)?.unsqueeze(0)?;
 
     println!("Building the Clip transformer.");
+    let clip_weights = ModelFile::Clip.get(clip_weights, sd_version)?;
     let text_model = sd_config.build_clip_transformer(&clip_weights, &device)?;
     let text_embeddings = text_model.forward(&tokens)?;
     let uncond_embeddings = text_model.forward(&uncond_tokens)?;
     let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?;
 
     println!("Building the autoencoder.");
+    let vae_weights = ModelFile::Vae.get(vae_weights, sd_version)?;
     let vae = sd_config.build_vae(&vae_weights, &device)?;
     println!("Building the unet.");
+    let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
     let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
 
     let bsize = 1;
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index c250ed56..023d8630 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -172,7 +172,11 @@ impl StableDiffusionConfig {
         )
     }
 
-    pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> {
+    pub fn build_vae<P: AsRef<std::path::Path>>(
+        &self,
+        vae_weights: P,
+        device: &Device,
+    ) -> Result<vae::AutoEncoderKL> {
         let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
         let weights = weights.deserialize()?;
         let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
@@ -181,9 +185,9 @@ impl StableDiffusionConfig {
         Ok(autoencoder)
     }
 
-    pub fn build_unet(
+    pub fn build_unet<P: AsRef<std::path::Path>>(
         &self,
-        unet_weights: &str,
+        unet_weights: P,
         device: &Device,
         in_channels: usize,
     ) -> Result<unet_2d::UNet2DConditionModel> {
@@ -198,9 +202,9 @@ impl StableDiffusionConfig {
         ddim::DDIMScheduler::new(n_steps, self.scheduler)
     }
 
-    pub fn build_clip_transformer(
+    pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
         &self,
-        clip_weights: &str,
+        clip_weights: P,
         device: &Device,
     ) -> Result<clip::ClipTextTransformer> {
         let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
```
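
A note on the stable_diffusion.rs side of the change: the builders now accept any path-like argument because explicit `--*-weights` flags arrive as `String`s while hf-hub downloads come back as `std::path::PathBuf`s. The sketch below is illustrative and not from the repository; `describe_weights` is a hypothetical stand-in for `build_vae`/`build_unet`/`build_clip_transformer`, showing how a single `P: AsRef<Path>` bound covers both callers.

```rust
use std::path::{Path, PathBuf};

// Hypothetical stand-in for the builders: all it needs is something usable as a path.
fn describe_weights<P: AsRef<Path>>(weights: P) -> String {
    format!("would mmap weights from {}", weights.as_ref().display())
}

fn main() {
    // A user-supplied override, as parsed from the command line.
    let from_cli: String = "local/unet.safetensors".to_string();
    // A PathBuf such as the one returned by hf-hub's `get` (location illustrative).
    let from_hub: PathBuf = PathBuf::from("/some/cache/dir/unet.safetensors");
    println!("{}", describe_weights(&from_cli));
    println!("{}", describe_weights(from_hub));
}
```

Inside the example, `ModelFile::get` returns exactly such a `PathBuf`, so the call sites in `run` only gain the added `get` lines.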