Diffstat (limited to 'candle-flash-attn/kernels/flash.h')
 candle-flash-attn/kernels/flash.h | 54 ++++++++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 42 insertions(+), 12 deletions(-)
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h
index be4ae0ca..80b517e9 100644
--- a/candle-flash-attn/kernels/flash.h
+++ b/candle-flash-attn/kernels/flash.h
@@ -7,15 +7,6 @@
#include <cuda.h>
#include <vector>
-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-//
-// #include <ATen/cuda/CUDAGraphsUtils.cuh>
-
-
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
@@ -53,6 +44,7 @@ struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
+ void * __restrict__ oaccum_ptr;
// The stride between rows of O.
index_t o_batch_stride;
@@ -64,9 +56,10 @@ struct Flash_fwd_params : public Qkv_params {
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
+ void * __restrict__ softmax_lseaccum_ptr;
// The dimensions.
- int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+ int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
// The scaling factors for the kernel.
float scale_softmax;
@@ -76,8 +69,30 @@ struct Flash_fwd_params : public Qkv_params {
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
+ // If provided, the actual length of each k sequence.
+ int * __restrict__ seqused_k;
+
int *__restrict__ blockmask;
+ // The K_new and V_new matrices.
+ void * __restrict__ knew_ptr;
+ void * __restrict__ vnew_ptr;
+
+ // The stride between rows of the K_new and V_new matrices.
+ index_t knew_batch_stride;
+ index_t vnew_batch_stride;
+ index_t knew_row_stride;
+ index_t vnew_row_stride;
+ index_t knew_head_stride;
+ index_t vnew_head_stride;
+
+ // The cos and sin matrices for rotary embedding.
+ void * __restrict__ rotary_cos_ptr;
+ void * __restrict__ rotary_sin_ptr;
+
+ // The indices to index into the KV cache.
+ int *__restrict__ cache_batch_idx;
+
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
@@ -88,11 +103,22 @@ struct Flash_fwd_params : public Qkv_params {
float rp_dropout;
float scale_softmax_rp_dropout;
- // Random state.
- // at::PhiloxCudaState philox_args;
+ // Local window size
+ int window_size_left, window_size_right;
bool is_bf16;
bool is_causal;
+
+ // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+ // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+ bool is_seqlens_k_cumulative;
+
+ bool is_rotary_interleaved;
+
+ int num_splits; // For split-KV version
+
+ void * __restrict__ alibi_slopes_ptr;
+ index_t alibi_slopes_batch_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -132,10 +158,14 @@ struct Flash_bwd_params : public Flash_fwd_params {
// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
+
+ bool deterministic;
+ index_t dq_accum_split_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
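For orientation (not part of the commit), below is a minimal sketch of how a caller might populate the new Flash_fwd_params fields for a single-token decode step that appends new K/V tokens to a cache, applies rotary embedding, and uses the split-KV dispatch declared at the bottom of this header. Everything in it is illustrative: the function name, the placeholder sizes (head dim 128, fp16) and all device pointers are assumptions, and the Qkv_params base fields (q_ptr/k_ptr/v_ptr and their strides) are defined outside this hunk and omitted for brevity.

    // Hypothetical illustration only -- not part of this commit.
    #include <cmath>
    #include <cstring>
    #include <cuda_runtime.h>
    #include <cutlass/numeric_types.h>
    #include "flash.h"

    void launch_decode_step(void *knew_dev, void *vnew_dev,
                            int *cache_seqlens_dev,
                            void *cos_dev, void *sin_dev,
                            void *o_dev, void *oaccum_dev,
                            void *lse_dev, void *lseaccum_dev,
                            cudaStream_t stream) {
        Flash_fwd_params params;
        std::memset(&params, 0, sizeof(params));  // unused pointers stay null

        // Qkv_params base fields (q_ptr, k_ptr, v_ptr, their strides, head counts)
        // would be filled here; they are declared outside this hunk.

        params.b = 1;                   // batch size
        params.seqlen_q = 1;            // single new query token
        params.seqlen_k = 4096;         // KV-cache capacity (placeholder)
        params.seqlen_knew = 1;         // one new K/V token appended by the kernel
        params.d = 128;                 // head dimension (placeholder)
        params.scale_softmax = 1.0f / std::sqrt(128.0f);
        params.p_dropout = 1.0f;        // probability of keeping an activation, so 1.0 = no dropout

        // New K/V tokens that the kernel writes into the cache.
        params.knew_ptr = knew_dev;
        params.vnew_ptr = vnew_dev;

        // In cache mode cu_seqlens_k stores the current length of each sequence
        // directly, so the cumulative interpretation is turned off.
        params.cu_seqlens_k = cache_seqlens_dev;
        params.is_seqlens_k_cumulative = false;

        // Rotary embedding applied to the incoming tokens.
        params.rotary_cos_ptr = cos_dev;
        params.rotary_sin_ptr = sin_dev;
        params.rotary_dim = 128;
        params.is_rotary_interleaved = false;

        params.is_causal = true;
        params.window_size_left = -1;   // -1 disables local (sliding-window) attention
        params.window_size_right = -1;

        // Split-KV scratch: per-split partial outputs and log-sum-exp values.
        params.num_splits = 4;          // placeholder split count
        params.oaccum_ptr = oaccum_dev;
        params.softmax_lseaccum_ptr = lseaccum_dev;

        params.o_ptr = o_dev;
        params.softmax_lse_ptr = lse_dev;

        // Declared in flash.h; instantiated per dtype/head dim in the .cu kernels.
        run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128>(params, stream);
    }

With num_splits > 1 the kernel writes per-split partial results to oaccum_ptr and softmax_lseaccum_ptr and reduces them into o_ptr and softmax_lse_ptr, which is why both sets of pointers appear in the sketch.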