summaryrefslogtreecommitdiff
path: root/src/passes/RemoveUnusedModuleElements.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/RemoveUnusedModuleElements.cpp')
-rw-r--r--src/passes/RemoveUnusedModuleElements.cpp94
1 files changed, 91 insertions, 3 deletions
diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp
index 8f8fc24a3..2466d6252 100644
--- a/src/passes/RemoveUnusedModuleElements.cpp
+++ b/src/passes/RemoveUnusedModuleElements.cpp
@@ -26,6 +26,7 @@
#include "ir/module-utils.h"
#include "ir/utils.h"
#include "pass.h"
+#include "wasm-builder.h"
#include "wasm.h"
namespace wasm {
@@ -43,6 +44,27 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> {
std::set<ModuleElement> reachable;
bool usesMemory = false;
+ // The signatures that we have seen a call_ref for. When we see a RefFunc of a
+ // signature in here, we know it is reachable.
+ std::unordered_set<HeapType> calledSignatures;
+
+ // All the RefFuncs we've seen, grouped by heap type. When we see a CallRef of
+ // one of the types here, we know all the RefFuncs corresponding to it are
+ // reachable. This is the reverse side of calledSignatures: for a function to
+ // be reached via a reference, we need the combination of a RefFunc of it as
+ // well as a CallRef of that, and we may see them in any order. (Or, if the
+ // RefFunc is in a table, we need a CallIndirect, which is handled in the
+ // table logic.)
+ //
+ // After we see a call for a type, we can clear out the entry here for it, as
+ // we'll have that type in calledSignatures, and so this contains only
+ // RefFuncs that we have not seen a call for yet, hence "uncalledRefFuncMap."
+ //
+ // TODO: We assume a closed world in the GC space atm, but eventually should
+ // have a flag for that, and when the world is not closed we'd need to
+ // check for RefFuncs that flow out to exports.
+ std::unordered_map<HeapType, std::vector<Name>> uncalledRefFuncMap;
+
ReachabilityAnalyzer(Module* module, const std::vector<ModuleElement>& roots)
: module(module) {
queue = roots;
@@ -105,6 +127,33 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> {
}
void visitCallIndirect(CallIndirect* curr) { maybeAddTable(curr->table); }
+ void visitCallRef(CallRef* curr) {
+ // Ignore unreachable code.
+ if (!curr->target->type.isRef()) {
+ return;
+ }
+
+ auto type = curr->target->type.getHeapType();
+
+ // Call all the functions of that signature. We can then forget about
+ // them, as this signature will be marked as called.
+ auto iter = uncalledRefFuncMap.find(type);
+ if (iter != uncalledRefFuncMap.end()) {
+ // We must not have a type in both calledSignatures and
+ // uncalledRefFuncMap: once it is called, we do not track RefFuncs for
+ // it any more.
+ assert(calledSignatures.count(type) == 0);
+
+ for (Name target : iter->second) {
+ maybeAdd(ModuleElement(ModuleElementKind::Function, target));
+ }
+
+ uncalledRefFuncMap.erase(iter);
+ }
+
+ calledSignatures.insert(type);
+ }
+
void visitGlobalGet(GlobalGet* curr) {
maybeAdd(ModuleElement(ModuleElementKind::Global, curr->name));
}
@@ -126,7 +175,19 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> {
void visitMemorySize(MemorySize* curr) { usesMemory = true; }
void visitMemoryGrow(MemoryGrow* curr) { usesMemory = true; }
void visitRefFunc(RefFunc* curr) {
- maybeAdd(ModuleElement(ModuleElementKind::Function, curr->func));
+ auto type = curr->type.getHeapType();
+ if (calledSignatures.count(type)) {
+ // We must not have a type in both calledSignatures and
+ // uncalledRefFuncMap: once it is called, we do not track RefFuncs for it
+ // any more.
+ assert(uncalledRefFuncMap.count(type) == 0);
+
+ // We've seen a RefFunc for this, so it is reachable.
+ maybeAdd(ModuleElement(ModuleElementKind::Function, curr->func));
+ } else {
+ // We've never seen a CallRef for this, but might see one later.
+ uncalledRefFuncMap[type].push_back(curr->func);
+ }
}
void visitTableGet(TableGet* curr) { maybeAddTable(curr->table); }
void visitTableSet(TableSet* curr) { maybeAddTable(curr->table); }
@@ -199,15 +260,42 @@ struct RemoveUnusedModuleElements : public Pass {
importsMemory = true;
}
// For now, all functions that can be called indirectly are marked as roots.
+ // TODO: Compute this based on which ElementSegments are actually reachable,
+ // and which functions have a call_indirect of the proper type.
ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) {
roots.emplace_back(ModuleElementKind::Function, name);
});
// Compute reachability starting from the root set.
ReachabilityAnalyzer analyzer(module, roots);
+
+ // RefFuncs that are never called are a special case: We cannot remove the
+ // function, since then (ref.func $foo) would not validate. But if we know
+ // it is never called, at least the contents do not matter, so we can
+ // empty it out.
+ std::unordered_set<Name> uncalledRefFuncs;
+ for (auto& [type, targets] : analyzer.uncalledRefFuncMap) {
+ for (auto target : targets) {
+ uncalledRefFuncs.insert(target);
+ }
+ }
+
// Remove unreachable elements.
module->removeFunctions([&](Function* curr) {
- return analyzer.reachable.count(
- ModuleElement(ModuleElementKind::Function, curr->name)) == 0;
+ if (analyzer.reachable.count(
+ ModuleElement(ModuleElementKind::Function, curr->name))) {
+ return false;
+ }
+
+ if (uncalledRefFuncs.count(curr->name)) {
+ // See comment above on uncalledRefFuncs.
+ if (!curr->imported()) {
+ curr->body = Builder(*module).makeUnreachable();
+ }
+ return false;
+ }
+
+ // The function is not reached and has no references; remove it.
+ return true;
});
module->removeGlobals([&](Global* curr) {
return analyzer.reachable.count(