diff options
Diffstat (limited to 'candle-flash-attn/kernels/static_switch.h')
-rw-r--r-- | candle-flash-attn/kernels/static_switch.h | 53 |
1 files changed, 52 insertions, 1 deletions
diff --git a/candle-flash-attn/kernels/static_switch.h b/candle-flash-attn/kernels/static_switch.h index 4aa84740..20c2afd6 100644 --- a/candle-flash-attn/kernels/static_switch.h +++ b/candle-flash-attn/kernels/static_switch.h @@ -14,6 +14,7 @@ /// some_function<BoolConst>(...); /// }); /// ``` + #define BOOL_SWITCH(COND, CONST_NAME, ...) \ [&] { \ if (COND) { \ @@ -25,6 +26,56 @@ } \ }() +#ifdef FLASHATTENTION_DISABLE_DROPOUT + #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define DROPOUT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_ALIBI + #define ALIBI_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define ALIBI_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K + #define EVENK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + }() +#else + #define EVENK_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_SOFTCAP + #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define SOFTCAP_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_LOCAL + #define LOCAL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define LOCAL_SWITCH BOOL_SWITCH +#endif + #define FP16_SWITCH(COND, ...) \ [&] { \ if (COND) { \ @@ -36,7 +87,7 @@ } \ }() -#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ +#define HEADDIM_SWITCH(HEADDIM, ...) \ [&] { \ if (HEADDIM <= 32) { \ constexpr static int kHeadDim = 32; \ |