Diffstat (limited to 'candle-flash-attn/kernels')
-rw-r--r--  candle-flash-attn/kernels/flash_api.cu                    | 109
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu  |  19
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu  |  32
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu  |  17
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu  |  27
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu  |  16
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu  |  27
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu  |   9
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu  |   9
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu  |   9
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu  |   9
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu   |  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu   |  92
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu   |  19
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu   |  26
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu   |  17
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu   |  23
17 files changed, 379 insertions, 91 deletions
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
new file mode 100644
index 00000000..323aeaad
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -0,0 +1,109 @@
+#include "flash_fwd_launch_template.h"
+
+// TODO: Switch back to handling bf16.
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+ FWD_HEADDIM_SWITCH(params.d, [&] {
+ run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
+ });
+}
+
+// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+// FP16_SWITCH(!params.is_bf16, [&] {
+// FWD_HEADDIM_SWITCH(params.d, [&] {
+// run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+// });
+// });
+// }
+
+extern "C" void run_mha(
+ void *q_ptr,
+ void *k_ptr,
+ void *v_ptr,
+ void *o_ptr,
+ void *softmax_lse_ptr,
+
+ uint32_t q_batch_stride,
+ uint32_t k_batch_stride,
+ uint32_t v_batch_stride,
+ uint32_t o_batch_stride,
+
+ uint32_t q_row_stride,
+ uint32_t k_row_stride,
+ uint32_t v_row_stride,
+ uint32_t o_row_stride,
+
+ uint32_t q_head_stride,
+ uint32_t k_head_stride,
+ uint32_t v_head_stride,
+ uint32_t o_head_stride,
+
+ uint32_t b,
+ uint32_t h,
+ uint32_t h_k,
+ uint32_t d,
+ uint32_t d_rounded,
+ float softmax_scale,
+
+ uint32_t seqlen_q,
+ uint32_t seqlen_k,
+ uint32_t seqlen_q_rounded,
+ uint32_t seqlen_k_rounded,
+
+ int is_causal
+) {
+ Flash_fwd_params params;
+ // Reset the parameters
+ memset(&params, 0, sizeof(params));
+
+ // Set the pointers and strides.
+ params.q_ptr = q_ptr;
+ params.k_ptr = k_ptr;
+ params.v_ptr = v_ptr;
+ params.o_ptr = o_ptr;
+
+ params.softmax_lse_ptr = softmax_lse_ptr;
+
+ // All strides are in elements, not bytes.
+ params.q_batch_stride = q_batch_stride;
+ params.k_batch_stride = k_batch_stride;
+ params.v_batch_stride = v_batch_stride;
+ params.o_batch_stride = o_batch_stride;
+
+ params.q_row_stride = q_row_stride;
+ params.k_row_stride = k_row_stride;
+ params.v_row_stride = v_row_stride;
+ params.o_row_stride = o_row_stride;
+ params.q_head_stride = q_head_stride;
+ params.k_head_stride = k_head_stride;
+ params.v_head_stride = v_head_stride;
+ params.o_head_stride = o_head_stride;
+
+ // Set the dimensions.
+ params.b = b;
+ params.h = h;
+ params.h_k = h_k;
+ params.h_h_k_ratio = h / h_k;
+ params.seqlen_q = seqlen_q;
+ params.seqlen_k = seqlen_k;
+ params.seqlen_q_rounded = seqlen_q_rounded;
+ params.seqlen_k_rounded = seqlen_k_rounded;
+ params.d = d;
+ params.d_rounded = d_rounded;
+ params.is_causal = is_causal;
+
+ // Set the different scale values.
+ params.scale_softmax = softmax_scale;
+ params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+
+ params.p_dropout = 1.; // probability to keep
+ params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
+ params.rp_dropout = 1.f / params.p_dropout;
+ params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
+ params.is_bf16 = 0;
+ params.cu_seqlens_q = nullptr;
+ params.cu_seqlens_k = nullptr;
+ params.p_ptr = nullptr;
+
+ cudaStream_t stream = 0; // Use the default stream.
+ run_mha_fwd(params, stream);
+}
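
For reference, the following is a minimal host-side sketch (not part of the diff) of how the run_mha C entry point added above could be driven from C++. It assumes contiguous fp16 tensors in a (batch, seqlen, heads, head_dim) layout; the concrete sizes, the rounded dimensions, and the main() driver are illustrative assumptions only.

// Hypothetical standalone driver: calls the C ABI exported by flash_api.cu.
#include <cmath>
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

extern "C" void run_mha(
    void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *softmax_lse_ptr,
    uint32_t q_batch_stride, uint32_t k_batch_stride, uint32_t v_batch_stride, uint32_t o_batch_stride,
    uint32_t q_row_stride, uint32_t k_row_stride, uint32_t v_row_stride, uint32_t o_row_stride,
    uint32_t q_head_stride, uint32_t k_head_stride, uint32_t v_head_stride, uint32_t o_head_stride,
    uint32_t b, uint32_t h, uint32_t h_k, uint32_t d, uint32_t d_rounded, float softmax_scale,
    uint32_t seqlen_q, uint32_t seqlen_k, uint32_t seqlen_q_rounded, uint32_t seqlen_k_rounded,
    int is_causal);

int main() {
    const uint32_t b = 2, h = 8, h_k = 8, d = 64, seqlen = 128;
    // Rounded dimensions are assumed to be computed by the caller
    // (e.g. head_dim up to a multiple of 32, seqlen up to a multiple of 128).
    const uint32_t d_rounded = 64, seqlen_rounded = 128;

    const size_t n_elems = size_t(b) * seqlen * h * d;
    __half *q, *k, *v, *o;
    float *lse; // softmax log-sum-exp: one float per (batch, head, query row) -- assumption
    cudaMalloc(&q, n_elems * sizeof(__half));
    cudaMalloc(&k, n_elems * sizeof(__half));
    cudaMalloc(&v, n_elems * sizeof(__half));
    cudaMalloc(&o, n_elems * sizeof(__half));
    cudaMalloc(&lse, size_t(b) * h * seqlen * sizeof(float));
    // ... fill q/k/v on the device before calling run_mha ...

    // All strides are in elements, not bytes, for a contiguous (b, s, h, d) layout.
    const uint32_t batch_stride = seqlen * h * d;
    const uint32_t row_stride = h * d;
    const uint32_t head_stride = d;

    run_mha(q, k, v, o, lse,
            batch_stride, batch_stride, batch_stride, batch_stride,
            row_stride, row_stride, row_stride, row_stride,
            head_stride, head_stride, head_stride, head_stride,
            b, h, h_k, d, d_rounded, 1.0f / std::sqrt(float(d)),
            seqlen, seqlen, seqlen_rounded, seqlen_rounded,
            /*is_causal=*/0);
    cudaDeviceSynchronize();

    cudaFree(q); cudaFree(k); cudaFree(v); cudaFree(o); cudaFree(lse);
    return 0;
}
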
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
new file mode 100644
index 00000000..654400a7
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
@@ -0,0 +1,19 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::bfloat16_t;
+// if (params.p_dropout == 1.f) {
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
+// } else {
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
+// }
+// }
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
new file mode 100644
index 00000000..5b7254a9
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
@@ -0,0 +1,32 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::half_t;
+// if (params.p_dropout == 1.f) {
+// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
+// // 1st ones are good for H100, A100
+// // 2nd one is good for A6000 bc we get slightly better occupancy
+// } else {
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
+// // 1st one is good for H100, A100, A6000
+// }
+// }
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
new file mode 100644
index 00000000..6a9d60c3
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
@@ -0,0 +1,17 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::bfloat16_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// });
+// }
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
new file mode 100644
index 00000000..6c40a164
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::half_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
+// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
+// // For A100, H100, 1st is fastest.
+// });
+// }
+template<>
+void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
new file mode 100644
index 00000000..d2f4cba7
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
@@ -0,0 +1,16 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::bfloat16_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// });
+// }
+template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
new file mode 100644
index 00000000..2875c926
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::half_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// // This one is slightly faster for causal?
+// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
+// });
+// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
+// // For A6000, 1st is faster when causal, 3rd is faster when not causal
+// }
+template<>
+void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
new file mode 100644
index 00000000..982fe7ea
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
new file mode 100644
index 00000000..4c083f7b
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
new file mode 100644
index 00000000..cb074a95
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
new file mode 100644
index 00000000..ddf5e132
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
@@ -0,0 +1,9 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
new file mode 100644
index 00000000..81e359e1
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
index d8f071ef..91e6331e 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
@@ -20,94 +20,4 @@
template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
-}
-
-
-extern "C" void run_mha(
- void *q_ptr,
- void *k_ptr,
- void *v_ptr,
- void *o_ptr,
- void *softmax_lse_ptr,
-
- uint32_t q_batch_stride,
- uint32_t k_batch_stride,
- uint32_t v_batch_stride,
- uint32_t o_batch_stride,
-
- uint32_t q_row_stride,
- uint32_t k_row_stride,
- uint32_t v_row_stride,
- uint32_t o_row_stride,
-
- uint32_t q_head_stride,
- uint32_t k_head_stride,
- uint32_t v_head_stride,
- uint32_t o_head_stride,
-
- uint32_t b,
- uint32_t h,
- uint32_t h_k,
- uint32_t d,
- uint32_t d_rounded,
- float softmax_scale,
-
- uint32_t seqlen_q,
- uint32_t seqlen_k,
- uint32_t seqlen_q_rounded,
- uint32_t seqlen_k_rounded,
-
- int is_causal
-) {
- Flash_fwd_params params;
- // Reset the parameters
- memset(&params, 0, sizeof(params));
-
- // Set the pointers and strides.
- params.q_ptr = q_ptr;
- params.k_ptr = k_ptr;
- params.v_ptr = v_ptr;
- params.o_ptr = o_ptr;
-
- params.softmax_lse_ptr = softmax_lse_ptr;
-
- // All stride are in elements, not bytes.
- params.q_batch_stride = q_batch_stride;
- params.k_batch_stride = k_batch_stride;
- params.v_batch_stride = v_batch_stride;
- params.o_batch_stride = o_batch_stride;
-
- params.q_row_stride = q_row_stride;
- params.k_row_stride = k_row_stride;
- params.v_row_stride = v_row_stride;
- params.o_row_stride = o_row_stride;
- params.q_head_stride = q_head_stride;
- params.k_head_stride = k_head_stride;
- params.v_head_stride = v_head_stride;
- params.o_head_stride = o_head_stride;
-
- // Set the dimensions.
- params.b = b;
- params.h = h;
- params.h_k = h_k;
- params.h_h_k_ratio = h / h_k;
- params.seqlen_q = seqlen_q;
- params.seqlen_k = seqlen_k;
- params.seqlen_q_rounded = seqlen_q_rounded;
- params.seqlen_k_rounded = seqlen_k_rounded;
- params.d = d;
- params.d_rounded = d_rounded;
- params.is_causal = is_causal;
-
- // Set the different scale values.
- params.scale_softmax = softmax_scale;
- params.scale_softmax_log2 = softmax_scale * M_LOG2E;
-
- params.p_dropout = 1.; // probability to keep
- params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
- params.rp_dropout = 1.f / params.p_dropout;
- params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
-
- cudaStream_t stream = 0; // Use the default stream.
- run_mha_fwd_<cutlass::half_t, 32>(params, stream);
-}
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
new file mode 100644
index 00000000..fffcbebb
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
@@ -0,0 +1,19 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::bfloat16_t;
+// if (params.p_dropout == 1.f) {
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
+// } else {
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
+// }
+// }
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
new file mode 100644
index 00000000..01bd1716
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
@@ -0,0 +1,26 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::half_t;
+// if (params.p_dropout == 1.f) {
+// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+// // Using block size (64 x 256) is 27% slower for seqlen=2k
+// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
+// } else {
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
+// }
+// }
+template<>
+void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
new file mode 100644
index 00000000..b0b27db5
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
@@ -0,0 +1,17 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::bfloat16_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
+// });
+// }
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
+}
\ No newline at end of file
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
new file mode 100644
index 00000000..820b63cb
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
@@ -0,0 +1,23 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "flash_fwd_launch_template.h"
+
+// template<>
+// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+// using elem_type = cutlass::half_t;
+// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
+// // This 3rd one is good for H100, and A100, A6000
+// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
+// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
+// // These two are always slower
+// // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
+// // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
+// });
+// }
+template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
+}
\ No newline at end of file