diff options
Diffstat (limited to 'candle-flash-attn/kernels/kernel_traits.h')
-rw-r--r-- | candle-flash-attn/kernels/kernel_traits.h | 77 |
1 files changed, 54 insertions, 23 deletions
diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h index 3468e4bf..f000ff24 100644 --- a/candle-flash-attn/kernels/kernel_traits.h +++ b/candle-flash-attn/kernels/kernel_traits.h @@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base { SmemLayoutAtomQ{}, Shape<Int<kBlockN>, Int<kHeadDim>>{})); + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>, + Stride<_1, Int<kBlockKSmem>>>; using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle<kSwizzle, 3, 3>{}, - // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 - Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>, - Stride<_1, Int<kBlockKSmem>>>{})); + composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{})); using SmemLayoutVtransposed = decltype(tile_to_shape( SmemLayoutAtomVtransposed{}, Shape<Int<kHeadDim>, Int<kBlockN>>{})); // Maybe the VtransposeNoSwizzle just needs to have the right shape // And the strides don't matter? - using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomVtransposedNoSwizzle{}, + Shape<Int<kHeadDim>, Int<kBlockN>>{})); + // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); using SmemLayoutAtomO = decltype( composition(Swizzle<kSwizzle, 3, 3>{}, @@ -110,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape<Int<kBlockM>, Int<kHeadDim>>{})); - using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>; + using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>; + using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>; static constexpr int kSmemQCount = size(SmemLayoutQ{}); static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; @@ -138,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base { DefaultCopy >; using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{}, + make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{}, GmemLayoutAtom{}, Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{}, + make_tiled_copy(Copy_Atom<DefaultCopy, Element>{}, GmemLayoutAtom{}, Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; @@ -151,10 +155,30 @@ struct Flash_fwd_kernel_traits : public Base { Stride<Int<kGmemThreadsPerRowP>, _1>>; using GmemTiledCopyP = decltype( - make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{}, + make_tiled_copy(Copy_Atom<DefaultCopy, Element>{}, GmemLayoutAtomP{}, Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout<Shape <_16, _8>, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout<Shape <_8, _16>, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, + GmemLayoutAtomOaccum{}, + Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype( + make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype( + make_tiled_copy(Copy_Atom<DefaultCopy, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. @@ -223,16 +247,19 @@ struct Flash_bwd_kernel_traits : public Base { SmemLayoutAtomKV{}, make_shape(Int<kBlockN>{}, Int<kHeadDim>{}))); + using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>, + Stride<_1, Int<kBlockKSmem>>>; using SmemLayoutAtomKtransposed = decltype( - composition(Swizzle<kSwizzle, 3, 3>{}, - Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>, - Stride<_1, Int<kBlockKSmem>>>{})); + composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{})); using SmemLayoutKtransposed = decltype(tile_to_shape( SmemLayoutAtomKtransposed{}, make_shape(Int<kHeadDim>{}, Int<kBlockN>{}))); // Maybe the KtransposeNoSwizzle just needs to have the right shape // And the strides don't matter? - using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomKtransposedNoSwizzle{}, + make_shape(Int<kHeadDim>{}, Int<kBlockN>{}))); + // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 @@ -250,24 +277,30 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int<kBlockM>{}, Int<kBlockN>{}))); + using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>, + Stride<_1, Int<kPBlockN>>>; using SmemLayoutAtomPdStransposed = decltype( - composition(Swizzle<kSwizzlePdS, 3, 3>{}, - Layout<Shape<Int<kPBlockN>, Int<kBlockM>>, - Stride<_1, Int<kPBlockN>>>{})); + composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{})); using SmemLayoutPdStransposed = decltype(tile_to_shape( SmemLayoutAtomPdStransposed{}, make_shape(Int<kBlockN>{}, Int<kBlockM>{}))); - using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomPdStransposedNoSwizzle{}, + make_shape(Int<kBlockN>{}, Int<kBlockM>{}))); + // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>; + using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>, + Stride<_1, Int<kBlockKSmem>>>; using SmemLayoutAtomQdOtransposed = decltype( - composition(Swizzle<kSwizzle, 3, 3>{}, - Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>, - Stride<_1, Int<kBlockKSmem>>>{})); + composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); using SmemLayoutQdOtransposed = decltype(tile_to_shape( SmemLayoutAtomQdOtransposed{}, make_shape(Int<kHeadDim>{}, Int<kBlockM>{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposedNoSwizzle{}, + make_shape(Int<kHeadDim>{}, Int<kBlockM>{}))); + // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); using SmemLayoutAtomdKV = decltype( composition(Swizzle<kSwizzle, 3, 3>{}, @@ -292,13 +325,11 @@ struct Flash_bwd_kernel_traits : public Base { static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); static constexpr int kSmemPCount = size(SmemLayoutPdS{}); static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); - static constexpr int kSmemdPsumCount = kBlockM; static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); - static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) |