summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--candle-examples/examples/stable-diffusion/attention.rs67
-rw-r--r--candle-examples/examples/stable-diffusion/main.rs5
-rw-r--r--candle-examples/examples/stable-diffusion/stable_diffusion.rs26
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d.rs5
-rw-r--r--candle-examples/examples/stable-diffusion/unet_2d_blocks.rs7
5 files changed, 78 insertions, 32 deletions
diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs
index c3f76cd0..dc414889 100644
--- a/candle-examples/examples/stable-diffusion/attention.rs
+++ b/candle-examples/examples/stable-diffusion/attention.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
//! Attention Based Building Blocks
use candle::{IndexOp, Result, Tensor, D};
use candle_nn as nn;
@@ -61,6 +60,22 @@ impl FeedForward {
}
}
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+ unimplemented!("compile with '--features flash-attn'")
+}
+
#[derive(Debug)]
struct CrossAttention {
to_q: nn::Linear,
@@ -72,6 +87,7 @@ struct CrossAttention {
slice_size: Option<usize>,
span: tracing::Span,
span_attn: tracing::Span,
+ use_flash_attn: bool,
}
impl CrossAttention {
@@ -83,6 +99,7 @@ impl CrossAttention {
heads: usize,
dim_head: usize,
slice_size: Option<usize>,
+ use_flash_attn: bool,
) -> Result<Self> {
let inner_dim = dim_head * heads;
let context_dim = context_dim.unwrap_or(query_dim);
@@ -103,6 +120,7 @@ impl CrossAttention {
slice_size,
span,
span_attn,
+ use_flash_attn,
})
}
@@ -146,8 +164,28 @@ impl CrossAttention {
fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
let _enter = self.span_attn.enter();
- let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
- let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
+ let xs = if self.use_flash_attn {
+ let init_dtype = query.dtype();
+ let q = query
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ let k = key
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ let v = value
+ .to_dtype(candle::DType::F16)?
+ .unsqueeze(0)?
+ .transpose(1, 2)?;
+ flash_attn(&q, &k, &v, self.scale as f32, false)?
+ .transpose(1, 2)?
+ .squeeze(0)?
+ .to_dtype(init_dtype)?
+ } else {
+ let xs = query.matmul(&(key.t()? * self.scale)?)?;
+ nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?
+ };
self.reshape_batch_dim_to_heads(&xs)
}
@@ -160,15 +198,17 @@ impl CrossAttention {
let query = self.reshape_heads_to_batch_dim(&query)?;
let key = self.reshape_heads_to_batch_dim(&key)?;
let value = self.reshape_heads_to_batch_dim(&value)?;
- let xs = match self.slice_size {
- None => self.attention(&query, &key, &value)?,
- Some(slice_size) => {
- if query.dim(0)? / slice_size <= 1 {
- self.attention(&query, &key, &value)?
- } else {
- self.sliced_attention(&query, &key, &value, slice_size)?
- }
+ let dim0 = query.dim(0)?;
+ let slice_size = self.slice_size.and_then(|slice_size| {
+ if dim0 < slice_size {
+ None
+ } else {
+ Some(slice_size)
}
+ });
+ let xs = match slice_size {
+ None => self.attention(&query, &key, &value)?,
+ Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?,
};
self.to_out.forward(&xs)
}
@@ -194,6 +234,7 @@ impl BasicTransformerBlock {
d_head: usize,
context_dim: Option<usize>,
sliced_attention_size: Option<usize>,
+ use_flash_attn: bool,
) -> Result<Self> {
let attn1 = CrossAttention::new(
vs.pp("attn1"),
@@ -202,6 +243,7 @@ impl BasicTransformerBlock {
n_heads,
d_head,
sliced_attention_size,
+ use_flash_attn,
)?;
let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
let attn2 = CrossAttention::new(
@@ -211,6 +253,7 @@ impl BasicTransformerBlock {
n_heads,
d_head,
sliced_attention_size,
+ use_flash_attn,
)?;
let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
@@ -279,6 +322,7 @@ impl SpatialTransformer {
in_channels: usize,
n_heads: usize,
d_head: usize,
+ use_flash_attn: bool,
config: SpatialTransformerConfig,
) -> Result<Self> {
let inner_dim = n_heads * d_head;
@@ -304,6 +348,7 @@ impl SpatialTransformer {
d_head,
config.context_dim,
config.sliced_attention_size,
+ use_flash_attn,
)?;
transformer_blocks.push(tb)
}
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 2c9e2021..de20d4a7 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -90,6 +90,9 @@ struct Args {
/// Generate intermediary images at each step.
#[arg(long, action)]
intermediary_images: bool,
+
+ #[arg(long)]
+ use_flash_attn: bool,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
@@ -268,7 +271,7 @@ fn run(args: Args) -> Result<()> {
let vae = sd_config.build_vae(&vae_weights, &device)?;
println!("Building the unet.");
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
- let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
+ let unet = sd_config.build_unet(&unet_weights, &device, 4, args.use_flash_attn)?;
let bsize = 1;
for idx in 0..num_samples {
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index 023d8630..05ba41cb 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
use crate::schedulers::PredictionType;
use crate::{clip, ddim, unet_2d, vae};
use candle::{DType, Device, Result};
@@ -156,22 +155,6 @@ impl StableDiffusionConfig {
)
}
- pub fn v2_1_inpaint(
- sliced_attention_size: Option<usize>,
- height: Option<usize>,
- width: Option<usize>,
- ) -> Self {
- // https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json
- // This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction
- // type being "epsilon" by default and not "v_prediction".
- Self::v2_1_(
- sliced_attention_size,
- height,
- width,
- PredictionType::Epsilon,
- )
- }
-
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,
@@ -190,11 +173,18 @@ impl StableDiffusionConfig {
unet_weights: P,
device: &Device,
in_channels: usize,
+ use_flash_attn: bool,
) -> Result<unet_2d::UNet2DConditionModel> {
let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
let weights = weights.deserialize()?;
let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
- let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?;
+ let unet = unet_2d::UNet2DConditionModel::new(
+ vs_unet,
+ in_channels,
+ 4,
+ use_flash_attn,
+ self.unet.clone(),
+ )?;
Ok(unet)
}
diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
index 0bc820db..e52ec281 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
//! 2D UNet Denoising Models
//!
//! The 2D Unet models take as input a noisy sample and the current diffusion
@@ -103,6 +102,7 @@ impl UNet2DConditionModel {
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
+ use_flash_attn: bool,
config: UNet2DConditionModelConfig,
) -> Result<Self> {
let n_blocks = config.blocks.len();
@@ -161,6 +161,7 @@ impl UNet2DConditionModel {
in_channels,
out_channels,
Some(time_embed_dim),
+ use_flash_attn,
config,
)?;
Ok(UNetDownBlock::CrossAttn(block))
@@ -190,6 +191,7 @@ impl UNet2DConditionModel {
vs.pp("mid_block"),
bl_channels,
Some(time_embed_dim),
+ use_flash_attn,
mid_cfg,
)?;
@@ -242,6 +244,7 @@ impl UNet2DConditionModel {
prev_out_channels,
out_channels,
Some(time_embed_dim),
+ use_flash_attn,
config,
)?;
Ok(UNetUpBlock::CrossAttn(block))
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 2f397815..b7adb2c0 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
//! 2D UNet Building Blocks
//!
use crate::attention::{
@@ -393,6 +392,7 @@ impl UNetMidBlock2DCrossAttn {
vs: nn::VarBuilder,
in_channels: usize,
temb_channels: Option<usize>,
+ use_flash_attn: bool,
config: UNetMidBlock2DCrossAttnConfig,
) -> Result<Self> {
let vs_resnets = vs.pp("resnets");
@@ -423,6 +423,7 @@ impl UNetMidBlock2DCrossAttn {
in_channels,
n_heads,
in_channels / n_heads,
+ use_flash_attn,
attn_cfg,
)?;
let resnet = ResnetBlock2D::new(
@@ -588,6 +589,7 @@ impl CrossAttnDownBlock2D {
in_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
+ use_flash_attn: bool,
config: CrossAttnDownBlock2DConfig,
) -> Result<Self> {
let downblock = DownBlock2D::new(
@@ -613,6 +615,7 @@ impl CrossAttnDownBlock2D {
out_channels,
n_heads,
out_channels / n_heads,
+ use_flash_attn,
cfg,
)
})
@@ -789,6 +792,7 @@ impl CrossAttnUpBlock2D {
prev_output_channels: usize,
out_channels: usize,
temb_channels: Option<usize>,
+ use_flash_attn: bool,
config: CrossAttnUpBlock2DConfig,
) -> Result<Self> {
let upblock = UpBlock2D::new(
@@ -815,6 +819,7 @@ impl CrossAttnUpBlock2D {
out_channels,
n_heads,
out_channels / n_heads,
+ use_flash_attn,
cfg,
)
})