path: root/candle-flash-attn/kernels/kernel_traits.h
author    Laurent Mazare <laurent.mazare@gmail.com>    2023-07-26 07:48:10 +0100
committer GitHub <noreply@github.com>    2023-07-26 07:48:10 +0100
commit    d9f9c859afaeed95df420aca5fdb73f52f9239c5 (patch)
tree      2ef898b2906a24b57ea42b0294bc51b928f0513c /candle-flash-attn/kernels/kernel_traits.h
parent    c97d51243c177e0497ea7147f426c4cc1e532c9b (diff)
Add flash attention (#241)
* Add some flash-attn kernel, import the code for flash-attn v2 from Dao-AILab.
* More flash attn.
* Set up the flash attn parameters.
* Get things to compile locally.
* Move the flash attention files to a different directory.
* Build the static C library with nvcc.
* Add more flash attention.
* Update the build part.
* Better caching.
* Exclude flash attention from the default workspace.
* Put flash-attn behind a feature gate.
* Get the flash attn kernel to run.
* Move the flags to a more appropriate place.
* Enable flash attention in llama.
* Use flash attention in llama.
Diffstat (limited to 'candle-flash-attn/kernels/kernel_traits.h')
-rw-r--r--  candle-flash-attn/kernels/kernel_traits.h  366
1 file changed, 366 insertions, 0 deletions
diff --git a/candle-flash-attn/kernels/kernel_traits.h b/candle-flash-attn/kernels/kernel_traits.h
new file mode 100644
index 00000000..3468e4bf
--- /dev/null
+++ b/candle-flash-attn/kernels/kernel_traits.h
@@ -0,0 +1,366 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include "cute/algorithm/copy.hpp"
+
+#include "cutlass/cutlass.h"
+#include "cutlass/layout/layout.h"
+#include <cutlass/numeric_types.h>
+
+using namespace cute;
+
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
+struct Flash_kernel_traits {
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ using Element = elem_type;
+ static constexpr bool Has_cp_async = true;
+#else
+ using Element = cutlass::half_t;
+ static constexpr bool Has_cp_async = false;
+#endif
+
+ using ElementAccum = float;
+ using index_t = uint32_t;
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ using MMA_Atom_Arch = std::conditional_t<
+ std::is_same_v<elem_type, cutlass::half_t>,
+ MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
+ MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
+ >;
+ using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
+#else
+ using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
+ using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
+#endif
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+ using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
+ using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
+#else
+ using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
+ using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
+#endif
+};
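+
+// Illustrative note (not part of the upstream header): on an SM80+ compilation, a
+// hypothetical instantiation such as
+//   using Traits = Flash_kernel_traits<128, 128, 64, 4, cutlass::bfloat16_t>;
+// resolves Traits::Element to cutlass::bfloat16_t, Traits::Has_cp_async to true, and
+// Traits::MMA_Atom_Arch to MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>. Compiled for an
+// older architecture, Element falls back to cutlass::half_t and the SM75 16x8x8 atom
+// is used, with cp.async disabled.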
+
+// If Share_Q_K_smem is true, this forces Is_Q_in_regs to be true as well.
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
+ typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
+struct Flash_fwd_kernel_traits : public Base {
+ using Element = typename Base::Element;
+ using ElementAccum = typename Base::ElementAccum;
+ using index_t = typename Base::index_t;
+ static constexpr bool Has_cp_async = Base::Has_cp_async;
+ using SmemCopyAtom = typename Base::SmemCopyAtom;
+ using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
+
+ static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
+ static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
+
+ // The number of threads.
+ static constexpr int kNWarps = kNWarps_;
+ static constexpr int kNThreads = kNWarps * 32;
+
+ static constexpr int kBlockM = kBlockM_;
+ static constexpr int kBlockN = kBlockN_;
+ static constexpr int kHeadDim = kHeadDim_;
+ static_assert(kHeadDim % 32 == 0);
+ static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
+ static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
+ static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
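+    // For example (illustrative): d=128 gives kBlockKSmem=64, kBlockKGmem=128 and
+    // kSwizzle=3, while d=96 (a multiple of 32 but not of 64) gives kBlockKSmem=32,
+    // kBlockKGmem=32 and kSwizzle=2.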
+
+ using TiledMma = TiledMMA<
+ typename Base::MMA_Atom_Arch,
+ Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
+ typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+
+ using SmemLayoutAtomQ = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+                    // This has to be kBlockKSmem; using kHeadDim gives wrong results for d=128.
+ Layout<Shape<_8, Int<kBlockKSmem>>,
+ Stride<Int<kBlockKSmem>, _1>>{}));
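+    // Roughly speaking (an explanatory note, not upstream): Swizzle<3, 3, 3> XORs three
+    // of the row-index bits into the bits that select the 8-element chunk within a row,
+    // so consecutive rows land on different shared-memory banks and the 128-bit
+    // LDSM/STS accesses avoid bank conflicts.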
+ using SmemLayoutQ = decltype(tile_to_shape(
+ SmemLayoutAtomQ{},
+ Shape<Int<kBlockM>, Int<kHeadDim>>{}));
+
+ using SmemLayoutKV = decltype(tile_to_shape(
+ SmemLayoutAtomQ{},
+ Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+
+ using SmemLayoutAtomVtransposed = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+                    // This has to be kBlockN and not 8; otherwise we get wrong results for d=128.
+ Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+ Stride<_1, Int<kBlockKSmem>>>{}));
+ using SmemLayoutVtransposed = decltype(tile_to_shape(
+ SmemLayoutAtomVtransposed{},
+ Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+    // Maybe SmemLayoutVtransposedNoSwizzle just needs to have the right shape,
+    // and the strides don't matter?
+ using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
+
+ using SmemLayoutAtomO = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+ Layout<Shape<Int<8>, Int<kBlockKSmem>>,
+ Stride<Int<kBlockKSmem>, _1>>{}));
+ using SmemLayoutO = decltype(tile_to_shape(
+ SmemLayoutAtomO{},
+ Shape<Int<kBlockM>, Int<kHeadDim>>{}));
+ using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
+
+ static constexpr int kSmemQCount = size(SmemLayoutQ{});
+ static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
+ static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
+ static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
+ static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
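+    // Worked example (illustrative): with kBlockM=128, kBlockN=64, kHeadDim=128 and a
+    // 2-byte Element, kSmemQSize = 128*128*2 = 32 KiB and kSmemKVSize = 2*64*128*2 = 32 KiB,
+    // so kSmemSize is 64 KiB, or 32 KiB when Q and K/V share smem.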
+
+ static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
+ static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
+    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
+    // For example, for d=128, smem is split into two 64-column "pages" covering columns
+    // 0-63 and 64-127. With 16 threads per row for the gmem read, threads 0-7 would write
+    // to the first page and threads 8-15 to the second page when writing to smem,
+    // hitting the same banks in both pages.
+ static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
+ static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
+ using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
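+    // E.g. (illustrative) for a 2-byte Element, kGmemElemsPerLoad = 16 / 2 = 8, so each
+    // thread moves one 128-bit vector per copy; with kBlockKSmem=64 this gives
+    // kGmemThreadsPerRow = 8, and with kNThreads=128 the GmemLayoutAtom is a 16x8 grid
+    // of threads.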
+
+    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since the same
+    // threadblock never reads the same address twice. This is slightly faster.
+ using Gmem_copy_struct = std::conditional_t<
+ Has_cp_async,
+ SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
+ DefaultCopy
+ >;
+ using GmemTiledCopyQKV = decltype(
+ make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+ GmemLayoutAtom{},
+ Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
+ using GmemTiledCopyO = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ GmemLayoutAtom{},
+ Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
+ static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
+ static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
+ using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
+ Stride<Int<kGmemThreadsPerRowP>, _1>>;
+
+ using GmemTiledCopyP = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ GmemLayoutAtomP{},
+ Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
+
+};
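+
+// Usage sketch (hypothetical, for illustration): a forward launcher is expected to pick
+// a specialization per head dimension and request the corresponding dynamic smem, e.g.
+//   using Kernel_traits = Flash_fwd_kernel_traits</*kHeadDim=*/128, /*kBlockM=*/128,
+//                                                 /*kBlockN=*/64, /*kNWarps=*/4>;
+//   constexpr int smem_size = Kernel_traits::kSmemSize;  // bytes of dynamic shared memory
+// Above 48 KiB, the launch must opt in via cudaFuncSetAttribute with
+// cudaFuncAttributeMaxDynamicSharedMemorySize.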
+
+// Is_V_in_regs is an option to reduce smem usage, but it will increase register pressure.
+// No_double_buffer is another option to reduce smem usage, but will slow things down.
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
+ int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,
+ bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,
+ typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
+struct Flash_bwd_kernel_traits : public Base {
+ using Element = typename Base::Element;
+ using ElementAccum = typename Base::ElementAccum;
+ using index_t = typename Base::index_t;
+ static constexpr bool Has_cp_async = Base::Has_cp_async;
+ using SmemCopyAtom = typename Base::SmemCopyAtom;
+ using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
+
+ static constexpr bool Is_V_in_regs = Is_V_in_regs_;
+ static constexpr bool No_double_buffer = No_double_buffer_;
+
+ // The number of threads.
+ static constexpr int kNWarps = kNWarps_;
+ static constexpr int kNThreads = kNWarps * 32;
+
+ static constexpr int kBlockM = kBlockM_;
+ static constexpr int kBlockN = kBlockN_;
+ static constexpr int kHeadDim = kHeadDim_;
+ static_assert(kHeadDim % 32 == 0);
+ static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
+ static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
+ static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
+
+ static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
+ static_assert(kNWarps % AtomLayoutMSdP == 0);
+ static_assert(kNWarps % AtomLayoutNdKV == 0);
+ static_assert(kNWarps % AtomLayoutMdQ == 0);
+
+ using TiledMmaSdP = TiledMMA<
+ typename Base::MMA_Atom_Arch,
+ Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
+ typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+
+ using TiledMmadKV = TiledMMA<
+ typename Base::MMA_Atom_Arch,
+ Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
+ typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+
+ using TiledMmadQ = TiledMMA<
+ typename Base::MMA_Atom_Arch,
+ Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
+ typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
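+    // E.g. (illustrative) with kNWarps=8 and the default atom layouts (1, 2, 2): SdP uses
+    // a 1x8x1 warp arrangement, while dKV and dQ each use 2x4x1.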
+
+ using SmemLayoutAtomQdO = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+ Layout<Shape<_8, Int<kBlockKSmem>>,
+ Stride<Int<kBlockKSmem>, _1>>{}));
+ using SmemLayoutQdO = decltype(tile_to_shape(
+ SmemLayoutAtomQdO{},
+ make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
+
+ using SmemLayoutAtomKV = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+ Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,
+ Stride<Int<kBlockKSmem>, _1>>{}));
+ using SmemLayoutKV = decltype(tile_to_shape(
+ // SmemLayoutAtomQdO{},
+ SmemLayoutAtomKV{},
+ make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
+
+ using SmemLayoutAtomKtransposed = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+ Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+ Stride<_1, Int<kBlockKSmem>>>{}));
+ using SmemLayoutKtransposed = decltype(tile_to_shape(
+ SmemLayoutAtomKtransposed{},
+ make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
+    // Maybe SmemLayoutKtransposedNoSwizzle just needs to have the right shape,
+    // and the strides don't matter?
+ using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
+
+ // TODO: generalize to other values of kBlockN
+ // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
+ // static constexpr int kPBlockN = kBlockN;
+ static_assert(kBlockN >= 64);
+    // TD [2023-03-19]: I don't know why kPBlockN = 16 and kSwizzlePdS = 3 is the fastest.
+ static constexpr int kPBlockN = 64;
+ static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
+ // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
+ static constexpr int kSwizzlePdS = 3;
+ using SmemLayoutAtomPdS = decltype(
+ composition(Swizzle<kSwizzlePdS, 3, 3>{},
+ Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
+ Stride<Int<kPBlockN>, _1>>{}));
+ using SmemLayoutPdS = decltype(tile_to_shape(
+ SmemLayoutAtomPdS{},
+ make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
+ using SmemLayoutAtomPdStransposed = decltype(
+ composition(Swizzle<kSwizzlePdS, 3, 3>{},
+ Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
+ Stride<_1, Int<kPBlockN>>>{}));
+ using SmemLayoutPdStransposed = decltype(tile_to_shape(
+ SmemLayoutAtomPdStransposed{},
+ make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
+ using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
+ using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
+
+ using SmemLayoutAtomQdOtransposed = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+ Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
+ Stride<_1, Int<kBlockKSmem>>>{}));
+ using SmemLayoutQdOtransposed = decltype(tile_to_shape(
+ SmemLayoutAtomQdOtransposed{},
+ make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
+ using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
+
+ using SmemLayoutAtomdKV = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+ Layout<Shape<_8, Int<kBlockKSmem>>,
+ Stride<Int<kBlockKSmem>, _1>>{}));
+ using SmemLayoutdKV = decltype(tile_to_shape(
+ SmemLayoutAtomdKV{},
+ make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
+ using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
+
+ using SmemLayoutAtomdQ = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+ Layout<Shape<_8, Int<kBlockKSmem>>,
+ Stride<Int<kBlockKSmem>, _1>>{}));
+ using SmemLayoutdQ = decltype(tile_to_shape(
+ SmemLayoutAtomdQ{},
+ make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
+ using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
+
+ static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ
+ static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
+ static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
+ static constexpr int kSmemPCount = size(SmemLayoutPdS{});
+ static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
+ static constexpr int kSmemdPsumCount = kBlockM;
+ static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
+ static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
+ static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
+ static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
+ static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
+ static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
+ static constexpr int kSmemSize = kSmemQdOSize
+ + (!Is_V_in_regs
+ ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
+ : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
+ static constexpr int kSmemSize1colblock = kSmemQdOSize
+ + (!Is_V_in_regs
+ ? kSmemKVSize + kSmemdSSize + kSmemPSize
+ : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
+ static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
+ + kSmemdSSize + kSmemPSize;
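+    // Worked example (illustrative): kBlockM=64, kBlockN=128, kHeadDim=128, a 2-byte
+    // Element, double buffering and !Is_V_in_regs give kSmemQdOSize = 3*64*128*2 = 48 KiB,
+    // kSmemKVSize = 2*128*128*2 = 64 KiB, kSmemdSSize = kSmemPSize = kSmemdQSize = 16 KiB,
+    // hence kSmemSize = 48 + 64 + 16 + max(16, 16) = 144 KiB.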
+
+
+ static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
+ static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
+    // We use kBlockKSmem instead of kHeadDim here to avoid bank conflicts, though it
+    // doesn't seem to affect speed in practice.
+ static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
+ static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
+ using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
+ Stride<Int<kGmemThreadsPerRow>, _1>>;
+
+    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since the same
+    // threadblock never reads the same address twice. This is slightly faster.
+ using Gmem_copy_struct = std::conditional_t<
+ Has_cp_async,
+ SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
+ DefaultCopy
+ >;
+ using GmemTiledCopyQKV = decltype(
+ make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+ GmemLayoutAtom{},
+ Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
+ using GmemTiledCopydO = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ GmemLayoutAtom{},
+ Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
+ using GmemTiledCopydKV = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ GmemLayoutAtom{},
+ Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
+ using GmemTiledCopydQ = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+ GmemLayoutAtom{},
+ Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
+ using GmemLayoutAtomdQaccum = std::conditional_t<
+ kBlockKSmem == 32,
+ Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
+ Stride< _8, _1>>,
+ Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
+ Stride< _16, _1>>
+ >;
+ using GmemTiledCopydQaccum = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ GmemLayoutAtomdQaccum{},
+ Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
+
+ using GmemTiledCopydQaccumAtomicAdd = decltype(
+ make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+ Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
+ Stride<_32, _1>>{},
+ Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store
+
+};
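+
+// Usage sketch (hypothetical): a backward launcher would instantiate, for example,
+//   using Kernel_traits = Flash_bwd_kernel_traits</*kHeadDim=*/128, /*kBlockM=*/64,
+//                                                 /*kBlockN=*/128, /*kNWarps=*/8>;
+// and request Kernel_traits::kSmemSize bytes of dynamic shared memory (144 KiB in the
+// worked example above), again via cudaFuncAttributeMaxDynamicSharedMemorySize.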
+
+////////////////////////////////////////////////////////////////////////////////////////////////////