diff options
Diffstat (limited to 'candle-flash-attn/kernels/kernel_traits.h')
-rw-r--r-- | candle-flash-attn/kernels/kernel_traits.h | 119 |
1 files changed, 33 insertions, 86 deletions
diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h index f000ff24..5a7b7491 100644 --- a/candle-flash-attn/kernels/kernel_traits.h +++ b/candle-flash-attn/kernels/kernel_traits.h @@ -1,10 +1,10 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * Copyright (c) 2024, Tri Dao. ******************************************************************************/ #pragma once -#include "cute/algorithm/copy.hpp" +#include "cute/tensor.hpp" #include "cutlass/cutlass.h" #include "cutlass/layout/layout.h" @@ -24,7 +24,7 @@ struct Flash_kernel_traits { #endif using ElementAccum = float; - using index_t = uint32_t; + using index_t = int64_t; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 using MMA_Atom_Arch = std::conditional_t< @@ -32,10 +32,8 @@ struct Flash_kernel_traits { MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>, MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN> >; - using ValLayoutMNK = Layout<Shape<_1, _2, _1>>; #else using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>; - using ValLayoutMNK = Layout<Shape<_1, _2, _2>>; #endif #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 @@ -76,7 +74,7 @@ struct Flash_fwd_kernel_traits : public Base { using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + Tile<Int<16 * kNWarps>, _16, _16>>; using SmemLayoutAtomQ = decltype( composition(Swizzle<kSwizzle, 3, 3>{}, @@ -91,20 +89,10 @@ struct Flash_fwd_kernel_traits : public Base { SmemLayoutAtomQ{}, Shape<Int<kBlockN>, Int<kHeadDim>>{})); - // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 - using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>, - Stride<_1, Int<kBlockKSmem>>>; - using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{})); - using SmemLayoutVtransposed = decltype(tile_to_shape( - SmemLayoutAtomVtransposed{}, - Shape<Int<kHeadDim>, Int<kBlockN>>{})); - // Maybe the VtransposeNoSwizzle just needs to have the right shape - // And the strides don't matter? - using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( - SmemLayoutAtomVtransposedNoSwizzle{}, - Shape<Int<kHeadDim>, Int<kBlockN>>{})); - // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 + using SmemLayoutVtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{}))); + using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); using SmemLayoutAtomO = decltype( composition(Swizzle<kSwizzle, 3, 3>{}, @@ -116,10 +104,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>; using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>; - static constexpr int kSmemQCount = size(SmemLayoutQ{}); - static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; - static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); - static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); @@ -149,15 +135,6 @@ struct Flash_fwd_kernel_traits : public Base { make_tiled_copy(Copy_Atom<DefaultCopy, Element>{}, GmemLayoutAtom{}, Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store - static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); - using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>, - Stride<Int<kGmemThreadsPerRowP>, _1>>; - - using GmemTiledCopyP = decltype( - make_tiled_copy(Copy_Atom<DefaultCopy, Element>{}, - GmemLayoutAtomP{}, - Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, @@ -218,17 +195,17 @@ struct Flash_bwd_kernel_traits : public Base { using TiledMmaSdP = TiledMMA< typename Base::MMA_Atom_Arch, Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>, - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>; using TiledMmadKV = TiledMMA< typename Base::MMA_Atom_Arch, Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>, - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>; using TiledMmadQ = TiledMMA< typename Base::MMA_Atom_Arch, Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; using SmemLayoutAtomQdO = decltype( composition(Swizzle<kSwizzle, 3, 3>{}, @@ -247,26 +224,18 @@ struct Flash_bwd_kernel_traits : public Base { SmemLayoutAtomKV{}, make_shape(Int<kBlockN>{}, Int<kHeadDim>{}))); - using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>, - Stride<_1, Int<kBlockKSmem>>>; - using SmemLayoutAtomKtransposed = decltype( - composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{})); - using SmemLayoutKtransposed = decltype(tile_to_shape( - SmemLayoutAtomKtransposed{}, - make_shape(Int<kHeadDim>{}, Int<kBlockN>{}))); - // Maybe the KtransposeNoSwizzle just needs to have the right shape - // And the strides don't matter? - using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( - SmemLayoutAtomKtransposedNoSwizzle{}, - make_shape(Int<kHeadDim>{}, Int<kBlockN>{}))); - // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + using SmemLayoutKtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{}))); + using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 // static constexpr int kPBlockN = kBlockN; - static_assert(kBlockN >= 64); + // Temporarily disabling this for hdim 256 on sm86 and sm89 + // static_assert(kBlockN >= 64); + static_assert(kBlockN >= 32); // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. - static constexpr int kPBlockN = 64; + static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); static constexpr int kSwizzlePdS = 3; @@ -277,30 +246,15 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int<kBlockM>{}, Int<kBlockN>{}))); - using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>, - Stride<_1, Int<kPBlockN>>>; - using SmemLayoutAtomPdStransposed = decltype( - composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{})); - using SmemLayoutPdStransposed = decltype(tile_to_shape( - SmemLayoutAtomPdStransposed{}, - make_shape(Int<kBlockN>{}, Int<kBlockM>{}))); - using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( - SmemLayoutAtomPdStransposedNoSwizzle{}, - make_shape(Int<kBlockN>{}, Int<kBlockM>{}))); - // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemLayoutPdStransposed = decltype( + composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); + using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>; - using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>, - Stride<_1, Int<kBlockKSmem>>>; - using SmemLayoutAtomQdOtransposed = decltype( - composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); - using SmemLayoutQdOtransposed = decltype(tile_to_shape( - SmemLayoutAtomQdOtransposed{}, - make_shape(Int<kHeadDim>{}, Int<kBlockM>{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( - SmemLayoutAtomQdOtransposedNoSwizzle{}, - make_shape(Int<kHeadDim>{}, Int<kBlockM>{}))); - // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + using SmemLayoutQdOtransposed = decltype( + composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); using SmemLayoutAtomdKV = decltype( composition(Swizzle<kSwizzle, 3, 3>{}, @@ -320,16 +274,12 @@ struct Flash_bwd_kernel_traits : public Base { make_shape(Int<kBlockM>{}, Int<kHeadDim>{}))); using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>; - static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ - static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; - static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); - static constexpr int kSmemPCount = size(SmemLayoutPdS{}); - static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); - static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); - static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); - static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); - static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); - static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); + // Double buffer for sQ + static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) @@ -338,9 +288,6 @@ struct Flash_bwd_kernel_traits : public Base { + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + kSmemPSize : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); - static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 - + kSmemdSSize + kSmemPSize; - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); |