summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels/kernel_traits.h
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/kernels/kernel_traits.h')
-rw-r--r--candle-flash-attn/kernels/kernel_traits.h119
1 files changed, 33 insertions, 86 deletions
diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h
index f000ff24..5a7b7491 100644
--- a/candle-flash-attn/kernels/kernel_traits.h
+++ b/candle-flash-attn/kernels/kernel_traits.h
@@ -1,10 +1,10 @@
/******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
-#include "cute/algorithm/copy.hpp"
+#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
@@ -24,7 +24,7 @@ struct Flash_kernel_traits {
#endif
using ElementAccum = float;
- using index_t = uint32_t;
+ using index_t = int64_t;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using MMA_Atom_Arch = std::conditional_t<
@@ -32,10 +32,8 @@ struct Flash_kernel_traits {
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
- using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
- using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
@@ -76,7 +74,7 @@ struct Flash_fwd_kernel_traits : public Base {
using TiledMma = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
- typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+ Tile<Int<16 * kNWarps>, _16, _16>>;
using SmemLayoutAtomQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -91,20 +89,10 @@ struct Flash_fwd_kernel_traits : public Base {
SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
- // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
- using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
- Stride<_1, Int<kBlockKSmem>>>;
- using SmemLayoutAtomVtransposed = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
- using SmemLayoutVtransposed = decltype(tile_to_shape(
- SmemLayoutAtomVtransposed{},
- Shape<Int<kHeadDim>, Int<kBlockN>>{}));
- // Maybe the VtransposeNoSwizzle just needs to have the right shape
- // And the strides don't matter?
- using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
- SmemLayoutAtomVtransposedNoSwizzle{},
- Shape<Int<kHeadDim>, Int<kBlockN>>{}));
- // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
+ // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
+ using SmemLayoutVtransposed = decltype(
+ composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
+ using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -116,10 +104,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
- static constexpr int kSmemQCount = size(SmemLayoutQ{});
- static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
- static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
- static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
+ static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
+ static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
@@ -149,15 +135,6 @@ struct Flash_fwd_kernel_traits : public Base {
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
- static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
- static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
- using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
- Stride<Int<kGmemThreadsPerRowP>, _1>>;
-
- using GmemTiledCopyP = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
- GmemLayoutAtomP{},
- Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomOaccum = std::conditional_t<
kBlockKSmem == 32,
@@ -218,17 +195,17 @@ struct Flash_bwd_kernel_traits : public Base {
using TiledMmaSdP = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
- typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+ Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;
using TiledMmadKV = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
- typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+ Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;
using TiledMmadQ = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
- typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+ Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;
using SmemLayoutAtomQdO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -247,26 +224,18 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
- using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
- Stride<_1, Int<kBlockKSmem>>>;
- using SmemLayoutAtomKtransposed = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
- using SmemLayoutKtransposed = decltype(tile_to_shape(
- SmemLayoutAtomKtransposed{},
- make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
- // Maybe the KtransposeNoSwizzle just needs to have the right shape
- // And the strides don't matter?
- using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
- SmemLayoutAtomKtransposedNoSwizzle{},
- make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
- // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
+ using SmemLayoutKtransposed = decltype(
+ composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
+ using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));
// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
// static constexpr int kPBlockN = kBlockN;
- static_assert(kBlockN >= 64);
+ // Temporarily disabling this for hdim 256 on sm86 and sm89
+ // static_assert(kBlockN >= 64);
+ static_assert(kBlockN >= 32);
// TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
- static constexpr int kPBlockN = 64;
+ static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
static constexpr int kSwizzlePdS = 3;
@@ -277,30 +246,15 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutPdS = decltype(tile_to_shape(
SmemLayoutAtomPdS{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
- using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
- Stride<_1, Int<kPBlockN>>>;
- using SmemLayoutAtomPdStransposed = decltype(
- composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
- using SmemLayoutPdStransposed = decltype(tile_to_shape(
- SmemLayoutAtomPdStransposed{},
- make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
- using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
- SmemLayoutAtomPdStransposedNoSwizzle{},
- make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
- // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
+ using SmemLayoutPdStransposed = decltype(
+ composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
+ using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
+
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
- using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
- Stride<_1, Int<kBlockKSmem>>>;
- using SmemLayoutAtomQdOtransposed = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
- using SmemLayoutQdOtransposed = decltype(tile_to_shape(
- SmemLayoutAtomQdOtransposed{},
- make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
- using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
- SmemLayoutAtomQdOtransposedNoSwizzle{},
- make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
- // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
+ using SmemLayoutQdOtransposed = decltype(
+ composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
+ using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));
using SmemLayoutAtomdKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -320,16 +274,12 @@ struct Flash_bwd_kernel_traits : public Base {
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
- static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ
- static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
- static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
- static constexpr int kSmemPCount = size(SmemLayoutPdS{});
- static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
- static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
- static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
- static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
- static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
- static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
+ // Double buffer for sQ
+ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
+ static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
+ static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
+ static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
+ static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
static constexpr int kSmemSize = kSmemQdOSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
@@ -338,9 +288,6 @@ struct Flash_bwd_kernel_traits : public Base {
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + kSmemPSize
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
- static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
- + kSmemdSSize + kSmemPSize;
-
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");