diff options
Diffstat (limited to 'candle-transformers/src/models/mmdit')
-rw-r--r-- | candle-transformers/src/models/mmdit/blocks.rs | 54 | ||||
-rw-r--r-- | candle-transformers/src/models/mmdit/model.rs | 8 | ||||
-rw-r--r-- | candle-transformers/src/models/mmdit/projections.rs | 1 |
3 files changed, 53 insertions, 10 deletions
diff --git a/candle-transformers/src/models/mmdit/blocks.rs b/candle-transformers/src/models/mmdit/blocks.rs index e2b924a0..a1777f91 100644 --- a/candle-transformers/src/models/mmdit/blocks.rs +++ b/candle-transformers/src/models/mmdit/blocks.rs @@ -194,10 +194,16 @@ pub struct JointBlock { x_block: DiTBlock, context_block: DiTBlock, num_heads: usize, + use_flash_attn: bool, } impl JointBlock { - pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> { + pub fn new( + hidden_size: usize, + num_heads: usize, + use_flash_attn: bool, + vb: nn::VarBuilder, + ) -> Result<Self> { let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; @@ -205,13 +211,15 @@ impl JointBlock { x_block, context_block, num_heads, + use_flash_attn, }) } pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> { let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?; let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; - let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?; + let (context_attn, x_attn) = + joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?; let context_out = self.context_block .post_attention(&context_attn, context, &context_interm)?; @@ -224,16 +232,23 @@ pub struct ContextQkvOnlyJointBlock { x_block: DiTBlock, context_block: QkvOnlyDiTBlock, num_heads: usize, + use_flash_attn: bool, } impl ContextQkvOnlyJointBlock { - pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> { + pub fn new( + hidden_size: usize, + num_heads: usize, + use_flash_attn: bool, + vb: nn::VarBuilder, + ) -> Result<Self> { let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; Ok(Self { x_block, context_block, num_heads, + use_flash_attn, }) } @@ -241,7 +256,7 @@ impl ContextQkvOnlyJointBlock { let context_qkv = self.context_block.pre_attention(context, c)?; let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; - let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?; + let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?; let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?; Ok(x_out) @@ -266,7 +281,28 @@ fn flash_compatible_attention( attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2) } -fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> { +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result<Tensor> { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> { + unimplemented!("compile with '--features flash-attn'") +} + +fn joint_attn( + context_qkv: &Qkv, + x_qkv: &Qkv, + num_heads: usize, + use_flash_attn: bool, +) -> Result<(Tensor, Tensor)> { let qkv = Qkv { q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?, k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?, @@ -282,8 +318,12 @@ fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tenso let headdim = qkv.q.dim(D::Minus1)?; let softmax_scale = 1.0 / (headdim as f64).sqrt(); - // let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?; - let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?; + + let attn = if use_flash_attn { + flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)? + } else { + flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)? + }; let attn = attn.reshape((batch_size, seqlen, ()))?; let context_qkv_seqlen = context_qkv.q.dim(1)?; diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs index 1523836c..864b6623 100644 --- a/candle-transformers/src/models/mmdit/model.rs +++ b/candle-transformers/src/models/mmdit/model.rs @@ -23,7 +23,7 @@ pub struct Config { } impl Config { - pub fn sd3() -> Self { + pub fn sd3_medium() -> Self { Self { patch_size: 2, in_channels: 16, @@ -49,7 +49,7 @@ pub struct MMDiT { } impl MMDiT { - pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result<Self> { + pub fn new(cfg: &Config, use_flash_attn: bool, vb: nn::VarBuilder) -> Result<Self> { let hidden_size = cfg.head_size * cfg.depth; let core = MMDiTCore::new( cfg.depth, @@ -57,6 +57,7 @@ impl MMDiT { cfg.depth, cfg.patch_size, cfg.out_channels, + use_flash_attn, vb.clone(), )?; let patch_embedder = PatchEmbedder::new( @@ -135,6 +136,7 @@ impl MMDiTCore { num_heads: usize, patch_size: usize, out_channels: usize, + use_flash_attn: bool, vb: nn::VarBuilder, ) -> Result<Self> { let mut joint_blocks = Vec::with_capacity(depth - 1); @@ -142,6 +144,7 @@ impl MMDiTCore { joint_blocks.push(JointBlock::new( hidden_size, num_heads, + use_flash_attn, vb.pp(format!("joint_blocks.{}", i)), )?); } @@ -151,6 +154,7 @@ impl MMDiTCore { context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new( hidden_size, num_heads, + use_flash_attn, vb.pp(format!("joint_blocks.{}", depth - 1)), )?, final_layer: FinalLayer::new( diff --git a/candle-transformers/src/models/mmdit/projections.rs b/candle-transformers/src/models/mmdit/projections.rs index 1077398f..dc1e8ec9 100644 --- a/candle-transformers/src/models/mmdit/projections.rs +++ b/candle-transformers/src/models/mmdit/projections.rs @@ -42,7 +42,6 @@ pub struct QkvOnlyAttnProjections { impl QkvOnlyAttnProjections { pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result<Self> { - // {'dim': 1536, 'num_heads': 24} let head_dim = dim / num_heads; let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?; Ok(Self { qkv, head_dim }) |