diff options
Diffstat (limited to 'candle-examples/examples/stable-diffusion/stable_diffusion.rs')
-rw-r--r-- | candle-examples/examples/stable-diffusion/stable_diffusion.rs | 26 |
1 files changed, 8 insertions, 18 deletions
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs index 023d8630..05ba41cb 100644 --- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs +++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs @@ -1,4 +1,3 @@ -#![allow(dead_code)] use crate::schedulers::PredictionType; use crate::{clip, ddim, unet_2d, vae}; use candle::{DType, Device, Result}; @@ -156,22 +155,6 @@ impl StableDiffusionConfig { ) } - pub fn v2_1_inpaint( - sliced_attention_size: Option<usize>, - height: Option<usize>, - width: Option<usize>, - ) -> Self { - // https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json - // This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction - // type being "epsilon" by default and not "v_prediction". - Self::v2_1_( - sliced_attention_size, - height, - width, - PredictionType::Epsilon, - ) - } - pub fn build_vae<P: AsRef<std::path::Path>>( &self, vae_weights: P, @@ -190,11 +173,18 @@ impl StableDiffusionConfig { unet_weights: P, device: &Device, in_channels: usize, + use_flash_attn: bool, ) -> Result<unet_2d::UNet2DConditionModel> { let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? }; let weights = weights.deserialize()?; let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device); - let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?; + let unet = unet_2d::UNet2DConditionModel::new( + vs_unet, + in_channels, + 4, + use_flash_attn, + self.unet.clone(), + )?; Ok(unet) } |