summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels/flash_api.cu
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/kernels/flash_api.cu')
-rw-r--r--candle-flash-attn/kernels/flash_api.cu20
1 files changed, 10 insertions, 10 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 8113dbc7..4ca41b0a 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -1,15 +1,15 @@
+#include "kernels.h"
+#include "kernel_helpers.h"
#include "flash_fwd_launch_template.h"
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
- FP16_SWITCH(!params.is_bf16, [&] {
- FWD_HEADDIM_SWITCH(params.d, [&] {
-// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
- run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-// } else {
-// run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
-// }
- });
- });
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+ FP16_SWITCH(!params.is_bf16, [&] {
+ HEADDIM_SWITCH(params.d, [&] {
+ BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+ run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
+ });
+ });
+ });
}
extern "C" void run_mha(