summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorThomas Lively <7121787+tlively@users.noreply.github.com>2022-02-02 13:24:02 -0800
committerGitHub <noreply@github.com>2022-02-02 13:24:02 -0800
commit5387c4ec48ba753e44f3b1d92fd43ac366b10b5c (patch)
tree8708ba238bffdc87846bc7690fef6051c3285db4 /src
parent348131b0a383d2946f18923f7acfd963089f4f5d (diff)
downloadbinaryen-5387c4ec48ba753e44f3b1d92fd43ac366b10b5c.tar.gz
binaryen-5387c4ec48ba753e44f3b1d92fd43ac366b10b5c.tar.bz2
binaryen-5387c4ec48ba753e44f3b1d92fd43ac366b10b5c.zip
Topological sorting of types in isorecursive output (#4492)
Generally we try to order types by decreasing use count so that frequently used types get smaller indices. For the equirecursive and nominal systems, there are no contraints on the ordering of types, so we just have to sort them according to their use counts. For the isorecursive type system, however, there are a number of ordering constraints that have to be met for the type section to be valid. First, types in the same recursion group must be adjacent so they can be grouped together. Second, groups must be ordered topologically so that they only refer to types in themselves or prior groups. Update type ordering to produce a valid isorecursive output by performing a topological sort on the recursion groups. While performing the sort, prefer to visit and finish processing the most used groups first as a heuristic to improve the final ordering. Do not reorder types within groups, since doing so would change type identity and could affect the external interface of the module. Leave that reordering to an optimization pass (not yet implemented) that users can explicitly opt in to.
Diffstat (limited to 'src')
-rw-r--r--src/ir/module-utils.cpp217
-rw-r--r--src/ir/module-utils.h7
-rw-r--r--src/ir/type-updating.cpp2
-rw-r--r--src/support/insert_ordered.h2
-rw-r--r--src/wasm-type.h6
-rw-r--r--src/wasm/wasm-type.cpp70
6 files changed, 207 insertions, 97 deletions
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp
index 0d0ec531c..723675c85 100644
--- a/src/ir/module-utils.cpp
+++ b/src/ir/module-utils.cpp
@@ -33,6 +33,12 @@ struct Counts : public InsertOrderedMap<HeapType, size_t> {
note(ht);
}
}
+ // Ensure a type is included without increasing its count.
+ void include(HeapType type) {
+ if (!type.isBasic()) {
+ (*this)[type];
+ }
+ }
};
Counts getHeapTypeCounts(Module& wasm) {
@@ -128,14 +134,16 @@ Counts getHeapTypeCounts(Module& wasm) {
// Recursively traverse each reference type, which may have a child type that
// is itself a reference type. This reflects an appearance in the binary
- // format that is in the type section itself.
- // As we do this we may find more and more types, as nested children of
- // previous ones. Each such type will appear in the type section once, so
- // we just need to visit it once.
+ // format that is in the type section itself. As we do this we may find more
+ // and more types, as nested children of previous ones. Each such type will
+ // appear in the type section once, so we just need to visit it once. Also
+ // track which recursion groups we've already processed to avoid quadratic
+ // behavior when there is a single large group.
InsertOrderedSet<HeapType> newTypes;
for (auto& [type, _] : counts) {
newTypes.insert(type);
}
+ std::unordered_set<RecGroup> includedGroups;
while (!newTypes.empty()) {
auto iter = newTypes.begin();
auto ht = *iter;
@@ -156,16 +164,18 @@ Counts getHeapTypeCounts(Module& wasm) {
// is in flux, skip counting them to keep the type orderings in nominal
// test outputs more similar to the orderings in the equirecursive
// outputs. FIXME
- counts.note(*super);
+ counts.include(*super);
}
}
// Make sure we've noted the complete recursion group of each type as well.
auto recGroup = ht.getRecGroup();
- for (auto type : recGroup) {
- if (!counts.count(type)) {
- newTypes.insert(type);
- counts.note(type);
+ if (includedGroups.insert(recGroup).second) {
+ for (auto type : recGroup) {
+ if (!counts.count(type)) {
+ newTypes.insert(type);
+ counts.include(type);
+ }
}
}
}
@@ -173,37 +183,52 @@ Counts getHeapTypeCounts(Module& wasm) {
return counts;
}
-void coalesceRecGroups(IndexedHeapTypes& indexedTypes) {
- if (getTypeSystem() != TypeSystem::Isorecursive) {
- // No rec groups to coalesce.
- return;
- }
-
- // TODO: Perform a topological sort of the recursion groups to create a valid
- // ordering rather than this hack that just gets all the types in a group to
- // be adjacent.
- assert(indexedTypes.indices.empty());
- std::unordered_set<HeapType> seen;
- std::vector<HeapType> grouped;
- grouped.reserve(indexedTypes.types.size());
- for (auto type : indexedTypes.types) {
- if (seen.insert(type).second) {
- for (auto member : type.getRecGroup()) {
- grouped.push_back(member);
- seen.insert(member);
- }
- }
- }
- assert(grouped.size() == indexedTypes.types.size());
- indexedTypes.types = grouped;
-}
-
void setIndices(IndexedHeapTypes& indexedTypes) {
for (Index i = 0; i < indexedTypes.types.size(); i++) {
indexedTypes.indices[indexedTypes.types[i]] = i;
}
}
+// A utility data structure for performing topological sorts. The elements to be
+// sorted should all be `push`ed to the stack, then while the stack is not
+// empty, predecessors of the top element should be pushed. On each iteration,
+// if the top of the stack is unchanged after pushing all the predecessors, that
+// means the predecessors have all been finished so the current element is
+// finished as well and should be popped.
+template<typename T> struct TopologicalSortStack {
+ std::list<T> workStack;
+ // Map items to their locations in the work stack so that we can make sure
+ // each appears and is finished only once.
+ std::unordered_map<T, typename std::list<T>::iterator> locations;
+ // Remember which items we have finished so we don't visit them again.
+ std::unordered_set<T> finished;
+
+ bool empty() const { return workStack.empty(); }
+
+ void push(T item) {
+ if (finished.count(item)) {
+ return;
+ }
+ workStack.push_back(item);
+ auto newLocation = std::prev(workStack.end());
+ auto [it, inserted] = locations.insert({item, newLocation});
+ if (!inserted) {
+ // Remove the previous iteration of the pushed item and update its
+ // location.
+ workStack.erase(it->second);
+ it->second = newLocation;
+ }
+ }
+
+ T& peek() { return workStack.back(); }
+
+ void pop() {
+ finished.insert(workStack.back());
+ locations.erase(workStack.back());
+ workStack.pop_back();
+ }
+};
+
} // anonymous namespace
std::vector<HeapType> collectHeapTypes(Module& wasm) {
@@ -216,37 +241,117 @@ std::vector<HeapType> collectHeapTypes(Module& wasm) {
return types;
}
-IndexedHeapTypes getIndexedHeapTypes(Module& wasm) {
+IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
Counts counts = getHeapTypeCounts(wasm);
- IndexedHeapTypes indexedTypes;
+
+ if (getTypeSystem() != TypeSystem::Isorecursive) {
+ // Sort by frequency and then original insertion order.
+ std::vector<std::pair<HeapType, size_t>> sorted(counts.begin(),
+ counts.end());
+ std::stable_sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {
+ return a.second > b.second;
+ });
+
+ // Collect the results.
+ IndexedHeapTypes indexedTypes;
+ for (Index i = 0; i < sorted.size(); ++i) {
+ indexedTypes.types.push_back(sorted[i].first);
+ }
+
+ setIndices(indexedTypes);
+ return indexedTypes;
+ }
+
+ // Isorecursive types have to be arranged into topologically ordered recursion
+ // groups. Sort the groups by average use count among their members so that
+ // the topological sort will place frequently used types first.
+ struct GroupInfo {
+ size_t index;
+ double useCount = 0;
+ std::unordered_set<RecGroup> preds;
+ std::vector<RecGroup> sortedPreds;
+ GroupInfo(size_t index) : index(index) {}
+ bool operator<(const GroupInfo& other) const {
+ if (useCount != other.useCount) {
+ return useCount < other.useCount;
+ }
+ return index < other.index;
+ }
+ };
+
+ std::unordered_map<RecGroup, GroupInfo> groupInfos;
for (auto& [type, _] : counts) {
- indexedTypes.types.push_back(type);
+ RecGroup group = type.getRecGroup();
+ // Try to initialize a new info or get the existing info.
+ auto& info = groupInfos.insert({group, {groupInfos.size()}}).first->second;
+ // Update the reference count.
+ info.useCount += counts.at(type);
+ // Collect predecessor groups.
+ for (auto child : type.getReferencedHeapTypes()) {
+ if (!child.isBasic()) {
+ RecGroup otherGroup = child.getRecGroup();
+ if (otherGroup != group) {
+ info.preds.insert(otherGroup);
+ }
+ }
+ }
}
- coalesceRecGroups(indexedTypes);
- setIndices(indexedTypes);
- return indexedTypes;
-}
+ // Fix up the use counts to be averages to ensure groups are used comensurate
+ // with the amount of index space they occupy.
+ for (auto& [group, info] : groupInfos) {
+ info.useCount /= group.size();
+ }
-IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
- Counts counts = getHeapTypeCounts(wasm);
+ // Sort the preds of each group by increasing use count so the topological
+ // sort visits the most used first. Break ties with the group's appearance
+ // index to ensure determinism.
+ auto sortGroups = [&](std::vector<RecGroup>& groups) {
+ std::sort(groups.begin(), groups.end(), [&](auto& a, auto& b) {
+ return groupInfos.at(a) < groupInfos.at(b);
+ });
+ };
+ for (auto& [group, info] : groupInfos) {
+ info.sortedPreds.insert(
+ info.sortedPreds.end(), info.preds.begin(), info.preds.end());
+ sortGroups(info.sortedPreds);
+ info.preds.clear();
+ }
+
+ // Sort all the groups so the topological sort visits the most used first.
+ std::vector<RecGroup> sortedGroups;
+ sortedGroups.reserve(groupInfos.size());
+ for (auto& [group, _] : groupInfos) {
+ sortedGroups.push_back(group);
+ }
+ sortGroups(sortedGroups);
- // Sort by frequency and then original insertion order.
- std::vector<std::pair<HeapType, size_t>> sorted(counts.begin(), counts.end());
- std::stable_sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {
- return a.second > b.second;
- });
+ // Perform the topological sort.
+ TopologicalSortStack<RecGroup> workStack;
+ for (auto group : sortedGroups) {
+ workStack.push(group);
+ }
+ sortedGroups.clear();
+ while (!workStack.empty()) {
+ auto group = workStack.peek();
+ for (auto pred : groupInfos.at(group).sortedPreds) {
+ workStack.push(pred);
+ }
+ if (workStack.peek() == group) {
+ // All the predecessors are finished, so `group` is too.
+ workStack.pop();
+ sortedGroups.push_back(group);
+ }
+ }
- // Collect the results.
+ // Collect and return the grouped types.
IndexedHeapTypes indexedTypes;
- for (Index i = 0; i < sorted.size(); ++i) {
- indexedTypes.types.push_back(sorted[i].first);
+ indexedTypes.types.reserve(counts.size());
+ for (auto group : sortedGroups) {
+ for (auto member : group) {
+ indexedTypes.types.push_back(member);
+ }
}
-
- // TODO: Explicitly construct a linear extension of the partial order of
- // recursion groups by adding edges between unrelated groups according to
- // their use counts.
- coalesceRecGroups(indexedTypes);
setIndices(indexedTypes);
return indexedTypes;
}
diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h
index c0b4dbcfd..845acda00 100644
--- a/src/ir/module-utils.h
+++ b/src/ir/module-utils.h
@@ -465,11 +465,8 @@ struct IndexedHeapTypes {
};
// Similar to `collectHeapTypes`, but provides fast lookup of the index for each
-// type as well.
-IndexedHeapTypes getIndexedHeapTypes(Module& wasm);
-
-// The same as `getIndexedHeapTypes`, but also sorts the types by frequency of
-// use to minimize code size.
+// type as well. Also orders the types to be valid and sorts the types by
+// frequency of use to minimize code size.
IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm);
} // namespace wasm::ModuleUtils
diff --git a/src/ir/type-updating.cpp b/src/ir/type-updating.cpp
index 10b8f1bb8..aff67b4df 100644
--- a/src/ir/type-updating.cpp
+++ b/src/ir/type-updating.cpp
@@ -26,7 +26,7 @@ namespace wasm {
GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {}
void GlobalTypeRewriter::update() {
- indexedTypes = ModuleUtils::getIndexedHeapTypes(wasm);
+ indexedTypes = ModuleUtils::getOptimizedIndexedHeapTypes(wasm);
if (indexedTypes.types.empty()) {
return;
}
diff --git a/src/support/insert_ordered.h b/src/support/insert_ordered.h
index d4905b945..12c98a3b2 100644
--- a/src/support/insert_ordered.h
+++ b/src/support/insert_ordered.h
@@ -112,6 +112,8 @@ template<typename Key, typename T> struct InsertOrderedMap {
return insert(kv).first->second;
}
+ T& at(const Key& k) { return Map.at(k)->second; }
+
iterator find(const Key& k) {
auto it = Map.find(k);
if (it == Map.end()) {
diff --git a/src/wasm-type.h b/src/wasm-type.h
index ea12e52e3..8c8d85ebf 100644
--- a/src/wasm-type.h
+++ b/src/wasm-type.h
@@ -376,7 +376,11 @@ public:
static bool isSubType(HeapType left, HeapType right);
// Return the ordered HeapType children, looking through child Types.
- std::vector<HeapType> getHeapTypeChildren();
+ std::vector<HeapType> getHeapTypeChildren() const;
+
+ // Similar to `getHeapTypeChildren`, but also includes the supertype if it
+ // exists.
+ std::vector<HeapType> getReferencedHeapTypes() const;
// Return the LUB of two HeapTypes. The LUB always exists.
static HeapType getLeastUpperBound(HeapType a, HeapType b);
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp
index d05af96a7..89592fc37 100644
--- a/src/wasm/wasm-type.cpp
+++ b/src/wasm/wasm-type.cpp
@@ -1403,12 +1403,20 @@ bool HeapType::isSubType(HeapType left, HeapType right) {
return SubTyper().isSubType(left, right);
}
-std::vector<HeapType> HeapType::getHeapTypeChildren() {
+std::vector<HeapType> HeapType::getHeapTypeChildren() const {
HeapTypeChildCollector collector;
- collector.walkRoot(this);
+ collector.walkRoot(const_cast<HeapType*>(this));
return collector.children;
}
+std::vector<HeapType> HeapType::getReferencedHeapTypes() const {
+ auto types = getHeapTypeChildren();
+ if (auto super = getSuperType()) {
+ types.push_back(*super);
+ }
+ return types;
+}
+
HeapType HeapType::getLeastUpperBound(HeapType a, HeapType b) {
return TypeBounder().getLeastUpperBound(a, b);
}
@@ -1567,7 +1575,8 @@ bool SubTyper::isSubType(HeapType a, HeapType b) {
// Basic HeapTypes are never subtypes of compound HeapTypes.
return false;
}
- if (typeSystem == TypeSystem::Nominal) {
+ if (typeSystem == TypeSystem::Nominal ||
+ typeSystem == TypeSystem::Isorecursive) {
// Subtyping must be declared in a nominal system, not derived from
// structure, so we will not recurse. TODO: optimize this search with some
// form of caching.
@@ -2402,7 +2411,12 @@ size_t RecGroupHasher::hash(HeapType type) const {
// an index into a rec group. Only take the rec group identity into account if
// the child is not a member of the top-level group because in that case the
// group may not be canonicalized yet.
- size_t digest = wasm::hash(type.getRecGroupIndex());
+ size_t digest = wasm::hash(type.isBasic());
+ if (type.isBasic()) {
+ wasm::rehash(digest, type.getID());
+ return digest;
+ }
+ wasm::rehash(digest, type.getRecGroupIndex());
auto currGroup = type.getRecGroup();
if (currGroup != group) {
wasm::rehash(digest, currGroup.getID());
@@ -2525,6 +2539,9 @@ bool RecGroupEquator::eq(HeapType a, HeapType b) const {
// be canonicalized, explicitly check whether `a` and `b` are in the
// respective recursion groups of the respective top-level groups we are
// comparing, in which case the structure is still equivalent.
+ if (a.isBasic() || b.isBasic()) {
+ return a == b;
+ }
if (a.getRecGroupIndex() != b.getRecGroupIndex()) {
return false;
}
@@ -3526,22 +3543,9 @@ void canonicalizeEquirecursive(CanonicalizationState& state) {
info->supertype = nullptr;
}
-#if TIME_CANONICALIZATION
- auto start = std::chrono::steady_clock::now();
-#endif
-
// Canonicalize the shape of the type definition graph.
ShapeCanonicalizer minimized(state.results);
state.update(minimized.replacements);
-
-#if TIME_CANONICALIZATION
- auto end = std::chrono::steady_clock::now();
- std::cerr << "Shape canonicalization: "
- << std::chrono::duration_cast<std::chrono::milliseconds>(end -
- start)
- .count()
- << " ms\n";
-#endif
}
std::optional<TypeBuilder::Error>
@@ -3597,10 +3601,6 @@ canonicalizeNominal(CanonicalizationState& state) {
// Nominal types do not require separate canonicalization, so just validate
// that their subtyping is correct.
-#if TIME_CANONICALIZATION
- auto start = std::chrono::steady_clock::now();
-#endif
-
// Ensure there are no cycles in the subtype graph. This is the classic DFA
// algorithm for detecting cycles, but in the form of a simple loop because
// each node (type) has at most one child (supertype).
@@ -3627,14 +3627,6 @@ canonicalizeNominal(CanonicalizationState& state) {
return {*error};
}
-#if TIME_CANONICALIZATION
- auto end = std::chrono::steady_clock::now();
- std::cerr << "Validating subtyping took "
- << std::chrono::duration_cast<std::chrono::milliseconds>(end -
- start)
- .count()
- << " ms\n";
-#endif
return {};
}
@@ -3802,6 +3794,15 @@ TypeBuilder::BuildResult TypeBuilder::build() {
state.dump();
#endif
+#if TIME_CANONICALIZATION
+ using instant_t = std::chrono::time_point<std::chrono::steady_clock>;
+ auto getMillis = [&](instant_t start, instant_t end) {
+ return std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
+ .count();
+ };
+ auto start = std::chrono::steady_clock::now();
+#endif
+
switch (typeSystem) {
case TypeSystem::Equirecursive:
canonicalizeEquirecursive(state);
@@ -3819,18 +3820,19 @@ TypeBuilder::BuildResult TypeBuilder::build() {
}
#if TIME_CANONICALIZATION
- auto start = std::chrono::steady_clock::now();
+ auto afterStructureCanonicalization = std::chrono::steady_clock::now();
#endif
globallyCanonicalize(state);
#if TIME_CANONICALIZATION
auto end = std::chrono::steady_clock::now();
- std::cerr << "Global canonicalization took "
- << std::chrono::duration_cast<std::chrono::milliseconds>(end -
- start)
- .count()
+ std::cerr << "Total canonicalization time was " << getMillis(start, end)
<< " ms\n";
+ std::cerr << "Structure canonicalization took "
+ << getMillis(start, afterStructureCanonicalization) << " ms\n";
+ std::cerr << "Global canonicalization took "
+ << getMillis(afterStructureCanonicalization, end) << " ms\n";
#endif
// Note built signature types. See comment in `HeapType::HeapType(Signature)`.