path: root/candle-flash-attn
author     OlivierDehaene <Olivier.dehaene@gmail.com>   2024-01-05 18:28:55 +0100
committer  GitHub <noreply@github.com>                  2024-01-05 18:28:55 +0100
commit     8d1a57c9a0465b201e4e9e410e2b8fcde37b35f7 (patch)
tree       289466c9df7a7f21ec1e574cd6cfd7b957998a08 /candle-flash-attn
parent     3a7304cb0dbdf8ceeab8a4f5cf9b8e7ced822e20 (diff)
chore: update flash attention kernels (#1518)
* chore: update flash attention kernels
* fmt
* remove unused kernels
* force f32
* correct stride
Diffstat (limited to 'candle-flash-attn')
-rw-r--r--  candle-flash-attn/kernels/alibi.h                          62
-rw-r--r--  candle-flash-attn/kernels/block_info.h                     13
-rw-r--r--  candle-flash-attn/kernels/flash.h                          54
-rw-r--r--  candle-flash-attn/kernels/flash_api.cu                     40
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu   13
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu   26
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu   11
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu   21
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu   12
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu   21
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu    5
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu    5
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu    5
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu    5
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu     4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu    17
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu    13
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu    20
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu    11
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu    21
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_kernel.h              282
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_launch_template.h      63
-rw-r--r--  candle-flash-attn/kernels/kernel_traits.h                  77
-rw-r--r--  candle-flash-attn/kernels/kernel_traits_sm90.h            159
-rw-r--r--  candle-flash-attn/kernels/softmax.h                        57
-rw-r--r--  candle-flash-attn/kernels/utils.h                          92
-rw-r--r--  candle-flash-attn/src/ffi.rs                                8
-rw-r--r--  candle-flash-attn/src/lib.rs                              434
28 files changed, 1086 insertions, 465 deletions
diff --git a/candle-flash-attn/kernels/alibi.h b/candle-flash-attn/kernels/alibi.h
new file mode 100644
index 00000000..1afb3687
--- /dev/null
+++ b/candle-flash-attn/kernels/alibi.h
@@ -0,0 +1,62 @@
+#include <cmath>
+
+#include <cute/tensor.hpp>
+
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+
+#include "utils.h"
+
+namespace flash {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_causal, typename Engine, typename Layout>
+inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
+ const int col_idx_offset_,
+ const int max_seqlen_k,
+ const int row_idx_offset,
+ const int max_seqlen_q,
+ const int warp_row_stride,
+ const float alibi_slope) {
+ // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+ static_assert(Layout::rank == 2, "Only support 2D Tensor");
+ const int lane_id = threadIdx.x % 32;
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+ if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
+ #pragma unroll
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+ const int col_idx_base = col_idx_offset + nj * 8;
+ #pragma unroll
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
+ const int col_idx = col_idx_base + j;
+ #pragma unroll
+ for (int mi = 0; mi < size<0>(tensor); ++mi) {
+ tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+ }
+ }
+ }
+ } else { // Bias depends on both row_idx and col_idx
+ #pragma unroll
+ for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+ const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+ #pragma unroll
+ for (int i = 0; i < size<0, 0>(tensor); ++i) {
+ const int row_idx = row_idx_base + i * 8;
+ #pragma unroll
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+ const int col_idx_base = col_idx_offset + nj * 8;
+ #pragma unroll
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
+ const int col_idx = col_idx_base + j;
+ tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+ }
+ }
+ }
+ }
+ }
+}
+
+} // namespace flash
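
The new apply_alibi above adds a per-head linear bias to the attention scores before softmax, working directly on the MMA fragment layout. A minimal host-side C++ reference of the same bias on a dense [seqlen_q, seqlen_k] score matrix (the function and variable names here are illustrative, not part of the kernels):

#include <cmath>
#include <cstddef>
#include <vector>

// Reference for the two branches of apply_alibi: under a causal mask the bias
// only depends on the column, so the same vector is added to every row;
// otherwise the bias is -slope * |row - col| with Q aligned to the end of K.
void apply_alibi_reference(std::vector<float> &scores, int seqlen_q, int seqlen_k,
                           float alibi_slope, bool is_causal) {
    for (int i = 0; i < seqlen_q; ++i) {
        for (int j = 0; j < seqlen_k; ++j) {
            float &s = scores[static_cast<size_t>(i) * seqlen_k + j];
            if (is_causal) {
                s += alibi_slope * j;
            } else {
                s -= alibi_slope * std::abs(i + seqlen_k - seqlen_q - j);
            }
        }
    }
}

In flash_fwd_kernel.h below, the slope read from alibi_slopes_ptr is divided by scale_softmax before being passed in, so that after the softmax scaling the effective bias is exactly slope times the distance.
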
diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h
index 94251a41..65435e51 100644
--- a/candle-flash-attn/kernels/block_info.h
+++ b/candle-flash-attn/kernels/block_info.h
@@ -14,9 +14,12 @@ struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
- , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
+ , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
- , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
+ // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+ // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+ , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
+ , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}
@@ -32,8 +35,10 @@ struct BlockInfo {
const int sum_s_q;
const int sum_s_k;
- const uint32_t actual_seqlen_q;
- const uint32_t actual_seqlen_k;
+ const int actual_seqlen_q;
+ // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+ const int seqlen_k_cache;
+ const int actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
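
The seqlen_k selection added above is easier to read outside an initializer list. A short C++ sketch of the same logic, assuming a params object with the fields referenced in the constructor (the helper name is made up for illustration):

// How BlockInfo derives the effective K length for batch index `bidb`.
template <bool Varlen, typename Params>
int actual_seqlen_k_for(const Params &params, int bidb) {
    int seqlen_k_cache;
    if (!Varlen || params.cu_seqlens_k == nullptr) {
        seqlen_k_cache = params.seqlen_k;                 // fixed-length batch
    } else if (params.is_seqlens_k_cumulative) {          // cu_seqlens_k holds prefix sums
        seqlen_k_cache = params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb];
    } else {                                              // cu_seqlens_k holds lengths directly
        seqlen_k_cache = params.cu_seqlens_k[bidb];
    }
    if (params.seqused_k != nullptr) { return params.seqused_k[bidb]; } // explicit override
    // Account for K_new tokens being appended to the cache, if any.
    return seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew);
}
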
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h
index be4ae0ca..80b517e9 100644
--- a/candle-flash-attn/kernels/flash.h
+++ b/candle-flash-attn/kernels/flash.h
@@ -7,15 +7,6 @@
#include <cuda.h>
#include <vector>
-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-//
-// #include <ATen/cuda/CUDAGraphsUtils.cuh>
-
-
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
@@ -53,6 +44,7 @@ struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
+ void * __restrict__ oaccum_ptr;
// The stride between rows of O.
index_t o_batch_stride;
@@ -64,9 +56,10 @@ struct Flash_fwd_params : public Qkv_params {
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
+ void * __restrict__ softmax_lseaccum_ptr;
// The dimensions.
- int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+ int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
// The scaling factors for the kernel.
float scale_softmax;
@@ -76,8 +69,30 @@ struct Flash_fwd_params : public Qkv_params {
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
+ // If provided, the actual length of each k sequence.
+ int * __restrict__ seqused_k;
+
int *__restrict__ blockmask;
+ // The K_new and V_new matrices.
+ void * __restrict__ knew_ptr;
+ void * __restrict__ vnew_ptr;
+
+ // The stride between rows of the Q, K and V matrices.
+ index_t knew_batch_stride;
+ index_t vnew_batch_stride;
+ index_t knew_row_stride;
+ index_t vnew_row_stride;
+ index_t knew_head_stride;
+ index_t vnew_head_stride;
+
+ // The cos and sin matrices for rotary embedding.
+ void * __restrict__ rotary_cos_ptr;
+ void * __restrict__ rotary_sin_ptr;
+
+ // The indices to index into the KV cache.
+ int *__restrict__ cache_batch_idx;
+
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
@@ -88,11 +103,22 @@ struct Flash_fwd_params : public Qkv_params {
float rp_dropout;
float scale_softmax_rp_dropout;
- // Random state.
- // at::PhiloxCudaState philox_args;
+ // Local window size
+ int window_size_left, window_size_right;
bool is_bf16;
bool is_causal;
+
+ // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+ // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+ bool is_seqlens_k_cumulative;
+
+ bool is_rotary_interleaved;
+
+ int num_splits; // For split-KV version
+
+ void * __restrict__ alibi_slopes_ptr;
+ index_t alibi_slopes_batch_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -132,10 +158,14 @@ struct Flash_bwd_params : public Flash_fwd_params {
// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
+
+ bool deterministic;
+ index_t dq_accum_split_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
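
The new window_size_left and window_size_right fields implement local (sliding-window) attention: a query row may only attend to key columns inside a band around its own position once the end of Q is aligned with the end of K. A hedged C++ sketch of that visibility predicate (it mirrors the arithmetic used by apply_mask_local and apply_alibi in the kernels; non-negative window sizes are assumed):

// True when key column j is visible to query row i under a local window.
// The seqlen_k - seqlen_q shift places the last query on the last key, which
// is the alignment the forward kernel uses for both masking and ALiBi.
bool in_local_window(int i, int j, int seqlen_q, int seqlen_k,
                     int window_size_left, int window_size_right) {
    const int diag = i + seqlen_k - seqlen_q;  // column sitting on the causal diagonal
    return j >= diag - window_size_left && j <= diag + window_size_right && j < seqlen_k;
}

Causal attention is the special case of an unbounded left window with window_size_right == 0.
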
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 72991257..8113dbc7 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -1,17 +1,15 @@
#include "flash_fwd_launch_template.h"
-// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
-// FWD_HEADDIM_SWITCH(params.d, [&] {
-// run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream);
-// });
-// }
-
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
- FP16_SWITCH(!params.is_bf16, [&] {
- FWD_HEADDIM_SWITCH(params.d, [&] {
- run_mha_fwd_<elem_type, kHeadDim>(params, stream);
- });
- });
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
+ FP16_SWITCH(!params.is_bf16, [&] {
+ FWD_HEADDIM_SWITCH(params.d, [&] {
+// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
+ run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+// } else {
+// run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
+// }
+ });
+ });
}
extern "C" void run_mha(
@@ -20,6 +18,7 @@ extern "C" void run_mha(
void *v_ptr,
void *o_ptr,
void *softmax_lse_ptr,
+ void *alibi_slopes_ptr,
int32_t *cu_seqlens_q_ptr,
int32_t *cu_seqlens_k_ptr,
@@ -28,6 +27,7 @@ extern "C" void run_mha(
uint32_t k_batch_stride,
uint32_t v_batch_stride,
uint32_t o_batch_stride,
+ uint32_t alibi_slopes_batch_stride,
uint32_t q_row_stride,
uint32_t k_row_stride,
@@ -51,8 +51,11 @@ extern "C" void run_mha(
uint32_t seqlen_q_rounded,
uint32_t seqlen_k_rounded,
+ int is_bf16,
int is_causal,
- int is_bf16
+
+ int window_size_left,
+ int window_size_right
) {
Flash_fwd_params params;
// Reset the parameters
@@ -65,12 +68,14 @@ extern "C" void run_mha(
params.o_ptr = o_ptr;
params.softmax_lse_ptr = softmax_lse_ptr;
+ params.alibi_slopes_ptr = alibi_slopes_ptr;
// All stride are in elements, not bytes.
params.q_batch_stride = q_batch_stride;
params.k_batch_stride = k_batch_stride;
params.v_batch_stride = v_batch_stride;
params.o_batch_stride = o_batch_stride;
+ params.alibi_slopes_batch_stride = alibi_slopes_batch_stride;
params.q_row_stride = q_row_stride;
params.k_row_stride = k_row_stride;
@@ -92,7 +97,6 @@ extern "C" void run_mha(
params.seqlen_k_rounded = seqlen_k_rounded;
params.d = d;
params.d_rounded = d_rounded;
- params.is_causal = is_causal;
// Set the different scale values.
params.scale_softmax = softmax_scale;
@@ -106,6 +110,14 @@ extern "C" void run_mha(
params.cu_seqlens_q = cu_seqlens_q_ptr;
params.cu_seqlens_k = cu_seqlens_k_ptr;
params.p_ptr = nullptr; // used for `return_softmax`.
+ params.seqused_k = nullptr;
+
+ params.is_causal = is_causal;
+ params.window_size_left = window_size_left;
+ params.window_size_right = window_size_right;
+
+ params.is_seqlens_k_cumulative = true;
+ params.num_splits = 1;
cudaStream_t stream = 0; // Use the default stream.
run_mha_fwd(params, stream);
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
index 654400a7..6ffa4126 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
@@ -1,19 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
-// using elem_type = cutlass::bfloat16_t;
-// if (params.p_dropout == 1.f) {
-// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
-// } else {
-// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
-// }
-// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
-} \ No newline at end of file
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
index 5b7254a9..19b005ad 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
@@ -1,32 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-// template<>
-// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
-// using elem_type = cutlass::half_t;
-// if (params.p_dropout == 1.f) {
-// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
-// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
-// // 1st ones are good for H100, A100
-// // 2nd one is good for A6000 bc we get slightly better occupancy
-// } else {
-// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
-// // 1st one is good for H100, A100, A6000
-// }
-// }
-
template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
-} \ No newline at end of file
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
index 6a9d60c3..f674f481 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
@@ -1,17 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
-// using elem_type = cutlass::bfloat16_t;
-// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
-// });
-// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
-} \ No newline at end of file
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
index 6c40a164..afd0a8a3 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
@@ -1,27 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-// template<>
-// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
-// using elem_type = cutlass::half_t;
-// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
-// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
-// // For A100, H100, 1st is fastest.
-// });
-// }
template<>
void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
-} \ No newline at end of file
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
index d2f4cba7..aa91bdd6 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
@@ -1,16 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
-// using elem_type = cutlass::bfloat16_t;
-// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-// });
-// }
-template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
index 2875c926..37a96526 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
@@ -1,27 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-// template<>
-// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
-// using elem_type = cutlass::half_t;
-// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
-// // This one is slightly faster for causal?
-// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
-// });
-// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
-// // For A6000, 1st is faster when causal, 3rd is faster when not causal
-// }
template<>
void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
-} \ No newline at end of file
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
index 982fe7ea..167a0df2 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
@@ -1,9 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
index 4c083f7b..58ffe75c 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
@@ -1,9 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
index cb074a95..1b370141 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
@@ -1,9 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
index ddf5e132..9f35129c 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
@@ -1,9 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
index 81e359e1..770de6fc 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
@@ -1,10 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
-} \ No newline at end of file
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
index 91e6331e..8dbf8b94 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
@@ -1,23 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-// template<>
-// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-// using elem_type = cutlass::half_t;
-// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-// run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
-// // For dropout there might be a lot of register spilling?
-// // These two are very slow due to register spilling
-// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
-// // This one is slightly slower
-// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
-// });
-// }
template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
-} \ No newline at end of file
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
index fffcbebb..22eac878 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
@@ -1,19 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-// using elem_type = cutlass::bfloat16_t;
-// if (params.p_dropout == 1.f) {
-// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
-// } else {
-// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
-// }
-// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
-} \ No newline at end of file
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
index 01bd1716..e6da5dd2 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
@@ -1,26 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-// template<>
-// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-// using elem_type = cutlass::half_t;
-// if (params.p_dropout == 1.f) {
-// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
-// // Using block size (64 x 256) is 27% slower for seqlen=2k
-// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
-// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
-// } else {
-// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
-// }
-// }
template<>
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
-} \ No newline at end of file
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
index b0b27db5..9c003540 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
@@ -1,17 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-// using elem_type = cutlass::bfloat16_t;
-// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
-// });
-// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
-} \ No newline at end of file
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
index 820b63cb..8108696a 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
@@ -1,23 +1,10 @@
// Copyright (c) 2023, Tri Dao.
-
// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
#include "flash_fwd_launch_template.h"
-// template<>
-// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-// using elem_type = cutlass::half_t;
-// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
-// // This 3rd one is good for H100, and A100, A6000
-// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
-// // These two are always slower
-// // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
-// // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
-// });
-// }
-template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
-} \ No newline at end of file
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h
index 232dea0d..05f5f701 100644
--- a/candle-flash-attn/kernels/flash_fwd_kernel.h
+++ b/candle-flash-attn/kernels/flash_fwd_kernel.h
@@ -4,20 +4,18 @@
#pragma once
-#include <cmath>
#include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
-#include <cutlass/numeric_conversion.h>
#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
-#include "philox.cuh"
+
+#include "alibi.h"
namespace flash {
@@ -25,49 +23,6 @@ using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
-template <int MMA_M,
- class... Args,
- class TiledMMA>
-CUTE_HOST_DEVICE
-auto
-make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
- TiledMMA const& tiled_mma) {
- using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
- using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
- constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
- constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
- constexpr int MMAStride_M = MMA_M * AtomShape_M;
- auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
- Stride<_1, Int<MMAStride_M>> >{},
- make_layout(size<2>(TileShape_MNK{})));
- // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
- return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <int MMA_M,
- class... Args,
- class TiledMMA>
-CUTE_HOST_DEVICE
-auto
-make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
- TiledMMA const& tiled_mma) {
- using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
- using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
- constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
- constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
- constexpr int MMAStride_M = MMA_M * AtomShape_M;
- auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
- Stride<_1, Int<MMAStride_M>> >{},
- // TODO: Shouldn't this be size<1>?
- make_layout(size<2>(TileShape_MNK{})));
- // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
- return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
Tensor2 &acc_o, float softmax_scale_log2) {
@@ -77,7 +32,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
flash::reduce_sum(scores, scores_sum);
} else {
Tensor scores_max_prev = make_fragment_like(scores_max);
- copy(scores_max, scores_max_prev);
+ cute::copy(scores_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
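
softmax_rescale_o above is the online-softmax update at the heart of the forward pass: whenever a new K/V block raises the running row maximum, the previously accumulated output and row sum are rescaled before the new block is folded in. A single-row C++ sketch, ignoring the exp2/log2 trick and the Check_inf handling used in the kernel (initial state: row_max = -inf, row_sum = 0, out all zeros):

#include <algorithm>
#include <cmath>
#include <vector>

// One online-softmax step for one query row. `out` stands in for a row of
// acc_o (length head_dim) and `scores` for the corresponding row of S = Q K^T
// restricted to the current K/V block.
void online_softmax_step(std::vector<float> &out, std::vector<float> &scores,
                         float &row_max, float &row_sum, float softmax_scale) {
    const float block_max = *std::max_element(scores.begin(), scores.end());
    const float new_max = std::max(row_max, block_max);
    const float correction = std::exp((row_max - new_max) * softmax_scale);
    for (float &o : out) { o *= correction; }   // rescale earlier contributions
    row_sum *= correction;
    for (float &s : scores) {                   // scores become the P tile
        s = std::exp((s - new_max) * softmax_scale);
        row_sum += s;
    }
    row_max = new_max;
    // The caller then accumulates out += P * V_block; the final division by
    // row_sum happens once, in the epilogue.
}
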
@@ -103,23 +58,22 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem(
- Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
+ Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
) {
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
Layout l = tOrP.layout();
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
- // TODO(laurent): reactivate the following
- // CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
+ CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
#pragma unroll
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
- copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
+ cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
@@ -138,16 +92,65 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
- const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
- if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
+ const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+ if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
+ const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
- if (Is_causal) {
- n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
+ if (Is_causal || Is_local) {
+ n_block_max = std::min(n_block_max,
+ cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
// }
}
+ // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
+ // Otherwise we might read OOB elements from gK and gV.
+ if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
+ // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
+ // exit early and no one saves the rng state.
+// if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+// auto seeds = at::cuda::philox::unpack(params.philox_args);
+// params.rng_state[0] = std::get<0>(seeds);
+// params.rng_state[1] = std::get<1>(seeds);
+// params.rng_state[0] = 0;
+// params.rng_state[1] = 0;
+// }
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+ const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+ Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ make_stride(params.o_row_stride, _1{}));
+ Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+ typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+ Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+ Tensor tOrO = make_tensor<Element>(shape(tOgO));
+ clear(tOrO);
+ // Construct identity layout for sO
+ Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ // Repeat the partitioning with identity layouts
+ Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+ if (!Is_even_K) {
+ #pragma unroll
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+ }
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+ gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+ );
+ #pragma unroll
+ for (int m = 0; m < size<1>(tOgO); ++m) {
+ const int row = get<0>(tOcO(0, m, 0));
+ if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
+ }
+ return;
+ }
+ // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }
// We iterate over the blocks in reverse order. This is because the last block is the only one
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
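
The hunk above replaces the old causal-only bound on n_block_max with a window-aware [n_block_min, n_block_max) range and, when that range is empty (which also covers actual_seqlen_k == 0), writes zeros to gO, +infinity to gLSE, and returns. A host-side C++ sketch of the block-range arithmetic (kBlockM and kBlockN stand for the Kernel_traits tile sizes; in the upstream flash-attention setup the causal case corresponds to window_size_right == 0):

#include <algorithm>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

struct KvBlockRange { int lo, hi; };  // the kernel visits K/V blocks hi-1 down to lo

KvBlockRange kv_block_range(int m_block, int kBlockM, int kBlockN,
                            int seqlen_q, int seqlen_k,
                            bool is_causal, bool is_local,
                            int window_size_left, int window_size_right) {
    const int lo = !is_local
        ? 0
        : std::max(0, (m_block * kBlockM + seqlen_k - seqlen_q - window_size_left) / kBlockN);
    int hi = ceil_div(seqlen_k, kBlockN);
    if (is_causal || is_local) {
        hi = std::min(hi, ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q + window_size_right,
                                   kBlockN));
    }
    // hi <= lo means this row block attends to nothing: the kernel then writes
    // zeroed output rows and LSE = +inf instead of reading K/V out of bounds.
    return {lo, hi};
}
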
@@ -185,8 +188,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
- auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
- auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
+ typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+ auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+ typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
+ auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -208,16 +213,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Copy Atom retiling
//
- auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
- // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
+ auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+ auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
- auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
+ auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+ auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
- auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
+ auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+ auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
// TODO: this might need to change if we change the mma instruction in SM70
@@ -268,8 +275,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
- flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
- binfo.actual_seqlen_q - m_block * kBlockM);
+ flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+ binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
// // Copy rmem to smem
@@ -285,14 +292,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
- copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
+ cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads();
}
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
- flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
- binfo.actual_seqlen_k - n_block * kBlockN);
+ flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+ binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
// __syncthreads();
@@ -302,7 +309,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
- copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
+ cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
}
// auto seeds = at::cuda::philox::unpack(params.philox_args);
@@ -313,13 +320,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
clear(acc_o);
+ float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+
// For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
- constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
+ // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+ // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+ constexpr int n_masking_steps = (!Is_causal && !Is_local)
+ ? 1
+ : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
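
Two small additions sit at the top of the main loop above: the per-(batch, head) ALiBi slope is loaded and pre-divided by the softmax scale, and the number of "masking" iterations is widened to cover a seqlen_k that ends mid-block. A C++ sketch of both, with illustrative names (the kernel computes n_masking_steps at compile time from template parameters):

// ALiBi slope for (bidb, bidh), pre-divided by scale_softmax so the bias is
// unchanged by the scaling applied later inside softmax_rescale_o.
float load_alibi_slope(const float *alibi_slopes, long batch_stride,
                       int bidb, int bidh, float scale_softmax, bool has_alibi) {
    return !has_alibi ? 0.0f : alibi_slopes[bidb * batch_stride + bidh] / scale_softmax;
}

// How many K/V blocks need explicit score masking:
//  - no causal/local masking: just the (possibly ragged) last block,
//  - causal with even sequence lengths: the ceil(kBlockM / kBlockN) diagonal blocks,
//  - otherwise one extra block, since seqlen_k may end in the middle of a block.
constexpr int n_masking_steps(bool is_causal, bool is_local, bool is_even_mn,
                              int kBlockM, int kBlockN) {
    return (!is_causal && !is_local)
        ? 1
        : ((is_even_mn && is_causal) ? (kBlockM + kBlockN - 1) / kBlockN
                                     : (kBlockM + kBlockN - 1) / kBlockN + 1);
}
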
@@ -330,28 +343,42 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
- flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
- flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
- gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+ gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
- acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
+ acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+ smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
- // if (cute::thread0()) { print(scores); }
+ // if (cute::thread0()) { print_tensor(scores); }
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
- if (!Is_causal) {
- if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
+
+ if (Has_alibi) {
+ flash::apply_alibi<Is_causal>(
+ scores,
+ n_block * kBlockN,
+ binfo.actual_seqlen_k,
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q,
+ kNWarps * 16,
+ alibi_slope
+ );
+ }
+
+ if (!Is_causal && !Is_local) {
+ if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
@@ -364,20 +391,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Idk why it's get<1> and not get<0> of the stride.
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
// I can't get the stride from idx_row
- flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
- // m_block * kBlockM + get<0>(idx_row(0)),
- m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
- kNWarps * 16);
- // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
- // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
+ flash::apply_mask_local</*HasWSLeft=*/Is_local>(
+ scores, n_block * kBlockN, binfo.actual_seqlen_k,
+ // m_block * kBlockM + get<0>(idx_row(0)),
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q, kNWarps * 16,
+ params.window_size_left, params.window_size_right
+ // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
+ // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
+ );
+ // if (cute::thread0()) { print_tensor(scores); }
}
flash::cp_async_wait<0>();
__syncthreads();
- if (n_block > 0) {
+ if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
- flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
@@ -385,24 +416,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// TODO: when we have key_padding_mask we'll need to Check_inf
masking_step == 0
- ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
- : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+ ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+ : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
- uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
- uint32_t block_col_idx = n_block * (kBlockN / 32);
+ int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+ int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
- copy(tOrP, tOrP_copy);
+ cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
- flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+ flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
@@ -411,37 +442,38 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}
// if (cute::thread0()) { print(tOrP); }
- flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
+ flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration
- if (n_masking_steps > 1 && n_block <= 0) {
+ if (n_masking_steps > 1 && n_block <= n_block_min) {
--n_block;
break;
}
}
// These are the iterations where we don't need masking on S
- for (; n_block >= 0; --n_block) {
+ for (; n_block >= n_block_min; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
- flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
- acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
+ acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+ smem_thr_copy_Q, smem_thr_copy_K
);
flash::cp_async_wait<0>();
__syncthreads();
- if (n_block > 0) {
+ if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
- flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
@@ -449,22 +481,44 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
- softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+
+ if (Has_alibi) {
+ flash::apply_alibi<Is_causal>(
+ scores,
+ n_block * kBlockN,
+ binfo.actual_seqlen_k,
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q,
+ kNWarps * 16,
+ alibi_slope
+ );
+ }
+
+ if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+ flash::apply_mask_local(
+ scores, n_block * kBlockN, binfo.actual_seqlen_k,
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q, kNWarps * 16,
+ params.window_size_left, params.window_size_right
+ );
+ }
+
+ softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
- uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
- uint32_t block_col_idx = n_block * (kBlockN / 32);
+ int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+ int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
- copy(tOrP, tOrP_copy);
+ cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
- flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+ flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
@@ -472,7 +526,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
block_row_idx, block_col_idx, kNWarps);
}
- flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
+ flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}
// Epilogue
@@ -496,15 +550,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor rO = flash::convert_type<Element>(acc_o);
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
- auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
- // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
+ auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+ auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// sO has the same size as sQ, so we don't need to sync here.
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
- copy(smem_thr_copy_O, taccOrO, taccOsO);
+ cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
@@ -515,14 +569,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
- auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
+ typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
__syncthreads();
Tensor tOrO = make_tensor<Element>(shape(tOgO));
- copy(gmem_thr_copy_O, tOsO, tOrO);
+ cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
@@ -548,14 +603,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
- flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
- gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+ gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
}
+
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
@@ -571,7 +627,7 @@ inline __device__ void compute_attn(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
- flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+ flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
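Note: the new apply_alibi / apply_mask_local calls in the hunk above are driven by per-thread row and column offsets derived from the m16n8k16 MMA fragment layout. As a host-side sketch of that index arithmetic (Rust, not the device code; k_block_m / k_block_n stand in for the kernel-traits constants kBlockM / kBlockN):

// Host-side sketch of the coordinate math passed to apply_alibi / apply_mask_local:
// each warp owns 16 rows of the score tile and every group of 4 lanes shares a row.
fn score_tile_coords(
    tidx: usize,    // thread index within the thread block
    m_block: usize, // query (row) block index
    n_block: usize, // key (column) block index
    k_block_m: usize,
    k_block_n: usize,
) -> (usize, usize) {
    let warp = tidx / 32;
    let lane = tidx % 32;
    // Row offset handed to the masking helpers:
    //   m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4
    let row = m_block * k_block_m + warp * 16 + lane / 4;
    // Column base used inside the helpers: n_block * kBlockN + (lane_id % 4) * 2
    let col = n_block * k_block_n + (lane % 4) * 2;
    (row, col)
}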
diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h
index 398ce077..66ab6206 100644
--- a/candle-flash-attn/kernels/flash_fwd_launch_template.h
+++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h
@@ -4,15 +4,14 @@
#pragma once
-// #include <ATen/cuda/CUDAContext.h>
-
#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
- flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
+ static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
+ flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -26,35 +25,39 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h);
- // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
- // for cu_seqlens_q as well.
- const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
+ const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
- BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
+ BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
- BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
- // Will only return softmax if dropout, to reduce compilation time.
- auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
- // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
- // if (smem_size >= 48 * 1024) {
- // C10_CUDA_CHECK(cudaFuncSetAttribute(
- // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
- // }
- int ctas_per_sm;
- cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
- &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
- // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
- kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
- // C10_CUDA_KERNEL_LAUNCH_CHECK();
+ BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+ BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
+ BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+ // Will only return softmax if dropout, to reduce compilation time.
+ // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+ // If return_softmax, set IsEvenMNConst to false to reduce number of templates
+ // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+ // If Is_local, set Is_causal to false
+ auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+ // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
+ // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+ // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+ // int ctas_per_sm;
+ // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+ // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+ // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+ kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+ });
+ });
});
});
});
}
+
template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
- constexpr int Headdim = 32;
+ constexpr static int Headdim = 32;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
@@ -64,7 +67,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
- constexpr int Headdim = 64;
+ constexpr static int Headdim = 64;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
@@ -86,7 +89,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
- constexpr int Headdim = 96;
+ constexpr static int Headdim = 96;
// auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -112,7 +115,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
- constexpr int Headdim = 128;
+ constexpr static int Headdim = 128;
// auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -149,7 +152,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
- constexpr int Headdim = 160;
+ constexpr static int Headdim = 160;
// auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -179,7 +182,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
- constexpr int Headdim = 192;
+ constexpr static int Headdim = 192;
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
@@ -198,7 +201,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
- constexpr int Headdim = 224;
+ constexpr static int Headdim = 224;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
@@ -224,7 +227,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
- constexpr int Headdim = 256;
+ constexpr static int Headdim = 256;
int device;
cudaGetDevice(&device);
int max_smem_per_sm, max_smem_per_block;
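Note: the launcher now derives all dispatch booleans at runtime and feeds them through nested BOOL_SWITCH macros. A rough Rust restatement of those conditions (names are illustrative stand-ins, not the crate's API):

// Runtime flags run_flash_fwd computes before choosing a template instantiation.
struct LaunchFlags {
    is_even_mn: bool,
    is_even_k: bool,
    is_local: bool,
    has_alibi: bool,
    return_softmax: bool,
}

fn launch_flags(
    seqlen_q: usize,
    seqlen_k: usize,
    head_dim: usize,
    k_block_m: usize,
    k_block_n: usize,
    k_head_dim: usize,
    varlen: bool, // true when cu_seqlens_q / cu_seqlens_k are provided
    is_causal: bool,
    window_left: i32,
    window_right: i32,
    has_alibi: bool,
    wants_softmax: bool,
) -> LaunchFlags {
    LaunchFlags {
        // Even only when there is no varlen batching and both tile dims divide evenly.
        is_even_mn: !varlen && seqlen_k % k_block_n == 0 && seqlen_q % k_block_m == 0,
        is_even_k: head_dim == k_head_dim,
        // Local (sliding-window) attention is only instantiated when not causal.
        is_local: (window_left >= 0 || window_right >= 0) && !is_causal,
        has_alibi,
        return_softmax: wants_softmax,
    }
}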
diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h
index 3468e4bf..f000ff24 100644
--- a/candle-flash-attn/kernels/kernel_traits.h
+++ b/candle-flash-attn/kernels/kernel_traits.h
@@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base {
SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+ // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
+ using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+ Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomVtransposed = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{},
- // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
- Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
- Stride<_1, Int<kBlockKSmem>>>{}));
+ composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
- using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
+ using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
+ SmemLayoutAtomVtransposedNoSwizzle{},
+ Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+ // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -110,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
- using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
+ using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
+ using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
@@ -138,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base {
DefaultCopy
>;
using GmemTiledCopyQKV = decltype(
- make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+ make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
@@ -151,10 +155,30 @@ struct Flash_fwd_kernel_traits : public Base {
Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
+ using GmemLayoutAtomOaccum = std::conditional_t<
+ kBlockKSmem == 32,
+ Layout<Shape <_16, _8>, // Thread layout, 8 threads per row
+ Stride< _8, _1>>,
+ Layout<Shape <_8, _16>, // Thread layout, 16 threads per row
+ Stride< _16, _1>>
+ >;
+ using GmemTiledCopyOaccum = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ GmemLayoutAtomOaccum{},
+ Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
+ using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+ using GmemTiledCopyRotcossin = decltype(
+ make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+ GmemLayoutAtomRotcossin{},
+ Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
+ using GmemTiledCopyRotcossinCont = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+ GmemLayoutAtomRotcossin{},
+ Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
};
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
@@ -223,16 +247,19 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
+ using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+ Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomKtransposed = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{},
- Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
- Stride<_1, Int<kBlockKSmem>>>{}));
+ composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
using SmemLayoutKtransposed = decltype(tile_to_shape(
SmemLayoutAtomKtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// Maybe the KtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
- using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
+ using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
+ SmemLayoutAtomKtransposedNoSwizzle{},
+ make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
+ // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
@@ -250,24 +277,30 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutPdS = decltype(tile_to_shape(
SmemLayoutAtomPdS{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
+ using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
+ Stride<_1, Int<kPBlockN>>>;
using SmemLayoutAtomPdStransposed = decltype(
- composition(Swizzle<kSwizzlePdS, 3, 3>{},
- Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
- Stride<_1, Int<kPBlockN>>>{}));
+ composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
using SmemLayoutPdStransposed = decltype(tile_to_shape(
SmemLayoutAtomPdStransposed{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
- using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
+ using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
+ SmemLayoutAtomPdStransposedNoSwizzle{},
+ make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
+ // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
+ using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
+ Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomQdOtransposed = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{},
- Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
- Stride<_1, Int<kBlockKSmem>>>{}));
+ composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
- using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
+ using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
+ SmemLayoutAtomQdOtransposedNoSwizzle{},
+ make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
+ // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
using SmemLayoutAtomdKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -292,13 +325,11 @@ struct Flash_bwd_kernel_traits : public Base {
static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
static constexpr int kSmemPCount = size(SmemLayoutPdS{});
static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
- static constexpr int kSmemdPsumCount = kBlockM;
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
- static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
static constexpr int kSmemSize = kSmemQdOSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
diff --git a/candle-flash-attn/kernels/kernel_traits_sm90.h b/candle-flash-attn/kernels/kernel_traits_sm90.h
new file mode 100644
index 00000000..e07f3839
--- /dev/null
+++ b/candle-flash-attn/kernels/kernel_traits_sm90.h
@@ -0,0 +1,159 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include "cute/algorithm/copy.hpp"
+
+#include "cutlass/cutlass.h"
+#include "cutlass/layout/layout.h"
+#include <cutlass/numeric_types.h>
+
+using namespace cute;
+
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
+struct Flash_kernel_traits_sm90 {
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ using Element = elem_type;
+ static constexpr bool Has_cp_async = true;
+#else
+ using Element = cutlass::half_t;
+ static constexpr bool Has_cp_async = false;
+#endif
+
+ using ElementAccum = float;
+ using index_t = uint32_t;
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ using MMA_Atom_Arch = std::conditional_t<
+ std::is_same_v<elem_type, cutlass::half_t>,
+ MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
+ MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
+ >;
+ using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
+#else
+ using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
+ using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
+#endif
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+ using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
+ using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
+#else
+ using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
+ using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
+#endif
+};
+
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
+ typename Base=Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
+struct Flash_fwd_kernel_traits : public Base {
+ using Element = typename Base::Element;
+ using ElementAccum = typename Base::ElementAccum;
+ using index_t = typename Base::index_t;
+ static constexpr bool Has_cp_async = Base::Has_cp_async;
+ using SmemCopyAtom = typename Base::SmemCopyAtom;
+ using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
+
+ static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
+ static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
+
+ // The number of threads.
+ static constexpr int kNWarps = kNWarps_;
+ static constexpr int kNThreads = kNWarps * 32;
+
+ static constexpr int kBlockM = kBlockM_;
+ static constexpr int kBlockN = kBlockN_;
+ static constexpr int kHeadDim = kHeadDim_;
+ static_assert(kHeadDim % 32 == 0);
+ static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
+ static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
+ static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
+
+ using TiledMma = TiledMMA<
+ typename Base::MMA_Atom_Arch,
+ Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
+ typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+
+ using SmemLayoutAtomQ = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+ // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
+ Layout<Shape<_8, Int<kBlockKSmem>>,
+ Stride<Int<kBlockKSmem>, _1>>{}));
+ using SmemLayoutQ = decltype(tile_to_shape(
+ SmemLayoutAtomQ{},
+ Shape<Int<kBlockM>, Int<kHeadDim>>{}));
+
+ using SmemLayoutKV = decltype(tile_to_shape(
+ SmemLayoutAtomQ{},
+ Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+
+ using SmemLayoutAtomVtransposed = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+ // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
+ Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+ Stride<_1, Int<kBlockKSmem>>>{}));
+ using SmemLayoutVtransposed = decltype(tile_to_shape(
+ SmemLayoutAtomVtransposed{},
+ Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+ // Maybe the VtransposeNoSwizzle just needs to have the right shape
+ // And the strides don't matter?
+ using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
+
+ using SmemLayoutAtomO = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+ Layout<Shape<Int<8>, Int<kBlockKSmem>>,
+ Stride<Int<kBlockKSmem>, _1>>{}));
+ using SmemLayoutO = decltype(tile_to_shape(
+ SmemLayoutAtomO{},
+ Shape<Int<kBlockM>, Int<kHeadDim>>{}));
+ using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
+
+ static constexpr int kSmemQCount = size(SmemLayoutQ{});
+ static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
+ static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
+ static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
+ static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
+
+ static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
+ static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
+ // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
+ // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
+ // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
+ // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
+ // to the same banks.
+ static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
+ static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
+ using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
+
+ // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
+ // from the same address by the same threadblock. This is slightly faster.
+ using Gmem_copy_struct = std::conditional_t<
+ Has_cp_async,
+ SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
+ DefaultCopy
+ >;
+ using GmemTiledCopyQKV = decltype(
+ make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+ GmemLayoutAtom{},
+ Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
+ using GmemTiledCopyO = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ GmemLayoutAtom{},
+ Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
+ static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
+ static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
+ using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
+ Stride<Int<kGmemThreadsPerRowP>, _1>>;
+
+ using GmemTiledCopyP = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ GmemLayoutAtomP{},
+ Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
+
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
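Note: the gmem-copy comments above boil down to a few derived constants. A quick Rust illustration of how they fall out, assuming a 2-byte element type (f16 / bf16):

// Recompute the derived trait constants for a given head dim; illustration only.
fn gmem_layout(head_dim: usize, n_warps: usize) -> (usize, usize, usize) {
    let k_block_k_smem = if head_dim % 64 == 0 { 64 } else { 32 };
    let gmem_elems_per_load = 16 / 2; // sizeof(uint128_t) / sizeof(Element) = 8 for f16
    let gmem_threads_per_row = k_block_k_smem / gmem_elems_per_load;
    let n_threads = n_warps * 32;
    assert_eq!(head_dim % gmem_elems_per_load, 0);
    assert_eq!(n_threads % gmem_threads_per_row, 0);
    (k_block_k_smem, gmem_elems_per_load, gmem_threads_per_row)
}
// e.g. head_dim = 128 gives kBlockKSmem = 64, 8 values per 128-bit load and 8 threads
// per row, so one row's threads stay within a single 64-column smem "page" instead of
// 16 threads splitting across two pages and hitting the same banks.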
diff --git a/candle-flash-attn/kernels/softmax.h b/candle-flash-attn/kernels/softmax.h
index 3e9a7b45..09a93f14 100644
--- a/candle-flash-attn/kernels/softmax.h
+++ b/candle-flash-attn/kernels/softmax.h
@@ -8,8 +8,7 @@
#include <cute/tensor.hpp>
-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
#include "philox.cuh"
#include "utils.h"
@@ -117,15 +116,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
}
template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+ const int col_idx_offset_ = 0) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
- const uint32_t lane_id = threadIdx.x % 32;
+ const int lane_id = threadIdx.x % 32;
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+ const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
- const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
+ const int col_idx = col_idx_base + j;
if (col_idx >= max_seqlen_k) {
// Without the "make_coord" we get wrong results
#pragma unroll
@@ -137,30 +139,30 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
}
}
-template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
- const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
- const uint32_t warp_row_stride) {
+template <bool HasWSLeft=true, typename Engine, typename Layout>
+inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+ const int max_seqlen_k, const int row_idx_offset,
+ const int max_seqlen_q, const int warp_row_stride,
+ const int window_size_left, const int window_size_right) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
- const uint32_t lane_id = threadIdx.x % 32;
- // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
- const uint32_t row_idx_offset = row_idx_offset_;
- const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+ const int lane_id = threadIdx.x % 32;
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
- const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
+ const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
- const uint32_t row_idx = row_idx_base + i * 8;
- const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
+ const int row_idx = row_idx_base + i * 8;
+ const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+ const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
- const uint32_t col_idx_base = col_idx_offset + nj * 8;
+ const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
- const uint32_t col_idx = col_idx_base + j;
- if (col_idx >= col_idx_limit) {
+ const int col_idx = col_idx_base + j;
+ if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
@@ -174,10 +176,19 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
}
}
+template <typename Engine, typename Layout>
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+ const int max_seqlen_k, const int row_idx_offset,
+ const int max_seqlen_q, const int warp_row_stride) {
+ // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+ apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
+ max_seqlen_q, warp_row_stride, -1, 0);
+}
+
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
- const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
+ const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
@@ -186,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
- const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+ const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
#pragma unroll
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
@@ -204,8 +215,8 @@ inline __device__ void apply_mask_causal_w_idx(
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
unsigned long long seed, unsigned long long offset,
- uint32_t block_row_start, uint32_t block_col_start,
- uint32_t block_row_stride) {
+ int block_row_start, int block_col_start,
+ int block_row_stride) {
// tensor has shape (8, MMA_M, MMA_N / 2)
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
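Note: the rewritten apply_mask_local reduces causal and sliding-window masking to a single per-element test. A host-side Rust reference of that test (a sketch, not the device code; a negative left window here simply disables the left-hand check, mirroring the HasWSLeft=false causal path):

// Returns true when the (row, col) score element should be set to -inf.
fn is_masked(
    row_idx: i64,
    col_idx: i64,
    max_seqlen_q: i64,
    max_seqlen_k: i64,
    window_size_left: i64,
    window_size_right: i64,
) -> bool {
    let shift = max_seqlen_k - max_seqlen_q; // aligns the last query with the last key
    let col_limit_right = (row_idx + 1 + shift + window_size_right).min(max_seqlen_k);
    let col_limit_left = if window_size_left < 0 {
        0
    } else {
        (row_idx + shift - window_size_left).max(0)
    };
    col_idx >= col_limit_right || col_idx < col_limit_left
}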
diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h
index 2221a2fa..6fb39dc4 100644
--- a/candle-flash-attn/kernels/utils.h
+++ b/candle-flash-attn/kernels/utils.h
@@ -88,46 +88,6 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
-inline __device__ float2 half2_unpack(uint32_t a);
-
-template <>
-inline __device__ float2 half2_unpack<__half>(uint32_t a) {
- return __half22float2(reinterpret_cast<__half2 (&)>(a));
-}
-
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-template <>
-inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
- return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
-}
-#endif
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Convert two half2's or bf162's into float, then take their dot product.
-template <typename T>
-inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
- float2 af = flash::half2_unpack<T>(a);
- float2 bf = flash::half2_unpack<T>(b);
- return af.x * bf.x + af.y * bf.y;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
-template<typename T>
-inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
- float sum;
- sum = flash::hfma2_to_float<T>(a.x, b.x);
- sum += flash::hfma2_to_float<T>(a.y, b.y);
- sum += flash::hfma2_to_float<T>(a.z, b.z);
- sum += flash::hfma2_to_float<T>(a.w, b.w);
- return sum;
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
@@ -173,10 +133,12 @@ static __device__ inline T run(T x, Operator &op) {
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
typename Tensor2, typename Tensor3, typename Tensor4,
- typename TiledMma, typename TiledCopy0, typename TiledCopy1>
+ typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+ typename ThrCopyA, typename ThrCopyB>
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma,
- TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) {
+ TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
+ ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
@@ -184,13 +146,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
- if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
- if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
+ if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
+ if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
- if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
- if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
+ if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
+ if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
@@ -199,19 +161,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
- typename TiledMma, typename TiledCopy>
+ typename TiledMma, typename TiledCopy, typename ThrCopy>
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
- TiledMma tiled_mma, TiledCopy smem_thr_copy_B) {
+ TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+ ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
- copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+ cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
- copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+ cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
}
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
@@ -225,7 +188,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
- return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+ // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
+ // "int_tuple.hpp(74): error: conversion to inaccessible base class"
+ // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+ return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -241,9 +207,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
- return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
- get<0, 1>(l),
- get<1, 1, 1>(l));
+ // TD [2023-08-13]: Same error as above on Cutlass 3.2
+ // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
+ // get<0, 1>(l),
+ // get<1, 1, 1>(l));
+ return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
+ get<1>(get<0>(l)),
+ get<1>(get<1>(get<1>(l))));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -319,9 +289,9 @@ void cp_async_wait() {
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S,
+inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
- Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
+ Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
@@ -335,13 +305,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
- copy(thr_copy, S(_, m, k), D(_, m, k));
+ cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
- clear(D(_, m, k));
+ cute::clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
- clear(D(_, m, _));
+ cute::clear(D(_, m, _));
}
}
// TD [2023-04-13]: Strange that the code below can cause race condition.
@@ -350,7 +320,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
- // copy(thr_copy, S(_, m, _), D(_, m, _));
+ // copy(tiled_copy, S(_, m, _), D(_, m, _));
// } else if (Clear_OOB_MN) {
// clear(D(_, m, _));
// }
@@ -362,7 +332,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
- // copy(thr_copy, S(_, m, k), D(_, m, k));
+ // copy(tiled_copy, S(_, m, k), D(_, m, k));
// } else if (Clear_OOB_MN) {
// clear(D(_, m, k));
// }
diff --git a/candle-flash-attn/src/ffi.rs b/candle-flash-attn/src/ffi.rs
index 90f34e43..ca65520b 100644
--- a/candle-flash-attn/src/ffi.rs
+++ b/candle-flash-attn/src/ffi.rs
@@ -7,6 +7,8 @@ extern "C" {
v_ptr: *const c_void,
o_ptr: *const c_void,
softmax_lse_ptr: *const c_void,
+ alibi_slopes_ptr: *const c_void,
+
cu_seqlens_q_ptr: *const i32,
cu_seqlens_k_ptr: *const i32,
@@ -14,6 +16,7 @@ extern "C" {
k_batch_stride: u32,
v_batch_stride: u32,
o_batch_stride: u32,
+ alibi_slopes_batch_stride: u32,
q_row_stride: u32,
k_row_stride: u32,
@@ -37,8 +40,11 @@ extern "C" {
seqlen_q_rounded: u32,
seqlen_k_rounded: u32,
- is_causal: c_int,
is_bf16: c_int,
+ is_causal: c_int,
+
+ window_size_left: c_int,
+ window_size_right: c_int,
);
}
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 3395bd0d..21a06b5e 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -3,12 +3,14 @@ mod ffi;
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
pub struct FlashAttn {
pub softmax_scale: f32,
- pub causal: bool,
+ pub alibi_slopes: Option<Tensor>,
+ pub window_size_left: Option<usize>,
+ pub window_size_right: Option<usize>,
}
fn round_multiple(x: usize, m: usize) -> usize {
@@ -85,6 +87,51 @@ impl FlashAttn {
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
}
+ let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
+ if alibi_slopes.dtype() != DType::F32 {
+ candle::bail!(
+ "DType mismatch alibi_slopes {:?}, expected {:?}",
+ alibi_slopes.dtype(),
+ DType::F32
+ );
+ }
+
+ let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
+
+ if num_heads != alibi_slopes_layout.shape().dims1()? {
+ candle::bail!(
+ "shape mismatch alibi_slopes {:?}, expected {:?}",
+ alibi_slopes_layout.shape(),
+ (num_heads)
+ );
+ }
+
+ let alibi_slopes = match &*alibi_slopes {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
+ _ => candle::bail!("alibi_slopes must be a cuda tensor"),
+ };
+
+ let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
+
+ *alibi_slopes.device_ptr() as *const core::ffi::c_void
+ } else {
+ std::ptr::null()
+ };
+
+ // if window_size_left > seqlen_k or None => -1
+ let mut window_size_left = self
+ .window_size_left
+ .filter(|v| v <= &seqlen_k)
+ .map(|v| v as i32)
+ .unwrap_or(-1);
+
+ // if window_size_right > seqlen_k or None => -1
+ let mut window_size_right = self
+ .window_size_right
+ .filter(|v| v <= &seqlen_k)
+ .map(|v| v as i32)
+ .unwrap_or(-1);
+
let head_size = round_multiple(head_size_og, 8);
let head_size_rounded = round_multiple(head_size, 32);
let seqlen_q_rounded = round_multiple(seqlen_q, 128);
@@ -94,9 +141,22 @@ impl FlashAttn {
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
- let causal = if self.causal { 1 } else { 0 };
let is_bf16 = if is_bf16 { 1 } else { 0 };
+ // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+ // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+ let is_causal = if window_size_left < 0 && window_size_right == 0 {
+ 1
+ } else {
+ 0
+ };
+ if window_size_left < 0 && window_size_right >= 0 {
+ window_size_left = seqlen_k as i32;
+ }
+ if window_size_left >= 0 && window_size_right < 0 {
+ window_size_right = seqlen_k as i32;
+ }
+
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@@ -109,12 +169,14 @@ impl FlashAttn {
v_ptr,
dst_ptr,
softmax_lse_ptr,
+ /* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ std::ptr::null(),
/* cu_seqlens_k_ptr */ std::ptr::null(),
/* q_batch_stride */ q_stride[0] as u32,
/* k_batch_stride */ k_stride[0] as u32,
/* v_batch_stride */ v_stride[0] as u32,
/* o_batch_stride */ o_stride[0] as u32,
+ /* alibi_slopes_batch_stride */ 0,
/* q_row_stride */ q_stride[q_rank - 3] as u32,
/* k_row_stride */ k_stride[k_rank - 3] as u32,
/* v_row_stride */ v_stride[v_rank - 3] as u32,
@@ -133,8 +195,10 @@ impl FlashAttn {
/* seqlen_k */ seqlen_k as u32,
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
- /* is_causal */ causal,
/* is_bf16 */ is_bf16,
+ /* is_causal */ is_causal,
+ /* window_size_left */ window_size_left,
+ /* window_size_right */ window_size_right,
)
}
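Note: the is_causal / window normalization added above is small but easy to misread. A condensed Rust restatement of exactly that logic (illustrative helper, not part of the crate):

// Causal is encoded as (left < 0, right == 0); a one-sided window is completed with
// seqlen_k so the kernel always sees two non-negative bounds when running in local mode.
fn normalize_window(
    seqlen_k: usize,
    left: Option<usize>,
    right: Option<usize>,
) -> (bool, i32, i32) {
    let clamp = |w: Option<usize>| w.filter(|v| *v <= seqlen_k).map(|v| v as i32).unwrap_or(-1);
    let (mut l, mut r) = (clamp(left), clamp(right));
    let is_causal = l < 0 && r == 0;
    if l < 0 && r >= 0 {
        l = seqlen_k as i32;
    }
    if l >= 0 && r < 0 {
        r = seqlen_k as i32;
    }
    (is_causal, l, r)
}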
@@ -197,20 +261,137 @@ pub fn flash_attn(
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
+ let window_size_left = None;
+ let window_size_right = if causal { Some(0) } else { None };
+
let op = FlashAttn {
softmax_scale,
- causal,
+ alibi_slopes: None,
+ window_size_left,
+ window_size_right,
};
q.apply_op3(k, v, op)
}
-struct FlashAttnVarLen {
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q; the number of heads in q has to be divisible by the number of heads in k and v.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `window_size_left` - Limit the attention window to this many key/value tokens to the left of the query position (`None` means no limit on that side).
+/// * `window_size_right` - Limit the attention window to this many key/value tokens to the right of the query position (`None` means no limit on that side).
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`.
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
+pub fn flash_attn_windowed(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ softmax_scale: f32,
+ window_size_left: Option<usize>,
+ window_size_right: Option<usize>,
+) -> Result<Tensor> {
+ let op = FlashAttn {
+ softmax_scale,
+ alibi_slopes: None,
+ window_size_left,
+ window_size_right,
+ };
+ q.apply_op3(k, v, op)
+}
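Note: a usage sketch for the new entry point, assuming the crate is imported as `candle_flash_attn` and `q`/`k`/`v` are f16 or bf16 CUDA tensors of shape `(batch, seq_len, num_heads, head_size)`; the scale choice is an assumption, not prescribed by the API:

use candle::{Result, Tensor};

// Hypothetical helper: causal sliding-window attention over the previous 256 tokens.
fn sliding_window_attn(q: &Tensor, k: &Tensor, v: &Tensor, head_size: usize) -> Result<Tensor> {
    let scale = 1.0 / (head_size as f32).sqrt();
    // window_size_right = Some(0) keeps the mask causal; Some(256) bounds the look-back.
    candle_flash_attn::flash_attn_windowed(q, k, v, scale, Some(256), Some(0))
}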
+
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q; the number of heads in q has to be divisible by the number of heads in k and v.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
+pub fn flash_attn_alibi(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ alibi_slopes: &Tensor,
softmax_scale: f32,
causal: bool,
- max_seqlen_q: usize,
- max_seqlen_k: usize,
- seqlens_q: Tensor,
- seqlens_k: Tensor,
+) -> Result<Tensor> {
+ let window_size_left = None;
+ let window_size_right = if causal { Some(0) } else { None };
+
+ let op = FlashAttn {
+ softmax_scale,
+ alibi_slopes: Some(alibi_slopes.clone()),
+ window_size_left,
+ window_size_right,
+ };
+ q.apply_op3(k, v, op)
+}
+
+/// Flash-attention v2 layer.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q; the number of heads in q has to be divisible by the number of heads in k and v.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
+/// * `window_size_left` - Limit the attention window to this many key/value tokens to the left of the query position (`None` means no limit on that side).
+/// * `window_size_right` - Limit the attention window to this many key/value tokens to the right of the query position (`None` means no limit on that side).
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`.
+///
+/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
+pub fn flash_attn_alibi_windowed(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ alibi_slopes: &Tensor,
+ softmax_scale: f32,
+ window_size_left: Option<usize>,
+ window_size_right: Option<usize>,
+) -> Result<Tensor> {
+ let op = FlashAttn {
+ softmax_scale,
+ alibi_slopes: Some(alibi_slopes.clone()),
+ window_size_left,
+ window_size_right,
+ };
+ q.apply_op3(k, v, op)
+}
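Note: the ALiBi entry points expect an f32 slope tensor of shape `(num_heads)`. A sketch for building the standard geometric slope sequence from the ALiBi paper (this helper is not part of the crate and assumes `num_heads` is a power of two):

use candle::{Device, Result, Tensor};

// Standard ALiBi slopes: 2^(-8/n), 2^(-16/n), ... for n heads.
fn alibi_slopes(num_heads: usize, device: &Device) -> Result<Tensor> {
    let base = 2f32.powf(-8.0 / num_heads as f32);
    let slopes: Vec<f32> = (1..=num_heads).map(|i| base.powi(i as i32)).collect();
    // The kernels expect an f32 tensor of shape (num_heads) on the same CUDA device as q/k/v.
    Tensor::from_vec(slopes, (num_heads,), device)
}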
+
+struct FlashAttnVarLen {
+ pub softmax_scale: f32,
+ pub max_seqlen_q: usize,
+ pub max_seqlen_k: usize,
+ pub seqlens_q: Tensor,
+ pub seqlens_k: Tensor,
+ pub alibi_slopes: Option<Tensor>,
+ pub window_size_left: Option<usize>,
+ pub window_size_right: Option<usize>,
}
impl FlashAttnVarLen {
@@ -311,7 +492,54 @@ impl FlashAttnVarLen {
if nseqlens_k != nseqlens_q {
candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}")
}
+
let batch_size = nseqlens_q - 1;
+
+ let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
+ if alibi_slopes.dtype() != DType::F32 {
+ candle::bail!(
+ "DType mismatch alibi_slopes {:?}, expected {:?}",
+ alibi_slopes.dtype(),
+ DType::F32
+ );
+ }
+
+ let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
+
+ if num_heads != alibi_slopes_layout.shape().dims1()? {
+ candle::bail!(
+ "shape mismatch alibi_slopes {:?}, expected {:?}",
+ alibi_slopes_layout.shape(),
+ (num_heads)
+ );
+ }
+
+ let alibi_slopes = match &*alibi_slopes {
+ candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
+ _ => candle::bail!("alibi_slopes must be a cuda tensor"),
+ };
+
+ let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
+
+ *alibi_slopes.device_ptr() as *const core::ffi::c_void
+ } else {
+ std::ptr::null()
+ };
+
+ // if window_size_left > self.max_seqlen_k or None => -1
+ let mut window_size_left = self
+ .window_size_left
+ .filter(|v| v <= &self.max_seqlen_k)
+ .map(|v| v as i32)
+ .unwrap_or(-1);
+
+ // if window_size_right > self.max_seqlen_k or None => -1
+ let mut window_size_right = self
+ .window_size_right
+ .filter(|v| v <= &self.max_seqlen_k)
+ .map(|v| v as i32)
+ .unwrap_or(-1);
+
let head_size = round_multiple(head_size_og, 8);
let head_size_rounded = round_multiple(head_size, 32);
let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128);
@@ -323,9 +551,22 @@ impl FlashAttnVarLen {
.alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
.w()?;
- let causal = if self.causal { 1 } else { 0 };
let is_bf16 = if is_bf16 { 1 } else { 0 };
+ // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+ // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+ let is_causal = if window_size_left < 0 && window_size_right == 0 {
+ 1
+ } else {
+ 0
+ };
+ if window_size_left < 0 && window_size_right >= 0 {
+ window_size_left = self.max_seqlen_k as i32;
+ }
+ if window_size_left >= 0 && window_size_right < 0 {
+ window_size_right = self.max_seqlen_k as i32;
+ }
+
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@@ -340,12 +581,14 @@ impl FlashAttnVarLen {
v_ptr,
dst_ptr,
softmax_lse_ptr,
+ /* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ seqlens_q_ptr,
/* cu_seqlens_k_ptr */ seqlens_k_ptr,
/* q_batch_stride */ 0,
/* k_batch_stride */ 0,
/* v_batch_stride */ 0,
/* o_batch_stride */ 0,
+ /* alibi_slopes_batch_stride */ 0,
/* q_row_stride */ q_stride[q_rank - 3] as u32,
/* k_row_stride */ k_stride[k_rank - 3] as u32,
/* v_row_stride */ v_stride[v_rank - 3] as u32,
@@ -364,8 +607,10 @@ impl FlashAttnVarLen {
/* seqlen_k */ self.max_seqlen_k as u32,
/* seqlen_q_rounded */ seqlen_q_rounded as u32,
/* seqlen_k_rounded */ seqlen_k_rounded as u32,
- /* is_causal */ causal,
/* is_bf16 */ is_bf16,
+ /* is_causal */ is_causal,
+ /* window_size_left */ window_size_left,
+ /* window_size_right */ window_size_right,
)
}
@@ -440,13 +685,176 @@ pub fn flash_attn_varlen(
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
+ let window_size_left = None;
+ let window_size_right = if causal { Some(0) } else { None };
+
+ let op = FlashAttnVarLen {
+ softmax_scale,
+ max_seqlen_q,
+ max_seqlen_k,
+ seqlens_q: seqlens_q.clone(),
+ seqlens_k: seqlens_k.clone(),
+ alibi_slopes: None,
+ window_size_left,
+ window_size_right,
+ };
+ q.apply_op3(k, v, op)
+}
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
+/// than q; the number of heads in q has to be divisible by the number of heads in k and v.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum sequence length for k and v in the batch.
+/// * `window_size_left` - Limit attention to at most this many key/value tokens to the left of
+///   each query position; `None` means no limit.
+/// * `window_size_right` - Limit attention to at most this many key/value tokens to the right of
+///   each query position; `None` means no limit.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`.
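+///
+/// # Example
+///
+/// A minimal sketch of local (sliding-window) attention over two packed sequences. The device
+/// setup, shapes, window sizes and scaling below are illustrative assumptions, not values taken
+/// from this crate.
+///
+/// ```no_run
+/// use candle::{DType, Device, Tensor};
+/// use candle_flash_attn::flash_attn_varlen_windowed;
+///
+/// fn local_attention() -> candle::Result<Tensor> {
+///     let device = Device::new_cuda(0)?;
+///     // Two sequences of lengths 3 and 5 packed into one (total_tokens, num_heads, head_size) tensor.
+///     let (total, heads, head_size) = (8, 4, 64);
+///     let q = Tensor::randn(0f32, 1., (total, heads, head_size), &device)?.to_dtype(DType::F16)?;
+///     let k = q.clone();
+///     let v = q.clone();
+///     // Cumulative sequence lengths: `batch_size + 1` entries.
+///     let seqlens = Tensor::new(&[0u32, 3, 8], &device)?;
+///     flash_attn_varlen_windowed(
+///         &q, &k, &v, &seqlens, &seqlens,
+///         /* max_seqlen_q */ 5,
+///         /* max_seqlen_k */ 5,
+///         /* softmax_scale */ 1.0 / (head_size as f32).sqrt(),
+///         /* window_size_left */ Some(2),  // each token attends to at most 2 earlier tokens
+///         /* window_size_right */ Some(0), // and to nothing ahead of it
+///     )
+/// }
+/// ```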
+pub fn flash_attn_varlen_windowed(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ seqlens_q: &Tensor,
+ seqlens_k: &Tensor,
+ max_seqlen_q: usize,
+ max_seqlen_k: usize,
+ softmax_scale: f32,
+ window_size_left: Option<usize>,
+ window_size_right: Option<usize>,
+) -> Result<Tensor> {
+ let op = FlashAttnVarLen {
+ softmax_scale,
+ max_seqlen_q,
+ max_seqlen_k,
+ seqlens_q: seqlens_q.clone(),
+ seqlens_k: seqlens_k.clone(),
+ alibi_slopes: None,
+ window_size_left,
+ window_size_right,
+ };
+ q.apply_op3(k, v, op)
+}
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T * softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors `k` and `v` with fewer
+/// heads than `q`; the number of heads in `q` must be divisible by the number of heads in `k` and
+/// `v`.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - ALiBi slopes tensor with shape `(num_heads_q)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum sequence length for k and v in the batch.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
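+///
+/// # Example
+///
+/// A minimal sketch showing one common way to build per-head ALiBi slopes
+/// (`2^(-8 * (i + 1) / num_heads)` for a power-of-two number of heads) and pass them in. The
+/// shapes and values below are illustrative assumptions, not values taken from this crate.
+///
+/// ```no_run
+/// use candle::{DType, Device, Tensor};
+/// use candle_flash_attn::flash_attn_varlen_alibi;
+///
+/// fn alibi_attention() -> candle::Result<Tensor> {
+///     let device = Device::new_cuda(0)?;
+///     let (total, heads, head_size) = (8, 4, 64);
+///     let q = Tensor::randn(0f32, 1., (total, heads, head_size), &device)?.to_dtype(DType::F16)?;
+///     let k = q.clone();
+///     let v = q.clone();
+///     let seqlens = Tensor::new(&[0u32, 3, 8], &device)?;
+///     // One slope per query head, kept in f32.
+///     let slopes: Vec<f32> = (0..heads)
+///         .map(|i| 2f32.powf(-8.0 * (i + 1) as f32 / heads as f32))
+///         .collect();
+///     let alibi_slopes = Tensor::from_vec(slopes, heads, &device)?;
+///     flash_attn_varlen_alibi(
+///         &q, &k, &v, &alibi_slopes, &seqlens, &seqlens,
+///         /* max_seqlen_q */ 5,
+///         /* max_seqlen_k */ 5,
+///         /* softmax_scale */ 1.0 / (head_size as f32).sqrt(),
+///         /* causal */ true,
+///     )
+/// }
+/// ```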
+pub fn flash_attn_varlen_alibi(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ alibi_slopes: &Tensor,
+ seqlens_q: &Tensor,
+ seqlens_k: &Tensor,
+ max_seqlen_q: usize,
+ max_seqlen_k: usize,
+ softmax_scale: f32,
+ causal: bool,
+) -> Result<Tensor> {
+ let window_size_left = None;
+ let window_size_right = if causal { Some(0) } else { None };
+
+ let op = FlashAttnVarLen {
+ softmax_scale,
+ max_seqlen_q,
+ max_seqlen_k,
+ seqlens_q: seqlens_q.clone(),
+ seqlens_k: seqlens_k.clone(),
+ alibi_slopes: Some(alibi_slopes.clone()),
+ window_size_left,
+ window_size_right,
+ };
+ q.apply_op3(k, v, op)
+}
+
+#[allow(clippy::too_many_arguments)]
+/// Flash-attention v2 layer with variable-length batching.
+///
+/// This implements scaled dot-product attention, `softmax(Q @ K^T * softmax_scale) @ V`.
+/// Multi-query and grouped-query attention are supported by using tensors `k` and `v` with fewer
+/// heads than `q`; the number of heads in `q` must be divisible by the number of heads in `k` and
+/// `v`.
+///
+/// # Arguments
+///
+/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
+/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
+/// * `alibi_slopes` - ALiBi slopes tensor with shape `(num_heads_q)`.
+/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
+/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
+/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
+/// * `max_seqlen_k` - The maximum sequence length for k and v in the batch.
+/// * `window_size_left` - Limit attention to at most this many key/value tokens to the left of
+///   each query position; `None` means no limit.
+/// * `window_size_right` - Limit attention to at most this many key/value tokens to the right of
+///   each query position; `None` means no limit.
+///
+/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
+/// `seqlen_1 + seqlen_2`, etc.
+///
+/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
+///
+/// # Causal mask
+///
+/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
+/// of `Q @ K^T`.
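+///
+/// # Example
+///
+/// A minimal sketch combining ALiBi slopes with a sliding window; tensor construction follows
+/// the same pattern as `flash_attn_varlen_alibi`, and the window sizes below are illustrative
+/// assumptions. It also assumes queries and keys share the same cumulative sequence lengths.
+///
+/// ```no_run
+/// use candle::Tensor;
+/// use candle_flash_attn::flash_attn_varlen_alibi_windowed;
+///
+/// fn alibi_local_attention(
+///     q: &Tensor,
+///     k: &Tensor,
+///     v: &Tensor,
+///     alibi_slopes: &Tensor,
+///     seqlens: &Tensor,
+///     max_seqlen: usize,
+///     head_size: usize,
+/// ) -> candle::Result<Tensor> {
+///     flash_attn_varlen_alibi_windowed(
+///         q, k, v, alibi_slopes, seqlens, seqlens,
+///         /* max_seqlen_q */ max_seqlen,
+///         /* max_seqlen_k */ max_seqlen,
+///         /* softmax_scale */ 1.0 / (head_size as f32).sqrt(),
+///         /* window_size_left */ Some(128), // sliding window over the previous 128 tokens
+///         /* window_size_right */ Some(0),  // no look-ahead
+///     )
+/// }
+/// ```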
+pub fn flash_attn_varlen_alibi_windowed(
+ q: &Tensor,
+ k: &Tensor,
+ v: &Tensor,
+ alibi_slopes: &Tensor,
+ seqlens_q: &Tensor,
+ seqlens_k: &Tensor,
+ max_seqlen_q: usize,
+ max_seqlen_k: usize,
+ softmax_scale: f32,
+ window_size_left: Option<usize>,
+ window_size_right: Option<usize>,
+) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
- causal,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
+ alibi_slopes: Some(alibi_slopes.clone()),
+ window_size_left,
+ window_size_right,
};
q.apply_op3(k, v, op)
}