path: root/candle-flash-attn
author     Laurent Mazare <laurent.mazare@gmail.com>  2024-07-15 20:37:36 +0200
committer  GitHub <noreply@github.com>  2024-07-15 20:37:36 +0200
commit     30cdd769f9404035235830e602ae01d50f782fb5 (patch)
tree       dd9d8adcfce61fe40678f2967bbb25cecb1f679a /candle-flash-attn
parent     d74fbed3341f875fa81112e2f59565c464cd59d8 (diff)
Update the flash attn kernels. (#2333)
Diffstat (limited to 'candle-flash-attn')
-rw-r--r--  candle-flash-attn/build.rs  18
m---------  candle-flash-attn/cutlass  0
-rw-r--r--  candle-flash-attn/kernels/alibi.h  78
-rw-r--r--  candle-flash-attn/kernels/block_info.h  4
-rw-r--r--  candle-flash-attn/kernels/dropout.h  94
-rw-r--r--  candle-flash-attn/kernels/error.h  8
-rw-r--r--  candle-flash-attn/kernels/flash.h  35
-rw-r--r--  candle-flash-attn/kernels/flash_api.cu  20
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu  10
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu  4
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_kernel.h  1151
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_launch_template.h  424
-rw-r--r--  candle-flash-attn/kernels/kernel_helpers.h  50
-rw-r--r--  candle-flash-attn/kernels/kernel_traits.h  119
-rw-r--r--  candle-flash-attn/kernels/kernels.h  58
-rw-r--r--  candle-flash-attn/kernels/mask.h  213
-rw-r--r--  candle-flash-attn/kernels/philox.cuh  120
-rw-r--r--  candle-flash-attn/kernels/rotary.h  152
-rw-r--r--  candle-flash-attn/kernels/softmax.h  237
-rw-r--r--  candle-flash-attn/kernels/static_switch.h  53
-rw-r--r--  candle-flash-attn/kernels/utils.h  115
51 files changed, 2274 insertions, 899 deletions
diff --git a/candle-flash-attn/build.rs b/candle-flash-attn/build.rs
index 4002770b..53fec5de 100644
--- a/candle-flash-attn/build.rs
+++ b/candle-flash-attn/build.rs
@@ -4,7 +4,7 @@
use anyhow::{Context, Result};
use std::path::PathBuf;
-const KERNEL_FILES: [&str; 17] = [
+const KERNEL_FILES: [&str; 33] = [
"kernels/flash_api.cu",
"kernels/flash_fwd_hdim128_fp16_sm80.cu",
"kernels/flash_fwd_hdim160_fp16_sm80.cu",
@@ -22,6 +22,22 @@ const KERNEL_FILES: [&str; 17] = [
"kernels/flash_fwd_hdim32_bf16_sm80.cu",
"kernels/flash_fwd_hdim64_bf16_sm80.cu",
"kernels/flash_fwd_hdim96_bf16_sm80.cu",
+ "kernels/flash_fwd_hdim128_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim160_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim192_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim224_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim256_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim32_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim64_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim96_fp16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim128_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim160_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim192_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim224_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim256_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim32_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim64_bf16_causal_sm80.cu",
+ "kernels/flash_fwd_hdim96_bf16_causal_sm80.cu",
];
fn main() -> Result<()> {
diff --git a/candle-flash-attn/cutlass b/candle-flash-attn/cutlass
-Subproject c4f6b8c6bc94ff69048492fb34df0dfaf198393
+Subproject 7d49e6c7e2f8896c47f586706e67e1fb215529d
diff --git a/candle-flash-attn/kernels/alibi.h b/candle-flash-attn/kernels/alibi.h
index 1afb3687..e714233e 100644
--- a/candle-flash-attn/kernels/alibi.h
+++ b/candle-flash-attn/kernels/alibi.h
@@ -13,50 +13,62 @@ using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
-template <bool Is_causal, typename Engine, typename Layout>
-inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
- const int col_idx_offset_,
- const int max_seqlen_k,
- const int row_idx_offset,
- const int max_seqlen_q,
- const int warp_row_stride,
- const float alibi_slope) {
- // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
- static_assert(Layout::rank == 2, "Only support 2D Tensor");
- const int lane_id = threadIdx.x % 32;
- const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
- if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
- #pragma unroll
- for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
- const int col_idx_base = col_idx_offset + nj * 8;
+template <bool Is_causal>
+struct Alibi {
+
+ const float alibi_slope;
+ const int max_seqlen_k, max_seqlen_q;
+
+ __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
+ : alibi_slope(alibi_slope)
+ , max_seqlen_k(max_seqlen_k)
+ , max_seqlen_q(max_seqlen_q) {
+ };
+
+
+ template <typename Engine, typename Layout>
+ __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
+ const int col_idx_offset_,
+ const int row_idx_offset,
+ const int warp_row_stride) {
+ // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+ static_assert(Layout::rank == 2, "Only support 2D Tensor");
+ const int lane_id = threadIdx.x % 32;
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+ if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
- for (int j = 0; j < size<1, 0>(tensor); ++j) {
- const int col_idx = col_idx_base + j;
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+ const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
- for (int mi = 0; mi < size<0>(tensor); ++mi) {
- tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
+ const int col_idx = col_idx_base + j;
+ #pragma unroll
+ for (int mi = 0; mi < size<0>(tensor); ++mi) {
+ tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+ }
}
}
- }
- } else { // Bias depends on both row_idx and col_idx
- #pragma unroll
- for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
- const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+ } else { // Bias depends on both row_idx and col_idx
#pragma unroll
- for (int i = 0; i < size<0, 0>(tensor); ++i) {
- const int row_idx = row_idx_base + i * 8;
+ for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+ const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
- for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
- const int col_idx_base = col_idx_offset + nj * 8;
+ for (int i = 0; i < size<0, 0>(tensor); ++i) {
+ const int row_idx = row_idx_base + i * 8;
#pragma unroll
- for (int j = 0; j < size<1, 0>(tensor); ++j) {
- const int col_idx = col_idx_base + j;
- tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+ const int col_idx_base = col_idx_offset + nj * 8;
+ #pragma unroll
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
+ const int col_idx = col_idx_base + j;
+ tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+ }
}
}
}
}
}
-}
+
+};
} // namespace flash
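
The hunk above replaces the free apply_alibi function with an Alibi struct that carries the slope and the two sequence lengths, so call sites only pass per-tile indices; the bias arithmetic itself is unchanged. A minimal scalar sketch of what that arithmetic computes (plain C++, no cute tensor plumbing, values chosen purely for illustration):

#include <cstdio>
#include <cstdlib>

// Bias added to an attention score at (row, col), as in the unrolled loops above.
float alibi_bias(bool is_causal, float slope, int row, int col,
                 int seqlen_q, int seqlen_k) {
    if (is_causal) {
        // Causal path: the same per-column bias is added to every row.
        return slope * col;
    }
    // General path: the penalty grows with the key/query distance,
    // aligned to the bottom-right corner of the attention matrix.
    return -slope * std::abs(row + seqlen_k - seqlen_q - col);
}

int main() {
    // seqlen_q = 4, seqlen_k = 8, slope = 0.0625: query row 3 pays no penalty
    // for key column 7 (its "diagonal" element), but is penalized for column 2.
    std::printf("%f\n", alibi_bias(false, 0.0625f, 3, 7, 4, 8));  // -0.000000 (no penalty)
    std::printf("%f\n", alibi_bias(false, 0.0625f, 3, 2, 4, 8));  // -0.312500
    return 0;
}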
diff --git a/candle-flash-attn/kernels/block_info.h b/candle-flash-attn/kernels/block_info.h
index 65435e51..3a23a1e1 100644
--- a/candle-flash-attn/kernels/block_info.h
+++ b/candle-flash-attn/kernels/block_info.h
@@ -24,12 +24,12 @@ struct BlockInfo {
}
template <typename index_t>
- inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+ __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}
template <typename index_t>
- inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+ __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}
diff --git a/candle-flash-attn/kernels/dropout.h b/candle-flash-attn/kernels/dropout.h
new file mode 100644
index 00000000..4882f97d
--- /dev/null
+++ b/candle-flash-attn/kernels/dropout.h
@@ -0,0 +1,94 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include "philox.cuh"
+#include "utils.h"
+
+namespace flash {
+
+struct Dropout {
+
+ const unsigned long long seed, offset;
+ const uint8_t p_dropout_in_uint8_t;
+
+ __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
+ const uint8_t p_dropout_in_uint8_t,
+ const int bid, const int hid, const int tid, const int nheads)
+ : seed(seed)
+ , offset(offset + (bid * nheads + hid) * 32 + tid % 32)
+ , p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
+ }
+
+ template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
+ __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
+ int block_row_start, int block_col_start, int block_row_stride) {
+ // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
+ Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
+ using T = typename Engine::value_type;
+ auto encode_dropout = [](bool keep, T val) {
+ return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
+ };
+ static_assert(decltype(size<2>(tensor))::value % 2 == 0);
+ const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
+ const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
+ // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
+ #pragma unroll
+ for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
+ uint2 rowcol = make_uint2(block_row_start, block_col_start);
+ #pragma unroll
+ for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
+ // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
+ uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
+ // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
+ uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
+ // Special implementation for 16-bit types: we duplicate the threshold to the
+ // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
+ // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
+ // and the high 16 bits will be either 0xffff or 0x0000, depending on whether
+ // the random value is less than the threshold.
+ // We then do a bit-wise AND between the mask and the original value (in 32-bit).
+ // We're exploiting the fact that floating point comparison is equivalent to integer
+ // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
+ if (!encode_dropout_in_sign_bit
+ && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
+ uint16_t rnd_16[16];
+ #pragma unroll
+ for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
+ uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
+ #pragma unroll
+ for (int j = 0; j < 2; j++) {
+ Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
+ // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
+ // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
+ #pragma unroll
+ for (int i = 0; i < 4; i++) {
+ uint32_t mask;
+ asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
+ tensor_uint32(i) &= mask;
+ }
+ // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
+ }
+ } else {
+ #pragma unroll
+ for (int j = 0; j < 2; j++) {
+ #pragma unroll
+ for (int i = 0; i < 8; i++) {
+ tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
+ }
+ Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
+ // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
+ }
+ }
+ // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
+ // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
+ // // }
+ }
+ }
+ }
+
+};
+
+} // namespace flash
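
dropout.h is new in this layout: a Dropout struct seeded per (batch, head, lane) draws 16 random bytes per (row, column) pair from Philox and keeps an element when its byte is at or below an 8-bit threshold, with an optional f16x2 trick that turns the comparison into a bitmask-and-AND for 16-bit types. A scalar sketch of the keep test follows; the floor(p_keep * 255) threshold encoding is an assumption, since only the comparison side appears in this file:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    // Keep probability -> 8-bit threshold (assumed encoding, see note above).
    const float p_keep = 0.9f;
    const uint8_t threshold = (uint8_t)std::floor(p_keep * 255.0f);  // 229

    // Pretend Philox output: the kernel gets 16 bytes per (row, col) pair,
    // four are enough to show the test.
    const uint8_t rnd[4] = {12, 200, 255, 230};
    float vals[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    for (int i = 0; i < 4; ++i) {
        // Same test as `rnd_8[...] <= p_dropout_in_uint8_t` above; this is the
        // non-sign-bit-encoding branch, so dropped values become 0.
        const bool keep = rnd[i] <= threshold;
        vals[i] = keep ? vals[i] : 0.0f;
        std::printf("rnd=%3d keep=%d val=%f\n", int(rnd[i]), int(keep), vals[i]);
    }
    return 0;
}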
diff --git a/candle-flash-attn/kernels/error.h b/candle-flash-attn/kernels/error.h
new file mode 100644
index 00000000..03416924
--- /dev/null
+++ b/candle-flash-attn/kernels/error.h
@@ -0,0 +1,8 @@
+#pragma once
+
+#define C10_CUDA_CHECK(EXPR) \
+ do { \
+ const cudaError_t __err = EXPR; \
+ } while (0)
+
+#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError())
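
This error.h is deliberately a stub: C10_CUDA_CHECK evaluates the expression and discards the resulting cudaError_t, since PyTorch's error machinery is not available in this standalone build. If launch failures needed to be surfaced during debugging, a stricter local variant might look like the following (illustrative only, not part of the patch):

#include <cstdio>
#include <cuda_runtime.h>

#define CHECKED_CUDA_CHECK(EXPR)                                        \
    do {                                                                \
        const cudaError_t err_ = (EXPR);                                \
        if (err_ != cudaSuccess) {                                      \
            std::fprintf(stderr, "CUDA error %s at %s:%d\n",            \
                         cudaGetErrorString(err_), __FILE__, __LINE__); \
        }                                                               \
    } while (0)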
diff --git a/candle-flash-attn/kernels/flash.h b/candle-flash-attn/kernels/flash.h
index 80b517e9..88c2f22a 100644
--- a/candle-flash-attn/kernels/flash.h
+++ b/candle-flash-attn/kernels/flash.h
@@ -7,6 +7,14 @@
#include <cuda.h>
#include <vector>
+// #ifdef OLD_GENERATOR_PATH
+// #include <ATen/CUDAGeneratorImpl.h>
+// #else
+// #include <ATen/cuda/CUDAGeneratorImpl.h>
+// #endif
+//
+// #include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
+
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
@@ -14,7 +22,7 @@ constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
- using index_t = uint32_t;
+ using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
@@ -59,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ softmax_lseaccum_ptr;
// The dimensions.
- int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
+ int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q;
// The scaling factors for the kernel.
float scale_softmax;
@@ -91,7 +99,12 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ rotary_sin_ptr;
// The indices to index into the KV cache.
- int *__restrict__ cache_batch_idx;
+ int * __restrict__ cache_batch_idx;
+
+ // Paged KV cache
+ int * __restrict__ block_table;
+ index_t block_table_batch_stride;
+ int page_block_size;
// The dropout probability (probability of keeping an activation).
float p_dropout;
@@ -105,6 +118,13 @@ struct Flash_fwd_params : public Qkv_params {
// Local window size
int window_size_left, window_size_right;
+ float softcap;
+
+ // Random state.
+ // at::PhiloxCudaState philox_args;
+
+ // Pointer to the RNG seed (idx 0) and offset (idx 1).
+ uint64_t * rng_state;
bool is_bf16;
bool is_causal;
@@ -119,6 +139,9 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
+
+ bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
+ bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -165,7 +188,7 @@ struct Flash_bwd_params : public Flash_fwd_params {
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
-template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);
+template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
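
The flash.h changes widen index_t to int64_t, add softcap and RNG-state fields, flags for the unpadded/varlen LSE layout, and the paged KV cache parameters (block_table, block_table_batch_stride, page_block_size). The sketch below shows how a block table is conventionally used to turn a logical key position into a physical cache offset; the exact indexing inside the kernels is not part of this hunk, so the stride names here are placeholders:

#include <cstdint>

// Translate logical key position `t` in batch `bidb` into a physical row offset
// in a paged K cache. `k_page_stride` is the assumed stride between physical
// pages and `k_row_stride` the stride between rows inside a page.
int64_t paged_k_offset(const int* block_table, int64_t block_table_batch_stride,
                       int page_block_size, int bidb, int t,
                       int64_t k_page_stride, int64_t k_row_stride) {
    const int page = block_table[bidb * block_table_batch_stride + t / page_block_size];
    const int row_in_page = t % page_block_size;
    return (int64_t)page * k_page_stride + (int64_t)row_in_page * k_row_stride;
}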
diff --git a/candle-flash-attn/kernels/flash_api.cu b/candle-flash-attn/kernels/flash_api.cu
index 8113dbc7..4ca41b0a 100644
--- a/candle-flash-attn/kernels/flash_api.cu
+++ b/candle-flash-attn/kernels/flash_api.cu
@@ -1,15 +1,15 @@
+#include "kernels.h"
+#include "kernel_helpers.h"
#include "flash_fwd_launch_template.h"
-void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
- FP16_SWITCH(!params.is_bf16, [&] {
- FWD_HEADDIM_SWITCH(params.d, [&] {
-// if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
- run_mha_fwd_<elem_type, kHeadDim>(params, stream);
-// } else {
-// run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
-// }
- });
- });
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+ FP16_SWITCH(!params.is_bf16, [&] {
+ HEADDIM_SWITCH(params.d, [&] {
+ BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+ run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
+ });
+ });
+ });
}
extern "C" void run_mha(
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu
new file mode 100644
index 00000000..f19049b4
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
index 6ffa4126..cb135741 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_bf16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu
new file mode 100644
index 00000000..dfb04b78
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim128<cutlass::half_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
index 19b005ad..6df16b2c 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim128_fp16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu
new file mode 100644
index 00000000..230af906
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
index f674f481..cf1ffad2 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_bf16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu
new file mode 100644
index 00000000..1fc5ac59
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim160<cutlass::half_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
index afd0a8a3..a9796ade 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim160_fp16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim160<cutlass::half_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu
new file mode 100644
index 00000000..94792d4d
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
index aa91bdd6..76d5136b 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_bf16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu
new file mode 100644
index 00000000..9e5b21e0
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim192<cutlass::half_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
index 37a96526..b4019a0b 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim192_fp16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim192<cutlass::half_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu
new file mode 100644
index 00000000..a12a5f4a
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim224<cutlass::bfloat16_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
index 167a0df2..8690bdb1 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_bf16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim224<cutlass::bfloat16_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu
new file mode 100644
index 00000000..f01dad09
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim224<cutlass::half_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
index 58ffe75c..7ec1e16b 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim224_fp16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim224<cutlass::half_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu
new file mode 100644
index 00000000..3d816ab6
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
index 1b370141..c6c55229 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_bf16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu
new file mode 100644
index 00000000..0149abac
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim256<cutlass::half_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
index 9f35129c..9c9a1715 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim256_fp16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim256<cutlass::half_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu
new file mode 100644
index 00000000..29097ac3
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
index 770de6fc..cb52f34f 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_bf16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu
new file mode 100644
index 00000000..7bdadefb
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim32<cutlass::half_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
index 8dbf8b94..44b38816 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim32_fp16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim32<cutlass::half_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu
new file mode 100644
index 00000000..99cd728b
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
index 22eac878..c11096ac 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_bf16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu
new file mode 100644
index 00000000..2fbcd44e
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim64<cutlass::half_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
index e6da5dd2..7b65a9c9 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim64_fp16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim64<cutlass::half_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu
new file mode 100644
index 00000000..6fb3cf64
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
index 9c003540..e696b2f2 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_bf16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
+void run_mha_fwd_<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu
new file mode 100644
index 00000000..bb3b744d
--- /dev/null
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_causal_sm80.cu
@@ -0,0 +1,10 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template<>
+void run_mha_fwd_<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim96<cutlass::half_t, true>(params, stream);
+}
diff --git a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
index 8108696a..5f3accc3 100644
--- a/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
+++ b/candle-flash-attn/kernels/flash_fwd_hdim96_fp16_sm80.cu
@@ -5,6 +5,6 @@
#include "flash_fwd_launch_template.h"
template<>
-void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
- run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
+void run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
+ run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream);
}
diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h
index 05f5f701..1bf77f81 100644
--- a/candle-flash-attn/kernels/flash_fwd_kernel.h
+++ b/candle-flash-attn/kernels/flash_fwd_kernel.h
@@ -1,10 +1,10 @@
/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
-#include <cute/algorithm/copy.hpp>
+#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
@@ -14,66 +14,46 @@
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
-
-#include "alibi.h"
+#include "mask.h"
+#include "dropout.h"
+#include "rotary.h"
namespace flash {
using namespace cute;
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
-inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
- Tensor2 &acc_o, float softmax_scale_log2) {
- if (Is_first) {
- flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
- flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
- flash::reduce_sum(scores, scores_sum);
- } else {
- Tensor scores_max_prev = make_fragment_like(scores_max);
- cute::copy(scores_max, scores_max_prev);
- flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
- // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
- Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
- #pragma unroll
- for (int mi = 0; mi < size(scores_max); ++mi) {
- float scores_max_cur = !Check_inf
- ? scores_max(mi)
- : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
- float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
- scores_sum(mi) *= scores_scale;
- #pragma unroll
- for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
- }
- flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
- Tensor scores_sum_cur = make_fragment_like(scores_sum);
- flash::reduce_sum(scores, scores_sum_cur);
- #pragma unroll
- for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+ #pragma unroll
+ for (int i = 0; i < size(tensor); ++i) {
+ tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
}
-};
+}
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
-inline __device__ void write_softmax_to_gmem(
- Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
-) {
- // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
- Layout l = tOrP.layout();
- Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
- CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
- CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
- #pragma unroll
- for (int mi = 0; mi < size<1>(tPrP); ++mi) {
- cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
- }
-};
+template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
+__forceinline__ __device__ auto get_lse_tile(const Params &params, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) {
+ // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path.
+ // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick.
+ // Otherwise, it's written as (h, b, seqlen_q).
+ const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
+ auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
+ auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);
+
+ auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q);
+ auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : (
+ params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1)
+ );
-////////////////////////////////////////////////////////////////////////////////////////////////////
+ auto lse_layout = make_layout(lse_shape, lse_stride);
+ Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
+ auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
+ return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
+}
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
@@ -90,7 +70,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
- constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
+
+ auto seed_offset = std::make_tuple(0ull, 0ull);
+ // auto seed_offset = at::cuda::philox::unpack(params.philox_args);
+ flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
+ bidb, bidh, tidx, params.h);
+
+ // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
+ // exit early and no one saves the rng states.
+ if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+ params.rng_state[0] = std::get<0>(seed_offset);
+ params.rng_state[1] = std::get<1>(seed_offset);
+ }
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
@@ -107,23 +98,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
// Otherwise we might read OOB elements from gK and gV.
if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
- // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
- // exit early and no one saves the rng state.
-// if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
-// auto seeds = at::cuda::philox::unpack(params.philox_args);
-// params.rng_state[0] = std::get<0>(seeds);
-// params.rng_state[1] = std::get<1>(seeds);
-// params.rng_state[0] = 0;
-// params.rng_state[1] = 0;
-// }
- const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
- + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
- const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
- Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
- Shape<Int<kBlockM>, Int<kHeadDim>>{},
- make_stride(params.o_row_stride, _1{}));
- Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
- Shape<Int<kBlockM>>{}, Stride<_1>{});
+ Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+ + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
+ make_shape(binfo.actual_seqlen_q, params.h, params.d),
+ make_stride(params.o_row_stride, params.o_head_stride, _1{}));
+ Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ make_coord(m_block, 0)); // (kBlockM, kHeadDim)
+
+ Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -156,25 +138,27 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
- const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
- + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
- // We move K and V to the last block.
- const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
- + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
- const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
- + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
- Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
- Shape<Int<kBlockM>, Int<kHeadDim>>{},
- make_stride(params.q_row_stride, _1{}));
- Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
- Shape<Int<kBlockN>, Int<kHeadDim>>{},
- make_stride(params.k_row_stride, _1{}));
- Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
- Shape<Int<kBlockN>, Int<kHeadDim>>{},
- make_stride(params.v_row_stride, _1{}));
+ Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)
+ + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
+ make_shape(binfo.actual_seqlen_q, params.h, params.d),
+ make_stride(params.q_row_stride, params.q_head_stride, _1{}));
+ Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ make_coord(m_block, 0)); // (kBlockM, kHeadDim)
+ Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)
+ + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
+ make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
+ make_stride(params.k_row_stride, params.k_head_stride, _1{}));
+ Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
+ make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN)
+ Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)
+ + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
+ make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
+ make_stride(params.v_row_stride, params.v_head_stride, _1{}));
+ Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
+ make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN)
Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.seqlen_k_rounded, _1{}));
@@ -186,20 +170,17 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
- Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+ Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
- typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
- auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
- Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
+ Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
- Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
+ Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
- Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
@@ -207,6 +188,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
+ Tensor tSgS = thr_mma.partition_C(gP);
+
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
//
@@ -227,10 +210,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
- // TODO: this might need to change if we change the mma instruction in SM70
- Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
- Tensor scores_sum = make_fragment_like(scores_max);
-
//
// PREDICATES
//
@@ -273,16 +252,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Prologue
- Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
- // // Copy rmem to smem
- // // copy(tQrQ, tQsQ);
- // flash::cp_async_wait<0>();
- // __syncthreads();
// // if (cute::thread(1, 0)) { print(tQsQ); }
// // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
// // if (cute::thread0()) { print(sQNoSwizzle); }
@@ -298,7 +272,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
- flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+ flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
@@ -312,15 +286,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
}
- // auto seeds = at::cuda::philox::unpack(params.philox_args);
- // unsigned long long seed = std::get<0>(seeds);
- // unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
- unsigned long long seed = 0;
- unsigned long long offset = 0;
-
clear(acc_o);
- float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+ flash::Softmax<2 * size<1>(acc_o)> softmax;
+
+ const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+ flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
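+ // Note (explanatory, not from upstream): the ALiBi slope is pre-divided by scale_softmax because
+ // acc_s is only multiplied by the softmax scale later, inside softmax_rescale_o, so the effective
+ // bias per key position still comes out as slope * distance.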
// For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
@@ -342,12 +313,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Advance gV
if (masking_step > 0) {
- tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
- flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
- gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+ gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();
@@ -357,58 +327,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }
-
- // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
- Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
- // if (cute::thread0()) { print_tensor(scores); }
- // We don't put the masking before the matmul S = Q K^T because we don't clear sK
- // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
- // can produce Inf / NaN.
-
- if (Has_alibi) {
- flash::apply_alibi<Is_causal>(
- scores,
- n_block * kBlockN,
- binfo.actual_seqlen_k,
- m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
- binfo.actual_seqlen_q,
- kNWarps * 16,
- alibi_slope
- );
+ if constexpr (Is_softcap){
+ apply_softcap(acc_s, params.softcap);
}
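+ // (Soft-capping squashes the raw attention logits through a tanh so they stay bounded by the
+ // configured cap; the exact scaling lives in apply_softcap and the host-side parameter setup.)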
- if (!Is_causal && !Is_local) {
- if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
- } else {
- // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
- // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
- // static_assert(decltype(size<0>(taccScS))::value == 4);
- // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
- // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
- // Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout()));
- // flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
- // m_block * kBlockM);
- // Idk why it's get<1> and not get<0> of the stride.
- // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
- // I can't get the stride from idx_row
- flash::apply_mask_local</*HasWSLeft=*/Is_local>(
- scores, n_block * kBlockN, binfo.actual_seqlen_k,
- // m_block * kBlockM + get<0>(idx_row(0)),
- m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
- binfo.actual_seqlen_q, kNWarps * 16,
- params.window_size_left, params.window_size_right
- // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
- // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
- );
- // if (cute::thread0()) { print_tensor(scores); }
- }
+ mask.template apply_mask<Is_causal, Is_even_MN>(
+ acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+ );
flash::cp_async_wait<0>();
__syncthreads();
if (n_block > n_block_min) {
- // Advance gK
- tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
- flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
@@ -416,33 +346,31 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// TODO: when we have key_padding_mask we'll need to Check_inf
masking_step == 0
- ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
- : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
-
- // Convert scores from fp32 to fp16/bf16
- Tensor rP = flash::convert_type<Element>(scores);
- // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
- // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+ ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
+ : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
+
+ // Convert acc_s from fp32 to fp16/bf16
+ Tensor rP = flash::convert_type<Element>(acc_s);
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
- Tensor tOrP_copy = make_fragment_like(tOrP);
- cute::copy(tOrP, tOrP_copy);
- flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
- tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
- block_row_idx, block_col_idx, kNWarps
+ Tensor rP_drop = make_fragment_like(rP);
+ cute::copy(rP, rP_drop);
+ dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
+ rP_drop, block_row_idx, block_col_idx, kNWarps
);
- flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
- tPgP.data() = tPgP.data() + (-kBlockN);
+ cute::copy(rP_drop, tSgS);
+ tSgS.data() = tSgS.data() + (-kBlockN);
}
if (Is_dropout) {
- flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
- block_row_idx, block_col_idx, kNWarps);
+ dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
}
- // if (cute::thread0()) { print(tOrP); }
- flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+ // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+ // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+ // if (cute::thread0()) { print(tOrP); }
+ flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration
@@ -458,93 +386,57 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
- // Advance gV
- tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
- flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
+ if constexpr (Is_softcap){
+ apply_softcap(acc_s, params.softcap);
+ }
flash::cp_async_wait<0>();
__syncthreads();
if (n_block > n_block_min) {
- // Advance gK
- tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
- flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
}
- // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
- Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-
- if (Has_alibi) {
- flash::apply_alibi<Is_causal>(
- scores,
- n_block * kBlockN,
- binfo.actual_seqlen_k,
- m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
- binfo.actual_seqlen_q,
- kNWarps * 16,
- alibi_slope
- );
- }
-
- if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
- flash::apply_mask_local(
- scores, n_block * kBlockN, binfo.actual_seqlen_k,
- m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
- binfo.actual_seqlen_q, kNWarps * 16,
- params.window_size_left, params.window_size_right
- );
- }
+ mask.template apply_mask</*Causal_mask=*/false>(
+ acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+ );
- softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+ softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
- Tensor rP = flash::convert_type<Element>(scores);
- // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
- // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+ Tensor rP = flash::convert_type<Element>(acc_s);
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
- Tensor tOrP_copy = make_fragment_like(tOrP);
- cute::copy(tOrP, tOrP_copy);
- flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
- tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
- block_row_idx, block_col_idx, kNWarps
+ Tensor rP_drop = make_fragment_like(rP);
+ cute::copy(rP, rP_drop);
+ dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
+ rP_drop, block_row_idx, block_col_idx, kNWarps
);
- flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
- tPgP.data() = tPgP.data() + (-kBlockN);
+ cute::copy(rP_drop, tSgS);
+ tSgS.data() = tSgS.data() + (-kBlockN);
}
if (Is_dropout) {
- flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
- block_row_idx, block_col_idx, kNWarps);
+ dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
}
- flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+ // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+ // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+ flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}
// Epilogue
- // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
- Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
- Tensor lse = make_fragment_like(scores_sum);
- #pragma unroll
- for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
- float sum = scores_sum(mi);
- float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
- lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
- float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
- #pragma unroll
- for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
- }
-
- // if (cute::thread0()) { print(acc_o_rowcol); }
+ Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
// Convert acc_o from fp32 to fp16/bf16
Tensor rO = flash::convert_type<Element>(acc_o);
@@ -560,14 +452,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
- const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
- + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
- const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
- Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
- Shape<Int<kBlockM>, Int<kHeadDim>>{},
- make_stride(params.o_row_stride, _1{}));
- Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
- Shape<Int<kBlockM>>{}, Stride<_1>{});
+ Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
+ + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
+ make_shape(binfo.actual_seqlen_q, params.h, params.d),
+ make_stride(params.o_row_stride, params.o_head_stride, _1{}));
+ Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ make_coord(m_block, 0)); // (kBlockM, kHeadDim)
+ Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
@@ -608,10 +499,584 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
);
}
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
+inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
+
+ using Element = typename Kernel_traits::Element;
+ using ElementAccum = typename Kernel_traits::ElementAccum;
+ using index_t = typename Kernel_traits::index_t;
+
+ // Shared memory.
+ extern __shared__ char smem_[];
+
+ // The thread index.
+ const int tidx = threadIdx.x;
+
+ constexpr int kBlockM = Kernel_traits::kBlockM;
+ constexpr int kBlockN = Kernel_traits::kBlockN;
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
+ constexpr int kNWarps = Kernel_traits::kNWarps;
+
+ using GmemTiledCopyO = std::conditional_t<
+ !Split,
+ typename Kernel_traits::GmemTiledCopyO,
+ typename Kernel_traits::GmemTiledCopyOaccum
+ >;
+ using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
+
+ const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
+ // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
+ if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
+
+ const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
+ const int n_block_min = !Is_local
+ ? n_split_idx * n_blocks_per_split
+ : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
+ int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
+ if (Is_causal || Is_local) {
+ n_block_max = std::min(n_block_max,
+ cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
+ }
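+ // Illustrative split bookkeeping (hypothetical sizes): with seqlen_k = 1000, kBlockN = 128 and
+ // num_n_splits = 3 there are ceil(1000 / 128) = 8 KV blocks and n_blocks_per_split =
+ // ceil(8 / 3) = 3, so the splits own block ranges [0, 3), [3, 6) and [6, 8) before the
+ // causal / local clamping of n_block_max above.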
+ if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0
+ // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
+ // Otherwise we might read OOB elements from gK and gV,
+ // or get wrong results when we combine gOaccum from different blocks.
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+ const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
+ + m_block * kBlockM) * params.d_rounded;
+ const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+ Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+ Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+ GmemTiledCopyO gmem_tiled_copy_Oaccum;
+ auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+ Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
+ Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
+ clear(tOrOaccum);
+ // Construct identity layout for sO
+ Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ // Repeat the partitioning with identity layouts
+ Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+ if (!Is_even_K) {
+ #pragma unroll
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+ }
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+ gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+ );
+ #pragma unroll
+ for (int m = 0; m < size<1>(tOgOaccum); ++m) {
+ const int row = get<0>(tOcO(0, m, 0));
+ if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; }
+ }
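+ // Writing -INFINITY in the Split case makes the combine kernel give this split zero weight
+ // (exp(-inf) == 0); the non-split path writes +INFINITY, matching the LSE convention used
+ // elsewhere for rows with no valid keys.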
+ return;
+ }
+
+ // We iterate over the blocks in reverse order. This is because the last block is the only one
+ // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
+ // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
+
+ // We move K and V to the last block.
+ const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
+ const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
+ const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
+ const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
+ const index_t row_offset_k = block_table == nullptr
+ ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
+ + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
+ : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
+ const index_t row_offset_v = block_table == nullptr
+ ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
+ : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
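+ // Paged-KV indexing example (hypothetical sizes): with kBlockN = 128 and page_block_size = 256,
+ // the last block n_block_max - 1 = 5 starts at logical row 5 * 128 = 640, so block_table_idx =
+ // 640 / 256 = 2 and block_table_offset = 128, i.e. the tile starts at row 128 of the physical
+ // page block_table[2].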
+
+ Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
+ make_shape(binfo.actual_seqlen_q, params.h, params.d),
+ make_stride(params.q_row_stride, params.q_head_stride, _1{}));
+ Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ make_coord(m_block, 0)); // (kBlockM, kHeadDim)
+ Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
+ make_stride(params.k_row_stride, _1{}));
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
+ Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
+ make_stride(params.v_row_stride, _1{}));
+
+ Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
+ typename Kernel_traits::SmemLayoutQ{});
+ Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
+ Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
+ Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
+ Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+
+ typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+ auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+
+ Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
+ Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
+ Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
+ Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
+ Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
+ Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
+
+ typename Kernel_traits::TiledMma tiled_mma;
+ auto thr_mma = tiled_mma.get_thread_slice(tidx);
+ Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
+ Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
+ Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
+
+ Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
+
+ //
+ // Copy Atom retiling
+ //
+
+ auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+ auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
+ Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
+
+ auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+ auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
+ Tensor tSsK = smem_thr_copy_K.partition_S(sK);
+
+ auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+ auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
+ Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
+
+ // PREDICATES
+ //
+
+ // // Allocate predicate tensors for m and n
+ // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
+ // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
+
+ // Construct identity layout for sQ and sK
+ Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
+
+ // Repeat the partitioning with identity layouts
+ Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+ Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+
+ // Allocate predicate tensors for k
+ Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
+ Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
+
+ // Set predicates for k bounds
+ if (!Is_even_K) {
+ #pragma unroll
+ for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
+ #pragma unroll
+ for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
+ }
+
+ // Prologue
+
+ // Copy from Knew to K, optionally apply rotary embedding.
+ typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
+ auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+ typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
+ auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
+ if constexpr (Append_KV) {
+ // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
+ // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
+ // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
+ const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+ Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
+ Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
+ make_stride(params.rotary_dim / 2, _1{}));
+ Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
+ Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
+ make_stride(params.rotary_dim / 2, _1{}));
+ Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
+ make_stride(params.rotary_dim / 2, _1{}));
+ Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
+ make_stride(params.rotary_dim / 2, _1{}));
+ Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
+ Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
+ Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+ Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+ // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
+ // if (cute::thread(8, 0)) { print_tensor(gCos); }
+ // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
+
+ const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb)
+ + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride;
+ const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb)
+ + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
+ // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them,
+ // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64].
+ // This maps to accessing the first 64 rows of knew_ptr.
+ Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr)
+ + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride),
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
+ make_stride(params.knew_row_stride, _1{}));
+ // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); }
+ Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr)
+ + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
+ Shape<Int<kBlockN>, Int<kHeadDim>>{},
+ make_stride(params.vnew_row_stride, _1{}));
+ Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
+ Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
+
+ const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
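+ // E.g. (illustrative): with seqlen_k_cache = 300 and kBlockN = 128, the first block that can
+ // contain newly appended K/V rows is 300 / 128 = 2, so the copy loop below never goes lower
+ // than block 2.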
+ auto tKgK_data = tKgK.data();
+ auto tVgV_data = tVgV.data();
+ for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
+ flash::copy_w_min_idx<Is_even_K>(
+ tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+ );
+ tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
+ if (params.rotary_dim == 0) {
+ flash::copy_w_min_idx<Is_even_K>(
+ tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
+ );
+ } else {
+ if (params.is_rotary_interleaved) {
+ // Don't clear OOB_K because we're writing to global memory
+ flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
+ tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
+ binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
+ );
+ tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
+ tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
+ } else {
+ // Don't clear OOB_K because we're writing to global memory
+ flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
+ tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN,
+ binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim
+ );
+ tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
+ tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
+
+ }
+ }
+ tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
+ if (block_table == nullptr) {
+ tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+ tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+ } else {
+ if (n_block > n_block_copy_min) {
+ const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
+ const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
+ const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
+ const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
+ const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
+ const int offset_diff = block_table_offset_next - block_table_offset_cur;
+ tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
+ tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
+ }
+ }
+ }
+ // Need this before we can read in K again, so that we'll see the updated K values.
+ __syncthreads();
+ tKgK.data() = tKgK_data;
+ tVgV.data() = tVgV_data;
+ }
+
+ // Read Q from gmem to smem, optionally apply rotary embedding.
+ if (!Append_KV || params.rotary_dim == 0) {
+ // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+ flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+ binfo.actual_seqlen_q - m_block * kBlockM);
+ } else {
+ const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
+ // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
+ // We do this by setting the row stride of gCos / gSin to 0.
+ Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
+ Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
+ make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+ Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
+ Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
+ make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+ Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin),
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+ Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+ Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
+ Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
+ Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+ Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+ if (params.is_rotary_interleaved) {
+ flash::copy_rotary_interleaved<Is_even_K>(
+ tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
+ 0, params.d, params.rotary_dim
+ );
+ } else {
+ flash::copy_rotary_contiguous<Is_even_K>(
+ tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM,
+ 0, params.d, params.rotary_dim
+ );
+ }
+ }
+
+ int n_block = n_block_max - 1;
+ // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
+ flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+ binfo.actual_seqlen_k - n_block * kBlockN);
+ cute::cp_async_fence();
+
+ // flash::cp_async_wait<0>();
+ // __syncthreads();
+ // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
+ // __syncthreads();
+
+ clear(acc_o);
+
+ flash::Softmax<2 * size<1>(acc_o)> softmax;
+
+ const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+ flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
+
+ // For performance reasons, we separate out two kinds of iterations:
+ // those that need masking on S, and those that don't.
+ // We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
+ // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
+ // We will have at least 1 "masking" iteration.
+
+ // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+ // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+ constexpr int n_masking_steps = (!Is_causal && !Is_local)
+ ? 1
+ : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
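+ // E.g. with kBlockM = 128 and kBlockN = 64 (illustrative): a non-causal, non-local kernel uses
+ // 1 masking step, an even-length causal kernel uses ceil_div(128, 64) = 2, and the uneven or
+ // local cases use 3.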
+ #pragma unroll
+ for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
+ Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
+ clear(acc_s);
+ flash::cp_async_wait<0>();
+ __syncthreads();
+
+ // Advance gV
+ if (masking_step > 0) {
+ if (block_table == nullptr) {
+ tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+ } else {
+ const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
+ const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
+ const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
+ const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
+ tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
+ }
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+ } else {
+ // Clear the smem tiles to account for predicated off loads
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+ gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+ );
+ }
+ cute::cp_async_fence();
+
+ flash::gemm(
+ acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+ smem_thr_copy_Q, smem_thr_copy_K
+ );
+ // if (cute::thread0()) { print(acc_s); }
+ if constexpr (Is_softcap){
+ apply_softcap(acc_s, params.softcap);
+ }
+
+ mask.template apply_mask<Is_causal, Is_even_MN>(
+ acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+ );
+
+ flash::cp_async_wait<0>();
+ __syncthreads();
+ // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
+ // __syncthreads();
+
+ if (n_block > n_block_min) {
+ // Advance gK
+ if (block_table == nullptr) {
+ tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+ } else {
+ const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
+ const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
+ const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
+ const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
+ tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
+ }
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+ // This cp_async_fence needs to be in the if block, otherwise the synchronization
+ // isn't right and we get race conditions.
+ cute::cp_async_fence();
+ }
+
+ // We have key_padding_mask so we'll need to Check_inf
+ masking_step == 0
+ ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2)
+ : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);
+ // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
+
+ // Convert acc_s from fp32 to fp16/bf16
+ Tensor rP = flash::convert_type<Element>(acc_s);
+ // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+ // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+
+ flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+
+ // This check is at the end of the loop since we always have at least 1 iteration
+ if (n_masking_steps > 1 && n_block <= n_block_min) {
+ --n_block;
+ break;
+ }
+ }
+
+ // These are the iterations where we don't need masking on S
+ for (; n_block >= n_block_min; --n_block) {
+ Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
+ clear(acc_s);
+ flash::cp_async_wait<0>();
+ __syncthreads();
+ // Advance gV
+ if (block_table == nullptr) {
+ tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+ } else {
+ const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
+ const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
+ const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
+ const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
+ tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
+ }
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+ cute::cp_async_fence();
+
+ flash::gemm(
+ acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+ smem_thr_copy_Q, smem_thr_copy_K
+ );
+ if constexpr (Is_softcap){
+ apply_softcap(acc_s, params.softcap);
+ }
+
+ flash::cp_async_wait<0>();
+ __syncthreads();
+ if (n_block > n_block_min) {
+ // Advance gK
+ if (block_table == nullptr) {
+ tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+ } else {
+ const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
+ const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
+ const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
+ const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
+ tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
+ }
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+ // This cp_async_fence needs to be in the if block, otherwise the synchronization
+ // isn't right and we get race conditions.
+ cute::cp_async_fence();
+ }
+
+ mask.template apply_mask</*Causal_mask=*/false>(
+ acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+ );
+ softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
+
+ Tensor rP = flash::convert_type<Element>(acc_s);
+ // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+ // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+
+ flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+ }
+
+ // Epilogue
+
+ Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);
+ // if (cute::thread0()) { print(lse); }
+
+ Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
+ // Partition sO to match the accumulator partitioning
+ using SmemTiledCopyO = std::conditional_t<
+ !Split,
+ typename Kernel_traits::SmemCopyAtomO,
+ typename Kernel_traits::SmemCopyAtomOaccum
+ >;
+ auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
+ auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
+ Tensor rO = flash::convert_type<ElementO>(acc_o);
+ Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
+ Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
+
+ // sOaccum is larger than sQ, so we need to syncthreads here
+ // TODO: allocate enough smem for sOaccum
+ if constexpr (Split) { __syncthreads(); }
+
+ cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
+
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+ const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
+ + m_block * kBlockM) * params.d_rounded;
+ const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
+ ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)
+ ) + m_block * kBlockM;
+
+ Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+ Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum),
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
+ // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); }
+
+ GmemTiledCopyO gmem_tiled_copy_Oaccum;
+ auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+ Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
+ Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
+
+ __syncthreads();
+
+ Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
+ cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
+
+ Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
+ static_assert(decltype(size<0>(taccOcO))::value == 4);
+ // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
+ Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
+ CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
+ if (get<1>(taccOcO_row(0)) == 0) {
+ #pragma unroll
+ for (int mi = 0; mi < size(lse); ++mi) {
+ const int row = get<0>(taccOcO_row(mi));
+ if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
+ }
+ }
+
+ // Construct identity layout for sO
+ Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ // Repeat the partitioning with identity layouts
+ Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+ if (!Is_even_K) {
+ #pragma unroll
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+ }
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+ gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+ );
+}
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
@@ -627,9 +1092,207 @@ inline __device__ void compute_attn(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
- flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+ flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params>
+inline __device__ void compute_attn_splitkv(const Params &params) {
+ const int m_block = blockIdx.x;
+ // The block index for the batch.
+ const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
+ // The block index for the head.
+ const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
+ const int n_split_idx = Split ? blockIdx.y : 0;
+ const int num_n_splits = Split ? gridDim.y : 1;
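+ // E.g. (illustrative): with Split and params.h = 8, blockIdx.z = 19 decodes to bidb = 19 / 8 = 2
+ // and bidh = 19 - 2 * 8 = 3, while blockIdx.y picks which slice of the KV sequence this block owns.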
+ flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
+template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
+inline __device__ void combine_attn_seqk_parallel(const Params &params) {
+ using Element = typename Kernel_traits::Element;
+ using ElementAccum = typename Kernel_traits::ElementAccum;
+ using index_t = typename Kernel_traits::index_t;
+ constexpr int kMaxSplits = 1 << Log_max_splits;
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
+ constexpr int kNThreads = Kernel_traits::kNThreads;
+
+ static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
+ static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
+ static_assert(kNThreads == 128, "We assume that each block has 128 threads");
+
+ // Shared memory.
+ // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
+ __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1];
+
+ // The thread and block index.
+ const int tidx = threadIdx.x;
+ const int bidx = blockIdx.x;
+
+ const index_t lse_size = params.b * params.h * params.seqlen_q;
+
+ const index_t row_offset_lse = bidx * kBlockM;
+ Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
+ Shape<Int<kMaxSplits>, Int<kBlockM>>{},
+ make_stride(lse_size, _1{}));
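+ // Each split stores its LSE values for the whole (b, h, seqlen_q) space contiguously (see
+ // row_offset_lseaccum in the split-kv kernel), so consecutive splits are lse_size apart and the
+ // accumulator can be viewed as a (kMaxSplits, kBlockM) tile with row stride lse_size.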
+
+ // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped; see comment in get_lse_tile.
+ // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}.
+ Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+ // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}.
+ Layout flat_layout = make_layout(lse_size);
+ Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b));
+ auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
+ Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
+ Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout));
+
+ Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout);
+
+ constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
+
+ // Read the LSE values from gmem and store them in shared memory, then transpose them.
+ constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
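+ // Worked numbers (assumed for illustration): with kMaxSplits = 32, kBlockM = 16 and kNThreads =
+ // 128 there are 32 * 16 = 512 LSE values per tile, so kNLsePerThread = 4 and each pass of the
+ // loop below loads kRowsPerLoadLSE = 128 / 16 = 8 split-rows.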
+ #pragma unroll
+ for (int l = 0; l < kNLsePerThread; ++l) {
+ const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
+ const int col = tidx % kBlockM;
+ ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
+ if (row < kMaxSplits) { sLSE[row][col] = lse; }
+ // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
+ }
+ // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
+ __syncthreads();
+ Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
+ constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
+ // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits
+ // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,
+ // kBlockM rows, so each time we load we can load 128 / kBlockM rows).
+ // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
+ // static_assert(kThreadsPerSplit <= 32);
+ static_assert(kRowsPerLoadTranspose <= 32);
+ static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
+ #pragma unroll
+ for (int l = 0; l < kNLsePerThread; ++l) {
+ const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
+ const int col = tidx / kRowsPerLoadTranspose;
+ lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;
+ // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
+ }
+
+ // Compute the logsumexp of the LSE along the split dimension.
+ ElementAccum lse_max = lse_accum(0);
+ #pragma unroll
+ for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
+ MaxOp<float> max_op;
+ lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
+ lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
+ float lse_sum = expf(lse_accum(0) - lse_max);
+ #pragma unroll
+ for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
+ SumOp<float> sum_op;
+ lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
+ // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise
+ // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum.
+ ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max;
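+ // I.e. the numerically stable log-sum-exp over splits: lse_logsum = lse_max + log(sum_l exp(lse_l - lse_max));
+ // each split's partial output is later rescaled by exp(lse_l - lse_logsum).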
+ // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
+ if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) {
+ if (params.unpadded_lse) {
+ const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
+ if (lse_offset < lse_size) {
+ gLSE_unpadded(lse_offset) = lse_logsum;
+ }
+ } else {
+ gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+ }
+ }
+ // Store the scales exp(lse - lse_logsum) in shared memory.
+ #pragma unroll
+ for (int l = 0; l < kNLsePerThread; ++l) {
+ const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
+ const int col = tidx / kRowsPerLoadTranspose;
+ if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); }
+ }
+ __syncthreads();
+
+ const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
+ Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ Stride<Int<kHeadDim>, _1>{});
+ constexpr int kBlockN = kNThreads / kBlockM;
+ using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
+ using GmemTiledCopyOaccum = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ GmemLayoutAtomOaccum{},
+ Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
+ GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+ auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+ Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
+ Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
+ Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
+ clear(tOrO);
+
+ // Predicates
+ Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
+ // Repeat the partitioning with identity layouts
+ Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
+ Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+ if (!Is_even_K) {
+ #pragma unroll
+ for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
+ }
+ // Load Oaccum in, then scale and accumulate into O
+ for (int split = 0; split < params.num_splits; ++split) {
+ flash::copy</*Is_even_MN=*/false, Is_even_K>(
+ gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
+ );
+ #pragma unroll
+ for (int m = 0; m < size<1>(tOrOaccum); ++m) {
+ int row = get<0>(tOcOaccum(0, m, 0));
+ ElementAccum lse_scale = sLSE[split][row];
+ #pragma unroll
+ for (int k = 0; k < size<2>(tOrOaccum); ++k) {
+ #pragma unroll
+ for (int i = 0; i < size<0>(tOrOaccum); ++i) {
+ tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
+ }
+ }
+ // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); }
+ }
+ tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
+ }
+ // if (cute::thread0()) { print_tensor(tOrO); }
+
+ Tensor rO = flash::convert_type<Element>(tOrO);
+ // Write to gO
+ #pragma unroll
+ for (int m = 0; m < size<1>(rO); ++m) {
+ const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
+ if (idx < params.b * params.h * params.seqlen_q) {
+ const int batch_idx = idx / (params.h * params.seqlen_q);
+ const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
+ // The index into the rows of Q
+ const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
+ auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride
+ + head_idx * params.o_head_stride + row * params.o_row_stride;
+ #pragma unroll
+ for (int k = 0; k < size<2>(rO); ++k) {
+ if (Is_even_K || tOpOaccum(k)) {
+ const int col = get<1>(tOcOaccum(0, m, k));
+ Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
+ Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
+ // TODO: Should check if this is using vectorized store, but it seems pretty fast
+ copy(rO(_, m, k), gO);
+ // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
+ // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k);
+ }
+ }
+ }
+ }
+}
+
} // namespace flash
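For reference, the combine step in the kernel above boils down to a log-sum-exp over the per-split LSE values followed by rescaling each split's partial output by exp(lse_split - lse_total). A minimal host-side sketch of that arithmetic, with hypothetical names and none of the kernel's tiling or shared-memory staging, is:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Combine num_splits partial attention outputs o[s] (each of length head_dim)
    // whose per-split log-sum-exp values are lse[s], mirroring the reduction above.
    std::vector<float> combine_splits(const std::vector<std::vector<float>> &o,
                                      const std::vector<float> &lse) {
        float lse_max = -INFINITY;
        for (float v : lse) lse_max = std::max(lse_max, v);
        lse_max = (lse_max == -INFINITY) ? 0.f : lse_max;               // all splits empty
        float lse_sum = 0.f;
        for (float v : lse) lse_sum += std::exp(v - lse_max);
        const float lse_total = (lse_sum == 0.f) ? INFINITY : std::log(lse_sum) + lse_max;
        std::vector<float> out(o.empty() ? 0 : o[0].size(), 0.f);
        for (std::size_t s = 0; s < lse.size(); ++s) {
            const float scale = std::exp(lse[s] - lse_total);           // 0 for empty splits
            for (std::size_t k = 0; k < out.size(); ++k) out[k] += scale * o[s][k];
        }
        return out;
    }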
diff --git a/candle-flash-attn/kernels/flash_fwd_launch_template.h b/candle-flash-attn/kernels/flash_fwd_launch_template.h
index 002dd8ec..9e5449d7 100644
--- a/candle-flash-attn/kernels/flash_fwd_launch_template.h
+++ b/candle-flash-attn/kernels/flash_fwd_launch_template.h
@@ -4,14 +4,49 @@
#pragma once
+// #include <ATen/cuda/CUDAContext.h>
+
+#include "error.h"
#include "static_switch.h"
#include "flash.h"
#include "flash_fwd_kernel.h"
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
-__global__ void flash_fwd_kernel(Flash_fwd_params params) {
- static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
- flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
+// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#define ARCH_SUPPORTS_FLASH
+#define KERNEL_PARAM_MODIFIER __grid_constant__
+#else
+#define KERNEL_PARAM_MODIFIER
+#endif
+
+// Define a macro for unsupported architecture handling to centralize the error message
+#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building for sm80-sm90, but this kernel was built for < sm80!");
+
+// Use a macro to clean up kernel definitions
+#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
+template<typename Kernel_traits, __VA_ARGS__> \
+__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
+
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
+ #if defined(ARCH_SUPPORTS_FLASH)
+ static_assert(!(Is_causal && Is_local)); // Enforce constraints
+ flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
+ #else
+ FLASH_UNSUPPORTED_ARCH
+ #endif
+}
+
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
+ #if defined(ARCH_SUPPORTS_FLASH)
+ flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
+ #else
+ FLASH_UNSUPPORTED_ARCH
+ #endif
+}
+
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
+ static_assert(Log_max_splits >= 1);
+ flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -29,181 +64,246 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
- BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
- BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+ EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+ LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
- BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
- // Will only return softmax if dropout, to reduce compilation time.
- // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
- // If return_softmax, set IsEvenMNConst to false to reduce number of templates
- // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
- // If Is_local, set Is_causal to false
- auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
- // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
- // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
- // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
- if (smem_size >= 48 * 1024) {
- cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
- }
- // int ctas_per_sm;
- // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
- // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
- // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
- kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+ ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+ SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+ // Will only return softmax if dropout, to reduce compilation time.
+ // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+ // If return_softmax, set IsEvenMNConst to false to reduce number of templates
+ // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+ // If Is_local, set Is_causal to false
+ auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout>;
+ // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
+ // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+ // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+ if (smem_size >= 48 * 1024) {
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+ }
+ // int ctas_per_sm;
+ // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+ // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+ // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+ kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
+ });
+ });
+ });
+ });
+ });
+}
+
+template<typename Kernel_traits, bool Is_causal>
+void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+ static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
+ static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
+ constexpr size_t smem_size = Kernel_traits::kSmemSize;
+ const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
+ dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
+ const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
+ const bool is_even_K = params.d == Kernel_traits::kHeadDim;
+ BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+ EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+ LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+ BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+ BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+ ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+ SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+ // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+ // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+ // If Is_local, set Is_causal to false
+ auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
+ // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
+ // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+ if (smem_size >= 48 * 1024) {
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+ }
+ kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
+ });
});
});
});
});
});
+ if (params.num_splits > 1) {
+ // We want kBlockM to be as small as possible for more parallelism.
+ // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
+ // If headdim is divisible by 64, then we set kBlockM = 8, etc.
+ constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+ dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
+ EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+ if (params.num_splits <= 2) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ } else if (params.num_splits <= 4) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ } else if (params.num_splits <= 8) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ } else if (params.num_splits <= 16) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ } else if (params.num_splits <= 32) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ } else if (params.num_splits <= 64) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ } else if (params.num_splits <= 128) {
+ flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+ }
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ });
+ }
}
+template<typename T, int Headdim, bool Is_causal>
+void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
+ constexpr static int kBlockM = 64; // Fixed for all head dimensions
+ // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
+ // and for headdim 192 with block size 64 x 128.
+ // Also for headdim 160 with block size 64 x 128 after the rotary addition.
+ constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+ run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
+}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 32;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
- BOOL_SWITCH(params.is_causal, Is_causal, [&] {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- });
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
- BOOL_SWITCH(params.is_causal, Is_causal, [&] {
- if constexpr(!Is_dropout) {
- // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
- // Using block size (64 x 256) is 27% slower for seqlen=2k
- // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
- } else {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- }
- });
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ if constexpr(!Is_dropout) {
+ // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+ // Using block size (64 x 256) is 27% slower for seqlen=2k
+ // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+ } else {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ }
});
}
-template<typename T>
+inline bool cuda_is_sm8x() {
+ // dprops = at::cuda::getCurrentDeviceProperties();
+ // return dprops->major == 8 && dprops->minor > 0;
+ return false;
+}
+
+template<typename T, bool Is_causal>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
- // auto dprops = at::cuda::getCurrentDeviceProperties();
- bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
- BOOL_SWITCH(params.is_causal, Is_causal, [&] {
- // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
- if (is_sm8x) {
- if constexpr(!Is_causal) {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- } else {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- }
- } else {
+ bool is_sm8x = cuda_is_sm8x();
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square).
+ if (is_sm8x) {
+ if constexpr(!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ } else {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
- // These two are always slower
- // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
- });
+ } else {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ }
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+ // These two are always slower
+ // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
});
}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
- // auto dprops = at::cuda::getCurrentDeviceProperties();
- bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
- BOOL_SWITCH(params.is_causal, Is_causal, [&] {
- if constexpr(!Is_dropout) {
- // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
- // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
- if (is_sm8x) {
- if constexpr(!Is_causal) {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- } else {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- }
+ bool is_sm8x = cuda_is_sm8x();
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ if constexpr(!Is_dropout) {
+ // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+ // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
+ if (is_sm8x) {
+ if constexpr(!Is_causal) {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // 1st ones are good for H100, A100
- // 2nd one is good for A6000 bc we get slightly better occupancy
} else {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
- });
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // 1st ones are good for H100, A100
+ // 2nd one is good for A6000 because we get slightly better occupancy
+ } else {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+ }
});
}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
- // auto dprops = at::cuda::getCurrentDeviceProperties();
- bool is_sm8x = true; // dprops->major == 8 && dprops->minor > 0;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
- BOOL_SWITCH(params.is_causal, Is_causal, [&] {
- // For A100, H100, 128 x 32 is the fastest.
- // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
- // and 128 x 64 with 8 warps is the fastest for non-causal.
- if (is_sm8x) {
- if constexpr(!Is_causal) {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
- } else {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- }
+ bool is_sm8x = cuda_is_sm8x();
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ // For A100, H100, 128 x 32 is the fastest.
+ // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+ // and 128 x 64 with 8 warps is the fastest for non-causal.
+ if (is_sm8x) {
+ if constexpr(!Is_causal) {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
- });
+ } else {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ }
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
});
}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192;
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
- BOOL_SWITCH(params.is_causal, Is_causal, [&] {
- if constexpr(!Is_dropout) {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
- } else {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- }
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
- });
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ if constexpr(!Is_dropout) {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ } else {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ }
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
});
}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 224;
int device;
@@ -211,25 +311,26 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+ if (status_ != cudaSuccess) {
+ C10_CUDA_CHECK(status_);
+ }
// printf("max_smem_per_block = %d\n", max_smem_per_block);
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
- BOOL_SWITCH(params.is_causal, Is_causal, [&] {
- if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
- } else {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- }
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
- // If we have N = 32, there are only 1024 elements to load at once, where each load
- // is 8 elements. This means we can only use 128 threads and not 256 threads.
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
- });
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ } else {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ }
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
+ // If we have N = 32, there are only 1024 elements to load at once, where each load
+ // is 8 elements. This means we can only use 128 threads and not 256 threads.
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}
-template<typename T>
+template<typename T, bool Is_causal>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
int device;
@@ -239,20 +340,21 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+ if (status_ != cudaSuccess) {
+ C10_CUDA_CHECK(status_);
+ }
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
- BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
- BOOL_SWITCH(params.is_causal, Is_causal, [&] {
- // For A100, we want to run with 128 x 64 (128KB smem).
- // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
- if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
- } else {
- run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- }
- // 64 KB
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
- // 96 KB
- // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
- });
+ DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+ // For A100, we want to run with 128 x 64 (128KB smem).
+ // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
+ if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ } else {
+ run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ }
+ // 64 KB
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+ // 96 KB
+ // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
}
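The split-KV path above makes two compile-time choices: the combine kernel's kBlockM, which shrinks as the head dimension grows so that 128 threads still cover a full row load, and Log_max_splits, the smallest power-of-two exponent covering params.num_splits. A small sketch reproducing that selection, with hypothetical helper names that are not part of this patch, is:

    // Mirrors the dispatch arithmetic in run_flash_splitkv_fwd above.
    constexpr int combine_kblock_m(int head_dim) {
        // 128 threads load 512 floats per pass, so kBlockM shrinks as head_dim grows.
        return head_dim % 128 == 0 ? 4 : (head_dim % 64 == 0 ? 8 : 16);
    }

    inline int combine_log_max_splits(int num_splits) {
        // Smallest L in 1..7 with (1 << L) >= num_splits, matching the <=2 ... <=128 ladder.
        int log_max_splits = 1;
        while ((1 << log_max_splits) < num_splits && log_max_splits < 7) { ++log_max_splits; }
        return log_max_splits;
    }

    // e.g. combine_kblock_m(128) == 4, combine_kblock_m(64) == 8,
    //      combine_log_max_splits(5) == 3 (instantiates the <=8 branch above).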
diff --git a/candle-flash-attn/kernels/kernel_helpers.h b/candle-flash-attn/kernels/kernel_helpers.h
new file mode 100644
index 00000000..22e40cc4
--- /dev/null
+++ b/candle-flash-attn/kernels/kernel_helpers.h
@@ -0,0 +1,50 @@
+// This header is not specific to our application and you'll probably want
+// something like this for any extension you're building. This includes the
+// infrastructure needed to serialize descriptors that are used with the
+// "opaque" parameter of the GPU custom call. In our example we'll use this
+// parameter to pass the size of our problem.
+
+#ifndef _GPU_OPS_KERNEL_HELPERS_H_
+#define _GPU_OPS_KERNEL_HELPERS_H_
+
+#include <cstdint>
+#include <stdexcept>
+#include <string>
+#include <type_traits>
+
+#define JAX_APEX_WARP_SIZE 32
+
+namespace gpu_ops {
+
+// https://en.cppreference.com/w/cpp/numeric/bit_cast
+template <class To, class From>
+typename std::enable_if<sizeof(To) == sizeof(From) &&
+ std::is_trivially_copyable<From>::value &&
+ std::is_trivially_copyable<To>::value,
+ To>::type
+bit_cast(const From &src) noexcept {
+ static_assert(std::is_trivially_constructible<To>::value,
+ "This implementation additionally requires destination type to "
+ "be trivially constructible");
+
+ To dst;
+ memcpy(&dst, &src, sizeof(To));
+ return dst;
+}
+
+template <typename T> std::string PackDescriptorAsString(const T &descriptor) {
+ return std::string(bit_cast<const char *>(&descriptor), sizeof(T));
+}
+
+template <typename T>
+const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
+ if (opaque_len != sizeof(T)) {
+ throw std::runtime_error("Invalid opaque object size");
+ }
+ return bit_cast<const T *>(opaque);
+}
+
+} // namespace gpu_ops
+
+#endif
+
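These helpers round-trip a plain-old-data descriptor through the custom call's opaque buffer: PackDescriptorAsString serializes it on the host, and UnpackDescriptor recovers it in the entry point after checking the size. A usage sketch, assuming the header above is included and using a hypothetical ExampleDescriptor, looks like:

    struct ExampleDescriptor {
        std::uint32_t rows;
        std::uint32_t cols;
    };

    void example_descriptor_round_trip() {
        ExampleDescriptor desc{128, 64};
        // Host side: serialize the POD struct into an opaque byte string.
        std::string opaque = gpu_ops::PackDescriptorAsString(desc);
        // Custom-call side: validate the length and reinterpret the bytes.
        const ExampleDescriptor *back =
            gpu_ops::UnpackDescriptor<ExampleDescriptor>(opaque.data(), opaque.size());
        (void)back;  // back->rows == 128, back->cols == 64
    }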
diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h
index f000ff24..5a7b7491 100644
--- a/candle-flash-attn/kernels/kernel_traits.h
+++ b/candle-flash-attn/kernels/kernel_traits.h
@@ -1,10 +1,10 @@
/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
-#include "cute/algorithm/copy.hpp"
+#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
@@ -24,7 +24,7 @@ struct Flash_kernel_traits {
#endif
using ElementAccum = float;
- using index_t = uint32_t;
+ using index_t = int64_t;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using MMA_Atom_Arch = std::conditional_t<
@@ -32,10 +32,8 @@ struct Flash_kernel_traits {
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
- using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
- using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
@@ -76,7 +74,7 @@ struct Flash_fwd_kernel_traits : public Base {
using TiledMma = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
- typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+ Tile<Int<16 * kNWarps>, _16, _16>>;
using SmemLayoutAtomQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -91,20 +89,10 @@ struct Flash_fwd_kernel_traits : public Base {
SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
- // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
- using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
- Stride<_1, Int<kBlockKSmem>>>;
- using SmemLayoutAtomVtransposed = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
- using SmemLayoutVtransposed = decltype(tile_to_shape(
- SmemLayoutAtomVtransposed{},
- Shape<Int<kHeadDim>, Int<kBlockN>>{}));
- // Maybe the VtransposeNoSwizzle just needs to have the right shape
- // And the strides don't matter?
- using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
- SmemLayoutAtomVtransposedNoSwizzle{},
- Shape<Int<kHeadDim>, Int<kBlockN>>{}));
- // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
+ // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
+ using SmemLayoutVtransposed = decltype(
+ composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
+ using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -116,10 +104,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
- static constexpr int kSmemQCount = size(SmemLayoutQ{});
- static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
- static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
- static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
+ static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
+ static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
@@ -149,15 +135,6 @@ struct Flash_fwd_kernel_traits : public Base {
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
- static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
- static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
- using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
- Stride<Int<kGmemThreadsPerRowP>, _1>>;
-
- using GmemTiledCopyP = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
- GmemLayoutAtomP{},
- Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomOaccum = std::conditional_t<
kBlockKSmem == 32,
@@ -218,17 +195,17 @@ struct Flash_bwd_kernel_traits : public Base {
using TiledMmaSdP = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
- typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+ Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;
using TiledMmadKV = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
- typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+ Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;
using TiledMmadQ = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
- typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+ Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;
using SmemLayoutAtomQdO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -247,26 +224,18 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
- using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
- Stride<_1, Int<kBlockKSmem>>>;
- using SmemLayoutAtomKtransposed = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
- using SmemLayoutKtransposed = decltype(tile_to_shape(
- SmemLayoutAtomKtransposed{},
- make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
- // Maybe the KtransposeNoSwizzle just needs to have the right shape
- // And the strides don't matter?
- using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
- SmemLayoutAtomKtransposedNoSwizzle{},
- make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
- // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
+ using SmemLayoutKtransposed = decltype(
+ composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
+ using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));
// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
// static constexpr int kPBlockN = kBlockN;
- static_assert(kBlockN >= 64);
+ // Temporarily disabling this for hdim 256 on sm86 and sm89
+ // static_assert(kBlockN >= 64);
+ static_assert(kBlockN >= 32);
// TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
- static constexpr int kPBlockN = 64;
+ static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
static constexpr int kSwizzlePdS = 3;
@@ -277,30 +246,15 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutPdS = decltype(tile_to_shape(
SmemLayoutAtomPdS{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
- using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
- Stride<_1, Int<kPBlockN>>>;
- using SmemLayoutAtomPdStransposed = decltype(
- composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
- using SmemLayoutPdStransposed = decltype(tile_to_shape(
- SmemLayoutAtomPdStransposed{},
- make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
- using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
- SmemLayoutAtomPdStransposedNoSwizzle{},
- make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
- // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
+ using SmemLayoutPdStransposed = decltype(
+ composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
+ using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
+
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
- using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
- Stride<_1, Int<kBlockKSmem>>>;
- using SmemLayoutAtomQdOtransposed = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
- using SmemLayoutQdOtransposed = decltype(tile_to_shape(
- SmemLayoutAtomQdOtransposed{},
- make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
- using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
- SmemLayoutAtomQdOtransposedNoSwizzle{},
- make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
- // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
+ using SmemLayoutQdOtransposed = decltype(
+ composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
+ using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));
using SmemLayoutAtomdKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -320,16 +274,12 @@ struct Flash_bwd_kernel_traits : public Base {
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
- static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ
- static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
- static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
- static constexpr int kSmemPCount = size(SmemLayoutPdS{});
- static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
- static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
- static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
- static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
- static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
- static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
+ // Double buffer for sQ
+ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
+ static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
+ static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
+ static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
+ static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
static constexpr int kSmemSize = kSmemQdOSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
@@ -338,9 +288,6 @@ struct Flash_bwd_kernel_traits : public Base {
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + kSmemPSize
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
- static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
- + kSmemdSSize + kSmemPSize;
-
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
diff --git a/candle-flash-attn/kernels/kernels.h b/candle-flash-attn/kernels/kernels.h
new file mode 100644
index 00000000..20d6605f
--- /dev/null
+++ b/candle-flash-attn/kernels/kernels.h
@@ -0,0 +1,58 @@
+#ifndef _GPU_OPS_KERNELS_H_
+#define _GPU_OPS_KERNELS_H_
+
+#include <cuda_runtime_api.h>
+
+#include <cstddef>
+#include <cstdint>
+
+#include <stdlib.h>
+#include <stdint.h>
+
+namespace gpu_ops {
+
+struct MHAParams {
+ uint32_t q_batch_stride;
+ uint32_t k_batch_stride;
+ uint32_t v_batch_stride;
+ uint32_t o_batch_stride;
+
+ uint32_t q_row_stride;
+ uint32_t k_row_stride;
+ uint32_t v_row_stride;
+ uint32_t o_row_stride;
+
+ uint32_t q_head_stride;
+ uint32_t k_head_stride;
+ uint32_t v_head_stride;
+ uint32_t o_head_stride;
+
+ uint32_t b;
+ uint32_t h;
+ uint32_t h_k;
+ uint32_t d;
+ uint32_t d_rounded;
+ float softmax_scale;
+ float softcap;
+
+ uint32_t seqlen_q;
+ uint32_t seqlen_k;
+ uint32_t seqlen_q_rounded;
+ uint32_t seqlen_k_rounded;
+
+ int window_size_left;
+ int window_size_right;
+
+ int is_causal;
+ int is_bf16;
+};
+
+void run_mha_fwd_j(cudaStream_t stream, void **buffers,
+ const char *opaque,
+ std::size_t opaque_len);
+void run_mha_bwd_j(cudaStream_t stream, void **buffers,
+ const char *opaque,
+ std::size_t opaque_len);
+}
+
+#endif
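MHAParams mirrors the shape and stride fields of Flash_fwd_params, and the two run_mha_*_j entry points follow the opaque-descriptor convention from kernel_helpers.h. A hedged sketch of how such an entry point is typically wired up, assuming both headers are included (hypothetical body, not the flash_api.cu code from this patch), is:

    namespace gpu_ops {

    // Illustration only: unpack the descriptor, then dispatch on dtype.
    inline void run_mha_fwd_j_sketch(cudaStream_t stream, void **buffers,
                                     const char *opaque, std::size_t opaque_len) {
        const MHAParams &p = *UnpackDescriptor<MHAParams>(opaque, opaque_len);
        // buffers[] is assumed to hold the q/k/v input pointers and the output pointer.
        if (p.is_bf16) {
            // ... launch the bf16 kernels for head dim p.d on `stream` ...
        } else {
            // ... launch the fp16 kernels ...
        }
        (void)stream; (void)buffers;
    }

    } // namespace gpu_ops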
diff --git a/candle-flash-attn/kernels/mask.h b/candle-flash-attn/kernels/mask.h
new file mode 100644
index 00000000..7ba435a3
--- /dev/null
+++ b/candle-flash-attn/kernels/mask.h
@@ -0,0 +1,213 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <cute/tensor.hpp>
+
+namespace flash {
+
+using namespace cute;
+
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+ const int col_idx_offset_ = 0) {
+ // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+ static_assert(Layout::rank == 2, "Only support 2D Tensor");
+ const int lane_id = threadIdx.x % 32;
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+ #pragma unroll
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+ const int col_idx_base = col_idx_offset + nj * 8;
+ #pragma unroll
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
+ const int col_idx = col_idx_base + j;
+ if (col_idx >= max_seqlen_k) {
+ // Without the "make_coord" we get wrong results
+ #pragma unroll
+ for (int mi = 0; mi < size<0>(tensor); ++mi) {
+ tensor(mi, make_coord(j, nj)) = -INFINITY;
+ }
+ }
+ }
+ }
+}
+
+template <bool HasWSLeft=true, typename Engine, typename Layout>
+__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+ const int max_seqlen_k, const int row_idx_offset,
+ const int max_seqlen_q, const int warp_row_stride,
+ const int window_size_left, const int window_size_right) {
+ // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+ static_assert(Layout::rank == 2, "Only support 2D Tensor");
+ const int lane_id = threadIdx.x % 32;
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+ #pragma unroll
+ for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+ const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+ #pragma unroll
+ for (int i = 0; i < size<0, 0>(tensor); ++i) {
+ const int row_idx = row_idx_base + i * 8;
+ const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+ const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
+ #pragma unroll
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+ const int col_idx_base = col_idx_offset + nj * 8;
+ #pragma unroll
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
+ const int col_idx = col_idx_base + j;
+ if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
+ tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+ }
+ }
+ }
+ // if (cute::thread0()) {
+ // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
+ // print(tensor(make_coord(i, mi), _));
+ // // print(tensor(_, j + nj * size<1, 0>(tensor)));
+ // }
+ }
+ }
+}
+
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+ const int max_seqlen_k, const int row_idx_offset,
+ const int max_seqlen_q, const int warp_row_stride) {
+ // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+ apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
+ max_seqlen_q, warp_row_stride, -1, 0);
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void apply_mask_causal_w_idx(
+ Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
+ const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
+{
+ // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+ static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+ static_assert(Layout1::rank == 2, "Only support 2D Tensor");
+ CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
+ CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
+ #pragma unroll
+ for (int mi = 0; mi < size<0>(tensor); ++mi) {
+ const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
+ #pragma unroll
+ for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
+ if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
+ tensor(mi, ni) = -INFINITY;
+ }
+ }
+ // if (cute::thread0()) {
+ // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
+ // print(tensor(_, make_coord(j, ni)));
+ // // print(tensor(_, j + ni * size<1, 0>(tensor)));
+ // }
+ }
+}
+
+template <bool Is_causal, bool Is_local, bool Has_alibi>
+struct Mask {
+
+ const int max_seqlen_k, max_seqlen_q;
+ const int window_size_left, window_size_right;
+ const float alibi_slope;
+
+ __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
+ const int window_size_left, const int window_size_right,
+ const float alibi_slope=0.f)
+ : max_seqlen_k(max_seqlen_k)
+ , max_seqlen_q(max_seqlen_q)
+ , window_size_left(window_size_left)
+ , window_size_right(window_size_right)
+ , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
+ };
+
+ // Causal_mask: whether this particular iteration needs causal masking
+ template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
+ __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
+ const int col_idx_offset_,
+ const int row_idx_offset,
+ const int warp_row_stride) {
+ static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
+ static_assert(Layout::rank == 3, "Only support 3D Tensor");
+ static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
+ static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
+ // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
+ if constexpr (Need_masking) {
+ // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+ Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
+ // Do we need both row and column indices, or just column indices?
+ static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
+ const int lane_id = threadIdx.x % 32;
+ const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+ if constexpr (Col_idx_only) {
+ #pragma unroll
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+ const int col_idx_base = col_idx_offset + nj * 8;
+ #pragma unroll
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
+ const int col_idx = col_idx_base + j;
+ #pragma unroll
+ for (int mi = 0; mi < size<0>(tensor); ++mi) {
+ // No causal, no local
+ if constexpr (Has_alibi) {
+ tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+ }
+ if constexpr (!Is_even_MN) {
+ if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
+ }
+ }
+ }
+ }
+ } else {
+ #pragma unroll
+ for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+ const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+ #pragma unroll
+ for (int i = 0; i < size<0, 0>(tensor); ++i) {
+ const int row_idx = row_idx_base + i * 8;
+ const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+ const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
+ #pragma unroll
+ for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+ const int col_idx_base = col_idx_offset + nj * 8;
+ #pragma unroll
+ for (int j = 0; j < size<1, 0>(tensor); ++j) {
+ const int col_idx = col_idx_base + j;
+ if constexpr (Has_alibi) {
+ if constexpr (Is_causal) {
+ tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
+ } else {
+ tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+
+ }
+ }
+ if constexpr (Causal_mask) {
+ if (col_idx >= col_idx_limit_right) {
+ tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+ }
+ }
+ if constexpr (Is_local) {
+ if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
+ tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+ }
+ }
+ if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
+ // Causal and Local already handle MN masking
+ if (col_idx >= max_seqlen_k) {
+ tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ };
+
+};
+
+} // namespace flash
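The masking applied element-wise above reduces to a per-(row, column) window test: a key column survives iff it lies in [col_idx_limit_left, col_idx_limit_right), and causal masking is the window_size_right = 0 case with no left bound. A scalar sketch of that test, using a hypothetical helper, is:

    #include <algorithm>

    // Scalar reference for apply_mask_local / the Mask struct above: returns true
    // iff the (row_idx, col_idx) score should be kept (i.e. not set to -INFINITY).
    inline bool key_is_visible(int row_idx, int col_idx,
                               int max_seqlen_q, int max_seqlen_k,
                               int window_size_left, int window_size_right,
                               bool has_left_window) {
        // Align the query row with the key timeline when seqlen_q != seqlen_k.
        const int shift = row_idx + max_seqlen_k - max_seqlen_q;
        const int limit_left  = has_left_window ? std::max(0, shift - window_size_left) : 0;
        const int limit_right = std::min(max_seqlen_k, shift + 1 + window_size_right);
        return col_idx >= limit_left && col_idx < limit_right;
    }

    // Causal masking corresponds to
    // key_is_visible(row, col, seqlen_q, seqlen_k, /*wl=*/0, /*wr=*/0, /*has_left_window=*/false).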
diff --git a/candle-flash-attn/kernels/philox.cuh b/candle-flash-attn/kernels/philox.cuh
index 6ce1440f..cd7e4d2f 100644
--- a/candle-flash-attn/kernels/philox.cuh
+++ b/candle-flash-attn/kernels/philox.cuh
@@ -9,7 +9,7 @@ struct ull2 {
unsigned long long y;
};
-inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
+__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
uint2 *res;
unsigned long long tmp;
asm ("mul.wide.u32 %0, %1, %2;\n\t"
@@ -19,7 +19,7 @@ inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
return *res;
}
-inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
+__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
constexpr unsigned long kPhiloxSA = 0xD2511F53;
constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
@@ -28,7 +28,7 @@ inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
return ret;
}
-inline __device__ uint4 philox(unsigned long long seed,
+__forceinline__ __device__ uint4 philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset) {
constexpr unsigned long kPhilox10A = 0x9E3779B9;
@@ -49,117 +49,3 @@ inline __device__ uint4 philox(unsigned long long seed,
}
} // namespace flash
-
-namespace {
-
-class Philox {
-public:
- __device__ inline Philox(unsigned long long seed,
- unsigned long long subsequence,
- unsigned long long offset)
- : STATE(0)
- , seed_(seed)
- , offset_(offset)
- , key(reinterpret_cast<const uint2&>(seed)) {
- //key.x = (unsigned int)seed;
- //key.y = (unsigned int)(seed >> 32);
- //counter = make_uint4(0, 0, 0, 0);
- //counter.z = (unsigned int)(subsequence);
- //counter.w = (unsigned int)(subsequence >> 32);
- //STATE = 0;
- //incr_n(offset / 4);
-
- // key = reinterpret_cast<const uint2&>(seed);
- ull2 * tmp = reinterpret_cast<ull2*>(&counter);
- tmp->x = offset / 4;
- tmp->y = subsequence;
- // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
- // printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
- // }
- }
- __device__ inline uint4 operator()() {
- // // if (STATE == 0) {
- // uint4 counter_ = counter;
- // uint2 key_ = key;
- // // 7-round philox
- // #pragma unroll
- // for (int i = 0; i < 6; i++) {
- // counter_ = flash::philox_single_round(counter_, key_);
- // key_.x += (kPhilox10A);
- // key_.y += (kPhilox10B);
- // }
- // // output = philox_single_round(counter_, key_);
- // uint4 output = flash::philox_single_round(counter_, key_);
- // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
- // // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
- // // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
- // // }
- // incr();
- // // }
- // // return a float4 directly
- // // unsigned long ret;
- // // switch(STATE) {
- // // case 0: ret = output.x; break;
- // // case 1: ret = output.y; break;
- // // case 2: ret = output.z; break;
- // // case 3: ret = output.w; break;
- // //}
- // // STATE = (STATE + 1) % 4;
- // return output;
- return flash::philox(seed_, offset_, offset_);
- }
-
-private:
- unsigned long long offset_, seed_;
- struct ull2 {
- uint64_t x;
- uint64_t y;
- };
- uint4 counter;
- // uint4 output;
- const uint2 key;
- unsigned int STATE;
- __device__ inline void incr_n(unsigned long long n) {
- unsigned int nlo = (unsigned int)(n);
- unsigned int nhi = (unsigned int)(n >> 32);
- counter.x += nlo;
- if (counter.x < nlo)
- nhi++;
- counter.y += nhi;
- if (nhi <= counter.y)
- return;
- if (++counter.z)
- return;
- ++counter.w;
- }
-
- __device__ uint4 incr128 (uint4 ctr)
- {
- uint4 res;
- asm ("add.cc.u32 %0, %4, %8;\n\t"
- "addc.cc.u32 %1, %5, %9;\n\t"
- "addc.cc.u32 %2, %6, %10;\n\t"
- "addc.u32 %3, %7, %11;\n\t"
- : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
- : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
- "n"(1), "n"(0), "n"(0), "n"(0));
- return res;
- }
-
- __device__ inline void incr() {
- // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
- // printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
- // }
- counter = incr128(counter);
- // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
- // printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
- // }
- }
-
- static const unsigned long kPhilox10A = 0x9E3779B9;
- static const unsigned long kPhilox10B = 0xBB67AE85;
- // static const unsigned long kPhiloxSA = 0xD2511F53;
- // static const unsigned long kPhiloxSB = 0xCD9E8D57;
-};
-
-} // namespace
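
The dropout path keeps the counter-based Philox generator but drops the stateful Philox wrapper class above in favor of the free philox() function. For reference, here is a portable C++ sketch of the single mixing round; it uses a 64-bit multiply in place of the kernel's PTX mul.wide.u32, reuses the constants from the header, and everything else (struct names, round count in main) is illustrative:

```cpp
// Portable sketch of one Philox-4x32 round, mirroring philox_single_round above.
#include <cstdint>
#include <cstdio>

struct U2 { uint32_t x, y; };
struct U4 { uint32_t x, y, z, w; };

static U2 mulhilo32(uint32_t a, uint32_t b) {
    const uint64_t p = static_cast<uint64_t>(a) * b;
    return { static_cast<uint32_t>(p), static_cast<uint32_t>(p >> 32) };  // {lo, hi}
}

static U4 philox_single_round(U4 ctr, U2 key) {
    constexpr uint32_t kPhiloxSA = 0xD2511F53;
    constexpr uint32_t kPhiloxSB = 0xCD9E8D57;
    const U2 res0 = mulhilo32(kPhiloxSA, ctr.x);
    const U2 res1 = mulhilo32(kPhiloxSB, ctr.z);
    // High halves are XORed with the key and the "other" counter words.
    return { res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x };
}

int main() {
    U4 ctr{0, 0, 1, 0};
    U2 key{0x12345678u, 0x9ABCDEF0u};
    // A full generator runs several rounds, bumping the key each time.
    for (int i = 0; i < 7; ++i) {
        ctr = philox_single_round(ctr, key);
        key.x += 0x9E3779B9u;  // kPhilox10A
        key.y += 0xBB67AE85u;  // kPhilox10B
    }
    std::printf("%08x %08x %08x %08x\n", ctr.x, ctr.y, ctr.z, ctr.w);
    return 0;
}
```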
diff --git a/candle-flash-attn/kernels/rotary.h b/candle-flash-attn/kernels/rotary.h
new file mode 100644
index 00000000..7f1614ad
--- /dev/null
+++ b/candle-flash-attn/kernels/rotary.h
@@ -0,0 +1,152 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <cute/tensor.hpp>
+
+#include "utils.h"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace flash {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K=true, bool Clear_OOB_K=true,
+ typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
+ Tensor<Engine1, Layout1> &D,
+ Tensor<Engine2, Layout2> const &Cos,
+ Tensor<Engine2, Layout2> const &Sin,
+ Tensor<Engine3, Layout3> const &identity_MN,
+ const int max_MN, const int min_MN,
+ const int dim, const int rotary_dim) {
+ CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
+ CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
+ static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
+ static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
+ Tensor rCos = make_fragment_like(Cos);
+ Tensor rSin = make_fragment_like(Sin);
+ Tensor rS = make_fragment_like(S);
+ #pragma unroll
+ for (int m = 0; m < size<1>(S); ++m) {
+ if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+ #pragma unroll
+ for (int k = 0; k < size<2>(S); ++k) {
+ if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
+ cute::copy(S(_, m, k), rS(_, m, k));
+ if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
+ cute::copy(Cos(_, m, k), rCos(_, m, k));
+ cute::copy(Sin(_, m, k), rSin(_, m, k));
+ Tensor S_fp32 = convert_type<float>(rS(_, m, k));
+ Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
+ Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
+ #pragma unroll
+ for (int i = 0; i < size<0>(rS) / 2; ++i) {
+ float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
+ float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
+ S_fp32(2 * i) = real;
+ S_fp32(2 * i + 1) = imag;
+ }
+ // Not sure why, but the copy is needed for convert_type to work
+ Tensor S_fp32_copy = make_fragment_like(S_fp32);
+ cute::copy(S_fp32, S_fp32_copy);
+ using T = typename Engine0::value_type;
+ Tensor S_og_type = convert_type<T>(S_fp32_copy);
+ cute::copy(S_og_type, rS(_, m, k));
+ }
+ cute::copy(rS(_, m, k), D(_, m, k));
+ } else if (Clear_OOB_K) {
+ cute::clear(D(_, m, k));
+ }
+ }
+ }
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K=true, bool Clear_OOB_K=true,
+ typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
+ Tensor<Engine1, Layout1> &D,
+ Tensor<Engine2, Layout2> const &Cos,
+ Tensor<Engine2, Layout2> const &Sin,
+ Tensor<Engine3, Layout3> const &identity_MN,
+ const int max_MN, const int min_MN,
+ const int dim, const int rotary_dim) {
+ CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
+ CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
+ static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
+ Tensor rCos = make_fragment_like(Cos);
+ Tensor rSin = make_fragment_like(Sin);
+ Tensor rS = make_fragment_like(S);
+ Tensor rS_other = make_fragment_like(rS(_, 0, 0));
+ #pragma unroll
+ for (int m = 0; m < size<1>(S); ++m) {
+ if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+ #pragma unroll
+ for (int k = 0; k < size<2>(S); ++k) {
+ if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
+ cute::copy(S(_, m, k), rS(_, m, k));
+ if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
+ const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
+ Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
+ cute::copy(gS_other, rS_other);
+ // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
+ Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
+ Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
+ cute::copy(gCos, rCos(_, m, k));
+ cute::copy(gSin, rSin(_, m, k));
+ // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
+ Tensor S_fp32 = convert_type<float>(rS(_, m, k));
+ Tensor S_other_fp32 = convert_type<float>(rS_other);
+ Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
+ Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
+ #pragma unroll
+ for (int i = 0; i < size<0>(rS); ++i) {
+ S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
+ }
+ // Not sure why, but the copy is needed for convert_type to work
+ Tensor S_fp32_copy = make_fragment_like(S_fp32);
+ cute::copy(S_fp32, S_fp32_copy);
+ using T = typename Engine0::value_type;
+ Tensor S_og_type = convert_type<T>(S_fp32_copy);
+ cute::copy(S_og_type, rS(_, m, k));
+ // if (cute::thread0()) { print_tensor(rS(_, m, k)); }
+ }
+ cute::copy(rS(_, m, k), D(_, m, k));
+ } else if (Clear_OOB_K) {
+ cute::clear(D(_, m, k));
+ }
+ }
+ }
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace flash
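
Both rotary copy routines apply the same complex rotation; they differ only in how the second element of each pair is addressed (the adjacent element in the interleaved layout, the element rotary_dim/2 away in the contiguous layout). A scalar C++ sketch of the interleaved variant, with plain float vectors standing in for the register fragments (names illustrative):

```cpp
// Scalar sketch of the interleaved rotary rotation from copy_rotary_interleaved
// above: each pair (x[2i], x[2i+1]) is rotated by the angle whose cosine/sine
// are cos_v[i] / sin_v[i].
#include <cmath>
#include <cstdio>
#include <vector>

static void rotate_interleaved(std::vector<float> &x,
                               const std::vector<float> &cos_v,
                               const std::vector<float> &sin_v) {
    for (size_t i = 0; i < x.size() / 2; ++i) {
        const float real = x[2 * i] * cos_v[i] - x[2 * i + 1] * sin_v[i];
        const float imag = x[2 * i] * sin_v[i] + x[2 * i + 1] * cos_v[i];
        x[2 * i] = real;
        x[2 * i + 1] = imag;
    }
}

int main() {
    const float pi = 3.14159265358979f;
    std::vector<float> x = {1.f, 0.f, 0.f, 1.f};              // pairs (1,0) and (0,1)
    std::vector<float> c = {std::cos(pi / 2), std::cos(pi)};  // rotate by 90 and 180 degrees
    std::vector<float> s = {std::sin(pi / 2), std::sin(pi)};
    rotate_interleaved(x, c, s);
    for (float v : x) std::printf("% .3f ", v);  // approximately: 0 1 0 -1
    std::printf("\n");
    return 0;
}
```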
diff --git a/candle-flash-attn/kernels/softmax.h b/candle-flash-attn/kernels/softmax.h
index 09a93f14..ebf1b097 100644
--- a/candle-flash-attn/kernels/softmax.h
+++ b/candle-flash-attn/kernels/softmax.h
@@ -1,5 +1,5 @@
/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
@@ -20,7 +20,7 @@ using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
-__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
+__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
@@ -35,7 +35,7 @@ __device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Te
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
-__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
+__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
@@ -44,26 +44,26 @@ __device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Eng
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
-__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
+__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
+__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
-template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
+template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
- reduce_(tensor, sum, sum_op);
+ thread_reduce_<zero_init>(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
+__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@@ -78,14 +78,21 @@ inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
- tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+ // The following macro will disable the use of fma.
+ // See: https://github.com/pytorch/pytorch/issues/121558 for more details
+ // This macro is set by PyTorch, not by FlashAttention
+ #ifdef UNFUSE_FMA
+ tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
+ #else
+ tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+ #endif
}
}
}
// Apply the exp to all the elements.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
+__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@@ -115,169 +122,67 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
}
}
-template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
- const int col_idx_offset_ = 0) {
- // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
- static_assert(Layout::rank == 2, "Only support 2D Tensor");
- const int lane_id = threadIdx.x % 32;
- const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
- #pragma unroll
- for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
- const int col_idx_base = col_idx_offset + nj * 8;
- #pragma unroll
- for (int j = 0; j < size<1, 0>(tensor); ++j) {
- const int col_idx = col_idx_base + j;
- if (col_idx >= max_seqlen_k) {
- // Without the "make_coord" we get wrong results
- #pragma unroll
- for (int mi = 0; mi < size<0>(tensor); ++mi) {
- tensor(mi, make_coord(j, nj)) = -INFINITY;
- }
- }
- }
- }
-}
+////////////////////////////////////////////////////////////////////////////////////////////////////
-template <bool HasWSLeft=true, typename Engine, typename Layout>
-inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
- const int max_seqlen_k, const int row_idx_offset,
- const int max_seqlen_q, const int warp_row_stride,
- const int window_size_left, const int window_size_right) {
- // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
- static_assert(Layout::rank == 2, "Only support 2D Tensor");
- const int lane_id = threadIdx.x % 32;
- const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
- #pragma unroll
- for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
- const int row_idx_base = row_idx_offset + mi * warp_row_stride;
- #pragma unroll
- for (int i = 0; i < size<0, 0>(tensor); ++i) {
- const int row_idx = row_idx_base + i * 8;
- const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
- const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
+template <int kNRows>
+struct Softmax {
+
+ using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
+ TensorT row_max, row_sum;
+
+ __forceinline__ __device__ Softmax() {};
+
+ template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
+ __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
+ // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+ Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+ static_assert(decltype(size<0>(scores))::value == kNRows);
+ if (Is_first) {
+ flash::template reduce_max</*zero_init=*/true>(scores, row_max);
+ flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
+ flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
+ } else {
+ Tensor scores_max_prev = make_fragment_like(row_max);
+ cute::copy(row_max, scores_max_prev);
+ flash::template reduce_max</*zero_init=*/false>(scores, row_max);
+ // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+ Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+ static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
- for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
- const int col_idx_base = col_idx_offset + nj * 8;
+ for (int mi = 0; mi < size(row_max); ++mi) {
+ float scores_max_cur = !Check_inf
+ ? row_max(mi)
+ : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
+ float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
+ row_sum(mi) *= scores_scale;
#pragma unroll
- for (int j = 0; j < size<1, 0>(tensor); ++j) {
- const int col_idx = col_idx_base + j;
- if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
- tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
- }
- }
+ for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
- // if (cute::thread0()) {
- // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
- // print(tensor(make_coord(i, mi), _));
- // // print(tensor(_, j + nj * size<1, 0>(tensor)));
- // }
+ flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
+ // We don't do the reduce across threads here since we don't need to use the row_sum.
+ // We do that reduce at the end when we need to normalize the softmax.
+ flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
- }
-}
-
-template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
- const int max_seqlen_k, const int row_idx_offset,
- const int max_seqlen_q, const int warp_row_stride) {
- // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
- apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
- max_seqlen_q, warp_row_stride, -1, 0);
-}
+ };
-template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
-inline __device__ void apply_mask_causal_w_idx(
- Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
- const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
-{
- // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
- static_assert(Layout0::rank == 2, "Only support 2D Tensor");
- static_assert(Layout1::rank == 2, "Only support 2D Tensor");
- CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
- CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
- #pragma unroll
- for (int mi = 0; mi < size<0>(tensor); ++mi) {
- const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
+ template<bool Is_dropout=false, bool Split=false, typename Tensor0>
+ __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+ SumOp<float> sum_op;
+ quad_allreduce_(row_sum, row_sum, sum_op);
+ TensorT lse = make_fragment_like(row_sum);
+ Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+ static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
- for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
- if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
- tensor(mi, ni) = -INFINITY;
- }
+ for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
+ float sum = row_sum(mi);
+ float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+ lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
+ float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
+ #pragma unroll
+ for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
- // if (cute::thread0()) {
- // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
- // print(tensor(_, make_coord(j, ni)));
- // // print(tensor(_, j + ni * size<1, 0>(tensor)));
- // }
- }
-}
-
-template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
-inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
- unsigned long long seed, unsigned long long offset,
- int block_row_start, int block_col_start,
- int block_row_stride) {
- // tensor has shape (8, MMA_M, MMA_N / 2)
- using T = typename Engine::value_type;
- auto encode_dropout = [](bool keep, T val) {
- return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
+ return lse;
};
- static_assert(decltype(size<2>(tensor))::value % 2 == 0);
- const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
- const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
- // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
- #pragma unroll
- for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
- uint2 rowcol = make_uint2(block_row_start, block_col_start);
- #pragma unroll
- for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
- // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
- uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
- // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
- uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
- // Special implementation for 16-bit types: we duplicate the threshold to the
- // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
- // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
- // and the high 16 bits will be either 0xffff or 0x0000, depending on whether
- // the random value is less than the threshold.
- // We then do a bit-wise AND between the mask and the original value (in 32-bit).
- // We're exploiting the fact that floating point comparison is equivalent to integer
- // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
- if (!encode_dropout_in_sign_bit
- && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
- uint16_t rnd_16[16];
- #pragma unroll
- for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
- uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
- #pragma unroll
- for (int j = 0; j < 2; j++) {
- Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
- // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
- // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
- #pragma unroll
- for (int i = 0; i < 4; i++) {
- uint32_t mask;
- asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
- tensor_uint32(i) &= mask;
- }
- // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
- }
- } else {
- #pragma unroll
- for (int j = 0; j < 2; j++) {
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
- }
- Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
- // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
- }
- }
- // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
- // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
- // // }
- }
- }
-}
+};
} // namespace flash
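
The new Softmax struct implements the online softmax update: row_max and row_sum are carried across K blocks, and whenever a later block raises the max, the partial output accumulator and running sum are rescaled by exp2((old_max - new_max) * softmax_scale_log2). A simplified single-row, host-side C++ sketch of that update follows (illustrative values; the Check_inf handling and the final LSE step in normalize_softmax_lse are omitted):

```cpp
// Host sketch of the per-row online-softmax update performed by
// Softmax::softmax_rescale_o above. One query row, two K blocks; "V" is
// taken as all-ones so the rescaled accumulator must equal the running sum.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<std::vector<float>> blocks = {{0.1f, 0.7f, 0.2f}, {1.5f, 0.3f, 0.9f}};
    const float scale_log2 = 1.4426950408889634f;  // softmax_scale * log2(e), with softmax_scale = 1

    float row_max = -INFINITY, row_sum = 0.f, acc_o = 0.f;
    for (const auto &blk : blocks) {
        const float prev_max = row_max;
        for (float s : blk) row_max = std::max(row_max, s);
        // Rescale what was accumulated under the previous (smaller) max.
        const float scores_scale =
            (prev_max == -INFINITY) ? 1.f : std::exp2((prev_max - row_max) * scale_log2);
        row_sum *= scores_scale;
        acc_o *= scores_scale;
        // Add this block's contribution relative to the new max
        // (what scale_apply_exp2 + reduce_sum do in the kernel).
        for (float s : blk) {
            const float p = std::exp2(s * scale_log2 - row_max * scale_log2);
            row_sum += p;
            acc_o += p;  // p * V with V = 1
        }
    }
    std::printf("row_sum = %f, acc_o = %f\n", row_sum, acc_o);
    return 0;
}
```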
diff --git a/candle-flash-attn/kernels/static_switch.h b/candle-flash-attn/kernels/static_switch.h
index 4aa84740..20c2afd6 100644
--- a/candle-flash-attn/kernels/static_switch.h
+++ b/candle-flash-attn/kernels/static_switch.h
@@ -14,6 +14,7 @@
/// some_function<BoolConst>(...);
/// });
/// ```
+
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
@@ -25,6 +26,56 @@
} \
}()
+#ifdef FLASHATTENTION_DISABLE_DROPOUT
+ #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
+ [&] { \
+ constexpr static bool CONST_NAME = false; \
+ return __VA_ARGS__(); \
+ }()
+#else
+ #define DROPOUT_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_ALIBI
+ #define ALIBI_SWITCH(COND, CONST_NAME, ...) \
+ [&] { \
+ constexpr static bool CONST_NAME = false; \
+ return __VA_ARGS__(); \
+ }()
+#else
+ #define ALIBI_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
+ #define EVENK_SWITCH(COND, CONST_NAME, ...) \
+ [&] { \
+ constexpr static bool CONST_NAME = true; \
+ return __VA_ARGS__(); \
+ }()
+#else
+ #define EVENK_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_SOFTCAP
+ #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
+ [&] { \
+ constexpr static bool CONST_NAME = false; \
+ return __VA_ARGS__(); \
+ }()
+#else
+ #define SOFTCAP_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_LOCAL
+ #define LOCAL_SWITCH(COND, CONST_NAME, ...) \
+ [&] { \
+ constexpr static bool CONST_NAME = false; \
+ return __VA_ARGS__(); \
+ }()
+#else
+ #define LOCAL_SWITCH BOOL_SWITCH
+#endif
+
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
@@ -36,7 +87,7 @@
} \
}()
-#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
+#define HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 32) { \
constexpr static int kHeadDim = 32; \
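
The *_SWITCH macros turn a runtime flag into a constexpr template parameter, so each kernel instantiation compiles only the branches it needs, and the FLASHATTENTION_DISABLE_* defines collapse a switch to a single instantiation. A self-contained sketch of how such switches are typically nested by a caller (launch_kernel and run are illustrative stand-ins, not the actual launch code):

```cpp
// Self-contained illustration of the BOOL_SWITCH pattern above: a runtime
// bool becomes a constexpr template parameter inside the lambda, so dead
// branches are compiled out of each instantiation.
#include <cstdio>

#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

template <bool Is_causal, bool Is_even_MN>
void launch_kernel() {
    // The real code would instantiate and launch a CUDA kernel here.
    std::printf("Is_causal=%d Is_even_MN=%d\n", Is_causal, Is_even_MN);
}

void run(bool is_causal, int seqlen_q, int block_m) {
    const bool is_even_mn = (seqlen_q % block_m) == 0;
    BOOL_SWITCH(is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_mn, Is_even_MN, [&] {
            launch_kernel<Is_causal, Is_even_MN>();
        });
    });
}

int main() {
    run(true, 128, 64);   // both constants true
    run(false, 100, 64);  // both constants false
    return 0;
}
```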
diff --git a/candle-flash-attn/kernels/utils.h b/candle-flash-attn/kernels/utils.h
index 6fb39dc4..708aeddf 100644
--- a/candle-flash-attn/kernels/utils.h
+++ b/candle-flash-attn/kernels/utils.h
@@ -14,8 +14,7 @@
#include <cuda_bf16.h>
#endif
-#include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>
+#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
@@ -29,10 +28,10 @@ namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
-inline __device__ uint32_t relu2(const uint32_t x);
+__forceinline__ __device__ uint32_t relu2(const uint32_t x);
template<>
-inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
+__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@@ -50,7 +49,7 @@ inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
-inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
+__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
@@ -63,10 +62,10 @@ inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<typename T>
-inline __device__ uint32_t convert_relu2(const float2 x);
+__forceinline__ __device__ uint32_t convert_relu2(const float2 x);
template<>
-inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
+__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
uint32_t res;
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
@@ -75,7 +74,7 @@ inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
}
template<>
-inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
+__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
uint32_t res;
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
@@ -89,20 +88,20 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
template<typename T>
struct MaxOp {
-__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
+__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
-__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
+__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
-__device__ inline T operator()(T const & x, T const & y) { return x + y; }
+__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -111,7 +110,7 @@ template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
- static __device__ inline T run(T x, Operator &op) {
+ static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
@@ -123,7 +122,7 @@ struct Allreduce {
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
-static __device__ inline T run(T x, Operator &op) {
+static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
@@ -135,7 +134,7 @@ template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopyA, typename TiledCopyB,
typename ThrCopyA, typename ThrCopyB>
-inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
+__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma,
TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
@@ -162,9 +161,9 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
-inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
- TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
- ThrCopy smem_thr_copy_B) {
+__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
+ TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+ ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
@@ -184,42 +183,48 @@ inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
-inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
+__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
- // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
- // "int_tuple.hpp(74): error: conversion to inaccessible base class"
- // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
- return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
+ return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
-// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
-inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
+__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
using X = Underscore;
- static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
- static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
+ static_assert(decltype(size<0>(acc_layout))::value == 4);
+ static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
- constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
- auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
- // TD [2023-08-13]: Same error as above on Cutlass 3.2
- // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
- // get<0, 1>(l),
- // get<1, 1, 1>(l));
- return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
- get<1>(get<0>(l)),
- get<1>(get<1>(get<1>(l))));
+ if constexpr (mma_shape_K == 8) {
+ return acc_layout;
+ } else {
+ auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2))
+ return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+template<typename Layout>
+__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
+ using X = Underscore;
+ static_assert(decltype(size<0>(acc_layout))::value == 4);
+ static_assert(decltype(rank(acc_layout))::value == 3);
+ auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2))
+ return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename To_type, typename Engine, typename Layout>
-inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
+__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
@@ -231,7 +236,7 @@ inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout>
-inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
+__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
constexpr int numel = decltype(size(tensor))::value;
static_assert(numel % 2 == 0);
using value_t = typename Engine::value_type;
@@ -247,7 +252,7 @@ inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
template <typename To_type, typename Engine, typename Layout>
-inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
+__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
static_assert(std::is_same_v<float, From_type>);
@@ -289,7 +294,7 @@ void cp_async_wait() {
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
+__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
@@ -355,4 +360,34 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
////////////////////////////////////////////////////////////////////////////////////////////////////
+template <bool Is_even_K=true,
+ typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+ typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
+ Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+ Tensor<Engine3, Layout3> const &predicate_K,
+ const int max_MN=0, const int min_MN=0) {
+ CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+ CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
+ CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
+ CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
+ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
+ #pragma unroll
+ for (int m = 0; m < size<1>(S); ++m) {
+ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+ if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+ // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+ #pragma unroll
+ for (int k = 0; k < size<2>(S); ++k) {
+ if (Is_even_K || predicate_K(k)) {
+ cute::copy(S(_, m, k), D(_, m, k));
+ }
+ }
+ }
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
} // namespace flash
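
The Allreduce helper above performs a butterfly reduction with __shfl_xor_sync so that every lane in a group of THREADS lanes ends up holding the reduced value; the softmax code uses THREADS = 4 to reduce across the quad of lanes that owns one row. A standalone CUDA sketch of that pattern (kernel and names illustrative, not the library's code):

```cpp
// Minimal CUDA sketch of the XOR-butterfly allreduce used by flash::Allreduce.
#include <cstdio>
#include <cuda_runtime.h>

struct SumOp {
    __device__ float operator()(float a, float b) const { return a + b; }
};

template <int THREADS, typename Op>
__device__ float allreduce(float x, Op op) {
    #pragma unroll
    for (int offset = THREADS / 2; offset >= 1; offset /= 2) {
        x = op(x, __shfl_xor_sync(0xffffffffu, x, offset));
    }
    return x;
}

__global__ void quad_sum_kernel(float *out) {
    const float v = static_cast<float>(threadIdx.x);  // lane id 0..31
    const float s = allreduce<4>(v, SumOp{});         // sum within each quad of lanes
    out[threadIdx.x] = s;
}

int main() {
    float *d_out = nullptr;
    cudaMalloc(&d_out, 32 * sizeof(float));
    quad_sum_kernel<<<1, 32>>>(d_out);
    float h_out[32];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    // Lanes 0..3 each hold 0+1+2+3 = 6, lanes 4..7 hold 22, and so on.
    for (int i = 0; i < 8; ++i) std::printf("lane %d -> %.0f\n", i, h_out[i]);
    cudaFree(d_out);
    return 0;
}
```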