diff options
Diffstat (limited to 'candle-flash-attn/kernels/block_info.h')
-rw-r--r-- | candle-flash-attn/kernels/block_info.h | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h index 94251a41..65435e51 100644 --- a/candle-flash-attn/kernels/block_info.h +++ b/candle-flash-attn/kernels/block_info.h @@ -14,9 +14,12 @@ struct BlockInfo { template<typename Params> __device__ BlockInfo(const Params ¶ms, const int bidb) : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) - , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) - , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) + , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) { } @@ -32,8 +35,10 @@ struct BlockInfo { const int sum_s_q; const int sum_s_k; - const uint32_t actual_seqlen_q; - const uint32_t actual_seqlen_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; + const int actual_seqlen_k; }; //////////////////////////////////////////////////////////////////////////////////////////////////// |