diff options
author | Czxck001 <10724409+Czxck001@users.noreply.github.com> | 2024-10-13 13:08:40 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-13 22:08:40 +0200 |
commit | ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e (patch) | |
tree | 8f61fd8b9a4c86b08e50328d051e0acec3945fb3 /candle-transformers | |
parent | 0d96ec31e8be03f844ed0aed636d6217dee9c7bc (diff) | |
download | candle-ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e.tar.gz candle-ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e.tar.bz2 candle-ca7cf5cb3bb38d1b735e1db0efdac7eea1a9d43e.zip |
Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example
Add get_qkv_linear to handle different dimensionality in linears
Add stable diffusion 3 example
Add use_quant_conv and use_post_quant_conv for vae in stable diffusion
adapt existing AutoEncoderKLConfig to the change
add forward_until_encoder_layer to ClipTextTransformer
rename sd3 config to sd3_medium in mmdit; minor clean-up
Enable flash-attn for mmdit impl when the feature is enabled.
Add sd3 example codebase
add document
crediting references
pass the cargo fmt test
pass the clippy test
* fix typos
* expose cfg_scale and time_shift as options
* Replace the sample image with JPG version. Change image output format accordingly.
* make meaningful error messages
* remove the tail-end assignment in sd3_vae_vb_rename
* remove the CUDA requirement
* use default_value in clap args
* add use_flash_attn to turn on/off flash-attn for MMDiT at runtime
* resolve clippy errors and warnings
* use default_value_t
* Pin the web-sys dependency.
* Clippy fix.
---------
Co-authored-by: Laurent <laurent.mazare@gmail.com>
Diffstat (limited to 'candle-transformers')
7 files changed, 158 insertions, 33 deletions
diff --git a/candle-transformers/src/models/mmdit/blocks.rs b/candle-transformers/src/models/mmdit/blocks.rs index e2b924a0..a1777f91 100644 --- a/candle-transformers/src/models/mmdit/blocks.rs +++ b/candle-transformers/src/models/mmdit/blocks.rs @@ -194,10 +194,16 @@ pub struct JointBlock { x_block: DiTBlock, context_block: DiTBlock, num_heads: usize, + use_flash_attn: bool, } impl JointBlock { - pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> { + pub fn new( + hidden_size: usize, + num_heads: usize, + use_flash_attn: bool, + vb: nn::VarBuilder, + ) -> Result<Self> { let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; @@ -205,13 +211,15 @@ impl JointBlock { x_block, context_block, num_heads, + use_flash_attn, }) } pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> { let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?; let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; - let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?; + let (context_attn, x_attn) = + joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?; let context_out = self.context_block .post_attention(&context_attn, context, &context_interm)?; @@ -224,16 +232,23 @@ pub struct ContextQkvOnlyJointBlock { x_block: DiTBlock, context_block: QkvOnlyDiTBlock, num_heads: usize, + use_flash_attn: bool, } impl ContextQkvOnlyJointBlock { - pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> { + pub fn new( + hidden_size: usize, + num_heads: usize, + use_flash_attn: bool, + vb: nn::VarBuilder, + ) -> Result<Self> { let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; Ok(Self { x_block, context_block, num_heads, + use_flash_attn, }) } @@ -241,7 +256,7 @@ impl ContextQkvOnlyJointBlock { let context_qkv = self.context_block.pre_attention(context, c)?; let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; - let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?; + let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?; let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?; Ok(x_out) @@ -266,7 +281,28 @@ fn flash_compatible_attention( attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2) } -fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> { +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { + unimplemented!("compile with '--features flash-attn'") +} + +fn joint_attn( + context_qkv: &Qkv, + x_qkv: &Qkv, + num_heads: usize, + use_flash_attn: bool, +) -> Result<(Tensor, Tensor)> { let qkv = Qkv { q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?, k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?, @@ -282,8 +318,12 @@ fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tenso let headdim = qkv.q.dim(D::Minus1)?; let softmax_scale = 1.0 / (headdim as f64).sqrt(); - // let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?; - let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?; + + let attn = if use_flash_attn { + flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)? + } else { + flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)? + }; let attn = attn.reshape((batch_size, seqlen, ()))?; let context_qkv_seqlen = context_qkv.q.dim(1)?; diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs index 1523836c..864b6623 100644 --- a/candle-transformers/src/models/mmdit/model.rs +++ b/candle-transformers/src/models/mmdit/model.rs @@ -23,7 +23,7 @@ pub struct Config { } impl Config { - pub fn sd3() -> Self { + pub fn sd3_medium() -> Self { Self { patch_size: 2, in_channels: 16, @@ -49,7 +49,7 @@ pub struct MMDiT { } impl MMDiT { - pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result<Self> { + pub fn new(cfg: &Config, use_flash_attn: bool, vb: nn::VarBuilder) -> Result<Self> { let hidden_size = cfg.head_size * cfg.depth; let core = MMDiTCore::new( cfg.depth, @@ -57,6 +57,7 @@ impl MMDiT { cfg.depth, cfg.patch_size, cfg.out_channels, + use_flash_attn, vb.clone(), )?; let patch_embedder = PatchEmbedder::new( @@ -135,6 +136,7 @@ impl MMDiTCore { num_heads: usize, patch_size: usize, out_channels: usize, + use_flash_attn: bool, vb: nn::VarBuilder, ) -> Result<Self> { let mut joint_blocks = Vec::with_capacity(depth - 1); @@ -142,6 +144,7 @@ impl MMDiTCore { joint_blocks.push(JointBlock::new( hidden_size, num_heads, + use_flash_attn, vb.pp(format!("joint_blocks.{}", i)), )?); } @@ -151,6 +154,7 @@ impl MMDiTCore { context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new( hidden_size, num_heads, + use_flash_attn, vb.pp(format!("joint_blocks.{}", depth - 1)), )?, final_layer: FinalLayer::new( diff --git a/candle-transformers/src/models/mmdit/projections.rs b/candle-transformers/src/models/mmdit/projections.rs index 1077398f..dc1e8ec9 100644 --- a/candle-transformers/src/models/mmdit/projections.rs +++ b/candle-transformers/src/models/mmdit/projections.rs @@ -42,7 +42,6 @@ pub struct QkvOnlyAttnProjections { impl QkvOnlyAttnProjections { pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> { - // {'dim': 1536, 'num_heads': 24} let head_dim = dim / num_heads; let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?; Ok(Self { qkv, head_dim }) diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs index 5cc59e82..c04e6aa1 100644 --- a/candle-transformers/src/models/stable_diffusion/attention.rs +++ b/candle-transformers/src/models/stable_diffusion/attention.rs @@ -467,6 +467,24 @@ pub struct AttentionBlock { config: AttentionBlockConfig, } +// In the .safetensor weights of official Stable Diffusion 3 Medium Huggingface repo +// https://huggingface.co/stabilityai/stable-diffusion-3-medium +// Linear layer may use a different dimension for the weight in the linear, which is +// incompatible with the current implementation of the nn::linear constructor. +// This is a workaround to handle the different dimensions. +fn get_qkv_linear(channels: usize, vs: nn::VarBuilder) -> Result<nn::Linear> { + match vs.get((channels, channels), "weight") { + Ok(_) => nn::linear(channels, channels, vs), + Err(_) => { + let weight = vs + .get((channels, channels, 1, 1), "weight")? + .reshape((channels, channels))?; + let bias = vs.get((channels,), "bias")?; + Ok(nn::Linear::new(weight, Some(bias))) + } + } +} + impl AttentionBlock { pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> { let num_head_channels = config.num_head_channels.unwrap_or(channels); @@ -478,10 +496,10 @@ impl AttentionBlock { } else { ("query", "key", "value", "proj_attn") }; - let query = nn::linear(channels, channels, vs.pp(q_path))?; - let key = nn::linear(channels, channels, vs.pp(k_path))?; - let value = nn::linear(channels, channels, vs.pp(v_path))?; - let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?; + let query = get_qkv_linear(channels, vs.pp(q_path))?; + let key = get_qkv_linear(channels, vs.pp(k_path))?; + let value = get_qkv_linear(channels, vs.pp(v_path))?; + let proj_attn = get_qkv_linear(channels, vs.pp(out_path))?; let span = tracing::span!(tracing::Level::TRACE, "attn-block"); Ok(Self { group_norm, diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs index 5254818e..2f631248 100644 --- a/candle-transformers/src/models/stable_diffusion/clip.rs +++ b/candle-transformers/src/models/stable_diffusion/clip.rs @@ -388,6 +388,37 @@ impl ClipTextTransformer { let xs = self.encoder.forward(&xs, &causal_attention_mask)?; self.final_layer_norm.forward(&xs) } + + pub fn forward_until_encoder_layer( + &self, + xs: &Tensor, + mask_after: usize, + until_layer: isize, + ) -> Result<(Tensor, Tensor)> { + let (bsz, seq_len) = xs.dims2()?; + let xs = self.embeddings.forward(xs)?; + let causal_attention_mask = + Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?; + + let mut xs = xs.clone(); + let mut intermediate = xs.clone(); + + // Modified encoder.forward that returns the intermediate tensor along with final output. + let until_layer = if until_layer < 0 { + self.encoder.layers.len() as isize + until_layer + } else { + until_layer + } as usize; + + for (layer_id, layer) in self.encoder.layers.iter().enumerate() { + xs = layer.forward(&xs, &causal_attention_mask)?; + if layer_id == until_layer { + intermediate = xs.clone(); + } + } + + Ok((self.final_layer_norm.forward(&xs)?, intermediate)) + } } impl Module for ClipTextTransformer { diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs index 30f23975..37f4cdbf 100644 --- a/candle-transformers/src/models/stable_diffusion/mod.rs +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -65,6 +65,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let height = if let Some(height) = height { assert_eq!(height % 8, 0, "height has to be divisible by 8"); @@ -133,6 +135,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let scheduler = Arc::new(ddim::DDIMSchedulerConfig { prediction_type, @@ -214,6 +218,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let scheduler = Arc::new(ddim::DDIMSchedulerConfig { prediction_type, @@ -281,6 +287,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let scheduler = Arc::new( euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig { @@ -378,6 +386,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let scheduler = Arc::new(ddim::DDIMSchedulerConfig { ..Default::default() diff --git a/candle-transformers/src/models/stable_diffusion/vae.rs b/candle-transformers/src/models/stable_diffusion/vae.rs index 670b3f56..b3aba802 100644 --- a/candle-transformers/src/models/stable_diffusion/vae.rs +++ b/candle-transformers/src/models/stable_diffusion/vae.rs @@ -275,6 +275,8 @@ pub struct AutoEncoderKLConfig { pub layers_per_block: usize, pub latent_channels: usize, pub norm_num_groups: usize, + pub use_quant_conv: bool, + pub use_post_quant_conv: bool, } impl Default for AutoEncoderKLConfig { @@ -284,6 +286,8 @@ impl Default for AutoEncoderKLConfig { layers_per_block: 1, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, } } } @@ -315,8 +319,8 @@ impl DiagonalGaussianDistribution { pub struct AutoEncoderKL { encoder: Encoder, decoder: Decoder, - quant_conv: nn::Conv2d, - post_quant_conv: nn::Conv2d, + quant_conv: Option<nn::Conv2d>, + post_quant_conv: Option<nn::Conv2d>, pub config: AutoEncoderKLConfig, } @@ -342,20 +346,33 @@ impl AutoEncoderKL { }; let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?; let conv_cfg = Default::default(); - let quant_conv = nn::conv2d( - 2 * latent_channels, - 2 * latent_channels, - 1, - conv_cfg, - vs.pp("quant_conv"), - )?; - let post_quant_conv = nn::conv2d( - latent_channels, - latent_channels, - 1, - conv_cfg, - vs.pp("post_quant_conv"), - )?; + + let quant_conv = { + if config.use_quant_conv { + Some(nn::conv2d( + 2 * latent_channels, + 2 * latent_channels, + 1, + conv_cfg, + vs.pp("quant_conv"), + )?) + } else { + None + } + }; + let post_quant_conv = { + if config.use_post_quant_conv { + Some(nn::conv2d( + latent_channels, + latent_channels, + 1, + conv_cfg, + vs.pp("post_quant_conv"), + )?) + } else { + None + } + }; Ok(Self { encoder, decoder, @@ -368,13 +385,19 @@ impl AutoEncoderKL { /// Returns the distribution in the latent space. pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> { let xs = self.encoder.forward(xs)?; - let parameters = self.quant_conv.forward(&xs)?; + let parameters = match &self.quant_conv { + None => xs, + Some(quant_conv) => quant_conv.forward(&xs)?, + }; DiagonalGaussianDistribution::new(¶meters) } /// Takes as input some sampled values. pub fn decode(&self, xs: &Tensor) -> Result<Tensor> { - let xs = self.post_quant_conv.forward(xs)?; - self.decoder.forward(&xs) + let xs = match &self.post_quant_conv { + None => xs, + Some(post_quant_conv) => &post_quant_conv.forward(xs)?, + }; + self.decoder.forward(xs) } } |