author | Michael Feil <63565275+michaelfeil@users.noreply.github.com> | 2024-12-31 10:04:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-31 10:04:47 +0100 |
commit | 2a705e6f3739cd43b40139b1ee58141b733bcfc1 (patch) | |
tree | 13cc9822fc8b9e471335243651ccbd0bcbb4159f /candle-flash-attn/kernels | |
parent | a594ef669ca5ed82c1f19d2230b4b3dc9cb46f43 (diff) | |
Flash-Attn upgrade / SoftCap Candle-FlashAttn [3/n] (#2690)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
* softcap is working, including test and api.
* make softcap test case better
* unpadded lse added
Diffstat (limited to 'candle-flash-attn/kernels')
-rw-r--r-- | candle-flash-attn/kernels/flash_api.cu | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 00933419..d172bef8 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -53,6 +53,7 @@ extern "C" void run_mha(
     int is_bf16,
     int is_causal,
+    int unpadded_lse,
     int window_size_left,
     int window_size_right,
@@ -128,6 +129,7 @@ extern "C" void run_mha(
     params.is_seqlens_k_cumulative = true;
     params.num_splits = 1;
+    params.unpadded_lse = unpadded_lse;
     cudaStream_t stream = 0; // Use the default stream.
     run_mha_fwd(params, stream);
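The diff only threads the `unpadded_lse` flag through to `params`; it does not show what the flag changes. As a rough illustration, and assuming the flag follows the upstream FlashAttention variable-length convention (a padded log-sum-exp buffer of shape `(batch, nheads, max_seqlen_q)` versus an unpadded one of shape `(nheads, total_q)`), the two layouts could be indexed as sketched below. The helper names `padded_lse_offset` and `unpadded_lse_offset` are hypothetical and are not part of candle or flash-attn.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: offset of (batch b, head h, query row i) in a padded
// LSE buffer, ASSUMED to have shape (batch, nheads, max_seqlen_q).
std::size_t padded_lse_offset(std::size_t b, std::size_t h, std::size_t i,
                              std::size_t nheads, std::size_t max_seqlen_q) {
    assert(h < nheads && i < max_seqlen_q);
    return (b * nheads + h) * max_seqlen_q + i;
}

// Hypothetical helper: offset of the same element in an unpadded LSE buffer,
// ASSUMED to have shape (nheads, total_q). cu_seqlens_q holds batch + 1
// cumulative query lengths, so sequence b occupies rows
// [cu_seqlens_q[b], cu_seqlens_q[b + 1]).
std::size_t unpadded_lse_offset(std::size_t b, std::size_t h, std::size_t i,
                                const std::vector<std::size_t>& cu_seqlens_q) {
    assert(b + 1 < cu_seqlens_q.size());
    assert(cu_seqlens_q[b] + i < cu_seqlens_q[b + 1]);
    const std::size_t total_q = cu_seqlens_q.back();
    return h * total_q + cu_seqlens_q[b] + i;
}
```

Under these assumptions, the unpadded layout stores `nheads * total_q` values instead of `batch * nheads * max_seqlen_q`, which saves memory when sequence lengths in a batch vary widely.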