path: root/candle-flash-attn/kernels/flash.h
author    Michael Feil <63565275+michaelfeil@users.noreply.github.com>  2024-12-31 09:32:22 +0100
committer GitHub <noreply@github.com>  2024-12-31 09:32:22 +0100
commit    71cd6d55337b1541f602c1afffa6baf6dd75b09c (patch)
tree      207e7d050e9c4bd685a563e457b2e9ef59b66f20 /candle-flash-attn/kernels/flash.h
parent    d60eba140820326ffc7ec39a8982e91feb462732 (diff)
Flash-Attn upgrade / SoftCap Candle-FlashAttn [1/n] (#2688)
* update flash-attn v1
* restore: hdim224
* add 224 flash_fwd_template
* remove whitespace
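
For context on the "SoftCap" part of this upgrade: upstream flash-attn optionally squashes attention logits with a scaled tanh before the softmax. A minimal standalone sketch of the idea (illustration only, not the kernel's actual code path):

    #include <cmath>

    // Softcapping keeps each logit in (-softcap, +softcap) while staying
    // close to linear near zero: s -> softcap * tanh(s / softcap).
    // A softcap of 0 conventionally means the feature is disabled.
    inline float apply_softcap(float score, float softcap) {
        if (softcap == 0.0f) return score;  // disabled
        return softcap * std::tanh(score / softcap);
    }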
Diffstat (limited to 'candle-flash-attn/kernels/flash.h')
-rw-r--r--  candle-flash-attn/kernels/flash.h  13
1 file changed, 4 insertions(+), 9 deletions(-)
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h
index 88c2f22a..f21e4d62 100644
--- a/candle-flash-attn/kernels/flash.h
+++ b/candle-flash-attn/kernels/flash.h
@@ -7,13 +7,7 @@
#include <cuda.h>
#include <vector>
-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-//
-// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
+// #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
@@ -76,6 +70,7 @@ struct Flash_fwd_params : public Qkv_params {
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
+ int * __restrict__ leftpad_k;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
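
The new leftpad_k field mirrors upstream flash-attn's support for left-padded KV caches: valid keys for a batch entry start at a per-batch offset instead of index 0. A hedged usage sketch (helper name hypothetical; real kernels fold this into their indexing):

    #include "flash.h"

    // Hypothetical helper: first valid key index for batch entry `bidb`.
    // A null leftpad_k means no left padding anywhere.
    inline int first_valid_k(Flash_fwd_params const &params, int bidb) {
        return params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb];
    }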
@@ -189,6 +184,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+// template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
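
The surviving run_mha_fwd_ declaration is instantiated once per (dtype, head dim, causality) triple; the split-KV and backward declarations are commented out, apparently because only the plain forward path is exercised here. A sketch of how a caller might dispatch over head dimension (abbreviated and hypothetical; assumes the head-dim field `d` from the full Flash_fwd_params definition, and shows only a few of the compiled sizes):

    #include "flash.h"

    // Hypothetical dispatcher: round the runtime head dim up to the
    // nearest compiled run_mha_fwd_ instantiation.
    template<typename T, bool Is_causal>
    void run_mha_fwd_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
        int const hd = params.d;
        if      (hd <= 64)  run_mha_fwd_<T,  64, Is_causal>(params, stream);
        else if (hd <= 128) run_mha_fwd_<T, 128, Is_causal>(params, stream);
        else if (hd <= 224) run_mha_fwd_<T, 224, Is_causal>(params, stream);
        else                run_mha_fwd_<T, 256, Is_causal>(params, stream);
    }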