diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion/stable_diffusion.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/stable_diffusion.rs | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs index c250ed56..023d8630 100644 --- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs +++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs @@ -172,7 +172,11 @@ impl StableDiffusionConfig { ) } - pub fn build_vae(&self, vae_weights: &str, device: &Device) -> Result<vae::AutoEncoderKL> { + pub fn build_vae<P: AsRef<std::path::Path>>( + &self, + vae_weights: P, + device: &Device, + ) -> Result<vae::AutoEncoderKL> { let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? }; let weights = weights.deserialize()?; let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device); @@ -181,9 +185,9 @@ impl StableDiffusionConfig { Ok(autoencoder) } - pub fn build_unet( + pub fn build_unet<P: AsRef<std::path::Path>>( &self, - unet_weights: &str, + unet_weights: P, device: &Device, in_channels: usize, ) -> Result<unet_2d::UNet2DConditionModel> { @@ -198,9 +202,9 @@ impl StableDiffusionConfig { ddim::DDIMScheduler::new(n_steps, self.scheduler) } - pub fn build_clip_transformer( + pub fn build_clip_transformer<P: AsRef<std::path::Path>>( &self, - clip_weights: &str, + clip_weights: P, device: &Device, ) -> Result<clip::ClipTextTransformer> { let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? }; |