summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels/flash_api.cu
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/kernels/flash_api.cu')
-rw-r--r--candle-flash-attn/kernels/flash_api.cu2
1 files changed, 2 insertions, 0 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 00933419..d172bef8 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -53,6 +53,7 @@ extern "C" void run_mha(
int is_bf16,
int is_causal,
+ int unpadded_lse,
int window_size_left,
int window_size_right,
@@ -128,6 +129,7 @@ extern "C" void run_mha(
params.is_seqlens_k_cumulative = true;
params.num_splits = 1;
+ params.unpadded_lse = unpadded_lse;
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);