Diffstat (limited to 'candle-flash-attn/kernels/flash.h')
-rw-r--r-- | candle-flash-attn/kernels/flash.h | 35 +++++++++++++++++++++++++++++------
1 file changed, 29 insertions(+), 6 deletions(-)
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h
index 80b517e9..88c2f22a 100644
--- a/candle-flash-attn/kernels/flash.h
+++ b/candle-flash-attn/kernels/flash.h
@@ -7,6 +7,14 @@
 #include <cuda.h>
 #include <vector>
 
+// #ifdef OLD_GENERATOR_PATH
+// #include <ATen/CUDAGeneratorImpl.h>
+// #else
+// #include <ATen/cuda/CUDAGeneratorImpl.h>
+// #endif
+//
+// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
+
 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
 constexpr int D_DIM = 2;
@@ -14,7 +22,7 @@ constexpr int D_DIM = 2;
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 struct Qkv_params {
-    using index_t = uint32_t;
+    using index_t = int64_t;
     // The QKV matrices.
     void *__restrict__ q_ptr;
     void *__restrict__ k_ptr;
@@ -59,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ softmax_lseaccum_ptr;
 
     // The dimensions.
-    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
 
     // The scaling factors for the kernel.
     float scale_softmax;
@@ -91,7 +99,12 @@ struct Flash_fwd_params : public Qkv_params {
     void * __restrict__ rotary_sin_ptr;
 
     // The indices to index into the KV cache.
-    int *__restrict__ cache_batch_idx;
+    int * __restrict__ cache_batch_idx;
+
+    // Paged KV cache
+    int * __restrict__ block_table;
+    index_t block_table_batch_stride;
+    int page_block_size;
 
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
@@ -105,6 +118,13 @@ struct Flash_fwd_params : public Qkv_params {
 
     // Local window size
    int window_size_left, window_size_right;
+    float softcap;
+
+    // Random state.
+    // at::PhiloxCudaState philox_args;
+
+    // Pointer to the RNG seed (idx 0) and offset (idx 1).
+    uint64_t * rng_state;
 
     bool is_bf16;
     bool is_causal;
@@ -119,6 +139,9 @@ struct Flash_fwd_params : public Qkv_params {
 
     void * __restrict__ alibi_slopes_ptr;
     index_t alibi_slopes_batch_stride;
+
+    bool unpadded_lse;  // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
+    bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -165,7 +188,7 @@ struct Flash_bwd_params : public Flash_fwd_params {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
+template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
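The index_t widening from uint32_t to int64_t matters because the stride fields declared with this type are multiplied by batch and row indices, and a 32-bit product wraps once a tensor passes 2^32 elements (e.g. batch 8 x seqlen 65536 x 32 heads x headdim 128 = 2^34). A minimal sketch of the offset arithmetic this protects; the stride parameter names below are hypothetical stand-ins for fields not shown in these hunks:

#include <cstdint>

using index_t = int64_t;  // as in Qkv_params after this change

// Sketch only: with a 32-bit index_t either product below could wrap for
// large tensors; 64-bit arithmetic keeps the element offset exact.
inline index_t tile_offset(index_t batch_stride, index_t row_stride,
                           int b, int row) {
    return static_cast<index_t>(b) * batch_stride
         + static_cast<index_t>(row) * row_stride;
}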
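The new block_table, block_table_batch_stride and page_block_size fields add support for a paged (vLLM-style) KV cache: K and V no longer need to be contiguous per sequence; instead block_table maps each logical block of page_block_size tokens to a physical block. A hedged host-side sketch of the lookup, where the flat physical row layout is an assumption made for illustration (the real kernels resolve this inside their global-memory iterators):

#include <cstdint>

// Sketch: translate (batch, token) into the physical K/V row of a paged
// cache. Parameter names mirror Flash_fwd_params; the row layout is assumed.
inline int64_t paged_kv_row(const int* block_table,
                            int64_t block_table_batch_stride,
                            int page_block_size,
                            int b, int token) {
    const int logical_block  = token / page_block_size;
    const int within_block   = token % page_block_size;
    const int physical_block =
        block_table[b * block_table_batch_stride + logical_block];
    return static_cast<int64_t>(physical_block) * page_block_size
         + within_block;
}

Because only the small integer table is per-sequence, physical blocks can be allocated, shared, or freed without ever moving K/V data.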
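The softcap parameter enables tanh soft-capping of attention logits, as used by models such as Gemma 2: each score is squashed into (-softcap, softcap) before masking and softmax, bounding the logits without a hard clip. A minimal sketch of the math; the actual kernels fold the 1/softcap factor into the softmax scaling rather than dividing per element as written here:

#include <cmath>

// Sketch of tanh soft-capping; by convention softcap == 0 disables it.
inline float apply_softcap(float score, float softcap) {
    return softcap > 0.0f ? softcap * std::tanh(score / softcap) : score;
}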
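Finally, unpadded_lse and seqlenq_ngroups_swapped record layout decisions made on the host so the kernels index correctly. For the varlen path, unpadded_lse means the log-sum-exp buffer is packed as [nheads, total_q] rather than padded to [b, nheads, seqlen_q]. A sketch of the packed indexing, assuming the usual cu_seqlens_q prefix-sum array (a standard flash-attn varlen input that is not part of these hunks):

#include <cstdint>

// Sketch: locate the LSE entry for query q of batch b, head h in the packed
// [nheads, total_q] layout; cu_seqlens_q[b] is the exclusive prefix sum of
// per-sequence query lengths, so cu_seqlens_q[b] + q is the global row.
inline int64_t lse_index_unpadded(const int* cu_seqlens_q, int64_t total_q,
                                  int b, int h, int q) {
    return static_cast<int64_t>(h) * total_q + (cu_seqlens_q[b] + q);
}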