summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels/flash.h
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/kernels/flash.h')
-rw-r--r--candle-flash-attn/kernels/flash.h13
1 files changed, 4 insertions, 9 deletions
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h
index 88c2f22a..f21e4d62 100644
--- a/candle-flash-attn/kernels/flash.h
+++ b/candle-flash-attn/kernels/flash.h
@@ -7,13 +7,7 @@
#include <cuda.h>
#include <vector>
-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-//
-// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
+// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
@@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params {
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
+ int * __restrict__ leftpad_k;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
@@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);