summaryrefslogtreecommitdiff
path: root/src/passes/DeadArgumentElimination.cpp
diff options
context:
space:
mode:
authorAlon Zakai <azakai@google.com>2024-09-03 13:53:36 -0700
committerGitHub <noreply@github.com>2024-09-03 13:53:36 -0700
commit48d4c55116728823fbdbc1dd13f8f8916ac0b0af (patch)
tree23471599f6e53ff076428f5743adb77bcae32824 /src/passes/DeadArgumentElimination.cpp
parent8125f7ab5124bbe1cd18b2d8be6a3877ba97acdb (diff)
downloadbinaryen-48d4c55116728823fbdbc1dd13f8f8916ac0b0af.tar.gz
binaryen-48d4c55116728823fbdbc1dd13f8f8916ac0b0af.tar.bz2
binaryen-48d4c55116728823fbdbc1dd13f8f8916ac0b0af.zip
[NFC] Avoid repeated work in DeadArgumentElimination scanning (#6869)
This pass may do multiple iterations, and before this PR it scanned the entire module each time. That is simpler than tracking stale data, but it can be quite slow. This PR adds staleness tracking, which makes it over 3x faster (and this can be one of our slowest passes in some cases, so this is significant). To achieve this: * Add a staleness marker on function info. * Rewrite how we track unseen calls. Previously we used atomics in a clever way; now we just accumulate the data in a simple way (easier for staleness tracking). * Add staleness invalidation in the proper places. * Add a param to localizeCallsTo to allow us to learn when a function is changed. This kind of staleness analysis is usually not worthwhile, but given the 3x-plus speedup it seems justified. I fuzzed it directly, and also any staleness bug can lead to validation errors, so normal fuzzing also gives us good coverage here.
Diffstat (limited to 'src/passes/DeadArgumentElimination.cpp')
-rw-r--r--src/passes/DeadArgumentElimination.cpp183
1 file changed, 128 insertions, 55 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp
index 99a709654..63f2d7fdc 100644
--- a/src/passes/DeadArgumentElimination.cpp
+++ b/src/passes/DeadArgumentElimination.cpp
@@ -56,6 +56,9 @@ namespace wasm {
// Information for a function
struct DAEFunctionInfo {
+ // Whether this needs to be recomputed. This begins as true for the first
+ // computation, and we reset it every time we touch the function.
+ bool stale = true;
// The unused parameters, if any.
SortedVector unusedParams;
// Maps a function name to the calls going to it.
@@ -73,16 +76,17 @@ struct DAEFunctionInfo {
// removed as well.
bool hasTailCalls = false;
std::unordered_set<Name> tailCallees;
- // Whether the function can be called from places that
- // affect what we can do. For now, any call we don't
- // see inhibits our optimizations, but TODO: an export
- // could be worked around by exporting a thunk that
- // adds the parameter.
- // This is atomic so that we can write to it from any function at any time
- // during the parallel analysis phase which is run in DAEScanner.
- std::atomic<bool> hasUnseenCalls;
-
- DAEFunctionInfo() { hasUnseenCalls = false; }
+ // The set of functions that have calls from places that limit what we can do.
+ // For now, any call we don't see inhibits our optimizations, but TODO: an
+ // export could be worked around by exporting a thunk that adds the parameter.
+ //
+ // This is built up in parallel in each function, and combined at the end.
+ std::unordered_set<Name> hasUnseenCalls;
+
+ // Clears all data, which marks us as stale and in need of recomputation.
+ void clear() { *this = DAEFunctionInfo(); }
+
+ void markStale() { stale = true; }
};
using DAEFunctionInfoMap = std::unordered_map<Name, DAEFunctionInfo>;
@@ -97,10 +101,12 @@ struct DAEScanner
DAEScanner(DAEFunctionInfoMap* infoMap) : infoMap(infoMap) {}
+ // The map of all infos for all functions.
DAEFunctionInfoMap* infoMap;
- DAEFunctionInfo* info;
- Index numParams;
+ // The info for the function this instance operates on. We stash this as an
+ // optimization.
+ DAEFunctionInfo* info = nullptr;
void visitCall(Call* curr) {
if (!getModule()->getFunction(curr->target)->imported()) {
@@ -131,33 +137,40 @@ struct DAEScanner
}
void visitRefFunc(RefFunc* curr) {
- // We can't modify another function in parallel.
- assert((*infoMap).count(curr->func));
+ // RefFunc may be visited from either a function, in which case |info| was
+ // set, or module-level code (in which case we use the null function name in
+ // the infoMap).
+ auto* currInfo = info ? info : &(*infoMap)[Name()];
+
// Treat a ref.func as an unseen call, preventing us from changing the
// function's type. If we did change it, it could be an observable
// difference from the outside, if the reference escapes, for example.
// TODO: look for actual escaping?
// TODO: create a thunk for external uses that allow internal optimizations
- (*infoMap)[curr->func].hasUnseenCalls = true;
+ currInfo->hasUnseenCalls.insert(curr->func);
}
// main entry point
void doWalkFunction(Function* func) {
- numParams = func->getNumParams();
+ // Set the info for this function.
info = &((*infoMap)[func->name]);
+
+ if (!info->stale) {
+ // Nothing changed since last time.
+ return;
+ }
+
+ // Clear the data, mark us as no longer stale, and recompute everything.
+ info->clear();
+ info->stale = false;
+
+ auto numParams = func->getNumParams();
PostWalker<DAEScanner, Visitor<DAEScanner>>::doWalkFunction(func);
- // If there are relevant params, check if they are used. If we can't
- // optimize the function anyhow, there's no point (note that our check here
- // is technically racy - another thread could update hasUnseenCalls to true
- // around when we check it - but that just means that we might or might not
- // do some extra work, as we'll ignore the results later if we have unseen
- // calls. That is, the check for hasUnseenCalls here is just a minor
- // optimization to avoid pointless work. We can avoid that work if either
- // we know there is an unseen call before the parallel analysis that we are
- // part of, say if we are exported, or if another parallel function finds a
- // RefFunc to us and updates it before we check it).
- if (numParams > 0 && !info->hasUnseenCalls) {
+ // If there are params, check if they are used.
+ // TODO: This work could be avoided if we cannot optimize for other reasons.
+ // That would require deferring this to later and checking that.
+ if (numParams > 0) {
auto usedParams = ParamUtils::getUsedParams(func, getModule());
for (Index i = 0; i < numParams; i++) {
if (usedParams.count(i) == 0) {
@@ -176,50 +189,83 @@ struct DAE : public Pass {
bool optimize = false;
void run(Module* module) override {
+ DAEFunctionInfoMap infoMap;
+ // Ensure all entries exist so the parallel threads don't modify the data
+ // structure.
+ for (auto& func : module->functions) {
+ infoMap[func->name];
+ }
+ // The null name represents module-level code (not in a function).
+ infoMap[Name()];
+
// Iterate to convergence.
while (1) {
- if (!iteration(module)) {
+ if (!iteration(module, infoMap)) {
break;
}
}
}
- bool iteration(Module* module) {
+ bool iteration(Module* module, DAEFunctionInfoMap& infoMap) {
allDroppedCalls.clear();
- DAEFunctionInfoMap infoMap;
- // Ensure they all exist so the parallel threads don't modify the data
- // structure.
- for (auto& func : module->functions) {
- infoMap[func->name];
+#if DAE_DEBUG
+ // Enable this path to mark all contents as stale at the start of each
+ // iteration, which can be used to check for staleness bugs (that is, bugs
+ // where something should have been marked stale, but wasn't). Note, though,
+ // that staleness bugs can easily cause serious issues with validation (e.g.
+ // if data is stale we may miss that there is an additional caller, that
+ // prevents refining argument types etc.), so this may not be terribly
+ // helpful.
+ if (getenv("ALWAYS_MARK_STALE")) {
+ for (auto& [_, info] : infoMap) {
+ info.markStale();
+ }
}
+#endif
+
DAEScanner scanner(&infoMap);
scanner.walkModuleCode(module);
- for (auto& curr : module->exports) {
- if (curr->kind == ExternalKind::Function) {
- infoMap[curr->value].hasUnseenCalls = true;
- }
- }
// Scan all the functions.
scanner.run(getPassRunner(), module);
// Combine all the info.
+ struct CallContext {
+ Call* call;
+ Function* func;
+ };
std::map<Name, std::vector<Call*>> allCalls;
std::unordered_set<Name> tailCallees;
- for (auto& [_, info] : infoMap) {
+ std::unordered_set<Name> hasUnseenCalls;
+ // Track the function in which relevant expressions exist. When we modify
+ // those expressions we will need to mark the function's info as stale.
+ std::unordered_map<Expression*, Name> expressionFuncs;
+ for (auto& [func, info] : infoMap) {
for (auto& [name, calls] : info.calls) {
auto& allCallsToName = allCalls[name];
allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
+ for (auto* call : calls) {
+ expressionFuncs[call] = func;
+ }
}
for (auto& callee : info.tailCallees) {
tailCallees.insert(callee);
}
- for (auto& [name, calls] : info.droppedCalls) {
- allDroppedCalls[name] = calls;
+ for (auto& [call, dropp] : info.droppedCalls) {
+ allDroppedCalls[call] = dropp;
+ }
+ for (auto& name : info.hasUnseenCalls) {
+ hasUnseenCalls.insert(name);
+ }
+ }
+ // Exports are considered unseen calls.
+ for (auto& curr : module->exports) {
+ if (curr->kind == ExternalKind::Function) {
+ hasUnseenCalls.insert(curr->value);
}
}
- // Track which functions we changed, and optimize them later if necessary.
- std::unordered_set<Function*> changed;
+ // Track which functions we changed that are worth re-optimizing at the end.
+ std::unordered_set<Function*> worthOptimizing;
// If we refine return types then we will need to do more type updating
// at the end.
@@ -237,11 +283,25 @@ struct DAE : public Pass {
// This set tracks the functions for whom calls to it should be modified.
std::unordered_set<Name> callTargetsToLocalize;
+ // As we optimize, we mark things as stale.
+ auto markStale = [&](Name func) {
+ // We only ever mark functions stale (not the global scope, which we never
+ // modify). An attempt to modify the global scope, identified by a null
+ // function name, is a logic bug.
+ assert(func.is());
+ infoMap[func].markStale();
+ };
+ auto markCallersStale = [&](const std::vector<Call*>& calls) {
+ for (auto* call : calls) {
+ markStale(expressionFuncs[call]);
+ }
+ };
+
// We now have a mapping of all call sites for each function, and can look
// for optimization opportunities.
for (auto& [name, calls] : allCalls) {
// We can only optimize if we see all the calls and can modify them.
- if (infoMap[name].hasUnseenCalls) {
+ if (hasUnseenCalls.count(name)) {
continue;
}
auto* func = module->getFunction(name);
@@ -249,11 +309,14 @@ struct DAE : public Pass {
// affect whether an argument is used or not, it just refines the type
// where possible.
if (refineArgumentTypes(func, calls, module, infoMap[name])) {
- changed.insert(func);
+ worthOptimizing.insert(func);
+ markStale(func->name);
}
// Refine return types as well.
if (refineReturnTypes(func, calls, module)) {
refinedReturnTypes = true;
+ markStale(func->name);
+ markCallersStale(calls);
}
auto optimizedIndexes =
ParamUtils::applyConstantValues({func}, calls, {}, module);
@@ -262,6 +325,9 @@ struct DAE : public Pass {
// for that).
infoMap[name].unusedParams.insert(i);
}
+ if (!optimizedIndexes.empty()) {
+ markStale(func->name);
+ }
}
if (refinedReturnTypes) {
// Changing a call expression's return type can propagate out to its
@@ -271,7 +337,7 @@ struct DAE : public Pass {
}
// We now know which parameters are unused, and can potentially remove them.
for (auto& [name, calls] : allCalls) {
- if (infoMap[name].hasUnseenCalls) {
+ if (hasUnseenCalls.count(name)) {
continue;
}
auto* func = module->getFunction(name);
@@ -283,7 +349,9 @@ struct DAE : public Pass {
{func}, infoMap[name].unusedParams, calls, {}, module, getPassRunner());
if (!removedIndexes.empty()) {
// Success!
- changed.insert(func);
+ worthOptimizing.insert(func);
+ markStale(func->name);
+ markCallersStale(calls);
}
if (outcome == ParamUtils::RemovalOutcome::Failure) {
callTargetsToLocalize.insert(name);
@@ -293,13 +361,13 @@ struct DAE : public Pass {
// that we can't do this if we changed anything so far, as we may have
// modified allCalls (we can't modify a call site twice in one iteration,
// once to remove a param, once to drop the return value).
- if (changed.empty()) {
+ if (worthOptimizing.empty()) {
for (auto& func : module->functions) {
if (func->getResults() == Type::none) {
continue;
}
auto name = func->name;
- if (infoMap[name].hasUnseenCalls) {
+ if (hasUnseenCalls.count(name)) {
continue;
}
if (infoMap[name].hasTailCalls) {
@@ -323,17 +391,22 @@ struct DAE : public Pass {
removeReturnValue(func.get(), calls, module);
// TODO Removing a drop may also open optimization opportunities in the
// callers.
- changed.insert(func.get());
+ worthOptimizing.insert(func.get());
+ markStale(func->name);
+ markCallersStale(calls);
}
}
if (!callTargetsToLocalize.empty()) {
ParamUtils::localizeCallsTo(
- callTargetsToLocalize, *module, getPassRunner());
+ callTargetsToLocalize, *module, getPassRunner(), [&](Function* func) {
+ markStale(func->name);
+ });
}
- if (optimize && !changed.empty()) {
- OptUtils::optimizeAfterInlining(changed, module, getPassRunner());
+ if (optimize && !worthOptimizing.empty()) {
+ OptUtils::optimizeAfterInlining(worthOptimizing, module, getPassRunner());
}
- return !changed.empty() || refinedReturnTypes ||
+
+ return !worthOptimizing.empty() || refinedReturnTypes ||
!callTargetsToLocalize.empty();
}