diff options
Diffstat (limited to 'src/passes/DeadArgumentElimination.cpp')
-rw-r--r-- | src/passes/DeadArgumentElimination.cpp | 111 |
1 files changed, 10 insertions, 101 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index ec81639b0..ce4d17c74 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -40,11 +40,11 @@ #include "ir/effects.h" #include "ir/element-utils.h" #include "ir/find_all.h" -#include "ir/local-graph.h" #include "ir/lubs.h" #include "ir/module-utils.h" #include "ir/type-updating.h" #include "ir/utils.h" +#include "param-utils.h" #include "pass.h" #include "passes/opt-utils.h" #include "support/sorted_vector.h" @@ -155,41 +155,13 @@ struct DAEScanner // part of, say if we are exported, or if another parallel function finds a // RefFunc to us and updates it before we check it). if (numParams > 0 && !info->hasUnseenCalls) { - findUnusedParams(func); - } - } - - void findUnusedParams(Function* func) { - LocalGraph localGraph(func); - std::unordered_set<Index> usedParams; - for (auto& [get, sets] : localGraph.getSetses) { - if (!func->isParam(get->index)) { - continue; - } - - // Check if this get of a param index can read from the parameter value - // passed into the function. We want to ignore values set in the function - // like this: - // - // function foo(x) { - // x = 10; - // bar(x); // read of a param index, but not the param value passed in. - // } - for (auto* set : sets) { - // A nullptr value indicates there is no LocalSet* that sets the value, - // so it must be the parameter value. - if (!set) { - usedParams.insert(get->index); + auto usedParams = ParamUtils::getUsedParams(func); + for (Index i = 0; i < numParams; i++) { + if (usedParams.count(i) == 0) { + info->unusedParams.insert(i); } } } - - // We can now compute the unused params. - for (Index i = 0; i < numParams; i++) { - if (usedParams.count(i) == 0) { - info->unusedParams.insert(i); - } - } } }; @@ -315,38 +287,11 @@ struct DAE : public Pass { if (numParams == 0) { continue; } - // Iterate downwards, as we may remove more than one. - Index i = numParams - 1; - while (1) { - if (infoMap[name].unusedParams.has(i)) { - // Great, it's not used. Check if none of the calls has a param with - // side effects that we cannot remove (as if we can remove them, we - // will simply do that when we remove the parameter). Note: flattening - // the IR beforehand can help here. - bool callParamsAreValid = - std::none_of(calls.begin(), calls.end(), [&](Call* call) { - auto* operand = call->operands[i]; - return EffectAnalyzer(runner->options, *module, operand) - .hasUnremovableSideEffects(); - }); - // The type must be valid for us to handle as a local (since we - // replace the parameter with a local). - // TODO: if there are no references at all, we can avoid creating a - // local - bool typeIsValid = - TypeUpdating::canHandleAsLocal(func->getLocalType(i)); - if (callParamsAreValid && typeIsValid) { - // Wonderful, nothing stands in our way! Do it. - // TODO: parallelize this? - removeParameter(func, i, calls); - TypeUpdating::handleNonDefaultableLocals(func, *module); - changed.insert(func); - } - } - if (i == 0) { - break; - } - i--; + auto removedIndexes = ParamUtils::removeParameters( + {func}, infoMap[name].unusedParams, calls, {}, module, runner); + if (!removedIndexes.empty()) { + // Success! + changed.insert(func); } } // We can also tell which calls have all their return values dropped. Note @@ -395,42 +340,6 @@ struct DAE : public Pass { private: std::unordered_map<Call*, Expression**> allDroppedCalls; - void removeParameter(Function* func, Index i, std::vector<Call*>& calls) { - // It's cumbersome to adjust local names - TODO don't clear them? - Builder::clearLocalNames(func); - // Remove the parameter from the function. We must add a new local - // for uses of the parameter, but cannot make it use the same index - // (in general). - auto paramsType = func->getParams(); - std::vector<Type> params(paramsType.begin(), paramsType.end()); - auto type = params[i]; - params.erase(params.begin() + i); - func->setParams(Type(params)); - Index newIndex = Builder::addVar(func, type); - // Update local operations. - struct LocalUpdater : public PostWalker<LocalUpdater> { - Index removedIndex; - Index newIndex; - LocalUpdater(Function* func, Index removedIndex, Index newIndex) - : removedIndex(removedIndex), newIndex(newIndex) { - walk(func->body); - } - void visitLocalGet(LocalGet* curr) { updateIndex(curr->index); } - void visitLocalSet(LocalSet* curr) { updateIndex(curr->index); } - void updateIndex(Index& index) { - if (index == removedIndex) { - index = newIndex; - } else if (index > removedIndex) { - index--; - } - } - } localUpdater(func, i, newIndex); - // Remove the arguments from the calls. - for (auto* call : calls) { - call->operands.erase(call->operands.begin() + i); - } - } - void removeReturnValue(Function* func, std::vector<Call*>& calls, Module* module) { func->setResults(Type::none); |