author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-09-04 08:50:52 +0200
committer | GitHub <noreply@github.com> | 2023-09-04 07:50:52 +0100
commit | d0cdea95a5ec8f53b24c6de19f6029060339ed98 (patch)
tree | 61c22d5d923080ba1d673ca2bfdc80ee6aa51939 /candle-flash-attn/kernels/flash_api.cu
parent | 20512ba408f9840828e902b7dd824be5a0969feb (diff)
download | candle-d0cdea95a5ec8f53b24c6de19f6029060339ed98.tar.gz, candle-d0cdea95a5ec8f53b24c6de19f6029060339ed98.tar.bz2, candle-d0cdea95a5ec8f53b24c6de19f6029060339ed98.zip
Add back the bf16 flash-attn kernels. (#730)
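The change re-enables bf16 at runtime: `run_mha_fwd` now picks the element type through `FP16_SWITCH(!params.is_bf16, ...)` instead of hard-coding `cutlass::half_t`, and the `extern "C"` entry point gains an `is_bf16` flag so callers can select either dtype. As a rough sketch of the dispatch idea (the real macro lives in the flash-attention headers and may differ in detail, this is not the verbatim definition):

```cpp
// Sketch of the FP16_SWITCH static-dispatch pattern, assuming the usual
// flash-attention style; requires the cutlass headers for the two types.
// The boolean condition chooses which `elem_type` the templated kernel
// launcher inside the lambda body is instantiated with.
#define FP16_SWITCH(COND, ...)                 \
  [&] {                                        \
    if (COND) {                                \
      using elem_type = cutlass::half_t;       \
      return __VA_ARGS__();                    \
    } else {                                   \
      using elem_type = cutlass::bfloat16_t;   \
      return __VA_ARGS__();                    \
    }                                          \
  }()
```

With the condition `!params.is_bf16`, `is_bf16 == 0` resolves `elem_type` to `cutlass::half_t` and `is_bf16 == 1` to `cutlass::bfloat16_t`, so a single entry point covers both the f16 and bf16 kernels.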
Diffstat (limited to 'candle-flash-attn/kernels/flash_api.cu')
-rw-r--r-- | candle-flash-attn/kernels/flash_api.cu | 26
1 file changed, 13 insertions, 13 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index d928bcb6..72991257 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -1,20 +1,19 @@
 #include "flash_fwd_launch_template.h"
 
-// TODO: Switch back to handling bf16.
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-    FWD_HEADDIM_SWITCH(params.d, [&] {
-        run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
-    });
-}
-
 // void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-//     FP16_SWITCH(!params.is_bf16, [&] {
-//         FWD_HEADDIM_SWITCH(params.d, [&] {
-//             run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-//         });
+//     FWD_HEADDIM_SWITCH(params.d, [&] {
+//         run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
 //     });
 // }
 
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    FP16_SWITCH(!params.is_bf16, [&] {
+        FWD_HEADDIM_SWITCH(params.d, [&] {
+            run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+        });
+    });
+}
+
 extern "C" void run_mha(
     void *q_ptr,
     void *k_ptr,
@@ -52,7 +51,8 @@ extern "C" void run_mha(
     uint32_t seqlen_q_rounded,
     uint32_t seqlen_k_rounded,
 
-    int is_causal
+    int is_causal,
+    int is_bf16
 ) {
     Flash_fwd_params params;
     // Reset the parameters
@@ -102,7 +102,7 @@ extern "C" void run_mha(
     params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
     params.rp_dropout = 1.f / params.p_dropout;
    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
-    params.is_bf16 = 0;
+    params.is_bf16 = is_bf16;
     params.cu_seqlens_q = cu_seqlens_q_ptr;
     params.cu_seqlens_k = cu_seqlens_k_ptr;
     params.p_ptr = nullptr; // used for `return_softmax`.
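On the caller side, the new trailing `is_bf16` argument is a plain `int` (0 or 1), which keeps the `extern "C"` signature C-compatible for candle's Rust FFI bindings. A minimal, hypothetical caller-side sketch of deriving that flag from a dtype tag (the names here are invented for illustration and are not from the candle source):

```cpp
// Hypothetical helper, not part of candle: map a dtype tag to the
// `is_bf16` int flag expected by the updated run_mha signature.
#include <stdexcept>

enum class DType { F16, BF16, F32 };

int flash_attn_is_bf16(DType dtype) {
    switch (dtype) {
        case DType::F16:  return 0;  // run_mha_fwd dispatches to cutlass::half_t
        case DType::BF16: return 1;  // run_mha_fwd dispatches to cutlass::bfloat16_t
        default: throw std::invalid_argument("flash-attn kernels support only f16/bf16");
    }
}
```

Since `params.is_bf16` is now set from this argument instead of being pinned to 0, the bf16 path is reachable again without recompiling the kernels.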