diff options
Diffstat (limited to 'candle-flash-attn/kernels/flash_api.cu')
-rw-r--r-- | candle-flash-attn/kernels/flash_api.cu | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu new file mode 100644 index 00000000..323aeaad --- /dev/null +++ b/candle-flash-attn/kernels/flash_api.cu @@ -0,0 +1,109 @@ +#include "flash_fwd_launch_template.h" + +// TODO: Switch back to handling bf16. +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + FWD_HEADDIM_SWITCH(params.d, [&] { + run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream); + }); +} + +// void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { +// FP16_SWITCH(!params.is_bf16, [&] { +// FWD_HEADDIM_SWITCH(params.d, [&] { +// run_mha_fwd_<elem_type, kHeadDim>(params, stream); +// }); +// }); +// } + +extern "C" void run_mha( + void *q_ptr, + void *k_ptr, + void *v_ptr, + void *o_ptr, + void *softmax_lse_ptr, + + uint32_t q_batch_stride, + uint32_t k_batch_stride, + uint32_t v_batch_stride, + uint32_t o_batch_stride, + + uint32_t q_row_stride, + uint32_t k_row_stride, + uint32_t v_row_stride, + uint32_t o_row_stride, + + uint32_t q_head_stride, + uint32_t k_head_stride, + uint32_t v_head_stride, + uint32_t o_head_stride, + + uint32_t b, + uint32_t h, + uint32_t h_k, + uint32_t d, + uint32_t d_rounded, + float softmax_scale, + + uint32_t seqlen_q, + uint32_t seqlen_k, + uint32_t seqlen_q_rounded, + uint32_t seqlen_k_rounded, + + int is_causal +) { + Flash_fwd_params params; + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + // Set the pointers and strides. + params.q_ptr = q_ptr; + params.k_ptr = k_ptr; + params.v_ptr = v_ptr; + params.o_ptr = o_ptr; + + params.softmax_lse_ptr = softmax_lse_ptr; + + // All stride are in elements, not bytes. + params.q_batch_stride = q_batch_stride; + params.k_batch_stride = k_batch_stride; + params.v_batch_stride = v_batch_stride; + params.o_batch_stride = o_batch_stride; + + params.q_row_stride = q_row_stride; + params.k_row_stride = k_row_stride; + params.v_row_stride = v_row_stride; + params.o_row_stride = o_row_stride; + params.q_head_stride = q_head_stride; + params.k_head_stride = k_head_stride; + params.v_head_stride = v_head_stride; + params.o_head_stride = o_head_stride; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + params.is_causal = is_causal; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + params.p_dropout = 1.; // probability to keep + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + params.is_bf16 = 0; + params.cu_seqlens_q = nullptr; + params.cu_seqlens_k = nullptr; + params.p_ptr = nullptr; + + cudaStream_t stream = 0; // Use the default stream. + run_mha_fwd(params, stream); +} |