summary refs log tree commit diff
path: root/candle-transformers
diff options
context:
space:
mode:
author Laurent Mazare <laurent.mazare@gmail.com> 2024-10-27 10:01:04 +0100
committer GitHub <noreply@github.com> 2024-10-27 10:01:04 +0100
commit 37e0ab8c64eb8219e32cf546ac2aa570ed3d1f82 (patch)
tree 20981112a4c378cb9b90a2e6f856d2629d05b840 /candle-transformers
parent 07849aa595c65309ed9230a4c97035f471c6afb1 (diff)
download candle-37e0ab8c64eb8219e32cf546ac2aa570ed3d1f82.tar.gz
candle-37e0ab8c64eb8219e32cf546ac2aa570ed3d1f82.tar.bz2
candle-37e0ab8c64eb8219e32cf546ac2aa570ed3d1f82.zip
Stable diffusion 3.5 support. (#2578)
* Stable diffusion 3.5 support. * Clippy fixes. * CFG fix. * Remove some unnecessary clones. * Avoid duplicating some of the code.
Diffstat (limited to 'candle-transformers')
-rw-r--r-- candle-transformers/src/models/mmdit/model.rs | 14
-rw-r--r-- candle-transformers/src/models/mmdit/projections.rs | 30
2 files changed, 43 insertions(+), 1 deletion(-)
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs
index 864b6623..5b5c90b0 100644
--- a/candle-transformers/src/models/mmdit/model.rs
+++ b/candle-transformers/src/models/mmdit/model.rs
@@ -36,6 +36,20 @@ impl Config {
frequency_embedding_size: 256,
}
}
+
+ pub fn sd3_5_large() -> Self {
+ Self {
+ patch_size: 2,
+ in_channels: 16,
+ out_channels: 16,
+ depth: 38,
+ head_size: 64,
+ adm_in_channels: 2048,
+ pos_embed_max_size: 192,
+ context_embed_size: 4096,
+ frequency_embedding_size: 256,
+ }
+ }
}
pub struct MMDiT {
diff --git a/candle-transformers/src/models/mmdit/projections.rs b/candle-transformers/src/models/mmdit/projections.rs
index dc1e8ec9..27753285 100644
--- a/candle-transformers/src/models/mmdit/projections.rs
+++ b/candle-transformers/src/models/mmdit/projections.rs
@@ -56,6 +56,8 @@ impl QkvOnlyAttnProjections {
pub struct AttnProjections {
head_dim: usize,
qkv: nn::Linear,
+ ln_k: Option<candle_nn::RmsNorm>,
+ ln_q: Option<candle_nn::RmsNorm>,
proj: nn::Linear,
}
@@ -64,16 +66,42 @@ impl AttnProjections {
let head_dim = dim / num_heads;
let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
let proj = nn::linear(dim, dim, vb.pp("proj"))?;
+ let (ln_k, ln_q) = if vb.contains_tensor("ln_k.weight") {
+ let ln_k = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_k"))?;
+ let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?;
+ (Some(ln_k), Some(ln_q))
+ } else {
+ (None, None)
+ };
Ok(Self {
head_dim,
qkv,
proj,
+ ln_k,
+ ln_q,
})
}
pub fn pre_attention(&self, x: &Tensor) -> Result<Qkv> {
let qkv = self.qkv.forward(x)?;
- split_qkv(&qkv, self.head_dim)
+ let Qkv { q, k, v } = split_qkv(&qkv, self.head_dim)?;
+ let q = match self.ln_q.as_ref() {
+ None => q,
+ Some(l) => {
+ let (b, t, h) = q.dims3()?;
+ l.forward(&q.reshape((b, t, (), self.head_dim))?)?
+ .reshape((b, t, h))?
+ }
+ };
+ let k = match self.ln_k.as_ref() {
+ None => k,
+ Some(l) => {
+ let (b, t, h) = k.dims3()?;
+ l.forward(&k.reshape((b, t, (), self.head_dim))?)?
+ .reshape((b, t, h))?
+ }
+ };
+ Ok(Qkv { q, k, v })
}
pub fn post_attention(&self, x: &Tensor) -> Result<Tensor> {