diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-08-11 19:57:06 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-11 18:57:06 +0100 |
commit | 1d0157bbc4807f993cecc0de7dbbe0f305a68cd4 (patch) | |
tree | 492d559f4b09b5332127a1af27be49be62ab5d93 /candle-examples/examples/stable-diffusion/stable_diffusion.rs | |
parent | 91dbf907d3ee45dd4777efa82c1f431907ce8125 (diff) | |
download | candle-1d0157bbc4807f993cecc0de7dbbe0f305a68cd4.tar.gz candle-1d0157bbc4807f993cecc0de7dbbe0f305a68cd4.tar.bz2 candle-1d0157bbc4807f993cecc0de7dbbe0f305a68cd4.zip |
Stable diffusion: retrieve the model files from the HF hub. (#414)
* Retrieve the model files from the HF hub in the stable diffusion example.
* Add to the readme.
Diffstat (limited to 'candle-examples/examples/stable-diffusion/stable_diffusion.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/stable_diffusion.rs | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs index c250ed56..023d8630 100644 --- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs +++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs @@ -172,7 +172,11 @@ impl StableDiffusionConfig { ) } - pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> { + pub fn build_vae<P: AsRef<std::path::Path>>( + &self, + vae_weights: P, + device: &Device, + ) -> Result<vae::AutoEncoderKL> { let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? }; let weights = weights.deserialize()?; let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device); @@ -181,9 +185,9 @@ impl StableDiffusionConfig { Ok(autoencoder) } - pub fn build_unet( + pub fn build_unet<P: AsRef<std::path::Path>>( &self, - unet_weights: &str, + unet_weights: P, device: &Device, in_channels: usize, ) -> Result<unet_2d::UNet2DConditionModel> { @@ -198,9 +202,9 @@ impl StableDiffusionConfig { ddim::DDIMScheduler::new(n_steps, self.scheduler) } - pub fn build_clip_transformer( + pub fn build_clip_transformer<P: AsRef<std::path::Path>>( &self, - clip_weights: &str, + clip_weights: P, device: &Device, ) -> Result<clip::ClipTextTransformer> { let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? }; |