Diffstat (limited to 'candle-flash-attn/kernels/flash.h')
 candle-flash-attn/kernels/flash.h | 35 +++++++++++++++++++++++++++++++------
 1 file changed, 29 insertions(+), 6 deletions(-)
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h
index 80b517e9..88c2f22a 100644
--- a/candle-flash-attn/kernels/flash.h
+++ b/candle-flash-attn/kernels/flash.h
@@ -7,6 +7,14 @@
#include <cuda.h>
#include <vector>
+// #ifdef OLD_GENERATOR_PATH
+// #include <ATen/CUDAGeneratorImpl.h>
+// #else
+// #include <ATen/cuda/CUDAGeneratorImpl.h>
+// #endif
+//
+// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
+
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
@@ -14,7 +22,7 @@ constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
- using index_t = uint32_t;
+ using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
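
Widening index_t from uint32_t to int64_t lets the precomputed batch/row/head strides describe tensors whose element offsets no longer fit in 32 bits. A minimal sketch with hypothetical sizes; the helper below is illustrative and not part of the kernels:

#include <cstdint>

using index_t = int64_t;  // matches the new Qkv_params::index_t

// With seqlen = 64 * 1024, nheads = 32, headdim = 128, one sequence already
// holds 2^28 elements, so element offsets from batch index 16 onward exceed
// UINT32_MAX and would wrap with a 32-bit index type.
inline const char* q_element_ptr(const void* q_ptr,
                                 index_t q_batch_stride,  // elements per batch
                                 index_t q_row_stride,    // elements per row
                                 int bidb, int row, int elem_bytes) {
    // All arithmetic stays in 64 bits end to end.
    index_t offset = bidb * q_batch_stride + static_cast<index_t>(row) * q_row_stride;
    return static_cast<const char*>(q_ptr) + offset * elem_bytes;
}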
@@ -59,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ softmax_lseaccum_ptr;
// The dimensions.
- int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
+ int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
// The scaling factors for the kernel.
float scale_softmax;
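
The new total_q dimension covers the variable-length path, where queries from all sequences are packed into one buffer. Assuming the usual FlashAttention convention of a cumulative-lengths array of size b + 1 (not shown in this hunk), it would simply be the last entry; a hedged sketch:

#include <vector>

// Hedged sketch, assuming cu_seqlens_q = {0, len_0, len_0 + len_1, ...}:
// the last entry is the total number of query tokens packed across the batch.
inline int compute_total_q(const std::vector<int>& cu_seqlens_q) {
    return cu_seqlens_q.back();
}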
@@ -91,7 +99,12 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ rotary_sin_ptr;
// The indices to index into the KV cache.
- int *__restrict__ cache_batch_idx;
+ int * __restrict__ cache_batch_idx;
+
+ // Paged KV cache
+ int * __restrict__ block_table;
+ index_t block_table_batch_stride;
+ int page_block_size;
// The dropout probability (probability of keeping an activation).
float p_dropout;
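
The paged-KV fields describe a per-batch block table that maps logical KV positions to physical pages of page_block_size tokens. A hedged sketch of the addressing these fields imply; the helper name and raw pointer math are illustrative, the kernels work with cute layouts instead:

#include <cstdint>

using index_t = int64_t;

// Element offset of token `seqpos_k` of batch `bidb` inside a paged K cache,
// where k_row_stride is the stride (in elements) between consecutive K rows.
inline index_t paged_k_offset(const int* block_table,
                              index_t block_table_batch_stride,
                              int page_block_size,
                              index_t k_row_stride,
                              int bidb, int seqpos_k) {
    const int logical_block  = seqpos_k / page_block_size;  // which page
    const int in_block       = seqpos_k % page_block_size;  // row within page
    const int physical_block =
        block_table[bidb * block_table_batch_stride + logical_block];
    return (static_cast<index_t>(physical_block) * page_block_size + in_block)
           * k_row_stride;
}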
@@ -105,6 +118,13 @@ struct Flash_fwd_params : public Qkv_params {
// Local window size
int window_size_left, window_size_right;
+ float softcap;
+
+ // Random state.
+ // at::PhiloxCudaState philox_args;
+
+ // Pointer to the RNG seed (idx 0) and offset (idx 1).
+ uint64_t * rng_state;
bool is_bf16;
bool is_causal;
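
softcap bounds the attention logits before the softmax. A hedged sketch of the usual tanh soft-capping formulation; how the kernels fuse this with the softmax scale may differ:

#include <cmath>

// Hedged sketch: tanh soft-capping of one attention score. When softcap > 0,
// the logit is squashed into (-softcap, softcap); softcap == 0.f means
// "disabled" in this sketch.
inline float apply_softcap(float score, float softcap) {
    return softcap > 0.f ? softcap * tanhf(score / softcap) : score;
}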
@@ -119,6 +139,9 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
+
+ bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
+ bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
};
////////////////////////////////////////////////////////////////////////////////////////////////////
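
The two new flags change how downstream code must index the softmax LSE buffer. A hedged sketch contrasting the padded and unpadded layouts; the helper name and the cu_seqlens_q argument are assumptions for illustration:

#include <cstdint>

// Padded:   lse[b][h][q]    with shape [b, nheads, seqlen_q]
// Unpadded: lse[h][token]   with shape [nheads, total_q], where `token` is the
//           position of (bidb, q) inside the packed varlen batch.
inline int64_t lse_index(bool unpadded_lse,
                         int nheads, int seqlen_q, int64_t total_q,
                         const int* cu_seqlens_q,  // assumed varlen prefix sums
                         int bidb, int head, int q) {
    if (!unpadded_lse) {
        return (static_cast<int64_t>(bidb) * nheads + head) * seqlen_q + q;
    }
    const int64_t token = cu_seqlens_q[bidb] + q;  // position in packed batch
    return static_cast<int64_t>(head) * total_q + token;
}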
@@ -165,7 +188,7 @@ struct Flash_bwd_params : public Flash_fwd_params {
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
+template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
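
With Is_causal now a template parameter of the forward launchers, causality is resolved at compile time and the runtime flag only selects the instantiation (the backward declaration also drops the old configure pass). A hedged sketch of how a caller might dispatch for one (T, Headdim) pair; run_mha_fwd_dispatch_causal is a hypothetical helper, the real launchers cover every head dimension and dtype via generated .cu files:

#include "flash.h"  // assumed include path for the declarations above

template<typename T, int Headdim>
void run_mha_fwd_dispatch_causal(Flash_fwd_params &params, cudaStream_t stream) {
    // Map the runtime flag onto the compile-time Is_causal instantiation.
    if (params.is_causal) {
        run_mha_fwd_<T, Headdim, /*Is_causal=*/true>(params, stream);
    } else {
        run_mha_fwd_<T, Headdim, /*Is_causal=*/false>(params, stream);
    }
}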