author    Laurent Mazare <laurent.mazare@gmail.com>    2023-09-04 08:50:52 +0200
committer GitHub <noreply@github.com>    2023-09-04 07:50:52 +0100
commit    d0cdea95a5ec8f53b24c6de19f6029060339ed98 (patch)
tree      61c22d5d923080ba1d673ca2bfdc80ee6aa51939 /candle-flash-attn/kernels/flash_api.cu
parent    20512ba408f9840828e902b7dd824be5a0969feb (diff)
Add back the bf16 flash-attn kernels. (#730)
Diffstat (limited to 'candle-flash-attn/kernels/flash_api.cu')
-rw-r--r--  candle-flash-attn/kernels/flash_api.cu | 26
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index d928bcb6..72991257 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -1,20 +1,19 @@
#include "flash_fwd_launch_template.h"
-// TODO: Switch back to handling bf16.
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
- FWD_HEADDIM_SWITCH(params.d, [&] {
- run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
- });
-}
-
// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-// FP16_SWITCH(!params.is_bf16, [&] {
-// FWD_HEADDIM_SWITCH(params.d, [&] {
-// run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-// });
+// FWD_HEADDIM_SWITCH(params.d, [&] {
+// run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
// });
// }
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+ FP16_SWITCH(!params.is_bf16, [&] {
+ FWD_HEADDIM_SWITCH(params.d, [&] {
+ run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+ });
+ });
+}
+
extern "C" void run_mha(
void *q_ptr,
void *k_ptr,
@@ -52,7 +51,8 @@ extern "C" void run_mha(
uint32_t seqlen_q_rounded,
uint32_t seqlen_k_rounded,
- int is_causal
+ int is_causal,
+ int is_bf16
) {
Flash_fwd_params params;
// Reset the parameters
@@ -102,7 +102,7 @@ extern "C" void run_mha(
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
- params.is_bf16 = 0;
+ params.is_bf16 = is_bf16;
params.cu_seqlens_q = cu_seqlens_q_ptr;
params.cu_seqlens_k = cu_seqlens_k_ptr;
params.p_ptr = nullptr; // used for `return_softmax`.
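
Note: the restored run_mha_fwd uses FP16_SWITCH to pick the CUTLASS element type (fp16 vs. bf16) from params.is_bf16 before dispatching on the head dimension, which is why the extern "C" entry point now takes the extra is_bf16 flag. As a point of reference, the sketch below shows how such a compile-time switch is typically written. It is an assumption modeled on flash-attention's static_switch.h, not the verbatim macro shipped with these kernels; only the name elem_type is taken from the diff above.

#include <cutlass/numeric_types.h>

// Hedged sketch of an FP16_SWITCH-style dispatch macro (assumed, not the
// exact definition used here). The boolean condition selects the CUTLASS
// element type, and the immediately-invoked lambda makes `elem_type`
// visible to the callback pasted in via `__VA_ARGS__`.
#define FP16_SWITCH(COND, ...)               \
  [&] {                                      \
    if (COND) {                              \
      using elem_type = cutlass::half_t;     \
      return __VA_ARGS__();                  \
    } else {                                 \
      using elem_type = cutlass::bfloat16_t; \
      return __VA_ARGS__();                  \
    }                                        \
  }()

With a macro of this shape, FP16_SWITCH(!params.is_bf16, ...) instantiates run_mha_fwd_<cutlass::half_t, kHeadDim> when is_bf16 is 0 and run_mha_fwd_<cutlass::bfloat16_t, kHeadDim> when it is 1, restoring the bf16 code path that the previous fp16-only dispatch had dropped.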