summaryrefslogtreecommitdiff
path: root/candle-examples/examples/stable-diffusion/stable_diffusion.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-examples/examples/stable-diffusion/stable_diffusion.rs')
-rw-r--r--candle-examples/examples/stable-diffusion/stable_diffusion.rs26
1 files changed, 8 insertions, 18 deletions
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index 023d8630..05ba41cb 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
use crate::schedulers::PredictionType;
use crate::{clip, ddim, unet_2d, vae};
use candle::{DType, Device, Result};
@@ -156,22 +155,6 @@ impl StableDiffusionConfig {
)
}
- pub fn v2_1_inpaint(
- sliced_attention_size: Option<usize>,
- height: Option<usize>,
- width: Option<usize>,
- ) -> Self {
- // https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json
- // This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction
- // type being "epsilon" by default and not "v_prediction".
- Self::v2_1_(
- sliced_attention_size,
- height,
- width,
- PredictionType::Epsilon,
- )
- }
-
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,
@@ -190,11 +173,18 @@ impl StableDiffusionConfig {
unet_weights: P,
device: &Device,
in_channels: usize,
+ use_flash_attn: bool,
) -> Result<unet_2d::UNet2DConditionModel> {
let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
let weights = weights.deserialize()?;
let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
- let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?;
+ let unet = unet_2d::UNet2DConditionModel::new(
+ vs_unet,
+ in_channels,
+ 4,
+ use_flash_attn,
+ self.unet.clone(),
+ )?;
Ok(unet)
}