summary | refs | log | tree | commit | diff
path: root/candle-flash-attn/kernels/flash_api.cu
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/kernels/flash_api.cu')
-rw-r--r-- candle-flash-attn/kernels/flash_api.cu | 9
1 files changed, 6 insertions, 3 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 323aeaad..d928bcb6 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -22,6 +22,9 @@ extern "C" void run_mha(
void *o_ptr,
void *softmax_lse_ptr,
+ int32_t *cu_seqlens_q_ptr,
+ int32_t *cu_seqlens_k_ptr,
+
uint32_t q_batch_stride,
uint32_t k_batch_stride,
uint32_t v_batch_stride,
@@ -100,9 +103,9 @@ extern "C" void run_mha(
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
params.is_bf16 = 0;
- params.cu_seqlens_q = nullptr;
- params.cu_seqlens_k = nullptr;
- params.p_ptr = nullptr;
+ params.cu_seqlens_q = cu_seqlens_q_ptr;
+ params.cu_seqlens_k = cu_seqlens_k_ptr;
+ params.p_ptr = nullptr; // used for `return_softmax`.
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);