summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels/block_info.h
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/kernels/block_info.h')
-rw-r--r--candle-flash-attn/kernels/block_info.h13
1 files changed, 9 insertions, 4 deletions
diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h
index 94251a41..65435e51 100644
--- a/candle-flash-attn/kernels/block_info.h
+++ b/candle-flash-attn/kernels/block_info.h
@@ -14,9 +14,12 @@ struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
- , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
+ , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
- , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
+ // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+ // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+ , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
+ , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}
@@ -32,8 +35,10 @@ struct BlockInfo {
const int sum_s_q;
const int sum_s_k;
- const uint32_t actual_seqlen_q;
- const uint32_t actual_seqlen_k;
+ const int actual_seqlen_q;
+ // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+ const int seqlen_k_cache;
+ const int actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////