summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels/static_switch.h
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/kernels/static_switch.h')
-rw-r--r--candle-flash-attn/kernels/static_switch.h53
1 files changed, 52 insertions, 1 deletions
diff --git a/candle-flash-attn/kernels/static_switch.h b/candle-flash-attn/kernels/static_switch.h
index 4aa84740..20c2afd6 100644
--- a/candle-flash-attn/kernels/static_switch.h
+++ b/candle-flash-attn/kernels/static_switch.h
@@ -14,6 +14,7 @@
/// some_function<BoolConst>(...);
/// });
/// ```
+
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
@@ -25,6 +26,56 @@
} \
}()
+#ifdef FLASHATTENTION_DISABLE_DROPOUT
+ #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
+ [&] { \
+ constexpr static bool CONST_NAME = false; \
+ return __VA_ARGS__(); \
+ }()
+#else
+ #define DROPOUT_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_ALIBI
+ #define ALIBI_SWITCH(COND, CONST_NAME, ...) \
+ [&] { \
+ constexpr static bool CONST_NAME = false; \
+ return __VA_ARGS__(); \
+ }()
+#else
+ #define ALIBI_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
+ #define EVENK_SWITCH(COND, CONST_NAME, ...) \
+ [&] { \
+ constexpr static bool CONST_NAME = true; \
+ return __VA_ARGS__(); \
+ }()
+#else
+ #define EVENK_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_SOFTCAP
+ #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
+ [&] { \
+ constexpr static bool CONST_NAME = false; \
+ return __VA_ARGS__(); \
+ }()
+#else
+ #define SOFTCAP_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_LOCAL
+ #define LOCAL_SWITCH(COND, CONST_NAME, ...) \
+ [&] { \
+ constexpr static bool CONST_NAME = false; \
+ return __VA_ARGS__(); \
+ }()
+#else
+ #define LOCAL_SWITCH BOOL_SWITCH
+#endif
+
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
@@ -36,7 +87,7 @@
} \
}()
-#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
+#define HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 32) { \
constexpr static int kHeadDim = 32; \