-rw-r--r--  src/ir/module-utils.cpp                      | 217
-rw-r--r--  src/ir/module-utils.h                        |   7
-rw-r--r--  src/ir/type-updating.cpp                     |   2
-rw-r--r--  src/support/insert_ordered.h                 |   2
-rw-r--r--  src/wasm-type.h                              |   6
-rw-r--r--  src/wasm/wasm-type.cpp                       |  70
-rw-r--r--  test/lit/isorecursive-output-ordering.wast   | 100
7 files changed, 307 insertions(+), 97 deletions(-)
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp
index 0d0ec531c..723675c85 100644
--- a/src/ir/module-utils.cpp
+++ b/src/ir/module-utils.cpp
@@ -33,6 +33,12 @@ struct Counts : public InsertOrderedMap<HeapType, size_t> {
note(ht);
}
}
+ // Ensure a type is included without increasing its count.
+ void include(HeapType type) {
+ if (!type.isBasic()) {
+ (*this)[type];
+ }
+ }
};
Counts getHeapTypeCounts(Module& wasm) {
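
To illustrate the distinction: `note` records a use and bumps the count,
while `include` only guarantees the type has an entry, because `operator[]`
value-initializes a missing count to zero. Below is a minimal sketch of the
intended semantics, using `int` stand-ins for `HeapType` and a plain
`std::map` in place of `InsertOrderedMap`; it is not part of the patch:

    #include <cassert>
    #include <cstddef>
    #include <map>

    struct FakeCounts : std::map<int, std::size_t> {
      // Record a use, increasing the count.
      void note(int type) { (*this)[type]++; }
      // Guarantee an entry exists without adding any frequency weight.
      void include(int type) { (*this)[type]; }
    };

    int main() {
      FakeCounts counts;
      counts.note(1);    // a direct use: count becomes 1
      counts.include(1); // already present: count stays 1
      counts.include(2); // must be emitted, but gets no weight
      assert(counts.at(1) == 1);
      assert(counts.at(2) == 0);
    }

The distinction matters because supertypes and recursion group members must
appear in the type section even when they are never directly used, but they
should not displace genuinely hot types in the frequency sort.
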
@@ -128,14 +134,16 @@ Counts getHeapTypeCounts(Module& wasm) {
// Recursively traverse each reference type, which may have a child type that
// is itself a reference type. This reflects an appearance in the binary
- // format that is in the type section itself.
- // As we do this we may find more and more types, as nested children of
- // previous ones. Each such type will appear in the type section once, so
- // we just need to visit it once.
+ // format that is in the type section itself. As we do this we may find more
+ // and more types, as nested children of previous ones. Each such type will
+ // appear in the type section once, so we just need to visit it once. Also
+ // track which recursion groups we've already processed to avoid quadratic
+ // behavior when there is a single large group.
InsertOrderedSet<HeapType> newTypes;
for (auto& [type, _] : counts) {
newTypes.insert(type);
}
+ std::unordered_set<RecGroup> includedGroups;
while (!newTypes.empty()) {
auto iter = newTypes.begin();
auto ht = *iter;
@@ -156,16 +164,18 @@ Counts getHeapTypeCounts(Module& wasm) {
// is in flux, skip counting them to keep the type orderings in nominal
// test outputs more similar to the orderings in the equirecursive
// outputs. FIXME
- counts.note(*super);
+ counts.include(*super);
}
}
// Make sure we've noted the complete recursion group of each type as well.
auto recGroup = ht.getRecGroup();
- for (auto type : recGroup) {
- if (!counts.count(type)) {
- newTypes.insert(type);
- counts.note(type);
+ if (includedGroups.insert(recGroup).second) {
+ for (auto type : recGroup) {
+ if (!counts.count(type)) {
+ newTypes.insert(type);
+ counts.include(type);
+ }
}
}
}
@@ -173,37 +183,52 @@ Counts getHeapTypeCounts(Module& wasm) {
return counts;
}
-void coalesceRecGroups(IndexedHeapTypes& indexedTypes) {
- if (getTypeSystem() != TypeSystem::Isorecursive) {
- // No rec groups to coalesce.
- return;
- }
-
- // TODO: Perform a topological sort of the recursion groups to create a valid
- // ordering rather than this hack that just gets all the types in a group to
- // be adjacent.
- assert(indexedTypes.indices.empty());
- std::unordered_set<HeapType> seen;
- std::vector<HeapType> grouped;
- grouped.reserve(indexedTypes.types.size());
- for (auto type : indexedTypes.types) {
- if (seen.insert(type).second) {
- for (auto member : type.getRecGroup()) {
- grouped.push_back(member);
- seen.insert(member);
- }
- }
- }
- assert(grouped.size() == indexedTypes.types.size());
- indexedTypes.types = grouped;
-}
-
void setIndices(IndexedHeapTypes& indexedTypes) {
for (Index i = 0; i < indexedTypes.types.size(); i++) {
indexedTypes.indices[indexedTypes.types[i]] = i;
}
}
+// A utility data structure for performing topological sorts. The elements to be
+// sorted should all be `push`ed to the stack, then while the stack is not
+// empty, predecessors of the top element should be pushed. On each iteration,
+// if the top of the stack is unchanged after pushing all the predecessors, that
+// means the predecessors have all been finished so the current element is
+// finished as well and should be popped.
+template<typename T> struct TopologicalSortStack {
+ std::list<T> workStack;
+ // Map items to their locations in the work stack so that we can make sure
+ // each appears and is finished only once.
+ std::unordered_map<T, typename std::list<T>::iterator> locations;
+ // Remember which items we have finished so we don't visit them again.
+ std::unordered_set<T> finished;
+
+ bool empty() const { return workStack.empty(); }
+
+ void push(T item) {
+ if (finished.count(item)) {
+ return;
+ }
+ workStack.push_back(item);
+ auto newLocation = std::prev(workStack.end());
+ auto [it, inserted] = locations.insert({item, newLocation});
+ if (!inserted) {
+ // Remove the previous occurrence of the pushed item and update its
+ // location.
+ workStack.erase(it->second);
+ it->second = newLocation;
+ }
+ }
+
+ T& peek() { return workStack.back(); }
+
+ void pop() {
+ finished.insert(workStack.back());
+ locations.erase(workStack.back());
+ workStack.pop_back();
+ }
+};
+
} // anonymous namespace
std::vector<HeapType> collectHeapTypes(Module& wasm) {
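
The comment above describes a push/peek/pop protocol: push the roots, then
repeatedly push the top element's predecessors, and pop whenever the top
survives a round of pushes. Here is a self-contained sketch instantiating
the same structure for `int` elements with a hypothetical `deps` table; it
is not part of the patch, and the pop order it asserts is the resulting
topological order:

    #include <cassert>
    #include <iterator>
    #include <list>
    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    template<typename T> struct TopologicalSortStack {
      std::list<T> workStack;
      std::unordered_map<T, typename std::list<T>::iterator> locations;
      std::unordered_set<T> finished;

      bool empty() const { return workStack.empty(); }
      void push(T item) {
        if (finished.count(item)) {
          return;
        }
        workStack.push_back(item);
        auto newLocation = std::prev(workStack.end());
        auto [it, inserted] = locations.insert({item, newLocation});
        if (!inserted) {
          // Re-pushed: move the existing occurrence to the top.
          workStack.erase(it->second);
          it->second = newLocation;
        }
      }
      T& peek() { return workStack.back(); }
      void pop() {
        finished.insert(workStack.back());
        locations.erase(workStack.back());
        workStack.pop_back();
      }
    };

    int main() {
      // Hypothetical dependencies: 1 depends on 0; 2 depends on 0 and 1.
      std::unordered_map<int, std::vector<int>> deps{
        {0, {}}, {1, {0}}, {2, {0, 1}}};
      TopologicalSortStack<int> stack;
      stack.push(2);
      std::vector<int> order;
      while (!stack.empty()) {
        int item = stack.peek();
        for (int pred : deps[item]) {
          stack.push(pred);
        }
        if (stack.peek() == item) {
          // No predecessor displaced `item`, so they are all finished.
          stack.pop();
          order.push_back(item);
        }
      }
      // Predecessors come out before their dependents.
      assert((order == std::vector<int>{0, 1, 2}));
    }

Note that re-pushing an element moves it to the top rather than duplicating
it, which is what keeps the work stack linear in the number of elements.
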
@@ -216,37 +241,117 @@ std::vector<HeapType> collectHeapTypes(Module& wasm) {
return types;
}
-IndexedHeapTypes getIndexedHeapTypes(Module& wasm) {
+IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
Counts counts = getHeapTypeCounts(wasm);
- IndexedHeapTypes indexedTypes;
+
+ if (getTypeSystem() != TypeSystem::Isorecursive) {
+ // Sort by frequency and then original insertion order.
+ std::vector<std::pair<HeapType, size_t>> sorted(counts.begin(),
+ counts.end());
+ std::stable_sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {
+ return a.second > b.second;
+ });
+
+ // Collect the results.
+ IndexedHeapTypes indexedTypes;
+ for (Index i = 0; i < sorted.size(); ++i) {
+ indexedTypes.types.push_back(sorted[i].first);
+ }
+
+ setIndices(indexedTypes);
+ return indexedTypes;
+ }
+
+ // Isorecursive types have to be arranged into topologically ordered recursion
+ // groups. Sort the groups by average use count among their members so that
+ // the topological sort will place frequently used types first.
+ struct GroupInfo {
+ size_t index;
+ double useCount = 0;
+ std::unordered_set<RecGroup> preds;
+ std::vector<RecGroup> sortedPreds;
+ GroupInfo(size_t index) : index(index) {}
+ bool operator<(const GroupInfo& other) const {
+ if (useCount != other.useCount) {
+ return useCount < other.useCount;
+ }
+ return index < other.index;
+ }
+ };
+
+ std::unordered_map<RecGroup, GroupInfo> groupInfos;
for (auto& [type, _] : counts) {
- indexedTypes.types.push_back(type);
+ RecGroup group = type.getRecGroup();
+ // Try to initialize a new info or get the existing info.
+ auto& info = groupInfos.insert({group, {groupInfos.size()}}).first->second;
+ // Update the group's total use count.
+ info.useCount += counts.at(type);
+ // Collect predecessor groups.
+ for (auto child : type.getReferencedHeapTypes()) {
+ if (!child.isBasic()) {
+ RecGroup otherGroup = child.getRecGroup();
+ if (otherGroup != group) {
+ info.preds.insert(otherGroup);
+ }
+ }
+ }
}
- coalesceRecGroups(indexedTypes);
- setIndices(indexedTypes);
- return indexedTypes;
-}
+ // Fix up the use counts to be averages so that each group's count is
+ // commensurate with the amount of index space it occupies.
+ for (auto& [group, info] : groupInfos) {
+ info.useCount /= group.size();
+ }
-IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
- Counts counts = getHeapTypeCounts(wasm);
+ // Sort the preds of each group by increasing use count so the topological
+ // sort visits the most used first. Break ties with the group's appearance
+ // index to ensure determinism.
+ auto sortGroups = [&](std::vector<RecGroup>& groups) {
+ std::sort(groups.begin(), groups.end(), [&](auto& a, auto& b) {
+ return groupInfos.at(a) < groupInfos.at(b);
+ });
+ };
+ for (auto& [group, info] : groupInfos) {
+ info.sortedPreds.insert(
+ info.sortedPreds.end(), info.preds.begin(), info.preds.end());
+ sortGroups(info.sortedPreds);
+ info.preds.clear();
+ }
+
+ // Sort all the groups so the topological sort visits the most used first.
+ std::vector<RecGroup> sortedGroups;
+ sortedGroups.reserve(groupInfos.size());
+ for (auto& [group, _] : groupInfos) {
+ sortedGroups.push_back(group);
+ }
+ sortGroups(sortedGroups);
- // Sort by frequency and then original insertion order.
- std::vector<std::pair<HeapType, size_t>> sorted(counts.begin(), counts.end());
- std::stable_sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {
- return a.second > b.second;
- });
+ // Perform the topological sort.
+ TopologicalSortStack<RecGroup> workStack;
+ for (auto group : sortedGroups) {
+ workStack.push(group);
+ }
+ sortedGroups.clear();
+ while (!workStack.empty()) {
+ auto group = workStack.peek();
+ for (auto pred : groupInfos.at(group).sortedPreds) {
+ workStack.push(pred);
+ }
+ if (workStack.peek() == group) {
+ // All the predecessors are finished, so `group` is too.
+ workStack.pop();
+ sortedGroups.push_back(group);
+ }
+ }
- // Collect the results.
+ // Collect and return the grouped types.
IndexedHeapTypes indexedTypes;
- for (Index i = 0; i < sorted.size(); ++i) {
- indexedTypes.types.push_back(sorted[i].first);
+ indexedTypes.types.reserve(counts.size());
+ for (auto group : sortedGroups) {
+ for (auto member : group) {
+ indexedTypes.types.push_back(member);
+ }
}
-
- // TODO: Explicitly construct a linear extension of the partial order of
- // recursion groups by adding edges between unrelated groups according to
- // their use counts.
- coalesceRecGroups(indexedTypes);
setIndices(indexedTypes);
return indexedTypes;
}
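
To make the averaging step concrete with the first test module added below:
the six-type group containing $used-a-lot has 6 uses in total, an average of
1 per type, while the two-type group containing $used-a-bit has 4 uses, an
average of 2 per type. Ranking by totals would emit the large group first;
ranking by averages emits the small group first, which is the better trade
because each index slot the group occupies carries more uses.
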
diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h
index c0b4dbcfd..845acda00 100644
--- a/src/ir/module-utils.h
+++ b/src/ir/module-utils.h
@@ -465,11 +465,8 @@ struct IndexedHeapTypes {
};
// Similar to `collectHeapTypes`, but provides fast lookup of the index for each
-// type as well.
-IndexedHeapTypes getIndexedHeapTypes(Module& wasm);
-
-// The same as `getIndexedHeapTypes`, but also sorts the types by frequency of
-// use to minimize code size.
+// type as well. Also orders the types for valid emission, keeping recursion
+// groups contiguous and topologically sorted, while placing frequently used
+// types first to minimize code size.
IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm);
} // namespace wasm::ModuleUtils
diff --git a/src/ir/type-updating.cpp b/src/ir/type-updating.cpp
index 10b8f1bb8..aff67b4df 100644
--- a/src/ir/type-updating.cpp
+++ b/src/ir/type-updating.cpp
@@ -26,7 +26,7 @@ namespace wasm {
GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {}
void GlobalTypeRewriter::update() {
- indexedTypes = ModuleUtils::getIndexedHeapTypes(wasm);
+ indexedTypes = ModuleUtils::getOptimizedIndexedHeapTypes(wasm);
if (indexedTypes.types.empty()) {
return;
}
diff --git a/src/support/insert_ordered.h b/src/support/insert_ordered.h
index d4905b945..12c98a3b2 100644
--- a/src/support/insert_ordered.h
+++ b/src/support/insert_ordered.h
@@ -112,6 +112,8 @@ template<typename Key, typename T> struct InsertOrderedMap {
return insert(kv).first->second;
}
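+ // `Map` stores iterators into the insertion-ordered list, so dereference
+ // the iterator to reach the stored key/value pair.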
+ T& at(const Key& k) { return Map.at(k)->second; }
+
iterator find(const Key& k) {
auto it = Map.find(k);
if (it == Map.end()) {
diff --git a/src/wasm-type.h b/src/wasm-type.h
index ea12e52e3..8c8d85ebf 100644
--- a/src/wasm-type.h
+++ b/src/wasm-type.h
@@ -376,7 +376,11 @@ public:
static bool isSubType(HeapType left, HeapType right);
// Return the ordered HeapType children, looking through child Types.
- std::vector<HeapType> getHeapTypeChildren();
+ std::vector<HeapType> getHeapTypeChildren() const;
+
+ // Similar to `getHeapTypeChildren`, but also includes the supertype if it
+ // exists.
+ std::vector<HeapType> getReferencedHeapTypes() const;
// Return the LUB of two HeapTypes. The LUB always exists.
static HeapType getLeastUpperBound(HeapType a, HeapType b);
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp
index d05af96a7..89592fc37 100644
--- a/src/wasm/wasm-type.cpp
+++ b/src/wasm/wasm-type.cpp
@@ -1403,12 +1403,20 @@ bool HeapType::isSubType(HeapType left, HeapType right) {
return SubTyper().isSubType(left, right);
}
-std::vector<HeapType> HeapType::getHeapTypeChildren() {
+std::vector<HeapType> HeapType::getHeapTypeChildren() const {
HeapTypeChildCollector collector;
- collector.walkRoot(this);
+ collector.walkRoot(const_cast<HeapType*>(this));
return collector.children;
}
+std::vector<HeapType> HeapType::getReferencedHeapTypes() const {
+ auto types = getHeapTypeChildren();
+ if (auto super = getSuperType()) {
+ types.push_back(*super);
+ }
+ return types;
+}
+
HeapType HeapType::getLeastUpperBound(HeapType a, HeapType b) {
return TypeBounder().getLeastUpperBound(a, b);
}
@@ -1567,7 +1575,8 @@ bool SubTyper::isSubType(HeapType a, HeapType b) {
// Basic HeapTypes are never subtypes of compound HeapTypes.
return false;
}
- if (typeSystem == TypeSystem::Nominal) {
+ if (typeSystem == TypeSystem::Nominal ||
+ typeSystem == TypeSystem::Isorecursive) {
// Subtyping must be declared in a nominal system, not derived from
// structure, so we will not recurse. TODO: optimize this search with some
// form of caching.
@@ -2402,7 +2411,12 @@ size_t RecGroupHasher::hash(HeapType type) const {
// an index into a rec group. Only take the rec group identity into account if
// the child is not a member of the top-level group because in that case the
// group may not be canonicalized yet.
- size_t digest = wasm::hash(type.getRecGroupIndex());
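+ // Basic heap types have no rec group or group index, so hash them by their
+ // ID alone.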
+ size_t digest = wasm::hash(type.isBasic());
+ if (type.isBasic()) {
+ wasm::rehash(digest, type.getID());
+ return digest;
+ }
+ wasm::rehash(digest, type.getRecGroupIndex());
auto currGroup = type.getRecGroup();
if (currGroup != group) {
wasm::rehash(digest, currGroup.getID());
@@ -2525,6 +2539,9 @@ bool RecGroupEquator::eq(HeapType a, HeapType b) const {
// be canonicalized, explicitly check whether `a` and `b` are in the
// respective recursion groups of the respective top-level groups we are
// comparing, in which case the structure is still equivalent.
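+ // Basic types have no rec group index to inspect, so compare them directly.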
+ if (a.isBasic() || b.isBasic()) {
+ return a == b;
+ }
if (a.getRecGroupIndex() != b.getRecGroupIndex()) {
return false;
}
@@ -3526,22 +3543,9 @@ void canonicalizeEquirecursive(CanonicalizationState& state) {
info->supertype = nullptr;
}
-#if TIME_CANONICALIZATION
- auto start = std::chrono::steady_clock::now();
-#endif
-
// Canonicalize the shape of the type definition graph.
ShapeCanonicalizer minimized(state.results);
state.update(minimized.replacements);
-
-#if TIME_CANONICALIZATION
- auto end = std::chrono::steady_clock::now();
- std::cerr << "Shape canonicalization: "
- << std::chrono::duration_cast<std::chrono::milliseconds>(end -
- start)
- .count()
- << " ms\n";
-#endif
}
std::optional<TypeBuilder::Error>
@@ -3597,10 +3601,6 @@ canonicalizeNominal(CanonicalizationState& state) {
// Nominal types do not require separate canonicalization, so just validate
// that their subtyping is correct.
-#if TIME_CANONICALIZATION
- auto start = std::chrono::steady_clock::now();
-#endif
-
// Ensure there are no cycles in the subtype graph. This is the classic DFS
// cycle-detection algorithm, but in the form of a simple loop because
// each node (type) has at most one child (supertype).
@@ -3627,14 +3627,6 @@ canonicalizeNominal(CanonicalizationState& state) {
return {*error};
}
-#if TIME_CANONICALIZATION
- auto end = std::chrono::steady_clock::now();
- std::cerr << "Validating subtyping took "
- << std::chrono::duration_cast<std::chrono::milliseconds>(end -
- start)
- .count()
- << " ms\n";
-#endif
return {};
}
@@ -3802,6 +3794,15 @@ TypeBuilder::BuildResult TypeBuilder::build() {
state.dump();
#endif
+#if TIME_CANONICALIZATION
+ using instant_t = std::chrono::time_point<std::chrono::steady_clock>;
+ auto getMillis = [&](instant_t start, instant_t end) {
+ return std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
+ .count();
+ };
+ auto start = std::chrono::steady_clock::now();
+#endif
+
switch (typeSystem) {
case TypeSystem::Equirecursive:
canonicalizeEquirecursive(state);
@@ -3819,18 +3820,19 @@ TypeBuilder::BuildResult TypeBuilder::build() {
}
#if TIME_CANONICALIZATION
- auto start = std::chrono::steady_clock::now();
+ auto afterStructureCanonicalization = std::chrono::steady_clock::now();
#endif
globallyCanonicalize(state);
#if TIME_CANONICALIZATION
auto end = std::chrono::steady_clock::now();
- std::cerr << "Global canonicalization took "
- << std::chrono::duration_cast<std::chrono::milliseconds>(end -
- start)
- .count()
+ std::cerr << "Total canonicalization time was " << getMillis(start, end)
<< " ms\n";
+ std::cerr << "Structure canonicalization took "
+ << getMillis(start, afterStructureCanonicalization) << " ms\n";
+ std::cerr << "Global canonicalization took "
+ << getMillis(afterStructureCanonicalization, end) << " ms\n";
#endif
// Note built signature types. See comment in `HeapType::HeapType(Signature)`.
diff --git a/test/lit/isorecursive-output-ordering.wast b/test/lit/isorecursive-output-ordering.wast
new file mode 100644
index 000000000..6aaa7ba39
--- /dev/null
+++ b/test/lit/isorecursive-output-ordering.wast
@@ -0,0 +1,100 @@
+;; TODO: Autogenerate these checks! The current script cannot handle `rec`.
+
+;; RUN: foreach %s %t wasm-opt -all --hybrid -S -o - | filecheck %s
+
+(module
+ ;; Test that we order groups by average uses.
+
+ ;; CHECK: (rec
+ ;; CHECK-NEXT: (type $unused-6 (struct_subtype data))
+ ;; CHECK-NEXT: (type $used-a-bit (struct_subtype data))
+ ;; CHECK-NEXT: )
+
+ ;; CHECK-NEXT: (rec
+ ;; CHECK-NEXT: (type $unused-1 (struct_subtype data))
+ ;; CHECK-NEXT: (type $unused-2 (struct_subtype data))
+ ;; CHECK-NEXT: (type $unused-3 (struct_subtype data))
+ ;; CHECK-NEXT: (type $unused-4 (struct_subtype data))
+ ;; CHECK-NEXT: (type $used-a-lot (struct_subtype data))
+ ;; CHECK-NEXT: (type $unused-5 (struct_subtype data))
+ ;; CHECK-NEXT: )
+
+ (rec
+ (type $unused-1 (struct_subtype data))
+ (type $unused-2 (struct_subtype data))
+ (type $unused-3 (struct_subtype data))
+ (type $unused-4 (struct_subtype data))
+ (type $used-a-lot (struct_subtype data))
+ (type $unused-5 (struct_subtype data))
+ )
+
+ (rec
+ (type $unused-6 (struct_subtype data))
+ (type $used-a-bit (struct_subtype data))
+ )
+
+ (func $use (param (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot)) (result (ref $used-a-bit) (ref $used-a-bit) (ref $used-a-bit) (ref $used-a-bit))
+ (unreachable)
+ )
+)
+
+(module
+ ;; Test that we respect dependencies between groups before considering counts.
+
+ ;; CHECK: (rec
+ ;; CHECK-NEXT: (type $leaf (struct_subtype data))
+ ;; CHECK-NEXT: (type $unused (struct_subtype data))
+ ;; CHECK-NEXT: )
+
+ ;; CHECK-NEXT: (rec
+ ;; CHECK-NEXT: (type $shrub (struct_subtype $leaf))
+ ;; CHECK-NEXT: (type $used-a-ton (struct_subtype data))
+ ;; CHECK-NEXT: )
+
+ ;; CHECK-NEXT: (rec
+ ;; CHECK-NEXT: (type $twig (struct_subtype data))
+ ;; CHECK-NEXT: (type $used-a-bit (struct_subtype (field (ref $leaf)) data))
+ ;; CHECK-NEXT: )
+
+ ;; CHECK-NEXT: (rec
+ ;; CHECK-NEXT: (type $root (struct_subtype data))
+ ;; CHECK-NEXT: (type $used-a-lot (struct_subtype $twig))
+ ;; CHECK-NEXT: )
+
+ (rec
+ (type $leaf (struct_subtype data))
+ (type $unused (struct_subtype data))
+ )
+
+ (rec
+ (type $twig (struct_subtype data))
+ (type $used-a-bit (struct_subtype (ref $leaf) data))
+ )
+
+ (rec
+ (type $shrub (struct_subtype $leaf))
+ (type $used-a-ton (struct_subtype data))
+ )
+
+ (rec
+ (type $root (struct_subtype data))
+ (type $used-a-lot (struct_subtype $twig))
+ )
+
+ (func $use (param (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot) (ref $used-a-lot)) (result (ref $used-a-bit) (ref $used-a-bit) (ref $used-a-bit))
+ (local (ref null $used-a-ton) (ref null $used-a-ton) (ref null $used-a-ton) (ref null $used-a-ton) (ref null $used-a-ton) (ref null $used-a-ton) (ref null $used-a-ton))
+ (unreachable)
+ )
+)
+
+(module
+ ;; Test that basic heap type children do not trigger assertions.
+
+ (rec
+ (type $contains-basic (struct_subtype (ref any) data))
+ )
+
+ (func $use (param (ref $contains-basic))
+ (unreachable)
+ )
+)