diff options
author | Michael Feil <63565275+michaelfeil@users.noreply.github.com> | 2024-12-31 09:32:22 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-12-31 09:32:22 +0100 |
commit | 71cd6d55337b1541f602c1afffa6baf6dd75b09c (patch) | |
tree | 207e7d050e9c4bd685a563e457b2e9ef59b66f20 /candle-flash-attn/kernels/flash.h | |
parent | d60eba140820326ffc7ec39a8982e91feb462732 (diff) | |
download | candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.tar.gz candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.tar.bz2 candle-71cd6d55337b1541f602c1afffa6baf6dd75b09c.zip |
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
Diffstat (limited to 'candle-flash-attn/kernels/flash.h')
-rw-r--r-- | candle-flash-attn/kernels/flash.h | 13 |
1 file changed, 4 insertions, 9 deletions
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h index 88c2f22a..f21e4d62 100644 --- a/candle-flash-attn/kernels/flash.h +++ b/candle-flash-attn/kernels/flash.h @@ -7,13 +7,7 @@ #include <cuda.h> #include <vector> -// #ifdef OLD_GENERATOR_PATH -// #include <ATen/CUDAGeneratorImpl.h> -// #else -// #include <ATen/cuda/CUDAGeneratorImpl.h> -// #endif -// -// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack +// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; @@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params { // array of length b+1 holding starting offset of each sequence. int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; + int * __restrict__ leftpad_k; // If provided, the actual length of each k sequence. int * __restrict__ seqused_k; @@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream); -template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream); +// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream); -template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream); +// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream); |