diff options
author | Thomas Lively <tlively@google.com> | 2024-09-05 13:09:42 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-05 13:09:42 -0700 |
commit | fbbdc64a5bf69d47bab5a33b6ec148e9b79a6a84 (patch) | |
tree | 3c410ee1342ab0854e3d5c38ebd753d8b624e809 | |
parent | 7562a6b7da0ba517a46db1f1792677f3ebbf2a27 (diff) | |
download | binaryen-fbbdc64a5bf69d47bab5a33b6ec148e9b79a6a84.tar.gz binaryen-fbbdc64a5bf69d47bab5a33b6ec148e9b79a6a84.tar.bz2 binaryen-fbbdc64a5bf69d47bab5a33b6ec148e9b79a6a84.zip |
[NFC] Add a more powerful API for collecting heap types (#6904)
Many passes need to know both the set of all used types and also the
sets of private or public types. Previously there was no API to get both
at once, so getting both required two API calls that internally
collected all the types twice.
Furthermore, there are many reasons to collect heap types, and they have
different requirements about precisely which types need to be collected.
For example, in some edge cases the IR can reference heap types that do
not need to be emitted into a binary; passes that replace all types
would need to collect these types, but the binary writer would not. The
existing APIs for collecting types did not distinguish between these use
cases, so the code conservatively collected extra types that were not
always needed.
Refactor the type collecting code to expose a new API that takes a
description of which types need to be collected and returns the
appropriate types, their use counts, and optionally whether they are
each public or private.
Keep this change non-functional by commenting on places where the code
could be cleaned up or improved rather than actually making the changes.
Follow-up PRs will implement the improvements, which will necessarily
come with test changes.
-rw-r--r-- | src/ir/module-utils.cpp | 263 | ||||
-rw-r--r-- | src/ir/module-utils.h | 30 |
2 files changed, 177 insertions, 116 deletions
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp index bd73a2be9..e0bd9592a 100644 --- a/src/ir/module-utils.cpp +++ b/src/ir/module-utils.cpp @@ -304,8 +304,8 @@ void renameFunction(Module& wasm, Name oldName, Name newName) { namespace { // Helper for collecting HeapTypes and their frequencies. -struct Counts { - InsertOrderedMap<HeapType, size_t> counts; +struct TypeInfos { + InsertOrderedMap<HeapType, HeapTypeInfo> info; // Multivalue control flow structures need a function type, but the identity // of the function type (i.e. what recursion group it is in or whether it is @@ -315,7 +315,7 @@ struct Counts { void note(HeapType type) { if (!type.isBasic()) { - counts[type]++; + ++info[type].useCount; } } void note(Type type) { @@ -326,7 +326,7 @@ struct Counts { // Ensure a type is included without increasing its count. void include(HeapType type) { if (!type.isBasic()) { - counts[type]; + info[type]; } } void include(Type type) { @@ -339,142 +339,143 @@ struct Counts { assert(sig.params.size() == 0); if (sig.results.isTuple()) { // We have to use a function type. - controlFlowSignatures[sig]++; + ++controlFlowSignatures[sig]; } else if (sig.results != Type::none) { // The result type can be emitted directly instead of using a function // type. - note(sig.results[0]); + note(sig.results); } } + bool contains(HeapType type) { return info.count(type); } }; struct CodeScanner : PostWalker<CodeScanner, UnifiedExpressionVisitor<CodeScanner>> { - Counts& counts; + TypeInfos& info; + TypeInclusion inclusion; - CodeScanner(Module& wasm, Counts& counts) : counts(counts) { + CodeScanner(Module& wasm, TypeInfos& info, TypeInclusion inclusion) + : info(info), inclusion(inclusion) { setModule(&wasm); } void visitExpression(Expression* curr) { if (auto* call = curr->dynCast<CallIndirect>()) { - counts.note(call->heapType); + info.note(call->heapType); } else if (auto* call = curr->dynCast<CallRef>()) { - counts.note(call->target->type); + info.note(call->target->type); } else if (curr->is<RefNull>()) { - counts.note(curr->type); + info.note(curr->type); } else if (curr->is<Select>() && curr->type.isRef()) { // This select will be annotated in the binary, so note it. - counts.note(curr->type); + info.note(curr->type); } else if (curr->is<StructNew>()) { - counts.note(curr->type); + info.note(curr->type); } else if (curr->is<ArrayNew>()) { - counts.note(curr->type); + info.note(curr->type); } else if (curr->is<ArrayNewData>()) { - counts.note(curr->type); + info.note(curr->type); } else if (curr->is<ArrayNewElem>()) { - counts.note(curr->type); + info.note(curr->type); } else if (curr->is<ArrayNewFixed>()) { - counts.note(curr->type); + info.note(curr->type); } else if (auto* copy = curr->dynCast<ArrayCopy>()) { - counts.note(copy->destRef->type); - counts.note(copy->srcRef->type); + info.note(copy->destRef->type); + info.note(copy->srcRef->type); } else if (auto* fill = curr->dynCast<ArrayFill>()) { - counts.note(fill->ref->type); + info.note(fill->ref->type); } else if (auto* init = curr->dynCast<ArrayInitData>()) { - counts.note(init->ref->type); + info.note(init->ref->type); } else if (auto* init = curr->dynCast<ArrayInitElem>()) { - counts.note(init->ref->type); + info.note(init->ref->type); } else if (auto* cast = curr->dynCast<RefCast>()) { - counts.note(cast->type); + info.note(cast->type); } else if (auto* cast = curr->dynCast<RefTest>()) { - counts.note(cast->castType); + info.note(cast->castType); } else if (auto* cast = curr->dynCast<BrOn>()) { if (cast->op == BrOnCast || cast->op == BrOnCastFail) { - counts.note(cast->ref->type); - counts.note(cast->castType); + info.note(cast->ref->type); + info.note(cast->castType); } } else if (auto* get = curr->dynCast<StructGet>()) { - counts.note(get->ref->type); - // If the type we read is a reference type then we must include it. It is - // not written in the binary format, so it doesn't need to be counted, but - // it does need to be taken into account in the IR (this may be the only - // place this type appears in the entire binary, and we must scan all - // types as the analyses that use us depend on that). TODO: This is kind - // of a hack, so it would be nice to remove. If we could remove it, we - // could also remove some of the pruning logic in getHeapTypeCounts below. - counts.include(get->type); + info.note(get->ref->type); + // TODO: Just include curr->type for AllTypes and UsedIRTypes to avoid + // this special case and to avoid emitting unnecessary types in binaries. + info.include(get->type); } else if (auto* set = curr->dynCast<StructSet>()) { - counts.note(set->ref->type); + info.note(set->ref->type); } else if (auto* get = curr->dynCast<ArrayGet>()) { - counts.note(get->ref->type); - // See note on StructGet above. - counts.include(get->type); + info.note(get->ref->type); + // See above. + info.include(get->type); } else if (auto* set = curr->dynCast<ArraySet>()) { - counts.note(set->ref->type); + info.note(set->ref->type); } else if (auto* contBind = curr->dynCast<ContBind>()) { - counts.note(contBind->contTypeBefore); - counts.note(contBind->contTypeAfter); + info.note(contBind->contTypeBefore); + info.note(contBind->contTypeAfter); } else if (auto* contNew = curr->dynCast<ContNew>()) { - counts.note(contNew->contType); + info.note(contNew->contType); } else if (auto* resume = curr->dynCast<Resume>()) { - counts.note(resume->contType); + info.note(resume->contType); } else if (Properties::isControlFlowStructure(curr)) { - counts.noteControlFlow(Signature(Type::none, curr->type)); + info.noteControlFlow(Signature(Type::none, curr->type)); } } }; -// Count the number of times each heap type that would appear in the binary is -// referenced. If `prune`, exclude types that are never referenced, even though -// a binary would be invalid without them. -InsertOrderedMap<HeapType, size_t> getHeapTypeCounts(Module& wasm, - bool prune = false) { +void classifyTypeVisibility(Module& wasm, + InsertOrderedMap<HeapType, HeapTypeInfo>& types); + +} // anonymous namespace + +InsertOrderedMap<HeapType, HeapTypeInfo> collectHeapTypeInfo( + Module& wasm, TypeInclusion inclusion, VisibilityHandling visibility) { // Collect module-level info. - Counts counts; - CodeScanner(wasm, counts).walkModuleCode(&wasm); + TypeInfos info; + CodeScanner(wasm, info, inclusion).walkModuleCode(&wasm); for (auto& curr : wasm.globals) { - counts.note(curr->type); + info.note(curr->type); } for (auto& curr : wasm.tags) { - counts.note(curr->sig); + info.note(curr->sig); } for (auto& curr : wasm.tables) { - counts.note(curr->type); + info.note(curr->type); } for (auto& curr : wasm.elementSegments) { - counts.note(curr->type); + info.note(curr->type); } // Collect info from functions in parallel. - ModuleUtils::ParallelFunctionAnalysis<Counts, Immutable, InsertOrderedMap> - analysis(wasm, [&](Function* func, Counts& counts) { - counts.note(func->type); + ModuleUtils::ParallelFunctionAnalysis<TypeInfos, Immutable, InsertOrderedMap> + analysis(wasm, [&](Function* func, TypeInfos& info) { + info.note(func->type); for (auto type : func->vars) { - counts.note(type); + info.note(type); } if (!func->imported()) { - CodeScanner(wasm, counts).walk(func->body); + CodeScanner(wasm, info, inclusion).walk(func->body); } }); // Combine the function info with the module info. - for (auto& [_, functionCounts] : analysis.map) { - for (auto& [type, count] : functionCounts.counts) { - counts.counts[type] += count; + for (auto& [_, functionInfo] : analysis.map) { + for (auto& [type, typeInfo] : functionInfo.info) { + info.info[type].useCount += typeInfo.useCount; } - for (auto& [sig, count] : functionCounts.controlFlowSignatures) { - counts.controlFlowSignatures[sig] += count; + for (auto& [sig, count] : functionInfo.controlFlowSignatures) { + info.controlFlowSignatures[sig] += count; } } - if (prune) { - // Remove types that are not actually used. - auto it = counts.counts.begin(); - while (it != counts.counts.end()) { - if (it->second == 0) { + // TODO: Remove this once we remove the hack for StructGet and StructSet in + // CodeScanner. + if (inclusion == TypeInclusion::UsedIRTypes) { + auto it = info.info.begin(); + while (it != info.info.end()) { + if (it->second.useCount == 0) { auto deleted = it++; - counts.counts.erase(deleted); + info.info.erase(deleted); } else { ++it; } @@ -488,6 +489,7 @@ InsertOrderedMap<HeapType, size_t> getHeapTypeCounts(Module& wasm, // appear in the type section once, so we just need to visit it once. Also // track which recursion groups we've already processed to avoid quadratic // behavior when there is a single large group. + // TODO: Use a vector here, since we never try to add the same type twice. UniqueNonrepeatingDeferredQueue<HeapType> newTypes; std::unordered_map<Signature, HeapType> seenSigs; auto noteNewType = [&](HeapType type) { @@ -496,39 +498,43 @@ InsertOrderedMap<HeapType, size_t> getHeapTypeCounts(Module& wasm, seenSigs.insert({type.getSignature(), type}); } }; - for (auto& [type, _] : counts.counts) { + for (auto& [type, _] : info.info) { noteNewType(type); } - auto controlFlowIt = counts.controlFlowSignatures.begin(); + auto controlFlowIt = info.controlFlowSignatures.begin(); std::unordered_set<RecGroup> includedGroups; while (!newTypes.empty()) { while (!newTypes.empty()) { auto ht = newTypes.pop(); + // TODO: Use getReferencedHeapTypes instead and remove separate + // consideration of supertypes below. for (HeapType child : ht.getHeapTypeChildren()) { if (!child.isBasic()) { - if (!counts.counts.count(child)) { + if (!info.contains(child)) { noteNewType(child); } - counts.note(child); + info.note(child); } } if (auto super = ht.getDeclaredSuperType()) { - if (!counts.counts.count(*super)) { + if (!info.contains(*super)) { noteNewType(*super); - counts.note(*super); + // TODO: This should be unconditional for the count to be correct, but + // this will be moot once we use getReferencedHeapTypes above. + info.note(*super); } } // Make sure we've noted the complete recursion group of each type as // well. - if (!prune) { + if (inclusion != TypeInclusion::UsedIRTypes) { auto recGroup = ht.getRecGroup(); if (includedGroups.insert(recGroup).second) { for (auto type : recGroup) { - if (!counts.counts.count(type)) { + if (!info.contains(type)) { noteNewType(type); - counts.include(type); + info.include(type); } } } @@ -537,44 +543,54 @@ InsertOrderedMap<HeapType, size_t> getHeapTypeCounts(Module& wasm, // We've found all the types there are to find without considering more // control flow types. Consider one more control flow type and repeat. - for (; controlFlowIt != counts.controlFlowSignatures.end(); - ++controlFlowIt) { + for (; controlFlowIt != info.controlFlowSignatures.end(); ++controlFlowIt) { auto& [sig, count] = *controlFlowIt; if (auto it = seenSigs.find(sig); it != seenSigs.end()) { - counts.counts[it->second] += count; + info.info[it->second].useCount += count; } else { // We've never seen this signature before, so add a type for it. HeapType type(sig); noteNewType(type); - counts.counts[type] += count; + info.info[type].useCount += count; break; } } } - return counts.counts; -} - -void setIndices(IndexedHeapTypes& indexedTypes) { - for (Index i = 0; i < indexedTypes.types.size(); i++) { - indexedTypes.indices[indexedTypes.types[i]] = i; + if (visibility == VisibilityHandling::FindVisibility) { + classifyTypeVisibility(wasm, info.info); } + + return std::move(info.info); } -InsertOrderedSet<HeapType> getPublicTypeSet(Module& wasm) { - InsertOrderedSet<HeapType> publicTypes; +namespace { + +void classifyTypeVisibility(Module& wasm, + InsertOrderedMap<HeapType, HeapTypeInfo>& types) { + // We will need to traverse the types used by public types and mark them + // public as well. + std::vector<HeapType> workList; auto notePublic = [&](HeapType type) { if (type.isBasic()) { - return; + return false; } // All the rec group members are public as well. + bool inserted = false; for (auto member : type.getRecGroup()) { - if (!publicTypes.insert(member)) { - // We've already inserted this rec group. - break; + if (auto it = types.find(member); it != types.end()) { + if (it->second.visibility == Visibility::Public) { + // Since we mark all elements of a group public at once, if there is a + // member that is already public, all members must already be public. + break; + } + it->second.visibility = Visibility::Public; + workList.push_back(member); + inserted = true; } } + return inserted; }; // TODO: Consider Tags as well, but they should store HeapTypes instead of @@ -633,48 +649,63 @@ InsertOrderedSet<HeapType> getPublicTypeSet(Module& wasm) { } // Find all the other public types reachable from directly publicized types. - std::vector<HeapType> workList(publicTypes.begin(), publicTypes.end()); - while (workList.size()) { + while (!workList.empty()) { auto curr = workList.back(); workList.pop_back(); for (auto t : curr.getReferencedHeapTypes()) { - if (!t.isBasic() && publicTypes.insert(t)) { - workList.push_back(t); - } + notePublic(t); } } - return publicTypes; + for (auto& [_, info] : types) { + if (info.visibility != Visibility::Public) { + info.visibility = Visibility::Private; + } + } + + // TODO: In an open world, we need to consider subtypes of public types public + // as well, or potentially even consider all types to be public unless + // otherwise annotated. +} + +void setIndices(IndexedHeapTypes& indexedTypes) { + for (Index i = 0; i < indexedTypes.types.size(); i++) { + indexedTypes.indices[indexedTypes.types[i]] = i; + } } } // anonymous namespace std::vector<HeapType> collectHeapTypes(Module& wasm) { - auto counts = getHeapTypeCounts(wasm); + auto info = collectHeapTypeInfo(wasm); std::vector<HeapType> types; - types.reserve(counts.size()); - for (auto& [type, _] : counts) { + types.reserve(info.size()); + for (auto& [type, _] : info) { types.push_back(type); } return types; } std::vector<HeapType> getPublicHeapTypes(Module& wasm) { - auto publicTypes = getPublicTypeSet(wasm); + auto info = collectHeapTypeInfo( + wasm, TypeInclusion::BinaryTypes, VisibilityHandling::FindVisibility); std::vector<HeapType> types; - types.reserve(publicTypes.size()); - for (auto type : publicTypes) { - types.push_back(type); + types.reserve(info.size()); + for (auto& [type, typeInfo] : info) { + if (typeInfo.visibility == Visibility::Public) { + types.push_back(type); + } } return types; } std::vector<HeapType> getPrivateHeapTypes(Module& wasm) { - auto usedTypes = getHeapTypeCounts(wasm, true); - auto publicTypes = getPublicTypeSet(wasm); + auto info = collectHeapTypeInfo( + wasm, TypeInclusion::UsedIRTypes, VisibilityHandling::FindVisibility); std::vector<HeapType> types; - for (auto& [type, _] : usedTypes) { - if (!publicTypes.count(type)) { + types.reserve(info.size()); + for (auto& [type, typeInfo] : info) { + if (typeInfo.visibility == Visibility::Private) { types.push_back(type); } } @@ -682,7 +713,7 @@ std::vector<HeapType> getPrivateHeapTypes(Module& wasm) { } IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) { - auto counts = getHeapTypeCounts(wasm); + auto counts = collectHeapTypeInfo(wasm, TypeInclusion::BinaryTypes); // Collect the rec groups. std::unordered_map<RecGroup, size_t> groupIndices; @@ -700,7 +731,7 @@ IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) { for (auto group : groups) { size_t count = 0; for (auto type : group) { - count += counts.at(type); + count += counts.at(type).useCount; } groupCounts.push_back(count); } diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h index d9fd69428..46e524165 100644 --- a/src/ir/module-utils.h +++ b/src/ir/module-utils.h @@ -18,6 +18,7 @@ #define wasm_ir_module_h #include "pass.h" +#include "support/insert_ordered.h" #include "support/unique_deferring_queue.h" #include "wasm.h" @@ -442,6 +443,35 @@ template<typename T> struct CallGraphPropertyAnalysis { } }; +// Which types to collect. +// +// AllTypes - Any type anywhere reachable from anything. +// +// UsedIRTypes - Same as AllTypes, but excludes types reachable only because +// they are in a rec group with some other used type and types that are only +// used from other unreachable types. +// +// BinaryTypes - Only types that need to appear in the module's type section. +// +enum class TypeInclusion { AllTypes, UsedIRTypes, BinaryTypes }; + +// Whether to classify collected types as public and private. +enum class VisibilityHandling { NoVisibility, FindVisibility }; + +// Whether a type is public or private. If visibility is not analyzed, the +// visibility will be Unknown instead. +enum class Visibility { Unknown, Public, Private }; + +struct HeapTypeInfo { + Index useCount = 0; + Visibility visibility = Visibility::Unknown; +}; + +InsertOrderedMap<HeapType, HeapTypeInfo> collectHeapTypeInfo( + Module& wasm, + TypeInclusion inclusion = TypeInclusion::AllTypes, + VisibilityHandling visibility = VisibilityHandling::NoVisibility); + // Helper function for collecting all the non-basic heap types used in the // module, i.e. the types that would appear in the type section. std::vector<HeapType> collectHeapTypes(Module& wasm); |