diff options
Diffstat (limited to 'src/passes/DeadArgumentElimination.cpp')
-rw-r--r-- | src/passes/DeadArgumentElimination.cpp | 54 |
1 files changed, 39 insertions, 15 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index 789332b93..9ef762f98 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -57,6 +57,16 @@ struct DAEFunctionInfo { // Map of all calls that are dropped, to their drops' locations (so that // if we can optimize out the drop, we can replace the drop there). std::unordered_map<Call*, Expression**> droppedCalls; + // Whether this function contains any tail calls (including indirect tail + // calls) and the set of functions this function tail calls. Tail-callers and + // tail-callees cannot have their dropped returns removed because of the + // constraint that tail-callees must have the same return type as + // tail-callers. Indirectly tail called functions are already not optimized + // because being in a table inhibits DAE. TODO: Allow the removal of dropped + // returns from tail-callers if their tail-callees can have their returns + // removed as well. + bool hasTailCalls = false; + std::unordered_set<Name> tailCallees; // Whether the function can be called from places that // affect what we can do. For now, any call we don't // see inhibits our optimizations, but TODO: an export @@ -117,6 +127,16 @@ struct DAEScanner if (!getModule()->getFunction(curr->target)->imported()) { info->calls[curr->target].push_back(curr); } + if (curr->isReturn) { + info->hasTailCalls = true; + info->tailCallees.insert(curr->target); + } + } + + void visitCallIndirect(CallIndirect* curr) { + if (curr->isReturn) { + info->hasTailCalls = true; + } } void visitDrop(Drop* curr) { @@ -239,6 +259,7 @@ struct DAE : public Pass { DAEScanner(&infoMap).run(runner, module); // Combine all the info. std::unordered_map<Name, std::vector<Call*>> allCalls; + std::unordered_set<Name> tailCallees; for (auto& pair : infoMap) { auto& info = pair.second; for (auto& pair : info.calls) { @@ -247,6 +268,9 @@ struct DAE : public Pass { auto& allCallsToName = allCalls[name]; allCallsToName.insert(allCallsToName.end(), calls.begin(), calls.end()); } + for (auto& callee : info.tailCallees) { + tailCallees.insert(callee); + } for (auto& pair : info.droppedCalls) { allDroppedCalls[pair.first] = pair.second; } @@ -314,14 +338,11 @@ struct DAE : public Pass { // Great, it's not used. Check if none of the calls has a param with // side effects, as that would prevent us removing them (flattening // should have been done earlier). - bool canRemove = true; - for (auto* call : calls) { - auto* operand = call->operands[i]; - if (EffectAnalyzer(runner->options, operand).hasSideEffects()) { - canRemove = false; - break; - } - } + bool canRemove = + std::none_of(calls.begin(), calls.end(), [&](Call* call) { + auto* operand = call->operands[i]; + return EffectAnalyzer(runner->options, operand).hasSideEffects(); + }); if (canRemove) { // Wonderful, nothing stands in our way! Do it. // TODO: parallelize this? @@ -348,18 +369,21 @@ struct DAE : public Pass { if (infoMap[name].hasUnseenCalls) { continue; } + if (infoMap[name].hasTailCalls) { + continue; + } + if (tailCallees.find(name) != tailCallees.end()) { + continue; + } auto iter = allCalls.find(name); if (iter == allCalls.end()) { continue; } auto& calls = iter->second; - bool allDropped = true; - for (auto* call : calls) { - if (!allDroppedCalls.count(call)) { - allDropped = false; - break; - } - } + bool allDropped = + std::all_of(calls.begin(), calls.end(), [&](Call* call) { + return allDroppedCalls.count(call); + }); if (!allDropped) { continue; } |