Diffstat (limited to 'candle-flash-attn/kernels/flash_fwd_kernel.h')
-rw-r--r--  candle-flash-attn/kernels/flash_fwd_kernel.h  | 282
1 file changed, 169 insertions(+), 113 deletions(-)
diff --git a/candle-flash-attn/kernels/flash_fwd_kernel.h b/candle-flash-attn/kernels/flash_fwd_kernel.h
index 232dea0d..05f5f701 100644
--- a/candle-flash-attn/kernels/flash_fwd_kernel.h
+++ b/candle-flash-attn/kernels/flash_fwd_kernel.h
@@ -4,20 +4,18 @@
#pragma once
-#include <cmath>
#include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
-#include <cutlass/numeric_conversion.h>
#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
-#include "philox.cuh"
+
+#include "alibi.h"
namespace flash {
@@ -25,49 +23,6 @@ using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
-template <int MMA_M,
- class... Args,
- class TiledMMA>
-CUTE_HOST_DEVICE
-auto
-make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
- TiledMMA const& tiled_mma) {
- using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
- using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
- constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
- constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
- constexpr int MMAStride_M = MMA_M * AtomShape_M;
- auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
- Stride<_1, Int<MMAStride_M>> >{},
- make_layout(size<2>(TileShape_MNK{})));
- // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
- return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <int MMA_M,
- class... Args,
- class TiledMMA>
-CUTE_HOST_DEVICE
-auto
-make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
- TiledMMA const& tiled_mma) {
- using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
- using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
- constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
- constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
- constexpr int MMAStride_M = MMA_M * AtomShape_M;
- auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
- Stride<_1, Int<MMAStride_M>> >{},
- // TODO: Shouldn't this be size<1>?
- make_layout(size<2>(TileShape_MNK{})));
- // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
- return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
Tensor2 &acc_o, float softmax_scale_log2) {
@@ -77,7 +32,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
flash::reduce_sum(scores, scores_sum);
} else {
Tensor scores_max_prev = make_fragment_like(scores_max);
- copy(scores_max, scores_max_prev);
+ cute::copy(scores_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
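The branch above is the online-softmax update: when a later K/V block raises the running row maximum, the previously accumulated sum and output are rescaled before the new block's contribution is added. A scalar sketch of one such step (plain C++ with invented names, not the CuTe tensor code used here):

#include <cmath>

// Scalar sketch of one online-softmax step for a single row (illustrative only).
// `block_max` / `block_sum` are the max and the exp2-sum of the new block's scores,
// with the exp2 already taken relative to the *new* overall max.
void online_softmax_step(float &row_max, float &row_sum, float &acc_o,
                         float block_max, float block_sum, float scale_log2) {
    float new_max = std::fmax(row_max, block_max);
    float correction = std::exp2((row_max - new_max) * scale_log2);  // <= 1
    row_sum = row_sum * correction + block_sum;  // rescale the old sum, add the new block
    acc_o  *= correction;                        // previously accumulated P*V must be rescaled too
    row_max = new_max;
}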
@@ -103,23 +58,22 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem(
- Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
+ Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
) {
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
Layout l = tOrP.layout();
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
- // TODO(laurent): reactivate the following
- // CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
+ CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
#pragma unroll
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
- copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
+ cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
using Element = typename Kernel_traits::Element;
@@ -138,16 +92,65 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
- const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
- if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
+ const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+ if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
+ const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
- if (Is_causal) {
- n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
+ if (Is_causal || Is_local) {
+ n_block_max = std::min(n_block_max,
+ cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
// }
}
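To make the block-range arithmetic concrete, here is a small host-side recomputation of n_block_min / n_block_max using the same formulas (the sequence lengths, window sizes and tile sizes below are made-up example values):

#include <algorithm>
#include <cstdio>

int main() {
    // Hypothetical example values, not taken from any real launch.
    const int kBlockM = 128, kBlockN = 128;
    const int seqlen_q = 1000, seqlen_k = 1000;
    const int window_left = 256, window_right = 0;   // sliding-window ("local") attention
    const int m_block = 5;

    const int n_block_min =
        std::max(0, (m_block * kBlockM + seqlen_k - seqlen_q - window_left) / kBlockN);
    int n_block_max = (seqlen_k + kBlockN - 1) / kBlockN;                        // ceil_div
    n_block_max = std::min(n_block_max,
        ((m_block + 1) * kBlockM + seqlen_k - seqlen_q + window_right + kBlockN - 1) / kBlockN);
    printf("n_block range = [%d, %d)\n", n_block_min, n_block_max);             // [3, 6)
    return 0;
}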
+ // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
+ // Otherwise we might read OOB elements from gK and gV.
+ if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
+ // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
+ // exit early and no one saves the rng state.
+// if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+// auto seeds = at::cuda::philox::unpack(params.philox_args);
+// params.rng_state[0] = std::get<0>(seeds);
+// params.rng_state[1] = std::get<1>(seeds);
+// params.rng_state[0] = 0;
+// params.rng_state[1] = 0;
+// }
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+ const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+ Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+ Shape<Int<kBlockM>, Int<kHeadDim>>{},
+ make_stride(params.o_row_stride, _1{}));
+ Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
+ Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+ typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+ Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+ Tensor tOrO = make_tensor<Element>(shape(tOgO));
+ clear(tOrO);
+ // Construct identity layout for sO
+ Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ // Repeat the partitioning with identity layouts
+ Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+ Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+ if (!Is_even_K) {
+ #pragma unroll
+ for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+ }
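The tOpO predicate above guards the trailing columns of the head dimension when it is not a multiple of the per-thread copy width. A toy version of the same bounds check (all numbers invented):

#include <cstdio>

int main() {
    // Toy version of the tOpO predicate: mark which of a thread's column chunks
    // fall inside the real head dimension d.
    const int d = 72, chunk_width = 8, chunks_per_thread = 2, thread_col_base = 64;
    for (int k = 0; k < chunks_per_thread; ++k) {
        int col = thread_col_base + k * chunk_width;  // identity-tensor column coordinate
        bool in_bounds = col < d;                     // analogous to get<1>(tOcO(0,0,k)) < params.d
        printf("chunk %d starts at col %d -> %s\n", k, col, in_bounds ? "copy" : "skip");
    }
    return 0;
}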
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+ gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+ );
+ #pragma unroll
+ for (int m = 0; m < size<1>(tOgO); ++m) {
+ const int row = get<0>(tOcO(0, m, 0));
+ if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
+ }
+ return;
+ }
+ // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }
// We iterate over the blocks in reverse order. This is because the last block is the only one
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
@@ -185,8 +188,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
- auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
- auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
+ typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+ auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+ typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
+ auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -208,16 +213,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Copy Atom retiling
//
- auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
- // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
+ auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+ auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
- auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
+ auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+ auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
- auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
+ auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+ auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
// TODO: this might need to change if we change the mma instruction in SM70
@@ -268,8 +275,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tQrQ = make_fragment_like(tQgQ);
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
- flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
- binfo.actual_seqlen_q - m_block * kBlockM);
+ flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+ binfo.actual_seqlen_q - m_block * kBlockM);
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
// // Copy rmem to smem
@@ -285,14 +292,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
- copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
+ cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads();
}
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
- flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
- binfo.actual_seqlen_k - n_block * kBlockN);
+ flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+ binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
// __syncthreads();
@@ -302,7 +309,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
- copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
+ cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
}
// auto seeds = at::cuda::philox::unpack(params.philox_args);
@@ -313,13 +320,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
clear(acc_o);
+ float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+
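alibi_slope is read per (batch, head) and divided by params.scale_softmax, presumably so that the bias keeps its intended magnitude after the scores are later multiplied by the softmax scale. The bias itself is the usual ALiBi distance penalty; a tiny standalone illustration (the exact sign/offset conventions live in alibi.h, and the sizes below are invented):

#include <cmath>
#include <cstdio>

int main() {
    const int seqlen_q = 4, seqlen_k = 6;
    const float slope = 0.0625f;                        // per-head ALiBi slope
    for (int i = 0; i < seqlen_q; ++i) {
        const int i_aligned = i + seqlen_k - seqlen_q;  // align the last query with the last key
        for (int j = 0; j < seqlen_k; ++j) {
            float bias = -slope * std::fabs(float(i_aligned - j));  // added to the raw score
            printf("%6.2f ", bias);
        }
        printf("\n");
    }
    return 0;
}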
// For performance reasons, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
// We need masking on S for the very last block when K and V have a length that is not a multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
- constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
+ // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+ // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+ constexpr int n_masking_steps = (!Is_causal && !Is_local)
+ ? 1
+ : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
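Plugging made-up tile sizes into the expression above gives the concrete iteration counts:

#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    // Example values only; kBlockM / kBlockN come from Kernel_traits in the real kernel.
    constexpr int kBlockM = 128, kBlockN = 64;
    printf("neither causal nor local        : %d masking step(s)\n", 1);
    printf("causal, even seqlen (Is_even_MN): %d\n", ceil_div(kBlockM, kBlockN));       // 2
    printf("causal or local, uneven seqlen  : %d\n", ceil_div(kBlockM, kBlockN) + 1);   // 3
    return 0;
}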
#pragma unroll
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
@@ -330,28 +343,42 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Advance gV
if (masking_step > 0) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
- flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
- flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
- gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+ gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
- acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
+ acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+ smem_thr_copy_Q, smem_thr_copy_K
);
// if (cute::thread0()) { print(acc_s); }
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
- // if (cute::thread0()) { print(scores); }
+ // if (cute::thread0()) { print_tensor(scores); }
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
- if (!Is_causal) {
- if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
+
+ if (Has_alibi) {
+ flash::apply_alibi<Is_causal>(
+ scores,
+ n_block * kBlockN,
+ binfo.actual_seqlen_k,
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q,
+ kNWarps * 16,
+ alibi_slope
+ );
+ }
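The row index passed to apply_alibi (and to the masking calls below) follows the m16n8k16 accumulator layout: each warp owns a 16-row slab, and within a warp, lane L holds rows L/4 and L/4 + 8 of that slab. A quick check of the formula with an assumed tile size:

#include <cstdio>

int main() {
    const int kBlockM = 128, m_block = 0;   // assumed example values
    const int samples[] = {0, 4, 31, 32, 100};
    for (int tidx : samples) {
        // First accumulator row owned by this thread; the second one is +8.
        int row = m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4;
        printf("tidx %3d -> rows %d and %d\n", tidx, row, row + 8);
    }
    return 0;
}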
+
+ if (!Is_causal && !Is_local) {
+ if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
@@ -364,20 +391,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Idk why it's get<1> and not get<0> of the stride.
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
// I can't get the stride from idx_row
- flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
- // m_block * kBlockM + get<0>(idx_row(0)),
- m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
- kNWarps * 16);
- // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
- // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
+ flash::apply_mask_local</*HasWSLeft=*/Is_local>(
+ scores, n_block * kBlockN, binfo.actual_seqlen_k,
+ // m_block * kBlockM + get<0>(idx_row(0)),
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q, kNWarps * 16,
+ params.window_size_left, params.window_size_right
+ // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
+ // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
+ );
+ // if (cute::thread0()) { print_tensor(scores); }
}
flash::cp_async_wait<0>();
__syncthreads();
- if (n_block > 0) {
+ if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
- flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
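The fence/wait pattern here (issue the async K prefetch, commit it, compute on the tile already in shared memory, then wait and __syncthreads before touching the new tile) is the usual cp.async software pipeline; committing a group only when a copy was actually issued keeps the commit and wait counts consistent across iterations. A minimal double-buffered sketch using the raw CUDA pipeline primitives instead of CuTe's wrappers (all names and shapes invented, blockDim.x assumed to be 128):

#include <cuda_pipeline.h>

__global__ void prefetch_pipeline(const float4 *gmem, float4 *out, int n_tiles) {
    __shared__ float4 smem[2][128];
    int buf = 0;
    // Prime the pipeline with tile 0.
    __pipeline_memcpy_async(&smem[buf][threadIdx.x], &gmem[threadIdx.x], sizeof(float4));
    __pipeline_commit();
    for (int t = 0; t < n_tiles; ++t) {
        bool prefetch = (t + 1 < n_tiles);
        if (prefetch) {  // only commit when a copy was actually issued (cf. the n_block guard above)
            __pipeline_memcpy_async(&smem[buf ^ 1][threadIdx.x],
                                    &gmem[(t + 1) * 128 + threadIdx.x], sizeof(float4));
            __pipeline_commit();
        }
        __pipeline_wait_prior(prefetch ? 1 : 0);              // tile t has landed in smem[buf]
        __syncthreads();
        out[t * 128 + threadIdx.x] = smem[buf][threadIdx.x];  // stand-in for the GEMM
        __syncthreads();                                      // finish reading before it is overwritten
        buf ^= 1;
    }
}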
@@ -385,24 +416,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// TODO: when we have key_padding_mask we'll need to Check_inf
masking_step == 0
- ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
- : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+ ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+ : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
- uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
- uint32_t block_col_idx = n_block * (kBlockN / 32);
+ int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+ int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
- copy(tOrP, tOrP_copy);
+ cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
- flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+ flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
@@ -411,37 +442,38 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}
// if (cute::thread0()) { print(tOrP); }
- flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
+ flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration
- if (n_masking_steps > 1 && n_block <= 0) {
+ if (n_masking_steps > 1 && n_block <= n_block_min) {
--n_block;
break;
}
}
// These are the iterations where we don't need masking on S
- for (; n_block >= 0; --n_block) {
+ for (; n_block >= n_block_min; --n_block) {
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
clear(acc_s);
flash::cp_async_wait<0>();
__syncthreads();
// Advance gV
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
- flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
- acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
+ acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+ smem_thr_copy_Q, smem_thr_copy_K
);
flash::cp_async_wait<0>();
__syncthreads();
- if (n_block > 0) {
+ if (n_block > n_block_min) {
// Advance gK
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
- flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+ flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
@@ -449,22 +481,44 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
- softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+
+ if (Has_alibi) {
+ flash::apply_alibi<Is_causal>(
+ scores,
+ n_block * kBlockN,
+ binfo.actual_seqlen_k,
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q,
+ kNWarps * 16,
+ alibi_slope
+ );
+ }
+
+ if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+ flash::apply_mask_local(
+ scores, n_block * kBlockN, binfo.actual_seqlen_k,
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q, kNWarps * 16,
+ params.window_size_left, params.window_size_right
+ );
+ }
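For reference, the element-wise rule that apply_mask_local enforces is the standard sliding-window predicate, with the last query row aligned to the last key column. A tiny standalone illustration (all sizes invented):

#include <cstdio>

int main() {
    const int seqlen_q = 4, seqlen_k = 6, window_left = 2, window_right = 0;
    for (int i = 0; i < seqlen_q; ++i) {
        const int i_aligned = i + seqlen_k - seqlen_q;   // align last query with last key
        for (int j = 0; j < seqlen_k; ++j) {
            bool keep = (j >= i_aligned - window_left) && (j <= i_aligned + window_right);
            printf("%c ", keep ? '1' : '.');             // '1' = attended, '.' = masked out
        }
        printf("\n");
    }
    return 0;
}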
+
+ softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
- uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
- uint32_t block_col_idx = n_block * (kBlockN / 32);
+ int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+ int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor tOrP_copy = make_fragment_like(tOrP);
- copy(tOrP, tOrP_copy);
+ cute::copy(tOrP, tOrP_copy);
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
block_row_idx, block_col_idx, kNWarps
);
- flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+ flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
tPgP.data() = tPgP.data() + (-kBlockN);
}
if (Is_dropout) {
@@ -472,7 +526,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
block_row_idx, block_col_idx, kNWarps);
}
- flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
+ flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}
// Epilogue
@@ -496,15 +550,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor rO = flash::convert_type<Element>(acc_o);
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
- auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
- // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
+ auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+ auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// sO has the same size as sQ, so we don't need to sync here.
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
- copy(smem_thr_copy_O, taccOrO, taccOsO);
+ cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
@@ -515,14 +569,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
- auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
+ typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
__syncthreads();
Tensor tOrO = make_tensor<Element>(shape(tOgO));
- copy(gmem_thr_copy_O, tOsO, tOrO);
+ cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
@@ -548,14 +603,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
}
// Clear_OOB_K must be false since we don't want to write zeros to gmem
- flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
- gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+ flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+ gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
);
}
+
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
const int m_block = blockIdx.x;
// The block index for the batch.
@@ -571,7 +627,7 @@ inline __device__ void compute_attn(const Params &params) {
// the attention matrix. This way, as long as we have the batch, head, and the location of
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
- flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+ flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}
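As the comment notes, the dropout pattern is a pure function of (batch, head, 16 x 32 tile position), which is exactly what block_row_idx / block_col_idx compute in the main loop. A quick recomputation with made-up values:

#include <cstdio>

int main() {
    // Example values only; kBlockM / kBlockN come from Kernel_traits in the real kernel.
    const int kBlockM = 128, kBlockN = 128;
    const int m_block = 2, n_block = 3, tidx = 70;              // 3rd query block, 4th key block, warp 2
    int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;   // 2*8 + 2 = 18
    int block_col_idx = n_block * (kBlockN / 32);               // 3*4     = 12
    printf("16x32 dropout tile = (%d, %d)\n", block_row_idx, block_col_idx);
    return 0;
}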
////////////////////////////////////////////////////////////////////////////////////////////////////