author     Michael Feil <63565275+michaelfeil@users.noreply.github.com>  2024-12-31 10:04:47 +0100
committer  GitHub <noreply@github.com>  2024-12-31 10:04:47 +0100
commit     2a705e6f3739cd43b40139b1ee58141b733bcfc1 (patch)
tree       13cc9822fc8b9e471335243651ccbd0bcbb4159f
parent     a594ef669ca5ed82c1f19d2230b4b3dc9cb46f43 (diff)
download   candle-2a705e6f3739cd43b40139b1ee58141b733bcfc1.tar.gz
           candle-2a705e6f3739cd43b40139b1ee58141b733bcfc1.tar.bz2
           candle-2a705e6f3739cd43b40139b1ee58141b733bcfc1.zip
Flash-Attn upgrade / SoftCap Candle-FlashAttn [3/n] (#2690)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
* softcap is working, including test and api.
* make softcap test case better
* unpadded lse added
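
For context on the softcap item above: soft-capping squashes the attention logits through a tanh before the softmax, so no score can exceed the cap. The sketch below is a rough illustration of that transformation, not the kernel code; the softcap helper name and the sample values are invented for the example, while the convention that a cap of 0.0 disables capping mirrors the self.softcap.unwrap_or(0f32) default passed through the FFI in the diff below.

// Minimal sketch (not the kernel code): logit soft-capping as exposed by the
// flash-attention softcap option, applied here to plain f32 scalars. The real
// kernel applies the same transform element-wise to the attention scores
// before the softmax.
fn softcap(score: f32, cap: f32) -> f32 {
    if cap == 0.0 {
        // A cap of 0 conventionally disables soft-capping (the Rust side
        // passes `self.softcap.unwrap_or(0f32)` through the FFI).
        score
    } else {
        cap * (score / cap).tanh()
    }
}

fn main() {
    for &s in &[-100.0f32, -1.0, 0.0, 1.0, 10.0, 100.0] {
        println!("score {s:>7.1} -> capped {:>7.3}", softcap(s, 30.0));
    }
}

Large positive or negative scores saturate near +/-cap, while small scores pass through almost unchanged.
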
-rw-r--r--  candle-flash-attn/kernels/flash_api.cu |  2
-rw-r--r--  candle-flash-attn/src/ffi.rs           |  1
-rw-r--r--  candle-flash-attn/src/lib.rs           |  8
3 files changed, 7 insertions, 4 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 00933419..d172bef8 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -53,6 +53,7 @@ extern "C" void run_mha(
     int is_bf16,
     int is_causal,
+    int unpadded_lse,
     int window_size_left,
     int window_size_right,
@@ -128,6 +129,7 @@ extern "C" void run_mha(
     params.is_seqlens_k_cumulative = true;
     params.num_splits = 1;
+    params.unpadded_lse = unpadded_lse;
 
     cudaStream_t stream = 0; // Use the default stream.
     run_mha_fwd(params, stream);
diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs
index 47e54e2a..78d3a986 100644
--- a/candle-flash-attn/src/ffi.rs
+++ b/candle-flash-attn/src/ffi.rs
@@ -42,6 +42,7 @@ extern "C" {
        is_bf16: c_int,
        is_causal: c_int,
+       unpadded_lse: c_int,
        window_size_left: c_int,
        window_size_right: c_int,
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 22a6f1d6..1b2e5e43 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -200,6 +200,7 @@ impl FlashAttn {
                /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                /* is_bf16 */ is_bf16,
                /* is_causal */ is_causal,
+               /* upadded_lse */ 0,
                /* window_size_left */ window_size_left,
                /* window_size_right */ window_size_right,
                /* softcap */ self.softcap.unwrap_or(0f32),
@@ -518,7 +519,7 @@ impl FlashAttnVarLen {
            candle::bail!("the last dim of v must be contiguous {v_stride:?}")
        }
 
-       let (_total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
+       let (total_q, num_heads, head_size_og) = q_l.shape().dims3()?;
        let (total_k, num_heads_k, _head_size_og) = k_l.shape().dims3()?;
        let expected_kv = (total_k, num_heads_k, head_size_og);
        if expected_kv != k_l.shape().dims3()? {
@@ -601,9 +602,7 @@ impl FlashAttnVarLen {
        let elem_count = out_shape.elem_count();
        let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
-       let softmax_lse = dev
-           .alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
-           .w()?;
+       let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;
 
        let is_bf16 = if is_bf16 { 1 } else { 0 };
@@ -663,6 +662,7 @@ impl FlashAttnVarLen {
                /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                /* is_bf16 */ is_bf16,
                /* is_causal */ is_causal,
+               /* upadded_lse */ 1,
                /* window_size_left */ window_size_left,
                /* window_size_right */ window_size_right,
                /* softcap */ self.softcap.unwrap_or(0.0),
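
The last lib.rs hunk pairs with the allocation change above it: with unpadded LSE enabled on the varlen path, the softmax log-sum-exp buffer is sized as num_heads * total_q over the packed query tokens, instead of batch_size * num_heads * max_seqlen_q padded out to the longest sequence. A minimal sketch of the two sizes, using invented sequence lengths and head count (this is not code from the crate):

// Minimal sketch of the softmax-LSE buffer sizes touched by this commit.
// The padded layout reserves space as if every sequence were max_seqlen_q
// long; the unpadded layout stores one LSE value per (head, token) pair
// over the packed total_q query tokens.
fn padded_lse_len(batch_size: usize, num_heads: usize, max_seqlen_q: usize) -> usize {
    batch_size * num_heads * max_seqlen_q
}

fn unpadded_lse_len(num_heads: usize, total_q: usize) -> usize {
    num_heads * total_q
}

fn main() {
    // Hypothetical ragged batch: three sequences of 7, 128 and 513 query tokens.
    let seqlens_q = [7usize, 128, 513];
    let batch_size = seqlens_q.len();
    let total_q: usize = seqlens_q.iter().sum();
    let max_seqlen_q = *seqlens_q.iter().max().unwrap();
    let num_heads = 32;

    println!(
        "padded LSE buffer:   {} f32 values",
        padded_lse_len(batch_size, num_heads, max_seqlen_q)
    );
    println!(
        "unpadded LSE buffer: {} f32 values",
        unpadded_lse_len(num_heads, total_q)
    );
}

For ragged batches the unpadded layout is never larger, since it avoids reserving LSE slots for padding tokens, which is consistent with the diff above: the fixed-shape FlashAttn path keeps passing 0 for the new flag while FlashAttnVarLen now passes 1.
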