summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels
diff options
context:
space:
mode:
author Michael Feil <63565275+michaelfeil@users.noreply.github.com> 2024-12-31 10:04:47 +0100
committer GitHub <noreply@github.com> 2024-12-31 10:04:47 +0100
commit 2a705e6f3739cd43b40139b1ee58141b733bcfc1 (patch)
tree 13cc9822fc8b9e471335243651ccbd0bcbb4159f /candle-flash-attn/kernels
parent a594ef669ca5ed82c1f19d2230b4b3dc9cb46f43 (diff)
download candle-2a705e6f3739cd43b40139b1ee58141b733bcfc1.tar.gz
candle-2a705e6f3739cd43b40139b1ee58141b733bcfc1.tar.bz2
candle-2a705e6f3739cd43b40139b1ee58141b733bcfc1.zip
Flash-Attn upgrade / SoftCap Candle-FlashAttn [3/n] (#2690)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
* softcap is working, including test and api.
* make softcap test case better
* unpadded lse added
Diffstat (limited to 'candle-flash-attn/kernels')
-rw-r--r-- candle-flash-attn/kernels/flash_api.cu 2
1 file changed, 2 insertions, 0 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 00933419..d172bef8 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -53,6 +53,7 @@ extern "C" void run_mha(
int is_bf16,
int is_causal,
+ int unpadded_lse,
int window_size_left,
int window_size_right,
@@ -128,6 +129,7 @@ extern "C" void run_mha(
params.is_seqlens_k_cumulative = true;
params.num_splits = 1;
+ params.unpadded_lse = unpadded_lse;
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);