Diffstat (limited to 'src')
-rw-r--r--   src/passes/DeadArgumentElimination.cpp   183
-rw-r--r--   src/passes/param-utils.cpp                 14
-rw-r--r--   src/passes/param-utils.h                    5
3 files changed, 141 insertions, 61 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp
index 99a709654..63f2d7fdc 100644
--- a/src/passes/DeadArgumentElimination.cpp
+++ b/src/passes/DeadArgumentElimination.cpp
@@ -56,6 +56,9 @@ namespace wasm {
// Information for a function
struct DAEFunctionInfo {
+ // Whether this needs to be recomputed. This begins as true for the first
+ // computation, and we set it to true again whenever we modify the function.
+ bool stale = true;
// The unused parameters, if any.
SortedVector unusedParams;
// Maps a function name to the calls going to it.
@@ -73,16 +76,17 @@ struct DAEFunctionInfo {
// removed as well.
bool hasTailCalls = false;
std::unordered_set<Name> tailCallees;
- // Whether the function can be called from places that
- // affect what we can do. For now, any call we don't
- // see inhibits our optimizations, but TODO: an export
- // could be worked around by exporting a thunk that
- // adds the parameter.
- // This is atomic so that we can write to it from any function at any time
- // during the parallel analysis phase which is run in DAEScanner.
- std::atomic<bool> hasUnseenCalls;
-
- DAEFunctionInfo() { hasUnseenCalls = false; }
+ // The set of functions that have calls from places that limit what we can do.
+ // For now, any call we don't see inhibits our optimizations, but TODO: an
+ // export could be worked around by exporting a thunk that adds the parameter.
+ //
+ // This is built up in parallel in each function, and combined at the end.
+ std::unordered_set<Name> hasUnseenCalls;
+
+ // Clears all data, which marks us as stale and in need of recomputation.
+ void clear() { *this = DAEFunctionInfo(); }
+
+ void markStale() { stale = true; }
};
using DAEFunctionInfoMap = std::unordered_map<Name, DAEFunctionInfo>;
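
A minimal standalone sketch (not part of the patch; the types are simplified stand-ins for the real Binaryen classes) of the stale/clear() pattern these new fields enable: an entry starts stale, a scan recomputes it and clears the flag, and clear() resets to a default-constructed state, which is stale again by definition.

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

struct Info {
  // Begins stale so the first scan always runs.
  bool stale = true;
  std::unordered_set<std::string> unusedParams;

  // Resetting to a default-constructed Info also resets stale to true.
  void clear() { *this = Info(); }
  void markStale() { stale = true; }
};

int main() {
  std::unordered_map<std::string, Info> infoMap;
  infoMap["foo"]; // default-constructed, hence stale

  auto scan = [&](const std::string& name) {
    auto& info = infoMap[name];
    if (!info.stale) {
      std::cout << name << ": fresh, skipping\n";
      return;
    }
    info.clear();
    info.stale = false;
    std::cout << name << ": recomputed\n";
  };

  scan("foo");                // recomputed
  scan("foo");                // fresh, skipping
  infoMap["foo"].markStale(); // e.g. after modifying the function
  scan("foo");                // recomputed again
}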
@@ -97,10 +101,12 @@ struct DAEScanner
DAEScanner(DAEFunctionInfoMap* infoMap) : infoMap(infoMap) {}
+ // The map of all infos for all functions.
DAEFunctionInfoMap* infoMap;
- DAEFunctionInfo* info;
- Index numParams;
+ // The info for the function this instance operates on. We stash this as an
+ // optimization.
+ DAEFunctionInfo* info = nullptr;
void visitCall(Call* curr) {
if (!getModule()->getFunction(curr->target)->imported()) {
@@ -131,33 +137,40 @@ struct DAEScanner
}
void visitRefFunc(RefFunc* curr) {
- // We can't modify another function in parallel.
- assert((*infoMap).count(curr->func));
+ // RefFunc may be visited from either a function, in which case |info| was
+ // set, or module-level code (in which case we use the null function name in
+ // the infoMap).
+ auto* currInfo = info ? info : &(*infoMap)[Name()];
+
// Treat a ref.func as an unseen call, preventing us from changing the
// function's type. If we did change it, it could be an observable
// difference from the outside, if the reference escapes, for example.
// TODO: look for actual escaping?
// TODO: create a thunk for external uses that allow internal optimizations
- (*infoMap)[curr->func].hasUnseenCalls = true;
+ currInfo->hasUnseenCalls.insert(curr->func);
}
// main entry point
void doWalkFunction(Function* func) {
- numParams = func->getNumParams();
+ // Set the info for this function.
info = &((*infoMap)[func->name]);
+
+ if (!info->stale) {
+ // Nothing changed since last time.
+ return;
+ }
+
+ // Clear the data, mark us as no longer stale, and recompute everything.
+ info->clear();
+ info->stale = false;
+
+ auto numParams = func->getNumParams();
PostWalker<DAEScanner, Visitor<DAEScanner>>::doWalkFunction(func);
- // If there are relevant params, check if they are used. If we can't
- // optimize the function anyhow, there's no point (note that our check here
- // is technically racy - another thread could update hasUnseenCalls to true
- // around when we check it - but that just means that we might or might not
- // do some extra work, as we'll ignore the results later if we have unseen
- // calls. That is, the check for hasUnseenCalls here is just a minor
- // optimization to avoid pointless work. We can avoid that work if either
- // we know there is an unseen call before the parallel analysis that we are
- // part of, say if we are exported, or if another parallel function finds a
- // RefFunc to us and updates it before we check it).
- if (numParams > 0 && !info->hasUnseenCalls) {
+ // If there are params, check if they are used.
+ // TODO: This work could be avoided if we cannot optimize for other reasons.
+ // That would require deferring this to later and checking that.
+ if (numParams > 0) {
auto usedParams = ParamUtils::getUsedParams(func, getModule());
for (Index i = 0; i < numParams; i++) {
if (usedParams.count(i) == 0) {
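
For the |info ? info : &(*infoMap)[Name()]| fallback above: a RefFunc can be reached either while walking a particular function (|info| is set) or while walking module-level code (|info| is null), and the latter is bucketed under the null function name. A rough standalone illustration of that fallback, using an empty string as a stand-in for the null Name (hypothetical names, simplified types):

#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>

struct Info {
  bool stale = true;
  std::unordered_set<std::string> hasUnseenCalls;
};

// Record a function reference, either from inside a function (currentFunc is
// set) or from module-level code (currentFunc is empty).
void noteRefFunc(std::unordered_map<std::string, Info>& infoMap,
                 const std::optional<std::string>& currentFunc,
                 const std::string& target) {
  // The empty string plays the role of the null Name() used for module scope.
  auto& info = infoMap[currentFunc ? *currentFunc : std::string()];
  info.hasUnseenCalls.insert(target);
}

int main() {
  std::unordered_map<std::string, Info> infoMap;
  noteRefFunc(infoMap, std::nullopt, "exportedHelper");  // module-level ref
  noteRefFunc(infoMap, std::string("caller"), "callee"); // ref inside a function
  std::cout << infoMap[""].hasUnseenCalls.count("exportedHelper") << "\n"; // 1
  std::cout << infoMap["caller"].hasUnseenCalls.count("callee") << "\n";   // 1
}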
@@ -176,50 +189,83 @@ struct DAE : public Pass {
bool optimize = false;
void run(Module* module) override {
+ DAEFunctionInfoMap infoMap;
+ // Ensure all entries exist so the parallel threads don't modify the data
+ // structure.
+ for (auto& func : module->functions) {
+ infoMap[func->name];
+ }
+ // The null name represents module-level code (not in a function).
+ infoMap[Name()];
+
// Iterate to convergence.
while (1) {
- if (!iteration(module)) {
+ if (!iteration(module, infoMap)) {
break;
}
}
}
- bool iteration(Module* module) {
+ bool iteration(Module* module, DAEFunctionInfoMap& infoMap) {
allDroppedCalls.clear();
- DAEFunctionInfoMap infoMap;
- // Ensure they all exist so the parallel threads don't modify the data
- // structure.
- for (auto& func : module->functions) {
- infoMap[func->name];
+#if DAE_DEBUG
+ // Enable this path to mark all contents as stale at the start of each
+ // iteration, which can be used to check for staleness bugs (that is, bugs
+ // where something should have been marked stale, but wasn't). Note, though,
+ // that staleness bugs can easily cause serious issues with validation (e.g.
+ // if data is stale we may miss that there is an additional caller, that
+ // prevents refining argument types etc.), so this may not be terribly
+ // helpful.
+ if (getenv("ALWAYS_MARK_STALE")) {
+ for (auto& [_, info] : infoMap) {
+ info.markStale();
+ }
}
+#endif
+
DAEScanner scanner(&infoMap);
scanner.walkModuleCode(module);
- for (auto& curr : module->exports) {
- if (curr->kind == ExternalKind::Function) {
- infoMap[curr->value].hasUnseenCalls = true;
- }
- }
// Scan all the functions.
scanner.run(getPassRunner(), module);
// Combine all the info.
+ struct CallContext {
+ Call* call;
+ Function* func;
+ };
std::map<Name, std::vector<Call*>> allCalls;
std::unordered_set<Name> tailCallees;
- for (auto& [_, info] : infoMap) {
+ std::unordered_set<Name> hasUnseenCalls;
+ // Track the function in which relevant expressions exist. When we modify
+ // those expressions we will need to mark the function's info as stale.
+ std::unordered_map<Expression*, Name> expressionFuncs;
+ for (auto& [func, info] : infoMap) {
for (auto& [name, calls] : info.calls) {
auto& allCallsToName = allCalls[name];
allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
+ for (auto* call : calls) {
+ expressionFuncs[call] = func;
+ }
}
for (auto& callee : info.tailCallees) {
tailCallees.insert(callee);
}
- for (auto& [name, calls] : info.droppedCalls) {
- allDroppedCalls[name] = calls;
+ for (auto& [call, dropp] : info.droppedCalls) {
+ allDroppedCalls[call] = dropp;
+ }
+ for (auto& name : info.hasUnseenCalls) {
+ hasUnseenCalls.insert(name);
+ }
+ }
+ // Exports are considered unseen calls.
+ for (auto& curr : module->exports) {
+ if (curr->kind == ExternalKind::Function) {
+ hasUnseenCalls.insert(curr->value);
}
}
- // Track which functions we changed, and optimize them later if necessary.
- std::unordered_set<Function*> changed;
+ // Track which functions we changed that are worth re-optimizing at the end.
+ std::unordered_set<Function*> worthOptimizing;
// If we refine return types then we will need to do more type updating
// at the end.
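
The combining loop above does one extra piece of bookkeeping compared to before: expressionFuncs records, for every call, which function contains it, so that when a callee is later changed the functions holding calls to it can be marked stale. A simplified standalone sketch of that idea (stand-in types, not the real Call/Name classes):

#include <iostream>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

struct Call {
  std::string target;
};

int main() {
  // Per-function scan results: each function lists the calls it contains.
  std::unordered_map<std::string, std::vector<Call*>> callsInFunc;
  Call a{"callee"}, b{"callee"};
  callsInFunc["f"] = {&a};
  callsInFunc["g"] = {&b};

  // Combine: group calls by target and remember each call's enclosing function.
  std::map<std::string, std::vector<Call*>> allCalls;
  std::unordered_map<Call*, std::string> expressionFuncs;
  for (auto& [func, calls] : callsInFunc) {
    for (auto* call : calls) {
      allCalls[call->target].push_back(call);
      expressionFuncs[call] = func;
    }
  }

  // Later, when "callee" is modified, the functions containing calls to it
  // become stale and will be rescanned in the next iteration.
  std::unordered_map<std::string, bool> stale;
  for (auto* call : allCalls["callee"]) {
    stale[expressionFuncs[call]] = true;
  }
  std::cout << stale["f"] << stale["g"] << "\n"; // 11
}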
@@ -237,11 +283,25 @@ struct DAE : public Pass {
// This set tracks the functions for whom calls to it should be modified.
std::unordered_set<Name> callTargetsToLocalize;
+ // As we optimize, we mark things as stale.
+ auto markStale = [&](Name func) {
+ // We only ever mark functions stale (not the global scope, which we never
+ // modify). An attempt to modify the global scope, identified by a null
+ // function name, is a logic bug.
+ assert(func.is());
+ infoMap[func].markStale();
+ };
+ auto markCallersStale = [&](const std::vector<Call*>& calls) {
+ for (auto* call : calls) {
+ markStale(expressionFuncs[call]);
+ }
+ };
+
// We now have a mapping of all call sites for each function, and can look
// for optimization opportunities.
for (auto& [name, calls] : allCalls) {
// We can only optimize if we see all the calls and can modify them.
- if (infoMap[name].hasUnseenCalls) {
+ if (hasUnseenCalls.count(name)) {
continue;
}
auto* func = module->getFunction(name);
@@ -249,11 +309,14 @@ struct DAE : public Pass {
// affect whether an argument is used or not, it just refines the type
// where possible.
if (refineArgumentTypes(func, calls, module, infoMap[name])) {
- changed.insert(func);
+ worthOptimizing.insert(func);
+ markStale(func->name);
}
// Refine return types as well.
if (refineReturnTypes(func, calls, module)) {
refinedReturnTypes = true;
+ markStale(func->name);
+ markCallersStale(calls);
}
auto optimizedIndexes =
ParamUtils::applyConstantValues({func}, calls, {}, module);
@@ -262,6 +325,9 @@ struct DAE : public Pass {
// for that).
infoMap[name].unusedParams.insert(i);
}
+ if (!optimizedIndexes.empty()) {
+ markStale(func->name);
+ }
}
if (refinedReturnTypes) {
// Changing a call expression's return type can propagate out to its
@@ -271,7 +337,7 @@ struct DAE : public Pass {
}
// We now know which parameters are unused, and can potentially remove them.
for (auto& [name, calls] : allCalls) {
- if (infoMap[name].hasUnseenCalls) {
+ if (hasUnseenCalls.count(name)) {
continue;
}
auto* func = module->getFunction(name);
@@ -283,7 +349,9 @@ struct DAE : public Pass {
{func}, infoMap[name].unusedParams, calls, {}, module, getPassRunner());
if (!removedIndexes.empty()) {
// Success!
- changed.insert(func);
+ worthOptimizing.insert(func);
+ markStale(func->name);
+ markCallersStale(calls);
}
if (outcome == ParamUtils::RemovalOutcome::Failure) {
callTargetsToLocalize.insert(name);
@@ -293,13 +361,13 @@ struct DAE : public Pass {
// that we can't do this if we changed anything so far, as we may have
// modified allCalls (we can't modify a call site twice in one iteration,
// once to remove a param, once to drop the return value).
- if (changed.empty()) {
+ if (worthOptimizing.empty()) {
for (auto& func : module->functions) {
if (func->getResults() == Type::none) {
continue;
}
auto name = func->name;
- if (infoMap[name].hasUnseenCalls) {
+ if (hasUnseenCalls.count(name)) {
continue;
}
if (infoMap[name].hasTailCalls) {
@@ -323,17 +391,22 @@ struct DAE : public Pass {
removeReturnValue(func.get(), calls, module);
// TODO Removing a drop may also open optimization opportunities in the
// callers.
- changed.insert(func.get());
+ worthOptimizing.insert(func.get());
+ markStale(func->name);
+ markCallersStale(calls);
}
}
if (!callTargetsToLocalize.empty()) {
ParamUtils::localizeCallsTo(
- callTargetsToLocalize, *module, getPassRunner());
+ callTargetsToLocalize, *module, getPassRunner(), [&](Function* func) {
+ markStale(func->name);
+ });
}
- if (optimize && !changed.empty()) {
- OptUtils::optimizeAfterInlining(changed, module, getPassRunner());
+ if (optimize && !worthOptimizing.empty()) {
+ OptUtils::optimizeAfterInlining(worthOptimizing, module, getPassRunner());
}
- return !changed.empty() || refinedReturnTypes ||
+
+ return !worthOptimizing.empty() || refinedReturnTypes ||
!callTargetsToLocalize.empty();
}
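
Putting the DeadArgumentElimination.cpp changes together: infoMap is now created once in run() and carried across iterations, so each round only rescans functions whose info was marked stale by the previous round's edits. A rough sketch of that loop shape (illustrative only, not the actual pass code):

#include <iostream>
#include <string>
#include <unordered_map>

struct Info {
  bool stale = true;
  void markStale() { stale = true; }
};

// One round: rescan only stale entries. For illustration, pretend we modify
// "f" in round 1, which marks it stale for round 2; nothing changes after that.
bool iteration(std::unordered_map<std::string, Info>& infoMap, int round) {
  for (auto& [name, info] : infoMap) {
    if (!info.stale) {
      continue;
    }
    std::cout << "round " << round << ": rescanned " << name << "\n";
    info.stale = false;
  }
  bool changed = (round == 1);
  if (changed) {
    infoMap["f"].markStale();
  }
  return changed;
}

int main() {
  // Created once, before the loop, so staleness carries across iterations.
  std::unordered_map<std::string, Info> infoMap;
  infoMap["f"];
  infoMap["g"];
  int round = 1;
  while (iteration(infoMap, round)) {
    round++;
  }
  // Round 1 rescans f and g; round 2 rescans only f.
}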
diff --git a/src/passes/param-utils.cpp b/src/passes/param-utils.cpp
index f54f91bd9..a600e1928 100644
--- a/src/passes/param-utils.cpp
+++ b/src/passes/param-utils.cpp
@@ -286,18 +286,21 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs,
void localizeCallsTo(const std::unordered_set<Name>& callTargets,
Module& wasm,
- PassRunner* runner) {
+ PassRunner* runner,
+ std::function<void(Function*)> onChange) {
struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> {
bool isFunctionParallel() override { return true; }
std::unique_ptr<Pass> create() override {
- return std::make_unique<LocalizerPass>(callTargets);
+ return std::make_unique<LocalizerPass>(callTargets, onChange);
}
const std::unordered_set<Name>& callTargets;
+ std::function<void(Function*)> onChange;
- LocalizerPass(const std::unordered_set<Name>& callTargets)
- : callTargets(callTargets) {}
+ LocalizerPass(const std::unordered_set<Name>& callTargets,
+ std::function<void(Function*)> onChange)
+ : callTargets(callTargets), onChange(onChange) {}
void visitCall(Call* curr) {
if (!callTargets.count(curr->target)) {
@@ -310,6 +313,7 @@ void localizeCallsTo(const std::unordered_set<Name>& callTargets,
if (replacement != curr) {
replaceCurrent(replacement);
optimized = true;
+ onChange(getFunction());
}
}
@@ -323,7 +327,7 @@ void localizeCallsTo(const std::unordered_set<Name>& callTargets,
}
};
- LocalizerPass(callTargets).run(runner, &wasm);
+ LocalizerPass(callTargets, onChange).run(runner, &wasm);
}
void localizeCallsTo(const std::unordered_set<HeapType>& callTargets,
diff --git a/src/passes/param-utils.h b/src/passes/param-utils.h
index 35e5d9f80..c5c52f4ce 100644
--- a/src/passes/param-utils.h
+++ b/src/passes/param-utils.h
@@ -114,9 +114,12 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs,
// The set of targets can be function names (the individual functions we want to
// handle calls towards) or heap types (which will then include all functions
// with those types).
+//
+// The onChange() callback is called when we modify a function.
void localizeCallsTo(const std::unordered_set<Name>& callTargets,
Module& wasm,
- PassRunner* runner);
+ PassRunner* runner,
+ std::function<void(Function*)> onChange);
void localizeCallsTo(const std::unordered_set<HeapType>& callTargets,
Module& wasm,
PassRunner* runner);
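
The new onChange parameter lets callers learn which functions the helper actually rewrote; DAE uses it to mark those functions' info stale, as in the call site in DeadArgumentElimination.cpp above. A self-contained sketch of that callback pattern, with stand-in types rather than the real Binaryen Module/Function/PassRunner classes:

#include <functional>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

struct Function {
  std::string name;
  std::vector<std::string> callees; // names of functions this one calls
};

// Stand-in for ParamUtils::localizeCallsTo: "rewrites" any function containing
// a call to one of the targets, and reports each modified function via onChange.
void localizeCallsTo(const std::unordered_set<std::string>& callTargets,
                     std::vector<Function>& funcs,
                     std::function<void(Function*)> onChange) {
  for (auto& func : funcs) {
    for (auto& callee : func.callees) {
      if (callTargets.count(callee)) {
        // In the real pass the call is rewritten here (replaceCurrent); this
        // sketch only models the onChange notification.
        onChange(&func);
        break;
      }
    }
  }
}

int main() {
  std::vector<Function> funcs = {{"a", {"target"}}, {"b", {"other"}}};
  std::unordered_set<std::string> stale;
  localizeCallsTo({"target"}, funcs, [&](Function* func) {
    stale.insert(func->name); // mirrors DAE's markStale(func->name)
  });
  std::cout << stale.count("a") << " " << stale.count("b") << "\n"; // 1 0
}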