/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// NOTE(review): the copy of this file that was reviewed had been damaged by a
// text-extraction step: every `<...>` span starting with a letter or `/` had been
// removed (template parameter lists, explicit template arguments, include targets)
// and `&params` had been mis-encoded. They are restored below so that the file is
// well-formed again. Restorations that cannot be derived from usage inside this file
// carry their own NOTE(review) and should be confirmed against the project history.

#pragma once

// NOTE(review): the seven include targets below were restored — confirm.
#include <cmath>
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/numeric_conversion.h>

#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
#include "softmax.h"
#include "philox.cuh"

namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Builds a tiled copy for the A operand of `tiled_mma` with a custom tile.
//
// The M mode of the tile has shape (AtomShape_M, kNWarps) and strides
// (1, MMA_M * AtomShape_M), where kNWarps is the M extent of the MMA tile divided by
// the M extent of one MMA atom. The second mode is the default layout over the K
// extent of the MMA tile.
//
// Template parameters:
//   MMA_M    - multiplier applied to AtomShape_M to obtain the stride of the second
//              M sub-mode.
//   Args...  - arguments of the Copy_Atom.
//   TiledMMA - tiled MMA type; must expose TiledShape_MNK, AtomShape_MNK and
//              get_layoutA_TV().
template <int MMA_M,
          class... Args,
          class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
                                 TiledMMA           const& tiled_mma) {
    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
    constexpr int MMAStride_M = MMA_M * AtomShape_M;
    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
                              Stride<_1, Int<MMAStride_M>> >{},
                       make_layout(size<2>(TileShape_MNK{})));
    // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n");  }
    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Same construction as make_tiled_copy_A_warpcontiguousM, but uses the C
// (accumulator) thread-value layout of `tiled_mma` (get_layoutC_TV()).
//
// The second tile mode is built from size<2> of the MMA tile, exactly as for the A
// variant; the original TODO questioning that choice is kept below.
template <int MMA_M,
          class... Args,
          class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
                                 TiledMMA           const& tiled_mma) {
    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
    constexpr int MMAStride_M = MMA_M * AtomShape_M;
    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
                              Stride<_1, Int<MMAStride_M>> >{},
                       // TODO: Shouldn't this be size<1>?
                       make_layout(size<2>(TileShape_MNK{})));
    // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n");  }
    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// One step of the running (online) softmax over a new block of scores.
//
// Is_first == true:
//   scores_max <- row max of scores, scores <- exp2-scaled scores, scores_sum <- row sum.
// Is_first == false:
//   the previous row max is saved, scores_max is updated with the new block, and both
//   scores_sum and the rows of acc_o are multiplied by
//   exp2((previous max - current max) * softmax_scale_log2) before the row sums of the
//   newly exponentiated scores are added to scores_sum.
//
// Check_inf: when true, a current row max of -INFINITY is replaced by 0 when computing
// the rescale factor.
//
// Parameters:
//   scores             - score tile in (row, col) layout; overwritten in place.
//   scores_max         - per-row running max; updated in place.
//   scores_sum         - per-row running sum; updated in place.
//   acc_o              - output accumulator; rescaled in place when !Is_first.
//   softmax_scale_log2 - scale forwarded to scale_apply_exp2 and used for the rescale.
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
                                         Tensor2 &acc_o, float softmax_scale_log2) {
    if (Is_first) {
        flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
        flash::reduce_sum(scores, scores_sum);
    } else {
        Tensor scores_max_prev = make_fragment_like(scores_max);
        copy(scores_max, scores_max_prev);
        flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        #pragma unroll
        for (int mi = 0; mi < size(scores_max); ++mi) {
            float scores_max_cur = !Check_inf
                ? scores_max(mi)
                : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
            float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
            scores_sum(mi) *= scores_scale;
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
        }
        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
        Tensor scores_sum_cur = make_fragment_like(scores_sum);
        flash::reduce_sum(scores, scores_sum_cur);
        #pragma unroll
        for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Copies the register tile tOrP to the destination tile tPgP.
//
// Modes 1 and 2 of tOrP are merged into a single mode, and each slice `mi` of that
// merged mode is copied to tPgP(_, mi, 0) with gmem_thr_copy_P. Mode 2 of tPgP is
// statically required to have size 1.
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
inline __device__ void write_softmax_to_gmem(
    Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
) {
    // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
    Layout l = tOrP.layout();
    Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
    CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
    // TODO(laurent): reactivate the following
    // CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
    #pragma unroll
    for (int mi = 0; mi < size<1>(tPrP); ++mi) {
        copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Computes attention for one block of kBlockM query rows of one (batch, head) pair.
//
// The function walks the K/V blocks in reverse order (from n_block_max - 1 down to 0),
// maintaining a running softmax (scores_max / scores_sum) and the output accumulator
// acc_o, then normalises acc_o and writes the output tile and the per-row
// log-sum-exp to global memory. When Return_softmax is set, the (dropout-encoded)
// probabilities are additionally written through params.p_ptr.
//
// Parameters:
//   params  - kernel parameters (pointers, strides, sizes, softmax scales, dropout).
//   bidb    - batch index.
//   bidh    - head index.
//   m_block - index of the kBlockM-row query block to process.
//
// Execution notes:
//   * Uses threadIdx.x as the thread index and the dynamic shared memory array smem_.
//   * Returns before any barrier when the query block starts at or past
//     actual_seqlen_q, or when actual_seqlen_k == 0; the condition depends only on
//     the function arguments.
//   * The dropout seed and offset are hard-coded to 0 in this version (the Philox
//     unpacking is commented out).
//
// NOTE(review): the order of the template parameters was restored from the names used
// in the body — confirm against the launch code.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNWarps = Kernel_traits::kNWarps;
    constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;

    // NOTE(review): the BlockInfo template argument was restored — confirm.
    const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;

    int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
    if (Is_causal) {
        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
        // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
        //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
        // }
    }

    // We iterate over the blocks in reverse order. This is because the last block is the only one
    // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
    // might save us 1 register (we just need n_block instead of both n_block and n_block_max).

    const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
        + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
    // We move K and V to the last block.
    const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
        + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
    const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
        + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
    const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
        + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;

    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.q_row_stride, _1{}));
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.k_row_stride, _1{}));
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.v_row_stride, _1{}));
    Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
                            Shape<Int<kBlockM>, Int<kBlockN>>{},
                            make_stride(params.seqlen_k_rounded, _1{}));

    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                            typename Kernel_traits::SmemLayoutQ{});
    // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
                            typename Kernel_traits::SmemLayoutKV{});
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

    auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
    auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);

    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
    Tensor tPgP = gmem_thr_copy_P.partition_D(gP);

    typename Kernel_traits::TiledMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    Tensor tSrQ  = thr_mma.partition_fragment_A(sQ);                           // (MMA,MMA_M,MMA_K)
    Tensor tSrK  = thr_mma.partition_fragment_B(sK);                           // (MMA,MMA_N,MMA_K)
    Tensor tOrVt  = thr_mma.partition_fragment_B(sVtNoSwizzle);                // (MMA, MMA_K,MMA_N)

    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K

    //
    // Copy Atom retiling
    //

    auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
    // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
    // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
    // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}

    auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);

    auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

    // TODO: this might need to change if we change the mma instruction in SM70
    Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
    Tensor scores_sum = make_fragment_like(scores_max);

    //
    // PREDICATES
    //

    // // Allocate predicate tensors for m and n
    // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
    // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});

    // Construct identity layout for sQ and sK
    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));    // (BLK_N,BLK_K) -> (blk_n,blk_k)
    // Tensor tScQ = thr_mma.partition_A(cQ);                           // (MMA,MMA_M,MMA_K)
    // if (cute::thread0()) {
    //     print(tScQ.layout()); printf("\n");
    //     for (int i = 0; i < size(tScQ); ++i) {
    //         printf("%d ", get<0>(tScQ(i)));
    //     }
    //     printf("\n");
    //     for (int i = 0; i < size(tScQ); ++i) {
    //         printf("%d ", get<1>(tScQ(i)));
    //     }
    //     printf("\n");
    // }

    // Repeat the partitioning with identity layouts
    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);   // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)

    // Allocate predicate tensors for k
    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
    Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));

    // Set predicates for k bounds
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
        #pragma unroll
        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
    }

    // Prologue

    // NOTE(review): the template arguments of every flash::copy call in this function
    // were restored. Calls that pass a row limit use the Is_even_N / false form; calls
    // without a limit use Is_even_MN = true. Confirm against flash::copy in utils.h.
    Tensor tQrQ = make_fragment_like(tQgQ);
    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
    flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                                 binfo.actual_seqlen_q - m_block * kBlockM);
    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

    // // Copy rmem to smem
    // // copy(tQrQ, tQsQ);
    // flash::cp_async_wait<0>();
    // __syncthreads();
    // // if (cute::thread(1, 0)) { print(tQsQ); }
    // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
    // // if (cute::thread0()) { print(sQNoSwizzle); }

    if (Kernel_traits::Share_Q_K_smem) {
        // Q must be fully read into registers before K is loaded over the same smem
        // (sK starts at sQ.data() in this configuration).
        flash::cp_async_wait<0>();
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
        copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
        __syncthreads();
    }

    int n_block = n_block_max - 1;
    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
    flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
                                      binfo.actual_seqlen_k - n_block * kBlockN);
    cute::cp_async_fence();
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
    // __syncthreads();

    if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
        // Wait for the Q copy only; the K copy committed just above may still be in flight.
        flash::cp_async_wait<1>();
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
        copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
    }

    // auto seeds = at::cuda::philox::unpack(params.philox_args);
    // unsigned long long seed = std::get<0>(seeds);
    // unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
    unsigned long long seed = 0;
    unsigned long long offset = 0;

    clear(acc_o);

    // For performance reason, we separate out two kinds of iterations:
    // those that need masking on S, and those that don't.
    // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
    // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
    // We will have at least 1 "masking" iteration.

    constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
    #pragma unroll
    for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        flash::cp_async_wait<0>();
        __syncthreads();

        // Advance gV
        if (masking_step > 0) {
            tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
        } else {
            // Clear the smem tiles to account for predicated off loads
            flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
                gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
            );
        }
        cute::cp_async_fence();

        // NOTE(review): the A_in_regs template argument was restored — confirm.
        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
        );
        // if (cute::thread0()) { print(acc_s); }

        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        // if (cute::thread0()) { print(scores); }
        // We don't put the masking before the matmul S = Q K^T because we don't clear sK
        // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
        // can produce Inf / NaN.
        if (!Is_causal) {
            if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
        } else {
            // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
            // Tensor taccScS = thr_mma.partition_C(caccS);                           // (MMA,MMA_M,MMA_N)
            // static_assert(decltype(size<0>(taccScS))::value == 4);
            // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
            // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
            // Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout()));
            // flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
            //                               m_block * kBlockM);
            // Idk why it's get<1> and not get<0> of the stride.
            // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
            // I can't get the stride from idx_row
            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
                                     // m_block * kBlockM + get<0>(idx_row(0)),
                                     m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
                                     kNWarps * 16);
                                     // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
                                     // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
        }

        flash::cp_async_wait<0>();
        __syncthreads();
        if (n_block > 0) {
            // Advance gK
            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        // TODO: when we have key_padding_mask we'll need to Check_inf
        // NOTE(review): the template arguments of both calls were restored; in
        // particular Check_inf = Is_causal should be confirmed.
        masking_step == 0
            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

        // Convert scores from fp32 to fp16/bf16
        Tensor rP = flash::convert_type<Element>(scores);
        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
        uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
        uint32_t block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor tOrP_copy = make_fragment_like(tOrP);
            copy(tOrP, tOrP_copy);
            // NOTE(review): the encode_dropout_in_sign_bit argument was restored — confirm.
            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
                block_row_idx, block_col_idx, kNWarps
            );
            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
            tPgP.data() = tPgP.data() + (-kBlockN);
        }
        if (Is_dropout) {
            flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
                                 block_row_idx, block_col_idx, kNWarps);
        }
        // if (cute::thread0()) { print(tOrP); }

        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
        // if (cute::thread0()) { print(scores); }

        // This check is at the end of the loop since we always have at least 1 iteration
        if (n_masking_steps > 1 && n_block <= 0) {
            --n_block;
            break;
        }
    }

    // These are the iterations where we don't need masking on S
    for (; n_block >= 0; --n_block) {
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
        clear(acc_s);
        flash::cp_async_wait<0>();
        __syncthreads();
        // Advance gV
        tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
        cute::cp_async_fence();

        flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
        );

        flash::cp_async_wait<0>();
        __syncthreads();
        if (n_block > 0) {
            // Advance gK
            tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
            // This cp_async_fence needs to be in the if block, otherwise the synchronization
            // isn't right and we get race conditions.
            cute::cp_async_fence();
        }

        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

        Tensor rP = flash::convert_type<Element>(scores);
        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
        uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
        uint32_t block_col_idx = n_block * (kBlockN / 32);
        if (Return_softmax) {
            Tensor tOrP_copy = make_fragment_like(tOrP);
            copy(tOrP, tOrP_copy);
            flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
                block_row_idx, block_col_idx, kNWarps
            );
            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
            tPgP.data() = tPgP.data() + (-kBlockN);
        }
        if (Is_dropout) {
            flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
                                 block_row_idx, block_col_idx, kNWarps);
        }

        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
    }

    // Epilogue

    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
    Tensor lse = make_fragment_like(scores_sum);
    #pragma unroll
    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
        float sum = scores_sum(mi);
        // A zero or NaN row sum leaves the row unscaled and reports an LSE of INFINITY.
        float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
        lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
        float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
        #pragma unroll
        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
    }

    // if (cute::thread0()) { print(acc_o_rowcol); }

    // Convert acc_o from fp32 to fp16/bf16
    Tensor rO = flash::convert_type<Element>(acc_o);
    Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)
    // Partition sO to match the accumulator partitioning
    auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
    // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
    Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

    // sO has the same size as sQ, so we don't need to sync here.
    if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }

    copy(smem_thr_copy_O, taccOrO, taccOsO);

    const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
        + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
    const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
    Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.o_row_stride, _1{}));
    Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                              Shape<Int<kBlockM>>{}, Stride<_1>{});

    auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
    Tensor tOsO = gmem_thr_copy_O.partition_S(sO);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);

    // All threads must have finished writing sO before it is read back below.
    __syncthreads();

    Tensor tOrO = make_tensor<Element>(shape(tOgO));
    copy(gmem_thr_copy_O, tOsO, tOrO);

    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor taccOcO = thr_mma.partition_C(caccO);                           // (MMA,MMA_M,MMA_K)
    static_assert(decltype(size<0>(taccOcO))::value == 4);
    // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));                     // MMA_M
    // Only the threads whose first accumulator element sits in column 0 write the LSE.
    if (get<1>(taccOcO_row(0)) == 0) {
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
        }
    }

    // Construct identity layout for sO
    Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    // Repeat the partitioning with identity layouts
    Tensor tOcO = gmem_thr_copy_O.partition_D(cO);                           // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
    }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
    );
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Block-level entry point: maps the launch grid to one attention row block.
//
// Grid mapping (read directly from the built-in indices):
//   blockIdx.x -> m_block (query row block), blockIdx.y -> batch, blockIdx.z -> head.
// All work is delegated to compute_attn_1rowblock with the same template flags.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
inline __device__ void compute_attn(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;

    // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
    // them to have the same number of threads or have to traverse the attention matrix
    // in the same order.
    // In the Philox RNG, we use the offset to store the batch, head, and the lane id
    // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
    // the attention matrix. This way, as long as we have the batch, head, and the location of
    // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash