diff options
Diffstat (limited to 'candle-flash-attn/kernels/flash_api.cu')
-rw-r--r-- | candle-flash-attn/kernels/flash_api.cu | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu index 323aeaad..d928bcb6 100644 --- a/candle-flash-attn/kernels/flash_api.cu +++ b/candle-flash-attn/kernels/flash_api.cu @@ -22,6 +22,9 @@ extern "C" void run_mha( void *o_ptr, void *softmax_lse_ptr, + int32_t *cu_seqlens_q_ptr, + int32_t *cu_seqlens_k_ptr, + uint32_t q_batch_stride, uint32_t k_batch_stride, uint32_t v_batch_stride, @@ -100,9 +103,9 @@ extern "C" void run_mha( params.rp_dropout = 1.f / params.p_dropout; params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; params.is_bf16 = 0; - params.cu_seqlens_q = nullptr; - params.cu_seqlens_k = nullptr; - params.p_ptr = nullptr; + params.cu_seqlens_q = cu_seqlens_q_ptr; + params.cu_seqlens_k = cu_seqlens_k_ptr; + params.p_ptr = nullptr; // used for `return_softmax`. cudaStream_t stream = 0; // Use the default stream. run_mha_fwd(params, stream); |