summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels/kernel_traits.h
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/kernels/kernel_traits.h')
-rw-r--r--candle-flash-attn/kernels/kernel_traits.h77
1 files changed, 54 insertions, 23 deletions
diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h
index 3468e4bf..f000ff24 100644
--- a/candle-flash-attn/kernels/kernel_traits.h
+++ b/candle-flash-attn/kernels/kernel_traits.h
@@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base {
SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+ // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
+ using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+ Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomVtransposed = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{},
- // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
- Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
- Stride<_1, Int<kBlockKSmem>>>{}));
+ composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
- using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
+ using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
+ SmemLayoutAtomVtransposedNoSwizzle{},
+ Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+ // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -110,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
- using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
+ using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
+ using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
@@ -138,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base {
DefaultCopy
>;
using GmemTiledCopyQKV = decltype(
- make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+ make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
@@ -151,10 +155,30 @@ struct Flash_fwd_kernel_traits : public Base {
Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
+ using GmemLayoutAtomOaccum = std::conditional_t<
+ kBlockKSmem == 32,
+ Layout<Shape <_16, _8>, // Thread layout, 8 threads per row
+ Stride< _8, _1>>,
+ Layout<Shape <_8, _16>, // Thread layout, 16 threads per row
+ Stride< _16, _1>>
+ >;
+ using GmemTiledCopyOaccum = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ GmemLayoutAtomOaccum{},
+ Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
+ using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+ using GmemTiledCopyRotcossin = decltype(
+ make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+ GmemLayoutAtomRotcossin{},
+ Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
+ using GmemTiledCopyRotcossinCont = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+ GmemLayoutAtomRotcossin{},
+ Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
};
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
@@ -223,16 +247,19 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
+ using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+ Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomKtransposed = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{},
- Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
- Stride<_1, Int<kBlockKSmem>>>{}));
+ composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
using SmemLayoutKtransposed = decltype(tile_to_shape(
SmemLayoutAtomKtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
// Maybe the KtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
- using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
+ using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
+ SmemLayoutAtomKtransposedNoSwizzle{},
+ make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
+ // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
@@ -250,24 +277,30 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutPdS = decltype(tile_to_shape(
SmemLayoutAtomPdS{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
+ using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
+ Stride<_1, Int<kPBlockN>>>;
using SmemLayoutAtomPdStransposed = decltype(
- composition(Swizzle<kSwizzlePdS, 3, 3>{},
- Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
- Stride<_1, Int<kPBlockN>>>{}));
+ composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
using SmemLayoutPdStransposed = decltype(tile_to_shape(
SmemLayoutAtomPdStransposed{},
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
- using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
+ using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
+ SmemLayoutAtomPdStransposedNoSwizzle{},
+ make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
+ // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
+ using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
+ Stride<_1, Int<kBlockKSmem>>>;
using SmemLayoutAtomQdOtransposed = decltype(
- composition(Swizzle<kSwizzle, 3, 3>{},
- Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
- Stride<_1, Int<kBlockKSmem>>>{}));
+ composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
SmemLayoutAtomQdOtransposed{},
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
- using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
+ using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
+ SmemLayoutAtomQdOtransposedNoSwizzle{},
+ make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
+ // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
using SmemLayoutAtomdKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -292,13 +325,11 @@ struct Flash_bwd_kernel_traits : public Base {
static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
static constexpr int kSmemPCount = size(SmemLayoutPdS{});
static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
- static constexpr int kSmemdPsumCount = kBlockM;
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
- static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
static constexpr int kSmemSize = kSmemQdOSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)