diff options
Diffstat (limited to 'candle-flash-attn/kernels/kernel_traits.h')
-rw-r--r-- | candle-flash-attn/kernels/kernel_traits.h | 30 |
1 files changed, 15 insertions, 15 deletions
diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h index 5a7b7491..8c089748 100644 --- a/candle-flash-attn/kernels/kernel_traits.h +++ b/candle-flash-attn/kernels/kernel_traits.h @@ -101,8 +101,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape<Int<kBlockM>, Int<kHeadDim>>{})); - using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>; - using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>; + using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>; + using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>; static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); @@ -125,14 +125,14 @@ struct Flash_fwd_kernel_traits : public Base { using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, - DefaultCopy + AutoVectorizingCopyWithAssumedAlignment<128> >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{}, GmemLayoutAtom{}, Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom<DefaultCopy, Element>{}, + make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{}, GmemLayoutAtom{}, Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store @@ -144,7 +144,7 @@ struct Flash_fwd_kernel_traits : public Base { Stride< _16, _1>> >; using GmemTiledCopyOaccum = decltype( - make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, + make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store using GmemLayoutAtomRotcossin = GmemLayoutAtom; @@ -153,7 +153,7 @@ struct Flash_fwd_kernel_traits : public Base { GmemLayoutAtomRotcossin{}, Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( - make_tiled_copy(Copy_Atom<DefaultCopy, Element>{}, + make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{}, GmemLayoutAtomRotcossin{}, Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load }; @@ -250,7 +250,7 @@ struct Flash_bwd_kernel_traits : public Base { composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{}))); using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); - using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>; + using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>; using SmemLayoutQdOtransposed = decltype( composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{}))); @@ -263,7 +263,7 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutdKV = decltype(tile_to_shape( SmemLayoutAtomdKV{}, make_shape(Int<kBlockN>{}, Int<kHeadDim>{}))); - using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>; + using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>; using SmemLayoutAtomdQ = decltype( composition(Swizzle<kSwizzle, 3, 3>{}, @@ -272,7 +272,7 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutdQ = decltype(tile_to_shape( SmemLayoutAtomdQ{}, make_shape(Int<kBlockM>{}, Int<kHeadDim>{}))); - using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>; + using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>; // Double buffer for sQ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); @@ -303,22 +303,22 @@ struct Flash_bwd_kernel_traits : public Base { using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, - DefaultCopy + AutoVectorizingCopyWithAssumedAlignment<128> >; using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{}, GmemLayoutAtom{}, Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read using GmemTiledCopydO = decltype( - make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{}, + make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{}, GmemLayoutAtom{}, Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store using GmemTiledCopydKV = decltype( - make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{}, + make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{}, GmemLayoutAtom{}, Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store using GmemTiledCopydQ = decltype( - make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{}, + make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{}, GmemLayoutAtom{}, Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store using GmemLayoutAtomdQaccum = std::conditional_t< @@ -329,12 +329,12 @@ struct Flash_bwd_kernel_traits : public Base { Stride< _16, _1>> >; using GmemTiledCopydQaccum = decltype( - make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, + make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, GmemLayoutAtomdQaccum{}, Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store using GmemTiledCopydQaccumAtomicAdd = decltype( - make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, + make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, Layout<Shape <_8, _32>, // Thread layout, 8 threads per row Stride<_32, _1>>{}, Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store |