Diffstat (limited to 'src')
-rw-r--r--   src/passes/DeadArgumentElimination.cpp | 183
-rw-r--r--   src/passes/param-utils.cpp             |  14
-rw-r--r--   src/passes/param-utils.h               |   5
3 files changed, 141 insertions, 61 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp
index 99a709654..63f2d7fdc 100644
--- a/src/passes/DeadArgumentElimination.cpp
+++ b/src/passes/DeadArgumentElimination.cpp
@@ -56,6 +56,9 @@ namespace wasm {
 
 // Information for a function
 struct DAEFunctionInfo {
+  // Whether this needs to be recomputed. This begins as true for the first
+  // computation, and we reset it every time we touch the function.
+  bool stale = true;
   // The unused parameters, if any.
   SortedVector unusedParams;
   // Maps a function name to the calls going to it.
@@ -73,16 +76,17 @@ struct DAEFunctionInfo {
   // removed as well.
   bool hasTailCalls = false;
   std::unordered_set<Name> tailCallees;
-  // Whether the function can be called from places that
-  // affect what we can do. For now, any call we don't
-  // see inhibits our optimizations, but TODO: an export
-  // could be worked around by exporting a thunk that
-  // adds the parameter.
-  // This is atomic so that we can write to it from any function at any time
-  // during the parallel analysis phase which is run in DAEScanner.
-  std::atomic<bool> hasUnseenCalls;
-
-  DAEFunctionInfo() { hasUnseenCalls = false; }
+  // The set of functions that have calls from places that limit what we can do.
+  // For now, any call we don't see inhibits our optimizations, but TODO: an
+  // export could be worked around by exporting a thunk that adds the parameter.
+  //
+  // This is built up in parallel in each function, and combined at the end.
+  std::unordered_set<Name> hasUnseenCalls;
+
+  // Clears all data, which marks us as stale and in need of recomputation.
+  void clear() { *this = DAEFunctionInfo(); }
+
+  void markStale() { stale = true; }
 };
 
 using DAEFunctionInfoMap = std::unordered_map<Name, DAEFunctionInfo>;
@@ -97,10 +101,12 @@ struct DAEScanner
 
   DAEScanner(DAEFunctionInfoMap* infoMap) : infoMap(infoMap) {}
 
+  // The map of all infos for all functions.
   DAEFunctionInfoMap* infoMap;
-  DAEFunctionInfo* info;
-  Index numParams;
+  // The info for the function this instance operates on. We stash this as an
+  // optimization.
+  DAEFunctionInfo* info = nullptr;
 
   void visitCall(Call* curr) {
     if (!getModule()->getFunction(curr->target)->imported()) {
@@ -131,33 +137,40 @@ struct DAEScanner
   }
 
   void visitRefFunc(RefFunc* curr) {
-    // We can't modify another function in parallel.
-    assert((*infoMap).count(curr->func));
+    // RefFunc may be visited from either a function, in which case |info| was
+    // set, or module-level code (in which case we use the null function name in
+    // the infoMap).
+    auto* currInfo = info ? info : &(*infoMap)[Name()];
+
     // Treat a ref.func as an unseen call, preventing us from changing the
     // function's type. If we did change it, it could be an observable
     // difference from the outside, if the reference escapes, for example.
     // TODO: look for actual escaping?
     // TODO: create a thunk for external uses that allow internal optimizations
-    (*infoMap)[curr->func].hasUnseenCalls = true;
+    currInfo->hasUnseenCalls.insert(curr->func);
   }
 
   // main entry point
   void doWalkFunction(Function* func) {
-    numParams = func->getNumParams();
+    // Set the info for this function.
    info = &((*infoMap)[func->name]);
+
+    if (!info->stale) {
+      // Nothing changed since last time.
+      return;
+    }
+
+    // Clear the data, mark us as no longer stale, and recompute everything.
+    info->clear();
+    info->stale = false;
+
+    auto numParams = func->getNumParams();
     PostWalker<DAEScanner, Visitor<DAEScanner>>::doWalkFunction(func);
-    // If there are relevant params, check if they are used. If we can't
-    // optimize the function anyhow, there's no point (note that our check here
-    // is technically racy - another thread could update hasUnseenCalls to true
-    // around when we check it - but that just means that we might or might not
-    // do some extra work, as we'll ignore the results later if we have unseen
-    // calls. That is, the check for hasUnseenCalls here is just a minor
-    // optimization to avoid pointless work. We can avoid that work if either
-    // we know there is an unseen call before the parallel analysis that we are
-    // part of, say if we are exported, or if another parallel function finds a
-    // RefFunc to us and updates it before we check it).
-    if (numParams > 0 && !info->hasUnseenCalls) {
+    // If there are params, check if they are used.
+    // TODO: This work could be avoided if we cannot optimize for other reasons.
+    // That would require deferring this to later and checking that.
+    if (numParams > 0) {
       auto usedParams = ParamUtils::getUsedParams(func, getModule());
       for (Index i = 0; i < numParams; i++) {
         if (usedParams.count(i) == 0) {
@@ -176,50 +189,83 @@ struct DAE : public Pass {
   bool optimize = false;
 
   void run(Module* module) override {
+    DAEFunctionInfoMap infoMap;
+    // Ensure all entries exist so the parallel threads don't modify the data
+    // structure.
+    for (auto& func : module->functions) {
+      infoMap[func->name];
+    }
+    // The null name represents module-level code (not in a function).
+    infoMap[Name()];
+    // Iterate to convergence.
     while (1) {
-      if (!iteration(module)) {
+      if (!iteration(module, infoMap)) {
         break;
       }
     }
   }
 
-  bool iteration(Module* module) {
+  bool iteration(Module* module, DAEFunctionInfoMap& infoMap) {
     allDroppedCalls.clear();
-    DAEFunctionInfoMap infoMap;
-    // Ensure they all exist so the parallel threads don't modify the data
-    // structure.
-    for (auto& func : module->functions) {
-      infoMap[func->name];
+#if DAE_DEBUG
+    // Enable this path to mark all contents as stale at the start of each
+    // iteration, which can be used to check for staleness bugs (that is, bugs
+    // where something should have been marked stale, but wasn't). Note, though,
+    // that staleness bugs can easily cause serious issues with validation (e.g.
+    // if data is stale we may miss that there is an additional caller, that
+    // prevents refining argument types etc.), so this may not be terribly
+    // helpful.
+    if (getenv("ALWAYS_MARK_STALE")) {
+      for (auto& [_, info] : infoMap) {
+        info.markStale();
+      }
    }
+#endif
+
     DAEScanner scanner(&infoMap);
     scanner.walkModuleCode(module);
-    for (auto& curr : module->exports) {
-      if (curr->kind == ExternalKind::Function) {
-        infoMap[curr->value].hasUnseenCalls = true;
-      }
-    }
     // Scan all the functions.
     scanner.run(getPassRunner(), module);
     // Combine all the info.
+    struct CallContext {
+      Call* call;
+      Function* func;
+    };
     std::map<Name, std::vector<Call*>> allCalls;
     std::unordered_set<Name> tailCallees;
-    for (auto& [_, info] : infoMap) {
+    std::unordered_set<Name> hasUnseenCalls;
+    // Track the function in which relevant expressions exist. When we modify
+    // those expressions we will need to mark the function's info as stale.
+    std::unordered_map<Expression*, Name> expressionFuncs;
+    for (auto& [func, info] : infoMap) {
       for (auto& [name, calls] : info.calls) {
         auto& allCallsToName = allCalls[name];
         allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end());
+        for (auto* call : calls) {
+          expressionFuncs[call] = func;
+        }
       }
       for (auto& callee : info.tailCallees) {
         tailCallees.insert(callee);
       }
-      for (auto& [name, calls] : info.droppedCalls) {
-        allDroppedCalls[name] = calls;
+      for (auto& [call, dropp] : info.droppedCalls) {
+        allDroppedCalls[call] = dropp;
+      }
+      for (auto& name : info.hasUnseenCalls) {
+        hasUnseenCalls.insert(name);
+      }
+    }
+    // Exports are considered unseen calls.
+    for (auto& curr : module->exports) {
+      if (curr->kind == ExternalKind::Function) {
+        hasUnseenCalls.insert(curr->value);
       }
     }
-    // Track which functions we changed, and optimize them later if necessary.
-    std::unordered_set<Function*> changed;
+    // Track which functions we changed that are worth re-optimizing at the end.
+    std::unordered_set<Function*> worthOptimizing;
 
     // If we refine return types then we will need to do more type updating
     // at the end.
@@ -237,11 +283,25 @@ struct DAE : public Pass {
     // This set tracks the functions for whom calls to it should be modified.
     std::unordered_set<Name> callTargetsToLocalize;
 
+    // As we optimize, we mark things as stale.
+    auto markStale = [&](Name func) {
+      // We only ever mark functions stale (not the global scope, which we never
+      // modify). An attempt to modify the global scope, identified by a null
+      // function name, is a logic bug.
+      assert(func.is());
+      infoMap[func].markStale();
+    };
+    auto markCallersStale = [&](const std::vector<Call*>& calls) {
+      for (auto* call : calls) {
+        markStale(expressionFuncs[call]);
+      }
+    };
+
     // We now have a mapping of all call sites for each function, and can look
     // for optimization opportunities.
     for (auto& [name, calls] : allCalls) {
       // We can only optimize if we see all the calls and can modify them.
-      if (infoMap[name].hasUnseenCalls) {
+      if (hasUnseenCalls.count(name)) {
         continue;
       }
       auto* func = module->getFunction(name);
@@ -249,11 +309,14 @@ struct DAE : public Pass {
       // affect whether an argument is used or not, it just refines the type
       // where possible.
       if (refineArgumentTypes(func, calls, module, infoMap[name])) {
-        changed.insert(func);
+        worthOptimizing.insert(func);
+        markStale(func->name);
       }
       // Refine return types as well.
       if (refineReturnTypes(func, calls, module)) {
         refinedReturnTypes = true;
+        markStale(func->name);
+        markCallersStale(calls);
       }
       auto optimizedIndexes =
         ParamUtils::applyConstantValues({func}, calls, {}, module);
@@ -262,6 +325,9 @@ struct DAE : public Pass {
         // for that).
         infoMap[name].unusedParams.insert(i);
       }
+      if (!optimizedIndexes.empty()) {
+        markStale(func->name);
+      }
     }
     if (refinedReturnTypes) {
       // Changing a call expression's return type can propagate out to its
@@ -271,7 +337,7 @@ struct DAE : public Pass {
     }
     // We now know which parameters are unused, and can potentially remove them.
     for (auto& [name, calls] : allCalls) {
-      if (infoMap[name].hasUnseenCalls) {
+      if (hasUnseenCalls.count(name)) {
         continue;
       }
       auto* func = module->getFunction(name);
@@ -283,7 +349,9 @@ struct DAE : public Pass {
         {func}, infoMap[name].unusedParams, calls, {}, module, getPassRunner());
       if (!removedIndexes.empty()) {
         // Success!
-        changed.insert(func);
+        worthOptimizing.insert(func);
+        markStale(func->name);
+        markCallersStale(calls);
       }
       if (outcome == ParamUtils::RemovalOutcome::Failure) {
         callTargetsToLocalize.insert(name);
@@ -293,13 +361,13 @@ struct DAE : public Pass {
     // that we can't do this if we changed anything so far, as we may have
     // modified allCalls (we can't modify a call site twice in one iteration,
     // once to remove a param, once to drop the return value).
-    if (changed.empty()) {
+    if (worthOptimizing.empty()) {
       for (auto& func : module->functions) {
         if (func->getResults() == Type::none) {
           continue;
         }
         auto name = func->name;
-        if (infoMap[name].hasUnseenCalls) {
+        if (hasUnseenCalls.count(name)) {
          continue;
         }
         if (infoMap[name].hasTailCalls) {
@@ -323,17 +391,22 @@ struct DAE : public Pass {
         removeReturnValue(func.get(), calls, module);
         // TODO Removing a drop may also open optimization opportunities in the
         // callers.
-        changed.insert(func.get());
+        worthOptimizing.insert(func.get());
+        markStale(func->name);
+        markCallersStale(calls);
       }
     }
     if (!callTargetsToLocalize.empty()) {
       ParamUtils::localizeCallsTo(
-        callTargetsToLocalize, *module, getPassRunner());
+        callTargetsToLocalize, *module, getPassRunner(), [&](Function* func) {
+          markStale(func->name);
+        });
     }
-    if (optimize && !changed.empty()) {
-      OptUtils::optimizeAfterInlining(changed, module, getPassRunner());
+    if (optimize && !worthOptimizing.empty()) {
+      OptUtils::optimizeAfterInlining(worthOptimizing, module, getPassRunner());
     }
-
-    return !changed.empty() || refinedReturnTypes ||
+
+    return !worthOptimizing.empty() || refinedReturnTypes ||
            !callTargetsToLocalize.empty();
   }
diff --git a/src/passes/param-utils.cpp b/src/passes/param-utils.cpp
index f54f91bd9..a600e1928 100644
--- a/src/passes/param-utils.cpp
+++ b/src/passes/param-utils.cpp
@@ -286,18 +286,21 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs,
 
 void localizeCallsTo(const std::unordered_set<Name>& callTargets,
                      Module& wasm,
-                     PassRunner* runner) {
+                     PassRunner* runner,
+                     std::function<void(Function*)> onChange) {
   struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> {
     bool isFunctionParallel() override { return true; }
 
     std::unique_ptr<Pass> create() override {
-      return std::make_unique<LocalizerPass>(callTargets);
+      return std::make_unique<LocalizerPass>(callTargets, onChange);
     }
 
     const std::unordered_set<Name>& callTargets;
+    std::function<void(Function*)> onChange;
 
-    LocalizerPass(const std::unordered_set<Name>& callTargets)
-      : callTargets(callTargets) {}
+    LocalizerPass(const std::unordered_set<Name>& callTargets,
+                  std::function<void(Function*)> onChange)
+      : callTargets(callTargets), onChange(onChange) {}
 
     void visitCall(Call* curr) {
       if (!callTargets.count(curr->target)) {
@@ -310,6 +313,7 @@ void localizeCallsTo(const std::unordered_set<Name>& callTargets,
       if (replacement != curr) {
         replaceCurrent(replacement);
         optimized = true;
+        onChange(getFunction());
       }
     }
 
@@ -323,7 +327,7 @@ void localizeCallsTo(const std::unordered_set<Name>& callTargets,
     }
   };
 
-  LocalizerPass(callTargets).run(runner, &wasm);
+  LocalizerPass(callTargets, onChange).run(runner, &wasm);
 }
 
 void localizeCallsTo(const std::unordered_set<HeapType>& callTargets,
                      Module& wasm,
                      PassRunner* runner) {
diff --git a/src/passes/param-utils.h b/src/passes/param-utils.h
index 35e5d9f80..c5c52f4ce 100644
--- a/src/passes/param-utils.h
+++ b/src/passes/param-utils.h
@@ -114,9 +114,12 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs,
 // The set of targets can be function names (the individual functions we want to
 // handle calls towards) or heap types (which will then include all functions
 // with those types).
+//
+// The onChange() callback is called when we modify a function.
 void localizeCallsTo(const std::unordered_set<Name>& callTargets,
                      Module& wasm,
-                     PassRunner* runner);
+                     PassRunner* runner,
+                     std::function<void(Function*)> onChange);
 void localizeCallsTo(const std::unordered_set<HeapType>& callTargets,
                      Module& wasm,
                      PassRunner* runner);
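
For context, a minimal sketch of how a caller might use the new onChange parameter of ParamUtils::localizeCallsTo, in the spirit of the DAE change above. The helper name localizeAndMarkStale and the staleness map are illustrative assumptions; only localizeCallsTo and its signature come from this patch.

#include <unordered_map>
#include <unordered_set>

#include "param-utils.h"
#include "pass.h"
#include "wasm.h"

using namespace wasm;

// Localize calls to the given targets and flag the functions we end up
// modifying in a pre-populated staleness map, so a later analysis only needs
// to revisit those functions. The map must already contain an entry per
// function: LocalizerPass is function-parallel, so onChange may run on worker
// threads, and we only flip existing flags rather than inserting entries there.
static void localizeAndMarkStale(Module& module,
                                 PassRunner* runner,
                                 const std::unordered_set<Name>& targets,
                                 std::unordered_map<Name, bool>& stale) {
  ParamUtils::localizeCallsTo(
    targets, module, runner, [&](Function* func) {
      // Each function is walked by a single worker, so writes to distinct
      // map entries do not race.
      stale.at(func->name) = true;
    });
}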