Diffstat (limited to 'candle-flash-attn/kernels')
17 files changed, 379 insertions, 91 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
new file mode 100644
index 00000000..323aeaad
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -0,0 +1,109 @@
+#include "flash_fwd_launch_template.h"
+
+// TODO: Switch back to handling bf16.
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    FWD_HEADDIM_SWITCH(params.d, [&] {
+        run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
+    });
+}
+
+// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+//     FP16_SWITCH(!params.is_bf16, [&] {
+//         FWD_HEADDIM_SWITCH(params.d, [&] {
+//             run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+//         });
+//     });
+// }
+
+extern "C" void run_mha(
+    void *q_ptr,
+    void *k_ptr,
+    void *v_ptr,
+    void *o_ptr,
+    void *softmax_lse_ptr,
+
+    uint32_t q_batch_stride,
+    uint32_t k_batch_stride,
+    uint32_t v_batch_stride,
+    uint32_t o_batch_stride,
+
+    uint32_t q_row_stride,
+    uint32_t k_row_stride,
+    uint32_t v_row_stride,
+    uint32_t o_row_stride,
+
+    uint32_t q_head_stride,
+    uint32_t k_head_stride,
+    uint32_t v_head_stride,
+    uint32_t o_head_stride,
+
+    uint32_t b,
+    uint32_t h,
+    uint32_t h_k,
+    uint32_t d,
+    uint32_t d_rounded,
+    float softmax_scale,
+
+    uint32_t seqlen_q,
+    uint32_t seqlen_k,
+    uint32_t seqlen_q_rounded,
+    uint32_t seqlen_k_rounded,
+
+    int is_causal
+) {
+    Flash_fwd_params params;
+    // Reset the parameters
+    memset(&params, 0, sizeof(params));
+
+    // Set the pointers and strides.
+    params.q_ptr = q_ptr;
+    params.k_ptr = k_ptr;
+    params.v_ptr = v_ptr;
+    params.o_ptr = o_ptr;
+
+    params.softmax_lse_ptr = softmax_lse_ptr;
+
+    // All stride are in elements, not bytes.
+    params.q_batch_stride = q_batch_stride;
+    params.k_batch_stride = k_batch_stride;
+    params.v_batch_stride = v_batch_stride;
+    params.o_batch_stride = o_batch_stride;
+
+    params.q_row_stride = q_row_stride;
+    params.k_row_stride = k_row_stride;
+    params.v_row_stride = v_row_stride;
+    params.o_row_stride = o_row_stride;
+    params.q_head_stride = q_head_stride;
+    params.k_head_stride = k_head_stride;
+    params.v_head_stride = v_head_stride;
+    params.o_head_stride = o_head_stride;
+
+    // Set the dimensions.
+    params.b = b;
+    params.h = h;
+    params.h_k = h_k;
+    params.h_h_k_ratio = h / h_k;
+    params.seqlen_q = seqlen_q;
+    params.seqlen_k = seqlen_k;
+    params.seqlen_q_rounded = seqlen_q_rounded;
+    params.seqlen_k_rounded = seqlen_k_rounded;
+    params.d = d;
+    params.d_rounded = d_rounded;
+    params.is_causal = is_causal;
+
+    // Set the different scale values.
+    params.scale_softmax = softmax_scale;
+    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+
+    params.p_dropout = 1.; // probability to keep
+    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
+    params.rp_dropout = 1.f / params.p_dropout;
+    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
+    params.is_bf16 = 0;
+    params.cu_seqlens_q = nullptr;
+    params.cu_seqlens_k = nullptr;
+    params.p_ptr = nullptr;
+
+    cudaStream_t stream = 0; // Use the default stream.
+    run_mha_fwd(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
new file mode 100644
index 00000000..654400a7
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
@@ -0,0 +1,19 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+//     using elem_type = cutlass::bfloat16_t;
+//     if (params.p_dropout == 1.f) {
+//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
+//     } else {
+//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
+//     }
+// }
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
new file mode 100644
index 00000000..5b7254a9
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
@@ -0,0 +1,32 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+//     using elem_type = cutlass::half_t;
+//     if (params.p_dropout == 1.f) {
+//         // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
+//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
+//         // 1st ones are good for H100, A100
+//         // 2nd one is good for A6000 bc we get slightly better occupancy
+//     } else {
+//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
+//         // 1st one is good for H100, A100, A6000
+//     }
+// }
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
new file mode 100644
index 00000000..6a9d60c3
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
@@ -0,0 +1,17 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
+//     using elem_type = cutlass::bfloat16_t;
+//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
+//     });
+// }
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
new file mode 100644
index 00000000..6c40a164
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
+//     using elem_type = cutlass::half_t;
+//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
+//         // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
+//         // For A100, H100, 1st is fastest.
+//     });
+// }
+template<>
+void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
new file mode 100644
index 00000000..d2f4cba7
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+//     using elem_type = cutlass::bfloat16_t;
+//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+//         run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+//     });
+// }
+template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
new file mode 100644
index 00000000..2875c926
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+//     using elem_type = cutlass::half_t;
+//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+//         run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
+//         // This one is slightly faster for causal?
+//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
+//     });
+//     // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
+//     // For A6000, 1st is faster when causal, 3rd is faster when not causal
+// }
+template<>
+void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
new file mode 100644
index 00000000..982fe7ea
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
new file mode 100644
index 00000000..4c083f7b
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
new file mode 100644
index 00000000..cb074a95
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
new file mode 100644
index 00000000..ddf5e132
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
new file mode 100644
index 00000000..81e359e1
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
index d8f071ef..91e6331e 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
@@ -20,94 +20,4 @@
 template<>
 void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
-}
-
-
-extern "C" void run_mha(
-    void *q_ptr,
-    void *k_ptr,
-    void *v_ptr,
-    void *o_ptr,
-    void *softmax_lse_ptr,
-
-    uint32_t q_batch_stride,
-    uint32_t k_batch_stride,
-    uint32_t v_batch_stride,
-    uint32_t o_batch_stride,
-
-    uint32_t q_row_stride,
-    uint32_t k_row_stride,
-    uint32_t v_row_stride,
-    uint32_t o_row_stride,
-
-    uint32_t q_head_stride,
-    uint32_t k_head_stride,
-    uint32_t v_head_stride,
-    uint32_t o_head_stride,
-
-    uint32_t b,
-    uint32_t h,
-    uint32_t h_k,
-    uint32_t d,
-    uint32_t d_rounded,
-    float softmax_scale,
-
-    uint32_t seqlen_q,
-    uint32_t seqlen_k,
-    uint32_t seqlen_q_rounded,
-    uint32_t seqlen_k_rounded,
-
-    int is_causal
-) {
-    Flash_fwd_params params;
-    // Reset the parameters
-    memset(&params, 0, sizeof(params));
-
-    // Set the pointers and strides.
-    params.q_ptr = q_ptr;
-    params.k_ptr = k_ptr;
-    params.v_ptr = v_ptr;
-    params.o_ptr = o_ptr;
-
-    params.softmax_lse_ptr = softmax_lse_ptr;
-
-    // All stride are in elements, not bytes.
-    params.q_batch_stride = q_batch_stride;
-    params.k_batch_stride = k_batch_stride;
-    params.v_batch_stride = v_batch_stride;
-    params.o_batch_stride = o_batch_stride;
-
-    params.q_row_stride = q_row_stride;
-    params.k_row_stride = k_row_stride;
-    params.v_row_stride = v_row_stride;
-    params.o_row_stride = o_row_stride;
-    params.q_head_stride = q_head_stride;
-    params.k_head_stride = k_head_stride;
-    params.v_head_stride = v_head_stride;
-    params.o_head_stride = o_head_stride;
-
-    // Set the dimensions.
-    params.b = b;
-    params.h = h;
-    params.h_k = h_k;
-    params.h_h_k_ratio = h / h_k;
-    params.seqlen_q = seqlen_q;
-    params.seqlen_k = seqlen_k;
-    params.seqlen_q_rounded = seqlen_q_rounded;
-    params.seqlen_k_rounded = seqlen_k_rounded;
-    params.d = d;
-    params.d_rounded = d_rounded;
-    params.is_causal = is_causal;
-
-    // Set the different scale values.
-    params.scale_softmax = softmax_scale;
-    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
-
-    params.p_dropout = 1.; // probability to keep
-    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
-    params.rp_dropout = 1.f / params.p_dropout;
-    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
-
-    cudaStream_t stream = 0; // Use the default stream.
-    run_mha_fwd_<cutlass::half_t, 32>(params, stream);
-}
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
new file mode 100644
index 00000000..fffcbebb
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
@@ -0,0 +1,19 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+//     using elem_type = cutlass::bfloat16_t;
+//     if (params.p_dropout == 1.f) {
+//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
+//     } else {
+//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
+//     }
+// }
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
new file mode 100644
index 00000000..01bd1716
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
@@ -0,0 +1,26 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+//     using elem_type = cutlass::half_t;
+//     if (params.p_dropout == 1.f) {
+//         // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+//         // Using block size (64 x 256) is 27% slower for seqlen=2k
+//         // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
+//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
+//     } else {
+//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
+//     }
+// }
+template<>
+void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
new file mode 100644
index 00000000..b0b27db5
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
@@ -0,0 +1,17 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+//     using elem_type = cutlass::bfloat16_t;
+//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
+//     });
+// }
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
new file mode 100644
index 00000000..820b63cb
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
@@ -0,0 +1,23 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+//     using elem_type = cutlass::half_t;
+//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
+//         // This 3rd one is good for H100, and A100, A6000
+//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
+//         // These two are always slower
+//         // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
+//         // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
+//     });
+// }
+template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+    run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
+}
\ No newline at end of file
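
As a rough illustration of how the new run_mha entry point in flash_api.cu can be driven from host code, here is a minimal sketch. It assumes contiguous fp16 device buffers laid out as [batch, seqlen, heads, head_dim], assumes the usual FlashAttention-style rounding (head_dim to a multiple of 32, sequence lengths to a multiple of 128), and the helper names flash_fwd_contiguous and round_multiple are made up for the example; the real caller lives on the Rust side of candle-flash-attn and is not part of this diff.

// Illustrative host-side caller only; layout and rounding are assumptions, not
// guarantees made by the kernel code above.
#include <cmath>
#include <cstdint>

extern "C" void run_mha(
    void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *softmax_lse_ptr,
    uint32_t q_batch_stride, uint32_t k_batch_stride, uint32_t v_batch_stride, uint32_t o_batch_stride,
    uint32_t q_row_stride, uint32_t k_row_stride, uint32_t v_row_stride, uint32_t o_row_stride,
    uint32_t q_head_stride, uint32_t k_head_stride, uint32_t v_head_stride, uint32_t o_head_stride,
    uint32_t b, uint32_t h, uint32_t h_k, uint32_t d, uint32_t d_rounded, float softmax_scale,
    uint32_t seqlen_q, uint32_t seqlen_k, uint32_t seqlen_q_rounded, uint32_t seqlen_k_rounded,
    int is_causal);

static uint32_t round_multiple(uint32_t x, uint32_t m) { return (x + m - 1) / m * m; }

// q, k, v, o: contiguous fp16 device buffers shaped [b, seqlen, h, d];
// softmax_lse: fp32 device buffer with room for b * h * seqlen_q values.
void flash_fwd_contiguous(void *q, void *k, void *v, void *o, void *softmax_lse,
                          uint32_t b, uint32_t h, uint32_t seqlen_q, uint32_t seqlen_k,
                          uint32_t d, bool causal) {
    // Strides are in elements, not bytes (see the comment in flash_api.cu).
    uint32_t head_stride = d;                       // step between heads of one token
    uint32_t row_stride = h * d;                    // step between tokens
    uint32_t q_batch_stride = seqlen_q * row_stride;
    uint32_t kv_batch_stride = seqlen_k * row_stride;

    run_mha(q, k, v, o, softmax_lse,
            q_batch_stride, kv_batch_stride, kv_batch_stride, q_batch_stride,
            row_stride, row_stride, row_stride, row_stride,
            head_stride, head_stride, head_stride, head_stride,
            b, h, /*h_k=*/h, d, /*d_rounded=*/round_multiple(d, 32),
            /*softmax_scale=*/1.0f / std::sqrt(float(d)),
            seqlen_q, seqlen_k,
            /*seqlen_q_rounded=*/round_multiple(seqlen_q, 128),
            /*seqlen_k_rounded=*/round_multiple(seqlen_k, 128),
            causal ? 1 : 0);
}

Because the strides are in elements, a different memory layout only changes the three stride products computed above, not the call itself.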