diff options
Diffstat (limited to 'candle-flash-attn/kernels/flash_api.cu')
-rw-r--r-- | candle-flash-attn/kernels/flash_api.cu | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu index 8113dbc7..4ca41b0a 100644 --- a/candle-flash-attn/kernels/flash_api.cu +++ b/candle-flash-attn/kernels/flash_api.cu @@ -1,15 +1,15 @@ +#include "kernels.h" +#include "kernel_helpers.h" #include "flash_fwd_launch_template.h" -void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { - FP16_SWITCH(!params.is_bf16, [&] { - FWD_HEADDIM_SWITCH(params.d, [&] { -// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 - run_mha_fwd_<elem_type, kHeadDim>(params, stream); -// } else { -// run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream); -// } - }); - }); +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream); + }); + }); + }); } extern "C" void run_mha( |