5 files changed, 78 insertions(+), 32 deletions(-)
diff --git a/candle-examples/examples/stable-diffusion/attention.rs b/candle-examples/examples/stable-diffusion/attention.rs
index c3f76cd0..dc414889 100644
--- a/candle-examples/examples/stable-diffusion/attention.rs
+++ b/candle-examples/examples/stable-diffusion/attention.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
 //! Attention Based Building Blocks
 use candle::{IndexOp, Result, Tensor, D};
 use candle_nn as nn;
@@ -61,6 +60,22 @@ impl FeedForward {
     }
 }

+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 #[derive(Debug)]
 struct CrossAttention {
     to_q: nn::Linear,
@@ -72,6 +87,7 @@ struct CrossAttention {
     slice_size: Option<usize>,
     span: tracing::Span,
     span_attn: tracing::Span,
+    use_flash_attn: bool,
 }

 impl CrossAttention {
@@ -83,6 +99,7 @@ impl CrossAttention {
         heads: usize,
         dim_head: usize,
         slice_size: Option<usize>,
+        use_flash_attn: bool,
     ) -> Result<Self> {
         let inner_dim = dim_head * heads;
         let context_dim = context_dim.unwrap_or(query_dim);
@@ -103,6 +120,7 @@ impl CrossAttention {
             slice_size,
             span,
             span_attn,
+            use_flash_attn,
         })
     }

@@ -146,8 +164,28 @@ impl CrossAttention {

     fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
         let _enter = self.span_attn.enter();
-        let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
-        let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
+        let xs = if self.use_flash_attn {
+            let init_dtype = query.dtype();
+            let q = query
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            let k = key
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            let v = value
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            flash_attn(&q, &k, &v, self.scale as f32, false)?
+                .transpose(1, 2)?
+                .squeeze(0)?
+                .to_dtype(init_dtype)?
+        } else {
+            let xs = query.matmul(&(key.t()? * self.scale)?)?;
+            nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?
+        };
         self.reshape_batch_dim_to_heads(&xs)
     }

@@ -160,15 +198,17 @@ impl CrossAttention {
         let query = self.reshape_heads_to_batch_dim(&query)?;
         let key = self.reshape_heads_to_batch_dim(&key)?;
         let value = self.reshape_heads_to_batch_dim(&value)?;
-        let xs = match self.slice_size {
-            None => self.attention(&query, &key, &value)?,
-            Some(slice_size) => {
-                if query.dim(0)? / slice_size <= 1 {
-                    self.attention(&query, &key, &value)?
-                } else {
-                    self.sliced_attention(&query, &key, &value, slice_size)?
-                }
-            }
+        let dim0 = query.dim(0)?;
+        let slice_size = self.slice_size.and_then(|slice_size| {
+            if dim0 < slice_size {
+                None
+            } else {
+                Some(slice_size)
+            }
+        });
+        let xs = match slice_size {
+            None => self.attention(&query, &key, &value)?,
+            Some(slice_size) => self.sliced_attention(&query, &key, &value, slice_size)?,
         };
         self.to_out.forward(&xs)
     }
@@ -194,6 +234,7 @@ impl BasicTransformerBlock {
         d_head: usize,
         context_dim: Option<usize>,
         sliced_attention_size: Option<usize>,
+        use_flash_attn: bool,
     ) -> Result<Self> {
         let attn1 = CrossAttention::new(
             vs.pp("attn1"),
@@ -202,6 +243,7 @@ impl BasicTransformerBlock {
             n_heads,
             d_head,
             sliced_attention_size,
+            use_flash_attn,
         )?;
         let ff = FeedForward::new(vs.pp("ff"), dim, None, 4)?;
         let attn2 = CrossAttention::new(
@@ -211,6 +253,7 @@ impl BasicTransformerBlock {
             n_heads,
             d_head,
             sliced_attention_size,
+            use_flash_attn,
         )?;
         let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
         let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
@@ -279,6 +322,7 @@ impl SpatialTransformer {
         in_channels: usize,
         n_heads: usize,
         d_head: usize,
+        use_flash_attn: bool,
         config: SpatialTransformerConfig,
     ) -> Result<Self> {
         let inner_dim = n_heads * d_head;
@@ -304,6 +348,7 @@ impl SpatialTransformer {
                 d_head,
                 config.context_dim,
                 config.sliced_attention_size,
+                use_flash_attn,
             )?;
             transformer_blocks.push(tb)
         }
diff --git a/candle-examples/examples/stable-diffusion/main.rs b/candle-examples/examples/stable-diffusion/main.rs
index 2c9e2021..de20d4a7 100644
--- a/candle-examples/examples/stable-diffusion/main.rs
+++ b/candle-examples/examples/stable-diffusion/main.rs
@@ -90,6 +90,9 @@ struct Args {
     /// Generate intermediary images at each step.
     #[arg(long, action)]
     intermediary_images: bool,
+
+    #[arg(long)]
+    use_flash_attn: bool,
 }

 #[derive(Debug, Clone, Copy, clap::ValueEnum)]
@@ -268,7 +271,7 @@ fn run(args: Args) -> Result<()> {
     let vae = sd_config.build_vae(&vae_weights, &device)?;
     println!("Building the unet.");
     let unet_weights = ModelFile::Unet.get(unet_weights, sd_version)?;
-    let unet = sd_config.build_unet(&unet_weights, &device, 4)?;
+    let unet = sd_config.build_unet(&unet_weights, &device, 4, args.use_flash_attn)?;

     let bsize = 1;
     for idx in 0..num_samples {
diff --git a/candle-examples/examples/stable-diffusion/stable_diffusion.rs b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
index 023d8630..05ba41cb 100644
--- a/candle-examples/examples/stable-diffusion/stable_diffusion.rs
+++ b/candle-examples/examples/stable-diffusion/stable_diffusion.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
 use crate::schedulers::PredictionType;
 use crate::{clip, ddim, unet_2d, vae};
 use candle::{DType, Device, Result};
@@ -156,22 +155,6 @@ impl StableDiffusionConfig {
         )
     }

-    pub fn v2_1_inpaint(
-        sliced_attention_size: Option<usize>,
-        height: Option<usize>,
-        width: Option<usize>,
-    ) -> Self {
-        // https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json
-        // This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction
-        // type being "epsilon" by default and not "v_prediction".
-        Self::v2_1_(
-            sliced_attention_size,
-            height,
-            width,
-            PredictionType::Epsilon,
-        )
-    }
-
     pub fn build_vae<P: AsRef<std::path::Path>>(
         &self,
         vae_weights: P,
@@ -190,11 +173,18 @@ impl StableDiffusionConfig {
         unet_weights: P,
         device: &Device,
         in_channels: usize,
+        use_flash_attn: bool,
     ) -> Result<unet_2d::UNet2DConditionModel> {
         let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
         let weights = weights.deserialize()?;
         let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], DType::F32, device);
-        let unet = unet_2d::UNet2DConditionModel::new(vs_unet, in_channels, 4, self.unet.clone())?;
+        let unet = unet_2d::UNet2DConditionModel::new(
+            vs_unet,
+            in_channels,
+            4,
+            use_flash_attn,
+            self.unet.clone(),
+        )?;
         Ok(unet)
     }

diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs
index 0bc820db..e52ec281 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
 //! 2D UNet Denoising Models
 //!
 //! The 2D Unet models take as input a noisy sample and the current diffusion
@@ -103,6 +102,7 @@ impl UNet2DConditionModel {
         vs: nn::VarBuilder,
         in_channels: usize,
         out_channels: usize,
+        use_flash_attn: bool,
         config: UNet2DConditionModelConfig,
     ) -> Result<Self> {
         let n_blocks = config.blocks.len();
@@ -161,6 +161,7 @@ impl UNet2DConditionModel {
                     in_channels,
                     out_channels,
                     Some(time_embed_dim),
+                    use_flash_attn,
                     config,
                 )?;
                 Ok(UNetDownBlock::CrossAttn(block))
@@ -190,6 +191,7 @@ impl UNet2DConditionModel {
             vs.pp("mid_block"),
             bl_channels,
             Some(time_embed_dim),
+            use_flash_attn,
             mid_cfg,
         )?;

@@ -242,6 +244,7 @@ impl UNet2DConditionModel {
                     prev_out_channels,
                     out_channels,
                     Some(time_embed_dim),
+                    use_flash_attn,
                     config,
                 )?;
                 Ok(UNetUpBlock::CrossAttn(block))
diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
index 2f397815..b7adb2c0 100644
--- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
+++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs
@@ -1,4 +1,3 @@
-#![allow(dead_code)]
 //! 2D UNet Building Blocks
 //!
 use crate::attention::{
@@ -393,6 +392,7 @@ impl UNetMidBlock2DCrossAttn {
         vs: nn::VarBuilder,
         in_channels: usize,
         temb_channels: Option<usize>,
+        use_flash_attn: bool,
         config: UNetMidBlock2DCrossAttnConfig,
     ) -> Result<Self> {
         let vs_resnets = vs.pp("resnets");
@@ -423,6 +423,7 @@ impl UNetMidBlock2DCrossAttn {
                 in_channels,
                 n_heads,
                 in_channels / n_heads,
+                use_flash_attn,
                 attn_cfg,
             )?;
             let resnet = ResnetBlock2D::new(
@@ -588,6 +589,7 @@ impl CrossAttnDownBlock2D {
         in_channels: usize,
         out_channels: usize,
         temb_channels: Option<usize>,
+        use_flash_attn: bool,
         config: CrossAttnDownBlock2DConfig,
     ) -> Result<Self> {
         let downblock = DownBlock2D::new(
@@ -613,6 +615,7 @@ impl CrossAttnDownBlock2D {
                 out_channels,
                 n_heads,
                 out_channels / n_heads,
+                use_flash_attn,
                 cfg,
             )
         })
@@ -789,6 +792,7 @@ impl CrossAttnUpBlock2D {
         prev_output_channels: usize,
         out_channels: usize,
         temb_channels: Option<usize>,
+        use_flash_attn: bool,
         config: CrossAttnUpBlock2DConfig,
     ) -> Result<Self> {
         let upblock = UpBlock2D::new(
@@ -815,6 +819,7 @@ impl CrossAttnUpBlock2D {
                 out_channels,
                 n_heads,
                 out_channels / n_heads,
+                use_flash_attn,
                 cfg,
             )
         })
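Note: the cfg-gated flash_attn shim added to attention.rs is the usual pattern for optional native dependencies: both cfg branches expose the same signature, so call sites compile unchanged whether or not the kernel is built in, and the stub only fails at runtime if the accelerated path is actually taken. A minimal self-contained sketch of that pattern (the fast-op feature name and the doubling body are hypothetical, purely for illustration):

    // Compiled only when the Cargo feature is enabled: forward to the real kernel.
    #[cfg(feature = "fast-op")]
    fn fast_op(x: f32) -> f32 {
        x * 2.0
    }

    // Without the feature, a stub with the same signature keeps the crate
    // compiling and panics only if the accelerated path is actually exercised.
    #[cfg(not(feature = "fast-op"))]
    fn fast_op(_: f32) -> f32 {
        unimplemented!("compile with '--features fast-op'")
    }

    fn main() {
        println!("{}", fast_op(21.0));
    }

With the wiring above, the accelerated path should be reachable via something like cargo run --example stable-diffusion --release --features flash-attn -- --use-flash-attn (additional features such as cuda may also be required; --use-flash-attn is the clap long flag derived from the new use_flash_attn field).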