diff options
author | Alon Zakai <azakai@google.com> | 2022-04-28 11:26:37 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-28 11:26:37 -0700 |
commit | 408b2eb7df01f42157e24c2c58f01c4fa5c6b9d6 (patch) | |
tree | 77877306025388e73ab0e8fb771d9bbb154509f4 /src/passes/RemoveUnusedModuleElements.cpp | |
parent | 7ba0d8377dfaa9fbd24c0e5961fd2795d8349272 (diff) | |
download | binaryen-408b2eb7df01f42157e24c2c58f01c4fa5c6b9d6.tar.gz binaryen-408b2eb7df01f42157e24c2c58f01c4fa5c6b9d6.tar.bz2 binaryen-408b2eb7df01f42157e24c2c58f01c4fa5c6b9d6.zip |
RemoveUnusedModuleElements: Track CallRef/RefFunc more precisely (#4621)
If we see (ref.func $foo) that does not mean that $foo is reachable - we
must also see a (call_ref ..) of the proper type. Only after seeing both should
we mark the function as reachable, which this PR does.
This adds some complexity as we need to track intermediate state as we go,
since we could see the RefFunc before the CallRef or vice versa. We also
need to handle the case of a RefFunc without a CallRef properly: We cannot
remove the function, as the RefFunc must refer to it, but at least we can
empty out the body since we know it is never reached.
This removes an old wasm-opt test which is now superseded by a new lit
test.
On J2Wasm output this removes 3% of all functions, which account for
2.5% of total code size.
Diffstat (limited to 'src/passes/RemoveUnusedModuleElements.cpp')
-rw-r--r-- | src/passes/RemoveUnusedModuleElements.cpp | 94 |
1 files changed, 91 insertions, 3 deletions
diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp index 8f8fc24a3..2466d6252 100644 --- a/src/passes/RemoveUnusedModuleElements.cpp +++ b/src/passes/RemoveUnusedModuleElements.cpp @@ -26,6 +26,7 @@ #include "ir/module-utils.h" #include "ir/utils.h" #include "pass.h" +#include "wasm-builder.h" #include "wasm.h" namespace wasm { @@ -43,6 +44,27 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { std::set<ModuleElement> reachable; bool usesMemory = false; + // The signatures that we have seen a call_ref for. When we see a RefFunc of a + // signature in here, we know it is reachable. + std::unordered_set<HeapType> calledSignatures; + + // All the RefFuncs we've seen, grouped by heap type. When we see a CallRef of + // one of the types here, we know all the RefFuncs corresponding to it are + // reachable. This is the reverse side of calledSignatures: for a function to + // be reached via a reference, we need the combination of a RefFunc of it as + // well as a CallRef of that, and we may see them in any order. (Or, if the + // RefFunc is in a table, we need a CallIndirect, which is handled in the + // table logic.) + // + // After we see a call for a type, we can clear out the entry here for it, as + // we'll have that type in calledSignatures, and so this contains only + // RefFuncs that we have not seen a call for yet, hence "uncalledRefFuncMap." + // + // TODO: We assume a closed world in the GC space atm, but eventually should + // have a flag for that, and when the world is not closed we'd need to + // check for RefFuncs that flow out to exports. + std::unordered_map<HeapType, std::vector<Name>> uncalledRefFuncMap; + ReachabilityAnalyzer(Module* module, const std::vector<ModuleElement>& roots) : module(module) { queue = roots; @@ -105,6 +127,33 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { } void visitCallIndirect(CallIndirect* curr) { maybeAddTable(curr->table); } + void visitCallRef(CallRef* curr) { + // Ignore unreachable code. + if (!curr->target->type.isRef()) { + return; + } + + auto type = curr->target->type.getHeapType(); + + // Call all the functions of that signature. We can then forget about + // them, as this signature will be marked as called. + auto iter = uncalledRefFuncMap.find(type); + if (iter != uncalledRefFuncMap.end()) { + // We must not have a type in both calledSignatures and + // uncalledRefFuncMap: once it is called, we do not track RefFuncs for + // it any more. + assert(calledSignatures.count(type) == 0); + + for (Name target : iter->second) { + maybeAdd(ModuleElement(ModuleElementKind::Function, target)); + } + + uncalledRefFuncMap.erase(iter); + } + + calledSignatures.insert(type); + } + void visitGlobalGet(GlobalGet* curr) { maybeAdd(ModuleElement(ModuleElementKind::Global, curr->name)); } @@ -126,7 +175,19 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { void visitMemorySize(MemorySize* curr) { usesMemory = true; } void visitMemoryGrow(MemoryGrow* curr) { usesMemory = true; } void visitRefFunc(RefFunc* curr) { - maybeAdd(ModuleElement(ModuleElementKind::Function, curr->func)); + auto type = curr->type.getHeapType(); + if (calledSignatures.count(type)) { + // We must not have a type in both calledSignatures and + // uncalledRefFuncMap: once it is called, we do not track RefFuncs for it + // any more. + assert(uncalledRefFuncMap.count(type) == 0); + + // We've seen a RefFunc for this, so it is reachable. + maybeAdd(ModuleElement(ModuleElementKind::Function, curr->func)); + } else { + // We've never seen a CallRef for this, but might see one later. + uncalledRefFuncMap[type].push_back(curr->func); + } } void visitTableGet(TableGet* curr) { maybeAddTable(curr->table); } void visitTableSet(TableSet* curr) { maybeAddTable(curr->table); } @@ -199,15 +260,42 @@ struct RemoveUnusedModuleElements : public Pass { importsMemory = true; } // For now, all functions that can be called indirectly are marked as roots. + // TODO: Compute this based on which ElementSegments are actually reachable, + // and which functions have a call_indirect of the proper type. ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) { roots.emplace_back(ModuleElementKind::Function, name); }); // Compute reachability starting from the root set. ReachabilityAnalyzer analyzer(module, roots); + + // RefFuncs that are never called are a special case: We cannot remove the + // function, since then (ref.func $foo) would not validate. But if we know + // it is never called, at least the contents do not matter, so we can + // empty it out. + std::unordered_set<Name> uncalledRefFuncs; + for (auto& [type, targets] : analyzer.uncalledRefFuncMap) { + for (auto target : targets) { + uncalledRefFuncs.insert(target); + } + } + // Remove unreachable elements. module->removeFunctions([&](Function* curr) { - return analyzer.reachable.count( - ModuleElement(ModuleElementKind::Function, curr->name)) == 0; + if (analyzer.reachable.count( + ModuleElement(ModuleElementKind::Function, curr->name))) { + return false; + } + + if (uncalledRefFuncs.count(curr->name)) { + // See comment above on uncalledRefFuncs. + if (!curr->imported()) { + curr->body = Builder(*module).makeUnreachable(); + } + return false; + } + + // The function is not reached and has no references; remove it. + return true; }); module->removeGlobals([&](Global* curr) { return analyzer.reachable.count( |