summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Feil <63565275+michaelfeil@users.noreply.github.com>2024-12-31 10:04:47 +0100
committerGitHub <noreply@github.com>2024-12-31 10:04:47 +0100
commit2a705e6f3739cd43b40139b1ee58141b733bcfc1 (patch)
tree13cc9822fc8b9e471335243651ccbd0bcbb4159f
parenta594ef669ca5ed82c1f19d2230b4b3dc9cb46f43 (diff)
downloadcandle-2a705e6f3739cd43b40139b1ee58141b733bcfc1.tar.gz
candle-2a705e6f3739cd43b40139b1ee58141b733bcfc1.tar.bz2
candle-2a705e6f3739cd43b40139b1ee58141b733bcfc1.zip
Flash-Attn upgrade / SoftCap Candle-FlashAttn [3/n] (#2690)
* update flash-attn v1 * restore: hdim224 * add 224 flash_fwd_template * remove whitespace * softcap is working, including test and api. * make softcap test case better * unpadded lse added
-rw-r--r--candle-flash-attn/kernels/flash_api.cu2
-rw-r--r--candle-flash-attn/src/ffi.rs1
-rw-r--r--candle-flash-attn/src/lib.rs8
3 files changed, 7 insertions, 4 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 00933419..d172bef8 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -53,6 +53,7 @@ extern "C" void run_mha(
int is_bf16,
int is_causal,
+ int unpadded_lse,
int window_size_left,
int window_size_right,
@@ -128,6 +129,7 @@ extern "C" void run_mha(
params.is_seqlens_k_cumulative = true;
params.num_splits = 1;
+ params.unpadded_lse = unpadded_lse;
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);
diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs
index 47e54e2a..78d3a986 100644
--- a/candle-flash-attn/src/ffi.rs
+++ b/candle-flash-attn/src/ffi.rs
@@ -42,6 +42,7 @@ extern "C" {
is_bf16: c_int,
is_causal: c_int,
+ unpadded_lse: c_int,
window_size_left: c_int,
window_size_right: c_int,
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 22a6f1d6..1b2e5e43 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -200,6 +200,7 @@ impl FlashAttn {
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_bf16 */ is_bf16,
/* is_causal */ is_causal,
+ /* upadded_lse */ 0,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
/* softcap */ self.softcap.unwrap_or(0f32),
@@ -518,7 +519,7 @@ impl FlashAttnVarLen {
candle::bail!("the last dim of v must be contiguous {v_stride:?}")
}
- let (_total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
+ let (total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
let (total_k, num_heads_k, _head_size_og) = k_l.shape().dims3()?;
let expected_kv = (total_k, num_heads_k, head_size_og);
if expected_kv != k_l.shape().dims3()? {
@@ -601,9 +602,7 @@ impl FlashAttnVarLen {
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
- let softmax_lse = dev
- .alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
- .w()?;
+ let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };
@@ -663,6 +662,7 @@ impl FlashAttnVarLen {
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
/* is_bf16 */ is_bf16,
/* is_causal */ is_causal,
+ /* upadded_lse */ 1,
/* window_size_left */ window_size_left,
/* window_size_right */ window_size_right,
/* softcap */ self.softcap.unwrap_or(0.0),