diff options
Diffstat (limited to 'src/passes/DeadArgumentElimination.cpp')
-rw-r--r-- | src/passes/DeadArgumentElimination.cpp | 104 |
1 files changed, 101 insertions, 3 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index 434ea1ece..e4b8eef56 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -27,6 +27,7 @@ // If so, we can avoid even sending and receiving it. (Note how if // the previous point was true for an argument, then the second // must as well.) +// * Find return values ("return arguments" ;) that are never used. // // This pass does not depend on flattening, but it may be more effective, // as then call arguments never have side effects (which we need to @@ -53,6 +54,9 @@ struct DAEFunctionInfo { SortedVector unusedParams; // Maps a function name to the calls going to it. std::unordered_map<Name, std::vector<Call*>> calls; + // Map of all calls that are dropped, to their drops' locations (so that + // if we can optimize out the drop, we can replace the drop there). + std::unordered_map<Call*, Expression**> droppedCalls; // Whether the function can be called from places that // affect what we can do. For now, any call we don't // see inhibits our optimizations, but TODO: an export @@ -116,6 +120,12 @@ struct DAEScanner : public WalkerPass<CFGWalker<DAEScanner, Visitor<DAEScanner>, } } + void visitDrop(Drop* curr) { + if (auto* call = curr->value->dynCast<Call>()) { + info->droppedCalls[call] = getCurrentPointer(); + } + } + // main entry point void doWalkFunction(Function* func) { @@ -197,6 +207,15 @@ struct DAE : public Pass { bool optimize = false; void run(PassRunner* runner, Module* module) override { + // Iterate to convergence. + while (1) { + if (!iteration(runner, module)) { + break; + } + } + } + + bool iteration(PassRunner* runner, Module* module) { DAEFunctionInfoMap infoMap; // Ensure they all exist so the parallel threads don't modify the data structure. ModuleUtils::iterDefinedFunctions(*module, [&](Function* func) { @@ -230,6 +249,9 @@ struct DAE : public Pass { auto& allCallsToName = allCalls[name]; allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end()); } + for (auto& pair : info.droppedCalls) { + allDroppedCalls[pair.first] = pair.second; + } } // We now have a mapping of all call sites for each function. Check which // are always passed the same constant for a particular argument. @@ -237,7 +259,9 @@ struct DAE : public Pass { auto name = pair.first; // We can only optimize if we see all the calls and can modify // them. - if (infoMap[name].hasUnseenCalls) continue; + if (infoMap[name].hasUnseenCalls) { + continue; + } auto& calls = pair.second; auto* func = module->getFunction(name); auto numParams = func->getNumParams(); @@ -311,13 +335,48 @@ struct DAE : public Pass { i--; } } - if (optimize && changed.size() > 0) { + // We can also tell which calls have all their return values dropped. Note that we can't do this + // if we changed anything so far, as we may have modified allCalls (we can't modify a call site + // twice in one iteration, once to remove a param, once to drop the return value). + if (changed.empty()) { + for (auto& func : module->functions) { + if (func->result == none) { + continue; + } + auto name = func->name; + if (infoMap[name].hasUnseenCalls) { + continue; + } + auto iter = allCalls.find(name); + if (iter == allCalls.end()) { + continue; + } + auto& calls = iter->second; + bool allDropped = true; + for (auto* call : calls) { + if (!allDroppedCalls.count(call)) { + allDropped = false; + break; + } + } + if (!allDropped) { + continue; + } + removeReturnValue(func.get(), calls, module); + // TODO Removing a drop may also open optimization opportunities in the callers. + changed.insert(func.get()); + } + } + if (optimize && !changed.empty()) { OptUtils::optimizeAfterInlining(changed, module, runner); } + return !changed.empty(); } private: - void removeParameter(Function* func, Index i, std::vector<Call*> calls) { + std::unordered_map<Call*, Expression**> allDroppedCalls; + + void removeParameter(Function* func, Index i, std::vector<Call*>& calls) { // Clear the type, which is no longer accurate. func->type = Name(); // It's cumbersome to adjust local names - TODO don't clear them? @@ -354,6 +413,45 @@ private: call->operands.erase(call->operands.begin() + i); } } + + void removeReturnValue(Function* func, std::vector<Call*>& calls, Module* module) { + // Clear the type, which is no longer accurate. + func->type = Name(); + func->result = none; + Builder builder(*module); + // Remove any return values. + struct ReturnUpdater : public PostWalker<ReturnUpdater> { + Module* module; + ReturnUpdater(Function* func, Module* module) : module(module) { + walk(func->body); + } + void visitReturn(Return* curr) { + auto* value = curr->value; + assert(value); + curr->value = nullptr; + Builder builder(*module); + replaceCurrent(builder.makeSequence( + builder.makeDrop(value), + curr + )); + } + } returnUpdater(func, module); + // Remove any value flowing out. + if (isConcreteType(func->body->type)) { + func->body = builder.makeDrop(func->body); + } + // Remove the drops on the calls. + for (auto* call : calls) { + auto iter = allDroppedCalls.find(call); + assert(iter != allDroppedCalls.end()); + Expression** location = iter->second; + *location = call; + // Update the call's type. + if (call->type != unreachable) { + call->type = none; + } + } + } }; Pass *createDAEPass() { |