From 37e0ab8c64eb8219e32cf546ac2aa570ed3d1f82 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 27 Oct 2024 10:01:04 +0100
Subject: Stable diffusion 3.5 support. (#2578)

* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.
---
 candle-transformers/src/models/mmdit/model.rs |   14 ++++++++++
 .../src/models/mmdit/projections.rs           |   30 +++++++++++++++++++++-
 2 files changed, 43 insertions(+), 1 deletion(-)

(limited to 'candle-transformers')

diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs
index 864b6623..5b5c90b0 100644
--- a/candle-transformers/src/models/mmdit/model.rs
+++ b/candle-transformers/src/models/mmdit/model.rs
@@ -36,6 +36,20 @@ impl Config {
             frequency_embedding_size: 256,
         }
     }
+
+    pub fn sd3_5_large() -> Self {
+        Self {
+            patch_size: 2,
+            in_channels: 16,
+            out_channels: 16,
+            depth: 38,
+            head_size: 64,
+            adm_in_channels: 2048,
+            pos_embed_max_size: 192,
+            context_embed_size: 4096,
+            frequency_embedding_size: 256,
+        }
+    }
 }
 
 pub struct MMDiT {
diff --git a/candle-transformers/src/models/mmdit/projections.rs b/candle-transformers/src/models/mmdit/projections.rs
index dc1e8ec9..27753285 100644
--- a/candle-transformers/src/models/mmdit/projections.rs
+++ b/candle-transformers/src/models/mmdit/projections.rs
@@ -56,6 +56,8 @@ impl QkvOnlyAttnProjections {
 pub struct AttnProjections {
     head_dim: usize,
     qkv: nn::Linear,
+    ln_k: Option<candle_nn::RmsNorm>,
+    ln_q: Option<candle_nn::RmsNorm>,
     proj: nn::Linear,
 }
 
@@ -64,16 +66,42 @@ impl AttnProjections {
         let head_dim = dim / num_heads;
         let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
         let proj = nn::linear(dim, dim, vb.pp("proj"))?;
+        let (ln_k, ln_q) = if vb.contains_tensor("ln_k.weight") {
+            let ln_k = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_k"))?;
+            let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?;
+            (Some(ln_k), Some(ln_q))
+        } else {
+            (None, None)
+        };
         Ok(Self {
             head_dim,
             qkv,
             proj,
+            ln_k,
+            ln_q,
         })
     }
 
     pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
         let qkv = self.qkv.forward(x)?;
-        split_qkv(&qkv, self.head_dim)
+        let Qkv { q, k, v } = split_qkv(&qkv, self.head_dim)?;
+        let q = match self.ln_q.as_ref() {
+            None => q,
+            Some(l) => {
+                let (b, t, h) = q.dims3()?;
+                l.forward(&q.reshape((b, t, (), self.head_dim))?)?
+                    .reshape((b, t, h))?
+            }
+        };
+        let k = match self.ln_k.as_ref() {
+            None => k,
+            Some(l) => {
+                let (b, t, h) = k.dims3()?;
+                l.forward(&k.reshape((b, t, (), self.head_dim))?)?
+                    .reshape((b, t, h))?
+            }
+        };
+        Ok(Qkv { q, k, v })
     }
 
     pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
--
cgit v1.2.3
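
Usage sketch (not part of the patch): how the new preset might be selected next to the pre-existing `Config::sd3_medium` constructor, whose body closes just above the added `sd3_5_large` in model.rs. The `SdVersion` enum and `mmdit_config` helper are hypothetical names for illustration; only the two `Config` constructors come from candle-transformers.

    use candle_transformers::models::mmdit::model::Config;

    // Hypothetical checkpoint-variant selector; the enum is illustrative.
    enum SdVersion {
        V3Medium,
        V3_5Large,
    }

    fn mmdit_config(version: SdVersion) -> Config {
        match version {
            SdVersion::V3Medium => Config::sd3_medium(),
            // Same patch/channel layout as the medium preset, but a deeper
            // transformer (depth: 38) for the 3.5-large checkpoint.
            SdVersion::V3_5Large => Config::sd3_5_large(),
        }
    }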
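
The projections change adds SD 3.5's per-head QK RMS normalization and keys it off the checkpoint itself: `vb.contains_tensor("ln_k.weight")` lets the same `AttnProjections` load both SD3 checkpoints (no `ln_k`/`ln_q` tensors, so the norms stay `None`) and SD 3.5 ones. The reshape in `pre_attention` exposes a head axis so the norm runs over each head's `head_dim` slice rather than the full hidden dimension. Below is a self-contained sketch of just that step, assuming the candle-core and candle-nn crates; the shapes are arbitrary, and `RmsNorm::new` with unit weights stands in for the `rms_norm(head_dim, 1e-6, vb.pp("ln_q"))` builder the patch uses to read learned weights.

    use candle_core::{DType, Device, Module, Result, Tensor};
    use candle_nn::RmsNorm;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b, t, heads, head_dim) = (1, 4, 2, 8);
        let h = heads * head_dim;

        // Stand-in for the q (or k) projection output, shape (b, t, h).
        let q = Tensor::randn(0f32, 1f32, (b, t, h), &dev)?;

        // Unit weights replace the learned ln_q.weight read from the
        // VarBuilder in the patch; eps matches the 1e-6 used there.
        let ln_q = RmsNorm::new(Tensor::ones(head_dim, DType::F32, &dev)?, 1e-6);

        // The same reshape dance as pre_attention: split h into
        // (heads, head_dim) so the norm acts per head, then flatten back.
        let q = ln_q
            .forward(&q.reshape((b, t, (), head_dim))?)?
            .reshape((b, t, h))?;
        assert_eq!(q.dims(), &[b, t, h]);
        Ok(())
    }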