diff options
-rw-r--r-- | src/passes/DeadArgumentElimination.cpp | 60 |
1 file changed, 44 insertions, 16 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index 5ebca9cf0..784aec31a 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -70,7 +70,8 @@ struct DAEFunctionInfo { // because being in a table inhibits DAE. TODO: Allow the removal of dropped // returns from tail-callers if their tail-callees can have their returns // removed as well. - std::atomic<bool> hasTailCalls; + bool hasTailCalls = false; + std::unordered_set<Name> tailCallees; // Whether the function can be called from places that // affect what we can do. For now, any call we don't // see inhibits our optimizations, but TODO: an export @@ -80,10 +81,7 @@ struct DAEFunctionInfo { // during the parallel analysis phase which is run in DAEScanner. std::atomic<bool> hasUnseenCalls; - DAEFunctionInfo() { - hasUnseenCalls = false; - hasTailCalls = false; - } + DAEFunctionInfo() { hasUnseenCalls = false; } }; typedef std::unordered_map<Name, DAEFunctionInfo> DAEFunctionInfoMap; @@ -107,7 +105,7 @@ struct DAEScanner } if (curr->isReturn) { info->hasTailCalls = true; - (*infoMap)[curr->target].hasTailCalls = true; + info->tailCallees.insert(curr->target); } } @@ -203,11 +201,15 @@ struct DAE : public Pass { scanner.run(runner, module); // Combine all the info. std::map<Name, std::vector<Call*>> allCalls; + std::unordered_set<Name> tailCallees; for (auto& [_, info] : infoMap) { for (auto& [name, calls] : info.calls) { auto& allCallsToName = allCalls[name]; allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end()); } + for (auto& callee : info.tailCallees) { + tailCallees.insert(callee); + } for (auto& [name, calls] : info.droppedCalls) { allDroppedCalls[name] = calls; } @@ -249,19 +251,45 @@ struct DAE : public Pass { std::unordered_set<Function*> changed; // We now know which parameters are unused, and can potentially remove them. 
for (auto& [name, calls] : allCalls) { - auto& info = infoMap[name]; - if (info.hasUnseenCalls) { + if (infoMap[name].hasUnseenCalls) { continue; } auto* func = module->getFunction(name); - if (func->getNumParams() > 0 && - !ParamUtils::removeParameters( - {func}, info.unusedParams, calls, {}, module, runner) - .empty()) { + auto numParams = func->getNumParams(); + if (numParams == 0) { + continue; + } + auto removedIndexes = ParamUtils::removeParameters( + {func}, infoMap[name].unusedParams, calls, {}, module, runner); + if (!removedIndexes.empty()) { // Success! changed.insert(func); - } else if (!info.hasTailCalls && func->getResults() != Type::none) { - // We can also tell which calls have all their return values dropped. + } + } + // We can also tell which calls have all their return values dropped. Note + // that we can't do this if we changed anything so far, as we may have + // modified allCalls (we can't modify a call site twice in one iteration, + // once to remove a param, once to drop the return value). + if (changed.empty()) { + for (auto& func : module->functions) { + if (func->getResults() == Type::none) { + continue; + } + auto name = func->name; + if (infoMap[name].hasUnseenCalls) { + continue; + } + if (infoMap[name].hasTailCalls) { + continue; + } + if (tailCallees.count(name)) { + continue; + } + auto iter = allCalls.find(name); + if (iter == allCalls.end()) { + continue; + } + auto& calls = iter->second; bool allDropped = std::all_of(calls.begin(), calls.end(), [&](Call* call) { return allDroppedCalls.count(call); @@ -269,10 +297,10 @@ struct DAE : public Pass { if (!allDropped) { continue; } - removeReturnValue(func, calls, module); + removeReturnValue(func.get(), calls, module); // TODO Removing a drop may also open optimization opportunities in the // callers. - changed.insert(func); + changed.insert(func.get()); } } if (optimize && !changed.empty()) { |