author     Laurent Mazare <laurent.mazare@gmail.com>  2024-10-27 10:01:04 +0100
committer  GitHub <noreply@github.com>                2024-10-27 10:01:04 +0100
commit     37e0ab8c64eb8219e32cf546ac2aa570ed3d1f82 (patch)
tree       20981112a4c378cb9b90a2e6f856d2629d05b840 /candle-transformers
parent     07849aa595c65309ed9230a4c97035f471c6afb1 (diff)
Stable diffusion 3.5 support. (#2578)
* Stable diffusion 3.5 support.
* Clippy fixes.
* CFG fix.
* Remove some unnecessary clones.
* Avoid duplicating some of the code.
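
For orientation, a minimal sketch of picking up the new preset added by this commit. It assumes the Config fields are public and that the module path mirrors the file layout in the diff below; neither detail is shown in this patch, so treat the snippet as illustrative rather than as the crate's documented API.

use candle_transformers::models::mmdit::model::Config;

fn main() {
    // The values below are the ones hard-coded by sd3_5_large() in this patch.
    let cfg = Config::sd3_5_large();
    assert_eq!(cfg.depth, 38);
    assert_eq!(cfg.adm_in_channels, 2048);
    assert_eq!(cfg.context_embed_size, 4096);
    println!("SD 3.5 Large MMDiT: depth={}, head_size={}", cfg.depth, cfg.head_size);
}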
Diffstat (limited to 'candle-transformers')
-rw-r--r--  candle-transformers/src/models/mmdit/model.rs       | 14
-rw-r--r--  candle-transformers/src/models/mmdit/projections.rs | 30
2 files changed, 43 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs
index 864b6623..5b5c90b0 100644
--- a/candle-transformers/src/models/mmdit/model.rs
+++ b/candle-transformers/src/models/mmdit/model.rs
@@ -36,6 +36,20 @@ impl Config {
             frequency_embedding_size: 256,
         }
     }
+
+    pub fn sd3_5_large() -> Self {
+        Self {
+            patch_size: 2,
+            in_channels: 16,
+            out_channels: 16,
+            depth: 38,
+            head_size: 64,
+            adm_in_channels: 2048,
+            pos_embed_max_size: 192,
+            context_embed_size: 4096,
+            frequency_embedding_size: 256,
+        }
+    }
 }
 
 pub struct MMDiT {
diff --git a/candle-transformers/src/models/mmdit/projections.rs b/candle-transformers/src/models/mmdit/projections.rs
index dc1e8ec9..27753285 100644
--- a/candle-transformers/src/models/mmdit/projections.rs
+++ b/candle-transformers/src/models/mmdit/projections.rs
@@ -56,6 +56,8 @@ impl QkvOnlyAttnProjections {
 pub struct AttnProjections {
     head_dim: usize,
     qkv: nn::Linear,
+    ln_k: Option<candle_nn::RmsNorm>,
+    ln_q: Option<candle_nn::RmsNorm>,
     proj: nn::Linear,
 }
 
@@ -64,16 +66,42 @@ impl AttnProjections {
         let head_dim = dim / num_heads;
         let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
         let proj = nn::linear(dim, dim, vb.pp("proj"))?;
+        let (ln_k, ln_q) = if vb.contains_tensor("ln_k.weight") {
+            let ln_k = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_k"))?;
+            let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?;
+            (Some(ln_k), Some(ln_q))
+        } else {
+            (None, None)
+        };
         Ok(Self {
             head_dim,
             qkv,
             proj,
+            ln_k,
+            ln_q,
         })
     }
 
     pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
         let qkv = self.qkv.forward(x)?;
-        split_qkv(&qkv, self.head_dim)
+        let Qkv { q, k, v } = split_qkv(&qkv, self.head_dim)?;
+        let q = match self.ln_q.as_ref() {
+            None => q,
+            Some(l) => {
+                let (b, t, h) = q.dims3()?;
+                l.forward(&q.reshape((b, t, (), self.head_dim))?)?
+                    .reshape((b, t, h))?
+            }
+        };
+        let k = match self.ln_k.as_ref() {
+            None => k,
+            Some(l) => {
+                let (b, t, h) = k.dims3()?;
+                l.forward(&k.reshape((b, t, (), self.head_dim))?)?
+                    .reshape((b, t, h))?
+            }
+        };
+        Ok(Qkv { q, k, v })
     }
 
     pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {
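
The projections change adds optional QK normalization: q and k are reshaped from (batch, tokens, heads * head_dim) to expose the per-head axis, RMS-normalized over head_dim, then flattened back. The norm layers are only created when the checkpoint actually contains ln_k.weight, so checkpoints without those tensors (presumably the pre-3.5 ones) keep loading unchanged. Below is a minimal, self-contained sketch of that reshape-normalize-flatten step against the candle-core/candle-nn crates; the norm_per_head helper and the VarMap-backed demo weights are illustrative only, while rms_norm, reshape with a () hole, and forward mirror the patch.

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{VarBuilder, VarMap};

// (b, t, heads * head_dim) -> (b, t, heads, head_dim) -> RmsNorm over head_dim -> flatten back.
fn norm_per_head(x: &Tensor, ln: &candle_nn::RmsNorm, head_dim: usize) -> Result<Tensor> {
    let (b, t, h) = x.dims3()?;
    ln.forward(&x.reshape((b, t, (), head_dim))?)?.reshape((b, t, h))
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (head_dim, num_heads) = (64, 8);
    // Freshly initialized weights just for the demo; the model loads them from vb.pp("ln_q") / vb.pp("ln_k").
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?;
    let q = Tensor::randn(0f32, 1f32, (1, 4, num_heads * head_dim), &dev)?;
    let q = norm_per_head(&q, &ln_q, head_dim)?;
    println!("normalized q dims: {:?}", q.dims()); // [1, 4, 512]
    Ok(())
}

The RmsNorm weight has length head_dim, so one set of parameters is shared across all heads and tokens, and reshaping with a () hole lets candle infer the head count from the tensor itself rather than storing it in the struct.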