Diffstat (limited to 'candle-flash-attn/kernels/flash.h')
-rw-r--r-- | candle-flash-attn/kernels/flash.h | 54
1 file changed, 42 insertions(+), 12 deletions(-)
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h
index be4ae0ca..80b517e9 100644
--- a/candle-flash-attn/kernels/flash.h
+++ b/candle-flash-attn/kernels/flash.h
@@ -7,15 +7,6 @@
 #include <cuda.h>
 #include <vector>
 
-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-//
-// #include <ATen/cuda/CUDAGraphsUtils.cuh>
-
-
 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
 constexpr int D_DIM = 2;
@@ -53,6 +44,7 @@ struct Flash_fwd_params : public Qkv_params {
 
     // The O matrix (output).
     void * __restrict__ o_ptr;
+    void * __restrict__ oaccum_ptr;
 
     // The stride between rows of O.
     index_t o_batch_stride;
@@ -64,9 +56,10 @@ struct Flash_fwd_params : public Qkv_params {
 
     // The pointer to the softmax sum.
     void * __restrict__ softmax_lse_ptr;
+    void * __restrict__ softmax_lseaccum_ptr;
 
     // The dimensions.
-    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
 
     // The scaling factors for the kernel.
     float scale_softmax;
@@ -76,8 +69,30 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
 
+    // If provided, the actual length of each k sequence.
+    int * __restrict__ seqused_k;
+
     int *__restrict__ blockmask;
 
+    // The K_new and V_new matrices.
+    void * __restrict__ knew_ptr;
+    void * __restrict__ vnew_ptr;
+
+    // The stride between rows of the Q, K and V matrices.
+    index_t knew_batch_stride;
+    index_t vnew_batch_stride;
+    index_t knew_row_stride;
+    index_t vnew_row_stride;
+    index_t knew_head_stride;
+    index_t vnew_head_stride;
+
+    // The cos and sin matrices for rotary embedding.
+    void * __restrict__ rotary_cos_ptr;
+    void * __restrict__ rotary_sin_ptr;
+
+    // The indices to index into the KV cache.
+    int *__restrict__ cache_batch_idx;
+
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;
@@ -88,11 +103,22 @@ struct Flash_fwd_params : public Qkv_params {
     float rp_dropout;
     float scale_softmax_rp_dropout;
 
-    // Random state.
-    // at::PhiloxCudaState philox_args;
+    // Local window size
+    int window_size_left, window_size_right;
 
     bool is_bf16;
     bool is_causal;
+
+    // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+    // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+    bool is_seqlens_k_cumulative;
+
+    bool is_rotary_interleaved;
+
+    int num_splits;  // For split-KV version
+
+    void * __restrict__ alibi_slopes_ptr;
+    index_t alibi_slopes_batch_stride;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -132,10 +158,14 @@ struct Flash_bwd_params : public Flash_fwd_params {
 
     // The pointer to the softmax d sum.
     void *__restrict__ dsoftmax_sum;
+
+    bool deterministic;
+    index_t dq_accum_split_stride;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 
 template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
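
Why the new oaccum_ptr / softmax_lseaccum_ptr pair and num_splits: in a split-KV forward pass the key/value sequence is partitioned into num_splits chunks, each chunk produces a partial output O_i and a partial log-sum-exp lse_i, and a combine pass reduces them into the final O. The combine kernel is not shown in this header; the reduction below is the standard log-sum-exp combine such a scheme uses, written out here for reference rather than taken from this commit:

\[
\mathrm{lse} \;=\; \log \sum_{i=1}^{S} e^{\mathrm{lse}_i},
\qquad
O \;=\; \sum_{i=1}^{S} e^{\mathrm{lse}_i - \mathrm{lse}}\, O_i,
\]

where S is num_splits, the O_i are the per-split partial outputs held behind oaccum_ptr, and the lse_i are the per-split partial log-sum-exps held behind softmax_lseaccum_ptr. Because each term is rescaled by exp(lse_i - lse), the result is identical to computing softmax attention over the whole sequence at once.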
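
A minimal host-side sketch of how the new fields fit together. Everything here is illustrative: FwdParamsSketch is a trimmed, hypothetical mirror of Flash_fwd_params (the real struct lives in flash.h and inherits from Qkv_params), make_decode_params is an invented helper rather than a candle or flash-attention API, and index_t is assumed to be uint32_t to match Qkv_params in this version of the header.

#include <cstdint>
#include <cstdio>

// Trimmed, hypothetical mirror of the fields this commit adds to
// Flash_fwd_params; for illustration only.
struct FwdParamsSketch {
    using index_t = uint32_t;             // assumed to match Qkv_params

    // K_new / V_new: tokens to append to the KV cache this step.
    void *knew_ptr = nullptr;
    void *vnew_ptr = nullptr;
    int seqlen_knew = 0;

    // Rotary embedding tables and layout.
    void *rotary_cos_ptr = nullptr;
    void *rotary_sin_ptr = nullptr;
    int rotary_dim = 0;
    bool is_rotary_interleaved = false;

    // KV-cache indirection and split-KV accumulation buffers.
    int *cache_batch_idx = nullptr;       // maps batch index -> cache slot
    int num_splits = 1;                   // 1 = no split-KV
    void *oaccum_ptr = nullptr;           // per-split partial outputs
    void *softmax_lseaccum_ptr = nullptr; // per-split partial log-sum-exps

    // Local (sliding-window) attention bounds; negative values are
    // assumed to mean "unbounded", as in upstream FlashAttention.
    int window_size_left = -1, window_size_right = -1;

    // ALiBi slopes, one per head.
    void *alibi_slopes_ptr = nullptr;
    index_t alibi_slopes_batch_stride = 0;
};

// Hypothetical helper: configure a single-token decode step that appends
// one K/V token per sequence and scans the cache in four splits.
FwdParamsSketch make_decode_params(void *knew, void *vnew,
                                   void *oaccum, void *lseaccum) {
    FwdParamsSketch p;
    p.knew_ptr = knew;
    p.vnew_ptr = vnew;
    p.seqlen_knew = 1;   // one appended token per sequence
    p.num_splits = 4;    // partials combined as in the reduction above
    p.oaccum_ptr = oaccum;
    p.softmax_lseaccum_ptr = lseaccum;
    return p;
}

int main() {
    FwdParamsSketch p = make_decode_params(nullptr, nullptr, nullptr, nullptr);
    std::printf("seqlen_knew=%d num_splits=%d\n", p.seqlen_knew, p.num_splits);
    return 0;
}

With num_splits > 1, run_mha_fwd_splitkv_dispatch (declared at the bottom of the diff) is the entry point that would consume these fields; with num_splits == 1 the accumulation buffers go unused and the kernel writes o_ptr and softmax_lse_ptr directly.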