/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/algorithm/copy.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>

using namespace cute;

template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits_sm90 {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using Element = elem_type;
    static constexpr bool Has_cp_async = true;
#else
    using Element = cutlass::half_t;
    static constexpr bool Has_cp_async = false;
#endif

    using ElementAccum = float;
    using index_t = uint32_t;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
    >;
    using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
    using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};

template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
         typename Base=Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<kNWarps>, _1, _1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;         // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));

    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockN>, Int<kHeadDim>>{}));

    using SmemLayoutAtomVtransposed = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
                    Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
                           Stride<_1, Int<kBlockKSmem>>>{}));
    using SmemLayoutVtransposed = decltype(tile_to_shape(
        SmemLayoutAtomVtransposed{},
        Shape<Int<kHeadDim>, Int<kBlockN>>{}));
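    // Worked example (annotation added for clarity, not in the original source):
    // for kHeadDim = 128 with half_t elements, kBlockKSmem = 64 and kSwizzle = 3,
    // so each smem atom is an 8 x 64 tile whose rows span 128 bytes (all 32 banks).
    // Swizzle<3, 3, 3> XORs the 3-bit row index into bits [3, 6) of the column
    // index, i.e. it permutes which 8-element (16-byte) vector slot each row uses,
    // so a logical vector column is spread over eight distinct physical slots
    // across the 8 rows. This is what keeps the transposed (LDSM_T) reads of V
    // free of shared-memory bank conflicts.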
    // Maybe the VtransposedNoSwizzle just needs to have the right shape
    // and the strides don't matter?
    using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());

    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<8>, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
    using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;

    static constexpr int kSmemQCount = size(SmemLayoutQ{});
    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
    static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
    static constexpr int kSmemSize = Share_Q_K_smem
        ? std::max(kSmemQSize, kSmemKVSize)
        : kSmemQSize + kSmemKVSize;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128, because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", one covering columns 0-63
    // and the other columns 64-127. If we have 16 threads per row for the gmem read,
    // then when we write to smem, threads 0-7 write to the first page and threads 8-15
    // write to the second page, hitting the same banks.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

    static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
    using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
                                   Stride<Int<kGmemThreadsPerRowP>, _1>>;

    using GmemTiledCopyP = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtomP{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
};

////////////////////////////////////////////////////////////////////////////////////////////////////
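// Usage sketch (illustrative annotation, not part of the original file): a
// plausible instantiation for head dim 128 with 128x64 tiles and 8 warps. The
// alias name and tile/warp choices below are assumptions for the example, not
// a statement of what the launch code actually uses. The asserted values
// follow from the constexpr definitions above.
//
//   using Traits = Flash_fwd_kernel_traits<128, 128, 64, 8,
//                                          /*Is_Q_in_regs_=*/false,
//                                          /*Share_Q_K_smem_=*/false,
//                                          cutlass::half_t>;
//   static_assert(Traits::kNThreads == 256);        // 8 warps * 32 threads
//   static_assert(Traits::kBlockKSmem == 64);       // since 128 % 64 == 0
//   static_assert(Traits::kSwizzle == 3);           // 64-wide smem rows
//   static_assert(Traits::kGmemElemsPerLoad == 8);  // 16 B / sizeof(half_t)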