diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/DeadArgumentElimination.cpp | 92 |
1 files changed, 84 insertions, 8 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index a95d4f91a..a2231469d 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -15,10 +15,8 @@ */ // -// Optimizes call arguments in a whole-program manner, removing ones -// that are not used (dead). -// -// Specifically, this does these things: +// Optimizes call arguments in a whole-program manner. In particular, this +// removes ones that are not used (dead), but it also does more things: // // * Find functions for whom an argument is always passed the same // constant. If so, we can just set that local to that constant @@ -28,6 +26,8 @@ // the previous point was true for an argument, then the second // must as well.) // * Find return values ("return arguments" ;) that are never used. +// * Refine the types of arguments, that is make the argument type more +// specific if all the passed values allow that. // // This pass does not depend on flattening, but it may be more effective, // as then call arguments never have side effects (which we need to @@ -40,6 +40,7 @@ #include "cfg/cfg-traversal.h" #include "ir/effects.h" #include "ir/element-utils.h" +#include "ir/find_all.h" #include "ir/module-utils.h" #include "ir/type-updating.h" #include "pass.h" @@ -307,18 +308,22 @@ struct DAE : public Pass { allDroppedCalls[pair.first] = pair.second; } } - // We now have a mapping of all call sites for each function. Check which - // are always passed the same constant for a particular argument. + // We now have a mapping of all call sites for each function, and can look + // for optimization opportunities. for (auto& pair : allCalls) { auto name = pair.first; - // We can only optimize if we see all the calls and can modify - // them. + // We can only optimize if we see all the calls and can modify them. if (infoMap[name].hasUnseenCalls) { continue; } auto& calls = pair.second; auto* func = module->getFunction(name); auto numParams = func->getNumParams(); + // Refine argument types before doing anything else. This does not + // affect whether an argument is used or not, it just refines the type + // where possible. + refineArgumentTypes(func, calls, module); + // Check if all calls pass the same constant for a particular argument. for (Index i = 0; i < numParams; i++) { Literal value; for (auto* call : calls) { @@ -515,6 +520,77 @@ private: } } } + + // Given a function and all the calls to it, see if we can refine the type of + // its arguments. If we only pass in a subtype, we may as well refine the type + // to that. + // + // This assumes that the function has no calls aside from |calls|, that is, it + // is not exported or called from the table or by reference. + void refineArgumentTypes(Function* func, + const std::vector<Call*>& calls, + Module* module) { + if (!module->features.hasGC()) { + return; + } + auto numParams = func->getNumParams(); + std::vector<Type> newParamTypes; + newParamTypes.reserve(numParams); + for (Index i = 0; i < numParams; i++) { + auto originalType = func->getLocalType(i); + if (!originalType.isRef()) { + newParamTypes.push_back(originalType); + continue; + } + Type refinedType = Type::unreachable; + for (auto* call : calls) { + auto* operand = call->operands[i]; + refinedType = Type::getLeastUpperBound(refinedType, operand->type); + if (refinedType == originalType) { + // We failed to refine this parameter to anything more specific. + break; + } + } + + // Nothing is sent here at all; leave such optimizations to DCE. + if (refinedType == Type::unreachable) { + return; + } + newParamTypes.push_back(refinedType); + } + + // Check if we are able to optimize here before we do the work to scan the + // function body. + if (Type(newParamTypes) == func->getParams()) { + return; + } + + // In terms of parameters, we can do this. However, we must also check + // local operations in the body, as if the parameter is reused and written + // to, then those types must be taken into account as well. + for (auto* set : FindAll<LocalSet>(func->body).list) { + auto index = set->index; + if (func->isParam(index) && + !Type::isSubType(set->value->type, newParamTypes[index])) { + // TODO: we could still optimize here, by creating a new local. + newParamTypes[index] = func->getLocalType(index); + } + } + + auto newParams = Type(newParamTypes); + if (newParams == func->getParams()) { + return; + } + + // We can do this! Update the types, including the types of gets. + func->setParams(newParams); + for (auto* get : FindAll<LocalGet>(func->body).list) { + auto index = get->index; + if (func->isParam(index)) { + get->type = func->getLocalType(index); + } + } + } }; Pass* createDAEPass() { return new DAE(); } |