 src/ir/module-utils.cpp                    | 217
 src/ir/module-utils.h                      |   7
 src/ir/type-updating.cpp                   |   2
 src/support/insert_ordered.h               |   2
 src/wasm-type.h                            |   6
 src/wasm/wasm-type.cpp                     |  70
 test/lit/isorecursive-output-ordering.wast | 100
 7 files changed, 307 insertions(+), 97 deletions(-)
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp
index 0d0ec531c..723675c85 100644
--- a/src/ir/module-utils.cpp
+++ b/src/ir/module-utils.cpp
@@ -33,6 +33,12 @@ struct Counts : public InsertOrderedMap<HeapType, size_t> {
       note(ht);
     }
   }
+  // Ensure a type is included without increasing its count.
+  void include(HeapType type) {
+    if (!type.isBasic()) {
+      (*this)[type];
+    }
+  }
 };
 
 Counts getHeapTypeCounts(Module& wasm) {
@@ -128,14 +134,16 @@ Counts getHeapTypeCounts(Module& wasm) {
 
   // Recursively traverse each reference type, which may have a child type that
   // is itself a reference type. This reflects an appearance in the binary
-  // format that is in the type section itself.
-  // As we do this we may find more and more types, as nested children of
-  // previous ones. Each such type will appear in the type section once, so
-  // we just need to visit it once.
+  // format that is in the type section itself. As we do this we may find more
+  // and more types, as nested children of previous ones. Each such type will
+  // appear in the type section once, so we just need to visit it once. Also
+  // track which recursion groups we've already processed to avoid quadratic
+  // behavior when there is a single large group.
   InsertOrderedSet<HeapType> newTypes;
   for (auto& [type, _] : counts) {
     newTypes.insert(type);
   }
+  std::unordered_set<RecGroup> includedGroups;
   while (!newTypes.empty()) {
     auto iter = newTypes.begin();
     auto ht = *iter;
@@ -156,16 +164,18 @@ Counts getHeapTypeCounts(Module& wasm) {
        // is in flux, skip counting them to keep the type orderings in nominal
        // test outputs more similar to the orderings in the equirecursive
        // outputs. FIXME
-        counts.note(*super);
+        counts.include(*super);
      }
    }
 
    // Make sure we've noted the complete recursion group of each type as well.
    auto recGroup = ht.getRecGroup();
-    for (auto type : recGroup) {
-      if (!counts.count(type)) {
-        newTypes.insert(type);
-        counts.note(type);
+    if (includedGroups.insert(recGroup).second) {
+      for (auto type : recGroup) {
+        if (!counts.count(type)) {
+          newTypes.insert(type);
+          counts.include(type);
+        }
      }
    }
  }
@@ -173,37 +183,52 @@ Counts getHeapTypeCounts(Module& wasm) {
   return counts;
 }
 
-void coalesceRecGroups(IndexedHeapTypes& indexedTypes) {
-  if (getTypeSystem() != TypeSystem::Isorecursive) {
-    // No rec groups to coalesce.
-    return;
-  }
-
-  // TODO: Perform a topological sort of the recursion groups to create a valid
-  // ordering rather than this hack that just gets all the types in a group to
-  // be adjacent.
-  assert(indexedTypes.indices.empty());
-  std::unordered_set<HeapType> seen;
-  std::vector<HeapType> grouped;
-  grouped.reserve(indexedTypes.types.size());
-  for (auto type : indexedTypes.types) {
-    if (seen.insert(type).second) {
-      for (auto member : type.getRecGroup()) {
-        grouped.push_back(member);
-        seen.insert(member);
-      }
-    }
-  }
-  assert(grouped.size() == indexedTypes.types.size());
-  indexedTypes.types = grouped;
-}
-
 void setIndices(IndexedHeapTypes& indexedTypes) {
   for (Index i = 0; i < indexedTypes.types.size(); i++) {
     indexedTypes.indices[indexedTypes.types[i]] = i;
   }
 }
 
+// A utility data structure for performing topological sorts. The elements to
+// be sorted should all be `push`ed to the stack, then while the stack is not
+// empty, predecessors of the top element should be pushed.
+// On each iteration, if the top of the stack is unchanged after pushing all
+// the predecessors, that means the predecessors have all been finished, so
+// the current element is finished as well and should be popped.
+template<typename T> struct TopologicalSortStack {
+  std::list<T> workStack;
+  // Map items to their locations in the work stack so that we can make sure
+  // each appears and is finished only once.
+  std::unordered_map<T, typename std::list<T>::iterator> locations;
+  // Remember which items we have finished so we don't visit them again.
+  std::unordered_set<T> finished;
+
+  bool empty() const { return workStack.empty(); }
+
+  void push(T item) {
+    if (finished.count(item)) {
+      return;
+    }
+    workStack.push_back(item);
+    auto newLocation = std::prev(workStack.end());
+    auto [it, inserted] = locations.insert({item, newLocation});
+    if (!inserted) {
+      // Remove the previous occurrence of the pushed item and update its
+      // location.
+      workStack.erase(it->second);
+      it->second = newLocation;
+    }
+  }
+
+  T& peek() { return workStack.back(); }
+
+  void pop() {
+    finished.insert(workStack.back());
+    locations.erase(workStack.back());
+    workStack.pop_back();
+  }
+};
+
 } // anonymous namespace
 
 std::vector<HeapType> collectHeapTypes(Module& wasm) {
@@ -216,37 +241,117 @@ std::vector<HeapType> collectHeapTypes(Module& wasm) {
   return types;
 }
 
-IndexedHeapTypes getIndexedHeapTypes(Module& wasm) {
+IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
   Counts counts = getHeapTypeCounts(wasm);
-  IndexedHeapTypes indexedTypes;
+
+  if (getTypeSystem() != TypeSystem::Isorecursive) {
+    // Sort by frequency and then original insertion order.
+    std::vector<std::pair<HeapType, size_t>> sorted(counts.begin(),
+                                                    counts.end());
+    std::stable_sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {
+      return a.second > b.second;
+    });
+
+    // Collect the results.
+    IndexedHeapTypes indexedTypes;
+    for (Index i = 0; i < sorted.size(); ++i) {
+      indexedTypes.types.push_back(sorted[i].first);
+    }
+
+    setIndices(indexedTypes);
+    return indexedTypes;
+  }
+
+  // Isorecursive types have to be arranged into topologically ordered
+  // recursion groups. Sort the groups by average use count among their
+  // members so that the topological sort will place frequently used types
+  // first.
+  struct GroupInfo {
+    size_t index;
+    double useCount = 0;
+    std::unordered_set<RecGroup> preds;
+    std::vector<RecGroup> sortedPreds;
+    GroupInfo(size_t index) : index(index) {}
+    bool operator<(const GroupInfo& other) const {
+      if (useCount != other.useCount) {
+        return useCount < other.useCount;
+      }
+      return index < other.index;
+    }
+  };
+
+  std::unordered_map<RecGroup, GroupInfo> groupInfos;
   for (auto& [type, _] : counts) {
-    indexedTypes.types.push_back(type);
+    RecGroup group = type.getRecGroup();
+    // Try to initialize a new info or get the existing info.
+    auto& info = groupInfos.insert({group, {groupInfos.size()}}).first->second;
+    // Update the reference count.
+    info.useCount += counts.at(type);
+    // Collect predecessor groups.
+    for (auto child : type.getReferencedHeapTypes()) {
+      if (!child.isBasic()) {
+        RecGroup otherGroup = child.getRecGroup();
+        if (otherGroup != group) {
+          info.preds.insert(otherGroup);
+        }
+      }
+    }
   }
-  coalesceRecGroups(indexedTypes);
-  setIndices(indexedTypes);
-  return indexedTypes;
-}
 
+  // Fix up the use counts to be averages to ensure groups are used
+  // commensurate with the amount of index space they occupy.
+  for (auto& [group, info] : groupInfos) {
+    info.useCount /= group.size();
+  }
 
-IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
-  Counts counts = getHeapTypeCounts(wasm);
+  // Sort the preds of each group by increasing use count so the topological
+  // sort visits the most used first. Break ties with the group's appearance
+  // index to ensure determinism.
+  auto sortGroups = [&](std::vector<RecGroup>& groups) {
+    std::sort(groups.begin(), groups.end(), [&](auto& a, auto& b) {
+      return groupInfos.at(a) < groupInfos.at(b);
+    });
+  };
+  for (auto& [group, info] : groupInfos) {
+    info.sortedPreds.insert(
+      info.sortedPreds.end(), info.preds.begin(), info.preds.end());
+    sortGroups(info.sortedPreds);
+    info.preds.clear();
+  }
+
+  // Sort all the groups so the topological sort visits the most used first.
+  std::vector<RecGroup> sortedGroups;
+  sortedGroups.reserve(groupInfos.size());
+  for (auto& [group, _] : groupInfos) {
+    sortedGroups.push_back(group);
+  }
+  sortGroups(sortedGroups);
 
-  // Sort by frequency and then original insertion order.
-  std::vector<std::pair<HeapType, size_t>> sorted(counts.begin(), counts.end());
-  std::stable_sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {
-    return a.second > b.second;
-  });
+  // Perform the topological sort.
+  TopologicalSortStack<RecGroup> workStack;
+  for (auto group : sortedGroups) {
+    workStack.push(group);
+  }
+  sortedGroups.clear();
+  while (!workStack.empty()) {
+    auto group = workStack.peek();
+    for (auto pred : groupInfos.at(group).sortedPreds) {
+      workStack.push(pred);
+    }
+    if (workStack.peek() == group) {
+      // All the predecessors are finished, so `group` is too.
+      workStack.pop();
+      sortedGroups.push_back(group);
+    }
+  }
 
-  // Collect the results.
+  // Collect and return the grouped types.
   IndexedHeapTypes indexedTypes;
-  for (Index i = 0; i < sorted.size(); ++i) {
-    indexedTypes.types.push_back(sorted[i].first);
+  indexedTypes.types.reserve(counts.size());
+  for (auto group : sortedGroups) {
+    for (auto member : group) {
+      indexedTypes.types.push_back(member);
+    }
   }
-
-  // TODO: Explicitly construct a linear extension of the partial order of
-  // recursion groups by adding edges between unrelated groups according to
-  // their use counts.
-  coalesceRecGroups(indexedTypes);
   setIndices(indexedTypes);
   return indexedTypes;
 }
diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h
index c0b4dbcfd..845acda00 100644
--- a/src/ir/module-utils.h
+++ b/src/ir/module-utils.h
@@ -465,11 +465,8 @@ struct IndexedHeapTypes {
 };
 
 // Similar to `collectHeapTypes`, but provides fast lookup of the index for each
-// type as well.
-IndexedHeapTypes getIndexedHeapTypes(Module& wasm);
-
-// The same as `getIndexedHeapTypes`, but also sorts the types by frequency of
-// use to minimize code size.
+// type as well. Also orders the types into a valid sequence and sorts them by
+// frequency of use to minimize code size.
 IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm);
 
 } // namespace wasm::ModuleUtils
diff --git a/src/ir/type-updating.cpp b/src/ir/type-updating.cpp
index 10b8f1bb8..aff67b4df 100644
--- a/src/ir/type-updating.cpp
+++ b/src/ir/type-updating.cpp
@@ -26,7 +26,7 @@ namespace wasm {
 GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {}
 
 void GlobalTypeRewriter::update() {
-  indexedTypes = ModuleUtils::getIndexedHeapTypes(wasm);
+  indexedTypes = ModuleUtils::getOptimizedIndexedHeapTypes(wasm);
   if (indexedTypes.types.empty()) {
     return;
   }
diff --git a/src/support/insert_ordered.h b/src/support/insert_ordered.h
index d4905b945..12c98a3b2 100644
--- a/src/support/insert_ordered.h
+++ b/src/support/insert_ordered.h
@@ -112,6 +112,8 @@ template<typename Key, typename T> struct InsertOrderedMap {
     return insert(kv).first->second;
   }
 
+  T& at(const Key& k) { return Map.at(k)->second; }
+
   iterator find(const Key& k) {
     auto it = Map.find(k);
     if (it == Map.end()) {
diff --git a/src/wasm-type.h b/src/wasm-type.h
index ea12e52e3..8c8d85ebf 100644
--- a/src/wasm-type.h
+++ b/src/wasm-type.h
@@ -376,7 +376,11 @@ public:
   static bool isSubType(HeapType left, HeapType right);
 
   // Return the ordered HeapType children, looking through child Types.
-  std::vector<HeapType> getHeapTypeChildren();
+  std::vector<HeapType> getHeapTypeChildren() const;
+
+  // Similar to `getHeapTypeChildren`, but also includes the supertype if it
+  // exists.
+  std::vector<HeapType> getReferencedHeapTypes() const;
 
   // Return the LUB of two HeapTypes. The LUB always exists.
   static HeapType getLeastUpperBound(HeapType a, HeapType b);
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp
index d05af96a7..89592fc37 100644
--- a/src/wasm/wasm-type.cpp
+++ b/src/wasm/wasm-type.cpp
@@ -1403,12 +1403,20 @@ bool HeapType::isSubType(HeapType left, HeapType right) {
   return SubTyper().isSubType(left, right);
 }
 
-std::vector<HeapType> HeapType::getHeapTypeChildren() {
+std::vector<HeapType> HeapType::getHeapTypeChildren() const {
   HeapTypeChildCollector collector;
-  collector.walkRoot(this);
+  collector.walkRoot(const_cast<HeapType*>(this));
   return collector.children;
 }
 
+std::vector<HeapType> HeapType::getReferencedHeapTypes() const {
+  auto types = getHeapTypeChildren();
+  if (auto super = getSuperType()) {
+    types.push_back(*super);
+  }
+  return types;
+}
+
 HeapType HeapType::getLeastUpperBound(HeapType a, HeapType b) {
   return TypeBounder().getLeastUpperBound(a, b);
 }
@@ -1567,7 +1575,8 @@ bool SubTyper::isSubType(HeapType a, HeapType b) {
     // Basic HeapTypes are never subtypes of compound HeapTypes.
     return false;
   }
-  if (typeSystem == TypeSystem::Nominal) {
+  if (typeSystem == TypeSystem::Nominal ||
+      typeSystem == TypeSystem::Isorecursive) {
     // Subtyping must be declared in a nominal system, not derived from
     // structure, so we will not recurse. TODO: optimize this search with some
     // form of caching.
@@ -2402,7 +2411,12 @@ size_t RecGroupHasher::hash(HeapType type) const {
   // an index into a rec group. Only take the rec group identity into account if
   // the child is not a member of the top-level group because in that case the
   // group may not be canonicalized yet.
-  size_t digest = wasm::hash(type.getRecGroupIndex());
+  size_t digest = wasm::hash(type.isBasic());
+  if (type.isBasic()) {
+    wasm::rehash(digest, type.getID());
+    return digest;
+  }
+  wasm::rehash(digest, type.getRecGroupIndex());
   auto currGroup = type.getRecGroup();
   if (currGroup != group) {
     wasm::rehash(digest, currGroup.getID());
@@ -2525,6 +2539,9 @@ bool RecGroupEquator::eq(HeapType a, HeapType b) const {
   // be canonicalized, explicitly check whether `a` and `b` are in the
   // respective recursion groups of the respective top-level groups we are
   // comparing, in which case the structure is still equivalent.
+  if (a.isBasic() || b.isBasic()) {
+    return a == b;
+  }
   if (a.getRecGroupIndex() != b.getRecGroupIndex()) {
     return false;
   }
@@ -3526,22 +3543,9 @@ void canonicalizeEquirecursive(CanonicalizationState& state) {
     info->supertype = nullptr;
   }
 
-#if TIME_CANONICALIZATION
-  auto start = std::chrono::steady_clock::now();
-#endif
-
   // Canonicalize the shape of the type definition graph.
   ShapeCanonicalizer minimized(state.results);
   state.update(minimized.replacements);
-
-#if TIME_CANONICALIZATION
-  auto end = std::chrono::steady_clock::now();
-  std::cerr << "Shape canonicalization: "
-            << std::chrono::duration_cast<std::chrono::milliseconds>(end -
-                                                                     start)
-                 .count()
-            << " ms\n";
-#endif
 }
 
 std::optional<TypeBuilder::Error>
@@ -3597,10 +3601,6 @@ canonicalizeNominal(CanonicalizationState& state) {
   // Nominal types do not require separate canonicalization, so just validate
   // that their subtyping is correct.
 
-#if TIME_CANONICALIZATION
-  auto start = std::chrono::steady_clock::now();
-#endif
-
   // Ensure there are no cycles in the subtype graph. This is the classic DFA
   // algorithm for detecting cycles, but in the form of a simple loop because
   // each node (type) has at most one child (supertype).
@@ -3627,14 +3627,6 @@ canonicalizeNominal(CanonicalizationState& state) {
     return {*error};
   }
 
-#if TIME_CANONICALIZATION
-  auto end = std::chrono::steady_clock::now();
-  std::cerr << "Validating subtyping took "
-            << std::chrono::duration_cast<std::chrono::milliseconds>(end -
-                                                                     start)
-                 .count()
-            << " ms\n";
-#endif
   return {};
 }
@@ -3802,6 +3794,15 @@ TypeBuilder::BuildResult TypeBuilder::build() {
   state.dump();
 #endif
 
+#if TIME_CANONICALIZATION
+  using instant_t = std::chrono::time_point<std::chrono::steady_clock>;
+  auto getMillis = [&](instant_t start, instant_t end) {
+    return std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
+      .count();
+  };
+  auto start = std::chrono::steady_clock::now();
+#endif
+
   switch (typeSystem) {
     case TypeSystem::Equirecursive:
       canonicalizeEquirecursive(state);
@@ -3819,18 +3820,19 @@ TypeBuilder::BuildResult TypeBuilder::build() {
   }
 
 #if TIME_CANONICALIZATION
-  auto start = std::chrono::steady_clock::now();
+  auto afterStructureCanonicalization = std::chrono::steady_clock::now();
 #endif
 
   globallyCanonicalize(state);
 
 #if TIME_CANONICALIZATION
   auto end = std::chrono::steady_clock::now();
-  std::cerr << "Global canonicalization took "
-            << std::chrono::duration_cast<std::chrono::milliseconds>(end -
-                                                                     start)
-                 .count()
+  std::cerr << "Total canonicalization time was " << getMillis(start, end)
             << " ms\n";
+  std::cerr << "Structure canonicalization took "
+            << getMillis(start, afterStructureCanonicalization) << " ms\n";
+  std::cerr << "Global canonicalization took "
+            << getMillis(afterStructureCanonicalization, end) << " ms\n";
 #endif
 
   // Note built signature types. See comment in `HeapType::HeapType(Signature)`.
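To illustrate how the `TopologicalSortStack` added in src/ir/module-utils.cpp is driven, here is a minimal standalone sketch. The integer dependency graph and the `main` harness are invented for illustration and are not part of this patch; the stack itself is a trimmed copy of the patch's code, and the graph is assumed to be acyclic (as the supertype relation between rec groups is).

#include <iostream>
#include <iterator>
#include <list>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Trimmed-down copy of the TopologicalSortStack added in module-utils.cpp.
template<typename T> struct TopologicalSortStack {
  std::list<T> workStack;
  std::unordered_map<T, typename std::list<T>::iterator> locations;
  std::unordered_set<T> finished;

  bool empty() const { return workStack.empty(); }

  void push(T item) {
    if (finished.count(item)) {
      return;
    }
    workStack.push_back(item);
    auto newLocation = std::prev(workStack.end());
    auto [it, inserted] = locations.insert({item, newLocation});
    if (!inserted) {
      // The item was already on the stack; keep only the new occurrence so it
      // is visited before anything that depends on it.
      workStack.erase(it->second);
      it->second = newLocation;
    }
  }

  T& peek() { return workStack.back(); }

  void pop() {
    finished.insert(workStack.back());
    locations.erase(workStack.back());
    workStack.pop_back();
  }
};

int main() {
  // Toy acyclic dependency graph mapping each node to its predecessors, which
  // must be emitted first (like supertype groups preceding subtype groups).
  std::unordered_map<int, std::vector<int>> preds = {
    {0, {}}, {1, {0}}, {2, {0}}, {3, {1, 2}}};

  TopologicalSortStack<int> workStack;
  for (auto& [node, _] : preds) {
    workStack.push(node);
  }
  while (!workStack.empty()) {
    int node = workStack.peek();
    for (int pred : preds.at(node)) {
      workStack.push(pred);
    }
    if (workStack.peek() == node) {
      // No unfinished predecessors were pushed, so the node is finished.
      workStack.pop();
      // Prints a valid topological order, e.g. 0 1 2 3 or 0 2 1 3.
      std::cout << node << '\n';
    }
  }
}

Each node is pushed once up front; a node is popped only when pushing its predecessors leaves the top of the stack unchanged, which is exactly the loop shape used in `getOptimizedIndexedHeapTypes` above.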
diff --git a/test/lit/isorecursive-output-ordering.wast b/test/lit/isorecursive-output-ordering.wast
new file mode 100644
index 000000000..6aaa7ba39
--- /dev/null
+++ b/test/lit/isorecursive-output-ordering.wast
@@ -0,0 +1,100 @@
+;; TODO: Autogenerate these checks! The current script cannot handle `rec`.
+
+;; RUN: foreach %s %t wasm-opt -all --hybrid -S -o - | filecheck %s
+
+(module
+  ;; Test that we order groups by average uses.
+
+  ;; CHECK: (rec
+  ;; CHECK-NEXT: (type $unused-6 (struct_subtype data))
+  ;; CHECK-NEXT: (type $used-a-bit (struct_subtype data))
+  ;; CHECK-NEXT: )
+
+  ;; CHECK-NEXT: (rec
+  ;; CHECK-NEXT: (type $unused-1 (struct_subtype data))
+  ;; CHECK-NEXT: (type $unused-2 (struct_subtype data))
+  ;; CHECK-NEXT: (type $unused-3 (struct_subtype data))
+  ;; CHECK-NEXT: (type $unused-4 (struct_subtype data))
+  ;; CHECK-NEXT: (type $used-a-lot (struct_subtype data))
+  ;; CHECK-NEXT: (type $unused-5 (struct_subtype data))
+  ;; CHECK-NEXT: )
+
+  (rec
+    (type $unused-1 (struct_subtype data))
+    (type $unused-2 (struct_subtype data))
+    (type $unused-3 (struct_subtype data))
+    (type $unused-4 (struct_subtype data))
+    (type $used-a-lot (struct_subtype data))
+    (type $unused-5 (struct_subtype data))
+  )
+
+  (rec
+    (type $unused-6 (struct_subtype data))
+    (type $used-a-bit (struct_subtype data))
+  )
+
+  (func $use (param (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot)) (result (ref $used-a-bit) (ref $used-a-bit) (ref $used-a-bit) (ref $used-a-bit))
+    (unreachable)
+  )
+)
+
+(module
+  ;; Test that we respect dependencies between groups before considering counts.
+
+  ;; CHECK: (rec
+  ;; CHECK-NEXT: (type $leaf (struct_subtype data))
+  ;; CHECK-NEXT: (type $unused (struct_subtype data))
+  ;; CHECK-NEXT: )
+
+  ;; CHECK-NEXT: (rec
+  ;; CHECK-NEXT: (type $shrub (struct_subtype $leaf))
+  ;; CHECK-NEXT: (type $used-a-ton (struct_subtype data))
+  ;; CHECK-NEXT: )
+
+  ;; CHECK-NEXT: (rec
+  ;; CHECK-NEXT: (type $twig (struct_subtype data))
+  ;; CHECK-NEXT: (type $used-a-bit (struct_subtype (field (ref $leaf)) data))
+  ;; CHECK-NEXT: )
+
+  ;; CHECK-NEXT: (rec
+  ;; CHECK-NEXT: (type $root (struct_subtype data))
+  ;; CHECK-NEXT: (type $used-a-lot (struct_subtype $twig))
+  ;; CHECK-NEXT: )
+
+  (rec
+    (type $leaf (struct_subtype data))
+    (type $unused (struct_subtype data))
+  )
+
+  (rec
+    (type $twig (struct_subtype data))
+    (type $used-a-bit (struct_subtype (ref $leaf) data))
+  )
+
+  (rec
+    (type $shrub (struct_subtype $leaf))
+    (type $used-a-ton (struct_subtype data))
+  )
+
+  (rec
+    (type $root (struct_subtype data))
+    (type $used-a-lot (struct_subtype $twig))
+  )
+
+  (func $use (param (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot)) (result (ref $used-a-bit) (ref $used-a-bit) (ref $used-a-bit))
+    (local (ref null $used-a-ton) (ref null $used-a-ton) (ref null $used-a-ton) (ref null $used-a-ton) (ref null $used-a-ton) (ref null $used-a-ton) (ref null $used-a-ton))
+    (unreachable)
+  )
+)
+
+(module
+  ;; Test that basic heap type children do not trigger assertions.
+
+  (rec
+    (type $contains-basic (struct_subtype (ref any) data))
+  )
+
+  (func $use (param (ref $contains-basic))
+    (unreachable)
+  )
+)
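To make the "average uses" rule in the first test module concrete: $used-a-lot is referenced 6 times in the parameter list but shares its group with five unused types (6 / 6 = 1.0 average), while $used-a-bit is referenced 4 times in the results of a two-member group (4 / 2 = 2.0 average), so the smaller group is emitted first, matching the CHECK lines. A hypothetical standalone harness (not part of the test suite) showing that computation:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

struct GroupInfo {
  std::string name;
  double totalUses;
  double size;
  // Average use count per occupied type index, as computed by
  // getOptimizedIndexedHeapTypes.
  double average() const { return totalUses / size; }
};

int main() {
  // Counts from the first test module: $used-a-lot appears 6 times in the
  // param list and $used-a-bit 4 times in the results.
  std::vector<GroupInfo> groups = {
    {"(rec $unused-1..5 $used-a-lot)", 6, 6}, // average 1.0
    {"(rec $unused-6 $used-a-bit)", 4, 2},    // average 2.0
  };
  // Neither group references the other, so only the averages matter: the
  // two-member group sorts first.
  std::stable_sort(groups.begin(), groups.end(), [](auto& a, auto& b) {
    return a.average() > b.average();
  });
  for (auto& group : groups) {
    std::cout << group.name << " (average " << group.average() << ")\n";
  }
}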