summaryrefslogtreecommitdiff
path: root/candle-transformers/src/models/mmdit/model.rs
diff options
context:
space:
mode:
Diffstat (limited to 'candle-transformers/src/models/mmdit/model.rs')
-rw-r--r--candle-transformers/src/models/mmdit/model.rs8
1 files changed, 6 insertions, 2 deletions
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs
index 1523836c..864b6623 100644
--- a/candle-transformers/src/models/mmdit/model.rs
+++ b/candle-transformers/src/models/mmdit/model.rs
@@ -23,7 +23,7 @@ pub struct Config {
}
impl Config {
- pub fn sd3() -> Self {
+ pub fn sd3_medium() -> Self {
Self {
patch_size: 2,
in_channels: 16,
@@ -49,7 +49,7 @@ pub struct MMDiT {
}
impl MMDiT {
- pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result<Self> {
+ pub fn new(cfg: &Config, use_flash_attn: bool, vb: nn::VarBuilder) -> Result<Self> {
let hidden_size = cfg.head_size * cfg.depth;
let core = MMDiTCore::new(
cfg.depth,
@@ -57,6 +57,7 @@ impl MMDiT {
cfg.depth,
cfg.patch_size,
cfg.out_channels,
+ use_flash_attn,
vb.clone(),
)?;
let patch_embedder = PatchEmbedder::new(
@@ -135,6 +136,7 @@ impl MMDiTCore {
num_heads: usize,
patch_size: usize,
out_channels: usize,
+ use_flash_attn: bool,
vb: nn::VarBuilder,
) -> Result<Self> {
let mut joint_blocks = Vec::with_capacity(depth - 1);
@@ -142,6 +144,7 @@ impl MMDiTCore {
joint_blocks.push(JointBlock::new(
hidden_size,
num_heads,
+ use_flash_attn,
vb.pp(format!("joint_blocks.{}", i)),
)?);
}
@@ -151,6 +154,7 @@ impl MMDiTCore {
context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
hidden_size,
num_heads,
+ use_flash_attn,
vb.pp(format!("joint_blocks.{}", depth - 1)),
)?,
final_layer: FinalLayer::new(