summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion/stable_diffusion.rs
diff options
context:
space:
mode:
authorLaurent Mazare <laurent.mazare@gmail.com>2023-08-11 19:57:06 +0200
committerGitHub <noreply@github.com>2023-08-11 18:57:06 +0100
commit1d0157bbc4807f993cecc0de7dbbe0f305a68cd4 (patch)
tree492d559f4b09b5332127a1af27be49be62ab5d93 /candle-examples/examples/stable-diffusion/stable_diffusion.rs
parent91dbf907d3ee45dd4777efa82c1f431907ce8125 (diff)
downloadcandle-1d0157bbc4807f993cecc0de7dbbe0f305a68cd4.tar.gz
candle-1d0157bbc4807f993cecc0de7dbbe0f305a68cd4.tar.bz2
candle-1d0157bbc4807f993cecc0de7dbbe0f305a68cd4.zip
Stable diffusion: retrieve the model files from the HF hub. (#414)
* Retrieve the model files from the HF hub in the stable diffusion example. * Add to the readme.
Diffstat (limited to 'candle-examples/examples/stable-diffusion/stable_diffusion.rs')
-rw-r--r--candle-examples/examples/stable-diffusion/stable_diffusion.rs14
1 files changed, 9 insertions, 5 deletions
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index c250ed56..023d8630 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -172,7 +172,11 @@ impl StableDiffusionConfig {
)
}
- pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> {
+ pub fn build_vae<P: AsRef<std::path::Path>>(
+ &self,
+ vae_weights: P,
+ device: &Device,
+ ) -> Result<vae::AutoEncoderKL> {
let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
let weights = weights.deserialize()?;
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
@@ -181,9 +185,9 @@ impl StableDiffusionConfig {
Ok(autoencoder)
}
- pub fn build_unet(
+ pub fn build_unet<P: AsRef<std::path::Path>>(
&self,
- unet_weights: &str,
+ unet_weights: P,
device: &Device,
in_channels: usize,
) -> Result<unet_2d::UNet2DConditionModel> {
@@ -198,9 +202,9 @@ impl StableDiffusionConfig {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
- pub fn build_clip_transformer(
+ pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
&self,
- clip_weights: &str,
+ clip_weights: P,
device: &Device,
) -> Result<clip::ClipTextTransformer> {
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };