diff options
Diffstat (limited to 'candle-flash-attn/kernels/flash.h')
-rw-r--r-- | candle-flash-attn/kernels/flash.h | 13 |
1 files changed, 4 insertions, 9 deletions
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h index 88c2f22a..f21e4d62 100644 --- a/candle-flash-attn/kernels/flash.h +++ b/candle-flash-attn/kernels/flash.h @@ -7,13 +7,7 @@ #include <cuda.h> #include <vector> -// #ifdef OLD_GENERATOR_PATH -// #include <ATen/CUDAGeneratorImpl.h> -// #else -// #include <ATen/cuda/CUDAGeneratorImpl.h> -// #endif -// -// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack +// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; @@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params { // array of length b+1 holding starting offset of each sequence. int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; + int * __restrict__ leftpad_k; // If provided, the actual length of each k sequence. int * __restrict__ seqused_k; @@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream); -template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream); +// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream); -template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream); +// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream); |