Diffstat (limited to 'candle-transformers/src/models/mmdit')
-rw-r--r--   candle-transformers/src/models/mmdit/blocks.rs        54
-rw-r--r--   candle-transformers/src/models/mmdit/model.rs          8
-rw-r--r--   candle-transformers/src/models/mmdit/projections.rs    1
3 files changed, 53 insertions, 10 deletions
diff --git a/candle-transformers/src/models/mmdit/blocks.rs b/candle-transformers/src/models/mmdit/blocks.rs
index e2b924a0..a1777f91 100644
--- a/candle-transformers/src/models/mmdit/blocks.rs
+++ b/candle-transformers/src/models/mmdit/blocks.rs
@@ -194,10 +194,16 @@ pub struct JointBlock {
x_block: DiTBlock,
context_block: DiTBlock,
num_heads: usize,
+ use_flash_attn: bool,
}
impl JointBlock {
- pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+ pub fn new(
+ hidden_size: usize,
+ num_heads: usize,
+ use_flash_attn: bool,
+ vb: nn::VarBuilder,
+ ) -> Result<Self> {
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
@@ -205,13 +211,15 @@ impl JointBlock {
x_block,
context_block,
num_heads,
+ use_flash_attn,
})
}
pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> {
let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
- let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
+ let (context_attn, x_attn) =
+ joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
let context_out =
self.context_block
.post_attention(&context_attn, context, &context_interm)?;
@@ -224,16 +232,23 @@ pub struct ContextQkvOnlyJointBlock {
x_block: DiTBlock,
context_block: QkvOnlyDiTBlock,
num_heads: usize,
+ use_flash_attn: bool,
}
impl ContextQkvOnlyJointBlock {
- pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
+ pub fn new(
+ hidden_size: usize,
+ num_heads: usize,
+ use_flash_attn: bool,
+ vb: nn::VarBuilder,
+ ) -> Result<Self> {
let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?;
let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?;
Ok(Self {
x_block,
context_block,
num_heads,
+ use_flash_attn,
})
}
@@ -241,7 +256,7 @@ impl ContextQkvOnlyJointBlock {
let context_qkv = self.context_block.pre_attention(context, c)?;
let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;
- let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
+ let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?;
let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
Ok(x_out)
@@ -266,7 +281,28 @@ fn flash_compatible_attention(
attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2)
}
-fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> {
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+ unimplemented!("compile with '--features flash-attn'")
+}
+
+fn joint_attn(
+ context_qkv: &Qkv,
+ x_qkv: &Qkv,
+ num_heads: usize,
+ use_flash_attn: bool,
+) -> Result<(Tensor, Tensor)> {
let qkv = Qkv {
q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?,
k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?,
@@ -282,8 +318,12 @@ fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tenso
let headdim = qkv.q.dim(D::Minus1)?;
let softmax_scale = 1.0 / (headdim as f64).sqrt();
- // let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?;
- let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?;
+
+ let attn = if use_flash_attn {
+ flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?
+ } else {
+ flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?
+ };
let attn = attn.reshape((batch_size, seqlen, ()))?;
let context_qkv_seqlen = context_qkv.q.dim(1)?;
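For reference, the non-flash branch above falls back to flash_compatible_attention, whose body is only partially visible in this hunk. The sketch below is an assumed illustration of such a plain scaled-dot-product fallback in candle, using the (batch, seqlen, num_heads, head_dim) layout that candle_flash_attn::flash_attn expects; it is not the exact function from blocks.rs.

// Illustrative sketch only; the name and exact reshapes are assumptions, not the
// real flash_compatible_attention body.
use candle::{Result, Tensor};

fn naive_attention(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    // Move heads ahead of the sequence dim: (batch, num_heads, seqlen, head_dim).
    let q = q.transpose(1, 2)?.contiguous()?;
    let k = k.transpose(1, 2)?.contiguous()?;
    let v = v.transpose(1, 2)?.contiguous()?;
    // Attention scores scaled by 1/sqrt(head_dim), then softmax over the key dim.
    let scores = (q.matmul(&k.t()?)? * softmax_scale as f64)?;
    let probs = candle_nn::ops::softmax_last_dim(&scores)?;
    // Back to (batch, seqlen, num_heads, head_dim) to match the flash-attn output layout.
    probs.matmul(&v)?.transpose(1, 2)?.contiguous()
}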
diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs
index 1523836c..864b6623 100644
--- a/candle-transformers/src/models/mmdit/model.rs
+++ b/candle-transformers/src/models/mmdit/model.rs
@@ -23,7 +23,7 @@ pub struct Config {
}
impl Config {
- pub fn sd3() -> Self {
+ pub fn sd3_medium() -> Self {
Self {
patch_size: 2,
in_channels: 16,
@@ -49,7 +49,7 @@ pub struct MMDiT {
}
impl MMDiT {
- pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result<Self> {
+ pub fn new(cfg: &Config, use_flash_attn: bool, vb: nn::VarBuilder) -> Result<Self> {
let hidden_size = cfg.head_size * cfg.depth;
let core = MMDiTCore::new(
cfg.depth,
@@ -57,6 +57,7 @@ impl MMDiT {
cfg.depth,
cfg.patch_size,
cfg.out_channels,
+ use_flash_attn,
vb.clone(),
)?;
let patch_embedder = PatchEmbedder::new(
@@ -135,6 +136,7 @@ impl MMDiTCore {
num_heads: usize,
patch_size: usize,
out_channels: usize,
+ use_flash_attn: bool,
vb: nn::VarBuilder,
) -> Result<Self> {
let mut joint_blocks = Vec::with_capacity(depth - 1);
@@ -142,6 +144,7 @@ impl MMDiTCore {
joint_blocks.push(JointBlock::new(
hidden_size,
num_heads,
+ use_flash_attn,
vb.pp(format!("joint_blocks.{}", i)),
)?);
}
@@ -151,6 +154,7 @@ impl MMDiTCore {
context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new(
hidden_size,
num_heads,
+ use_flash_attn,
vb.pp(format!("joint_blocks.{}", depth - 1)),
)?,
final_layer: FinalLayer::new(
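Caller-side, the new signature means downstream code (e.g. a stable-diffusion-3 example) now passes the flag when building the model. A hypothetical sketch, assuming the usual candle/candle_nn crate aliases and a local mmdit.safetensors file; only Config::sd3_medium and the extra bool argument to MMDiT::new come from this diff.

use candle::{DType, Device, Result};
use candle_nn::VarBuilder;
use candle_transformers::models::mmdit::model::{Config, MMDiT};

fn load_mmdit(device: &Device) -> Result<MMDiT> {
    // Hypothetical weight file and dtype; adjust to the actual checkpoint in use.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["mmdit.safetensors"], DType::F16, device)?
    };
    // Only request the flash kernel when this build forwards a flash-attn feature;
    // otherwise the stub in blocks.rs panics at runtime.
    let use_flash_attn = cfg!(feature = "flash-attn");
    MMDiT::new(&Config::sd3_medium(), use_flash_attn, vb)
}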
diff --git a/candle-transformers/src/models/mmdit/projections.rs b/candle-transformers/src/models/mmdit/projections.rs
index 1077398f..dc1e8ec9 100644
--- a/candle-transformers/src/models/mmdit/projections.rs
+++ b/candle-transformers/src/models/mmdit/projections.rs
@@ -42,7 +42,6 @@ pub struct QkvOnlyAttnProjections {
impl QkvOnlyAttnProjections {
pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> {
- // {'dim': 1536, 'num_heads': 24}
let head_dim = dim / num_heads;
let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?;
Ok(Self { qkv, head_dim })
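As a side note on the projection above: the fused qkv linear maps dim to 3 * dim, which is later split back into per-head q, k, v tensors. A rough sketch of that split, assuming a (batch, seqlen, 3 * dim) input; the helper name and reshapes are illustrative, not the actual mmdit split code.

use candle::{D, Result, Tensor};

// Illustrative helper: split a fused qkv projection into per-head tensors of
// shape (batch, seqlen, num_heads, head_dim).
fn split_fused_qkv(qkv: &Tensor, head_dim: usize) -> Result<(Tensor, Tensor, Tensor)> {
    let (batch, seqlen, three_dim) = qkv.dims3()?;
    let num_heads = three_dim / 3 / head_dim;
    let chunks = qkv.chunk(3, D::Minus1)?; // [q, k, v], each (batch, seqlen, dim)
    let to_heads = |t: &Tensor| -> Result<Tensor> {
        t.contiguous()?.reshape((batch, seqlen, num_heads, head_dim))
    };
    Ok((to_heads(&chunks[0])?, to_heads(&chunks[1])?, to_heads(&chunks[2])?))
}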