diff options
author | Laurent Mazare <laurent.mazare@gmail.com> | 2023-07-26 14:16:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-26 14:16:37 +0100 |
commit | 2ce5f12513d0dafb04c7e345da9d4fba566cfa16 (patch) | |
tree | d8370aa035f667905e6f033e99e08fd93e677041 /candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu | |
parent | fa2b64d678ca83e2fbc3dabdecffbc778d5b067d (diff) | |
download | candle-2ce5f12513d0dafb04c7e345da9d4fba566cfa16.tar.gz candle-2ce5f12513d0dafb04c7e345da9d4fba566cfa16.tar.bz2 candle-2ce5f12513d0dafb04c7e345da9d4fba566cfa16.zip |
Again set a few extra params in flash-attn. (#245)
* Again set a few extra params.
* Use the appropriate kernel sizes.
* Add all the kernel sizes.
* Parallel compiling.
* Reduce the amount of parallelism.
* Add the missing kernel.
* Fix a typo.
* Remove bf16 support for now.
Diffstat (limited to 'candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu')
-rw-r--r-- | candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu | 92 |
1 files changed, 1 insertions, 91 deletions
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu index d8f071ef..91e6331e 100644 --- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu +++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu @@ -20,94 +20,4 @@ template<> void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) { run_mha_fwd_hdim32<cutlass::half_t>(params, stream); -} - - -extern "C" void run_mha( - void *q_ptr, - void *k_ptr, - void *v_ptr, - void *o_ptr, - void *softmax_lse_ptr, - - uint32_t q_batch_stride, - uint32_t k_batch_stride, - uint32_t v_batch_stride, - uint32_t o_batch_stride, - - uint32_t q_row_stride, - uint32_t k_row_stride, - uint32_t v_row_stride, - uint32_t o_row_stride, - - uint32_t q_head_stride, - uint32_t k_head_stride, - uint32_t v_head_stride, - uint32_t o_head_stride, - - uint32_t b, - uint32_t h, - uint32_t h_k, - uint32_t d, - uint32_t d_rounded, - float softmax_scale, - - uint32_t seqlen_q, - uint32_t seqlen_k, - uint32_t seqlen_q_rounded, - uint32_t seqlen_k_rounded, - - int is_causal -) { - Flash_fwd_params params; - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - // Set the pointers and strides. - params.q_ptr = q_ptr; - params.k_ptr = k_ptr; - params.v_ptr = v_ptr; - params.o_ptr = o_ptr; - - params.softmax_lse_ptr = softmax_lse_ptr; - - // All stride are in elements, not bytes. - params.q_batch_stride = q_batch_stride; - params.k_batch_stride = k_batch_stride; - params.v_batch_stride = v_batch_stride; - params.o_batch_stride = o_batch_stride; - - params.q_row_stride = q_row_stride; - params.k_row_stride = k_row_stride; - params.v_row_stride = v_row_stride; - params.o_row_stride = o_row_stride; - params.q_head_stride = q_head_stride; - params.k_head_stride = k_head_stride; - params.v_head_stride = v_head_stride; - params.o_head_stride = o_head_stride; - - // Set the dimensions. - params.b = b; - params.h = h; - params.h_k = h_k; - params.h_h_k_ratio = h / h_k; - params.seqlen_q = seqlen_q; - params.seqlen_k = seqlen_k; - params.seqlen_q_rounded = seqlen_q_rounded; - params.seqlen_k_rounded = seqlen_k_rounded; - params.d = d; - params.d_rounded = d_rounded; - params.is_causal = is_causal; - - // Set the different scale values. - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * M_LOG2E; - - params.p_dropout = 1.; // probability to keep - params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); - params.rp_dropout = 1.f / params.p_dropout; - params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; - - cudaStream_t stream = 0; // Use the default stream. - run_mha_fwd_<cutlass::half_t, 32>(params, stream); -} +}
\ No newline at end of file |