summaryrefslogtreecommitdiff
path: root/candle-flash-attn/kernels/kernel_traits.h
diff options
context:
space:
mode:
Diffstat (limited to 'candle-flash-attn/kernels/kernel_traits.h')
-rw-r--r--candle-flash-attn/kernels/kernel_traits.h30
1 files changed, 15 insertions, 15 deletions
diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h
index 5a7b7491..8c089748 100644
--- a/candle-flash-attn/kernels/kernel_traits.h
+++ b/candle-flash-attn/kernels/kernel_traits.h
@@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
- using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
- using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
+ using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
+ using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
@@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base {
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
- DefaultCopy
+ AutoVectorizingCopyWithAssumedAlignment<128>
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
@@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base {
Stride< _16, _1>>
>;
using GmemTiledCopyOaccum = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
@@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base {
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinCont = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
};
@@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base {
composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
- using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
+ using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
using SmemLayoutQdOtransposed = decltype(
composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
@@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutdKV = decltype(tile_to_shape(
SmemLayoutAtomdKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
- using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
+ using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
using SmemLayoutAtomdQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutdQ = decltype(tile_to_shape(
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
- using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
+ using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
// Double buffer for sQ
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
@@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base {
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
- DefaultCopy
+ AutoVectorizingCopyWithAssumedAlignment<128>
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopydO = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydKV = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydQ = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomdQaccum = std::conditional_t<
@@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base {
Stride< _16, _1>>
>;
using GmemTiledCopydQaccum = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomdQaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemTiledCopydQaccumAtomicAdd = decltype(
- make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
Stride<_32, _1>>{},
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store