diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/DeadArgumentElimination.cpp | 27 | ||||
-rw-r--r-- | src/passes/SignaturePruning.cpp | 68 | ||||
-rw-r--r-- | src/passes/param-utils.cpp | 183 | ||||
-rw-r--r-- | src/passes/param-utils.h | 61 |
4 files changed, 277 insertions, 62 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index d0705961b..4a341571e 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -216,11 +216,26 @@ struct DAE : public Pass { allDroppedCalls[name] = calls; } } + // Track which functions we changed, and optimize them later if necessary. std::unordered_set<Function*> changed; + // If we refine return types then we will need to do more type updating // at the end. bool refinedReturnTypes = false; + + // If we find that localizing call arguments can help (by moving their + // effects outside, so ParamUtils::removeParameters can handle them), then + // we do that at the end and perform another cycle. It is simpler to just do + // another cycle than to track the locations of calls, which is tricky as + // localization might move a call (if a call happens to be another call's + // param). In practice it is rare to find call arguments we want to remove, + // and even more rare to find effects get in the way, so this should not + // cause much overhead. + // + // This set tracks the functions for whom calls to it should be modified. + std::unordered_set<Name> callTargetsToLocalize; + // We now have a mapping of all call sites for each function, and can look // for optimization opportunities. for (auto& [name, calls] : allCalls) { @@ -263,12 +278,15 @@ struct DAE : public Pass { if (numParams == 0) { continue; } - auto removedIndexes = ParamUtils::removeParameters( + auto [removedIndexes, outcome] = ParamUtils::removeParameters( {func}, infoMap[name].unusedParams, calls, {}, module, getPassRunner()); if (!removedIndexes.empty()) { // Success! changed.insert(func); } + if (outcome == ParamUtils::RemovalOutcome::Failure) { + callTargetsToLocalize.insert(name); + } } // We can also tell which calls have all their return values dropped. 
Note // that we can't do this if we changed anything so far, as we may have @@ -307,10 +325,15 @@ struct DAE : public Pass { changed.insert(func.get()); } } + if (!callTargetsToLocalize.empty()) { + ParamUtils::localizeCallsTo( + callTargetsToLocalize, *module, getPassRunner()); + } if (optimize && !changed.empty()) { OptUtils::optimizeAfterInlining(changed, module, getPassRunner()); } - return !changed.empty() || refinedReturnTypes; + return !changed.empty() || refinedReturnTypes || + !callTargetsToLocalize.empty(); } private: diff --git a/src/passes/SignaturePruning.cpp b/src/passes/SignaturePruning.cpp index 23295a66a..2e4be89e8 100644 --- a/src/passes/SignaturePruning.cpp +++ b/src/passes/SignaturePruning.cpp @@ -67,6 +67,16 @@ struct SignaturePruning : public Pass { return; } + // The first iteration may suggest additional work is possible. If so, run + // another cycle. (Even more cycles may help, but limit ourselves to 2 for + // now.) + if (iteration(module)) { + iteration(module); + } + } + + // Returns true if more work is possible. + bool iteration(Module* module) { // First, find all the information we need. Start by collecting inside each // function in parallel. @@ -101,6 +111,16 @@ struct SignaturePruning : public Pass { // Map heap types to all functions with that type. InsertOrderedMap<HeapType, std::vector<Function*>> sigFuncs; + // Heap types of call targets that we found we should localize calls to, in + // order to fully handle them. (See similar code in DeadArgumentElimination + // for individual functions; here we handle a HeapType at a time.) A slight + // complication is that we cannot track heap types here: heap types are + // rewritten using |GlobalTypeRewriter::updateSignatures| below, and even + // types that we do not modify end up replaced (as the entire set of types + // becomes one new big rec group). We therefore need something more stable + // to track here, which we do using either a Call or a Call Ref. 
+ std::unordered_set<Expression*> callTargetsToLocalize; + // Combine all the information we gathered into that map, iterating in a // deterministic order as we build up vectors where the order matters. for (auto& f : module->functions) { @@ -215,12 +235,23 @@ struct SignaturePruning : public Pass { } auto oldParams = sig.params; - auto removedIndexes = ParamUtils::removeParameters(funcs, - unusedParams, - info.calls, - info.callRefs, - module, - getPassRunner()); + auto [removedIndexes, outcome] = + ParamUtils::removeParameters(funcs, + unusedParams, + info.calls, + info.callRefs, + module, + getPassRunner()); + if (outcome == ParamUtils::RemovalOutcome::Failure) { + // Use either a Call or a CallRef that has this type (see explanation + // above on |callTargetsToLocalize|). + if (!info.calls.empty()) { + callTargetsToLocalize.insert(info.calls[0]); + } else { + assert(!info.callRefs.empty()); + callTargetsToLocalize.insert(info.callRefs[0]); + } + } if (removedIndexes.empty()) { continue; } @@ -262,6 +293,31 @@ struct SignaturePruning : public Pass { // Rewrite the types. GlobalTypeRewriter::updateSignatures(newSignatures, *module); + + if (callTargetsToLocalize.empty()) { + return false; + } + + // Localize after updating signatures, to not interfere with that + // operation: localization adds locals, and the indexes of locals must be + // taken into account in |GlobalTypeRewriter::updateSignatures| (as var + // indexes change when params are pruned). 
+ std::unordered_set<HeapType> callTargetTypes; + for (auto* call : callTargetsToLocalize) { + HeapType type; + if (auto* c = call->dynCast<Call>()) { + type = module->getFunction(c->target)->type; + } else if (auto* c = call->dynCast<CallRef>()) { + type = c->target->type.getHeapType(); + } else { + WASM_UNREACHABLE("bad call"); + } + callTargetTypes.insert(type); + } + + ParamUtils::localizeCallsTo(callTargetTypes, *module, getPassRunner()); + + return true; } }; diff --git a/src/passes/param-utils.cpp b/src/passes/param-utils.cpp index e94ea95b1..0caccff11 100644 --- a/src/passes/param-utils.cpp +++ b/src/passes/param-utils.cpp @@ -14,11 +14,16 @@ * limitations under the License. */ +#include "passes/param-utils.h" +#include "ir/eh-utils.h" #include "ir/function-utils.h" #include "ir/local-graph.h" +#include "ir/localize.h" #include "ir/possible-constant.h" #include "ir/type-updating.h" +#include "pass.h" #include "support/sorted_vector.h" +#include "wasm-traversal.h" #include "wasm.h" namespace wasm::ParamUtils { @@ -45,12 +50,12 @@ std::unordered_set<Index> getUsedParams(Function* func) { return usedParams; } -bool removeParameter(const std::vector<Function*>& funcs, - Index index, - const std::vector<Call*>& calls, - const std::vector<CallRef*>& callRefs, - Module* module, - PassRunner* runner) { +RemovalOutcome removeParameter(const std::vector<Function*>& funcs, + Index index, + const std::vector<Call*>& calls, + const std::vector<CallRef*>& callRefs, + Module* module, + PassRunner* runner) { assert(funcs.size() > 0); auto* first = funcs[0]; #ifndef NDEBUG @@ -74,28 +79,31 @@ bool removeParameter(const std::vector<Function*>& funcs, // propagating that out, or by appending an unreachable after the call, but // for simplicity just ignore such cases; if we are called again later then // if DCE ran meanwhile then we could optimize. 
- auto hasBadEffects = [&](auto* call) { - auto& operands = call->operands; - bool hasUnremovable = - EffectAnalyzer(runner->options, *module, operands[index]) - .hasUnremovableSideEffects(); - bool wouldChangeType = call->type == Type::unreachable && !call->isReturn && - operands[index]->type == Type::unreachable; - return hasUnremovable || wouldChangeType; + auto checkEffects = [&](auto* call) { + auto* operand = call->operands[index]; + + if (operand->type == Type::unreachable) { + return Failure; + } + + bool hasUnremovable = EffectAnalyzer(runner->options, *module, operand) + .hasUnremovableSideEffects(); + + return hasUnremovable ? Failure : Success; }; - bool callParamsAreValid = - std::none_of(calls.begin(), calls.end(), [&](Call* call) { - return hasBadEffects(call); - }); - if (!callParamsAreValid) { - return false; + + for (auto* call : calls) { + auto result = checkEffects(call); + if (result != Success) { + return result; + } } - bool callRefParamsAreValid = - std::none_of(callRefs.begin(), callRefs.end(), [&](CallRef* call) { - return hasBadEffects(call); - }); - if (!callRefParamsAreValid) { - return false; + + for (auto* call : callRefs) { + auto result = checkEffects(call); + if (result != Success) { + return result; + } } // The type must be valid for us to handle as a local (since we @@ -104,7 +112,7 @@ bool removeParameter(const std::vector<Function*>& funcs, // local bool typeIsValid = TypeUpdating::canHandleAsLocal(first->getLocalType(index)); if (!typeIsValid) { - return false; + return Failure; } // We can do it! 
@@ -161,17 +169,18 @@ bool removeParameter(const std::vector<Function*>& funcs, call->operands.erase(call->operands.begin() + index); } - return true; + return Success; } -SortedVector removeParameters(const std::vector<Function*>& funcs, - SortedVector indexes, - const std::vector<Call*>& calls, - const std::vector<CallRef*>& callRefs, - Module* module, - PassRunner* runner) { +std::pair<SortedVector, RemovalOutcome> +removeParameters(const std::vector<Function*>& funcs, + SortedVector indexes, + const std::vector<Call*>& calls, + const std::vector<CallRef*>& callRefs, + Module* module, + PassRunner* runner) { if (indexes.empty()) { - return {}; + return {{}, Success}; } assert(funcs.size() > 0); @@ -188,8 +197,8 @@ SortedVector removeParameters(const std::vector<Function*>& funcs, SortedVector removed; while (1) { if (indexes.has(i)) { - if (removeParameter(funcs, i, calls, callRefs, module, runner)) { - // Success! + auto outcome = removeParameter(funcs, i, calls, callRefs, module, runner); + if (outcome == Success) { removed.insert(i); } } @@ -198,7 +207,11 @@ SortedVector removeParameters(const std::vector<Function*>& funcs, } i--; } - return removed; + RemovalOutcome finalOutcome = Success; + if (removed.size() < indexes.size()) { + finalOutcome = Failure; + } + return {removed, finalOutcome}; } SortedVector applyConstantValues(const std::vector<Function*>& funcs, @@ -246,4 +259,98 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs, return optimized; } +void localizeCallsTo(const std::unordered_set<Name>& callTargets, + Module& wasm, + PassRunner* runner) { + struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> { + bool isFunctionParallel() override { return true; } + + std::unique_ptr<Pass> create() override { + return std::make_unique<LocalizerPass>(callTargets); + } + + const std::unordered_set<Name>& callTargets; + + LocalizerPass(const std::unordered_set<Name>& callTargets) + : callTargets(callTargets) {} + + void 
visitCall(Call* curr) { + if (!callTargets.count(curr->target)) { + return; + } + + ChildLocalizer localizer( + curr, getFunction(), *getModule(), getPassOptions()); + auto* replacement = localizer.getReplacement(); + if (replacement != curr) { + replaceCurrent(replacement); + optimized = true; + } + } + + bool optimized = false; + + void visitFunction(Function* curr) { + if (optimized) { + // Localization can add blocks, which might move pops. + EHUtils::handleBlockNestedPops(curr, *getModule()); + } + } + }; + + LocalizerPass(callTargets).run(runner, &wasm); +} + +void localizeCallsTo(const std::unordered_set<HeapType>& callTargets, + Module& wasm, + PassRunner* runner) { + struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> { + bool isFunctionParallel() override { return true; } + + std::unique_ptr<Pass> create() override { + return std::make_unique<LocalizerPass>(callTargets); + } + + const std::unordered_set<HeapType>& callTargets; + + LocalizerPass(const std::unordered_set<HeapType>& callTargets) + : callTargets(callTargets) {} + + void visitCall(Call* curr) { + handleCall(curr, getModule()->getFunction(curr->target)->type); + } + + void visitCallRef(CallRef* curr) { + auto type = curr->target->type; + if (type.isRef()) { + handleCall(curr, type.getHeapType()); + } + } + + void handleCall(Expression* call, HeapType type) { + if (!callTargets.count(type)) { + return; + } + + ChildLocalizer localizer( + call, getFunction(), *getModule(), getPassOptions()); + auto* replacement = localizer.getReplacement(); + if (replacement != call) { + replaceCurrent(replacement); + optimized = true; + } + } + + bool optimized = false; + + void visitFunction(Function* curr) { + if (optimized) { + EHUtils::handleBlockNestedPops(curr, *getModule()); + } + } + }; + + LocalizerPass(callTargets).run(runner, &wasm); +} + } // namespace wasm::ParamUtils diff --git a/src/passes/param-utils.h b/src/passes/param-utils.h index 202c8b007..4c458390a 100644 --- 
a/src/passes/param-utils.h +++ b/src/passes/param-utils.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef wasm_ir_function_h -#define wasm_ir_function_h +#ifndef wasm_pass_param_utils_h +#define wasm_pass_param_utils_h #include "pass.h" #include "support/sorted_vector.h" @@ -44,6 +44,16 @@ namespace wasm::ParamUtils { // } std::unordered_set<Index> getUsedParams(Function* func); +// The outcome of an attempt to remove one or more parameters. +enum RemovalOutcome { + // We removed successfully. + Success = 0, + // We failed to remove. This may be because of fixable nested effects, in + // which case the caller can move those effects out (e.g. using + // ChildLocalizer, or the helper localizeCallsTo below) and repeat. + Failure = 1, +}; + // Try to remove a parameter from a set of functions and replace it with a local // instead. This may not succeed if the parameter type cannot be used in a // local, or if we hit another limitation, in which case this returns false and @@ -64,21 +74,26 @@ std::unordered_set<Index> getUsedParams(Function* func); // need adjusting and it is easier to do it all in one place. Also, the caller // can update all the types at once throughout the program after making // multiple calls to removeParameter(). -bool removeParameter(const std::vector<Function*>& funcs, - Index index, - const std::vector<Call*>& calls, - const std::vector<CallRef*>& callRefs, - Module* module, - PassRunner* runner); +RemovalOutcome removeParameter(const std::vector<Function*>& funcs, + Index index, + const std::vector<Call*>& calls, + const std::vector<CallRef*>& callRefs, + Module* module, + PassRunner* runner); // The same as removeParameter, but gets a sorted list of indexes. It tries to -// remove them all, and returns which we removed. 
-SortedVector removeParameters(const std::vector<Function*>& funcs, - SortedVector indexes, - const std::vector<Call*>& calls, - const std::vector<CallRef*>& callRefs, - Module* module, - PassRunner* runner); +// remove them all, and returns which we removed, as well as an outcome: +// Success if every requested index was removed, and Failure if at least one +// index could not be removed (in which case moving the effects of call +// operands out of the way, e.g. using localizeCallsTo below, and trying again +// may allow more to be removed). +std::pair<SortedVector, RemovalOutcome> +removeParameters(const std::vector<Function*>& funcs, + SortedVector indexes, + const std::vector<Call*>& calls, + const std::vector<CallRef*>& callRefs, + Module* module, + PassRunner* runner); // Given a set of functions and the calls and call_refs that reach them, find // which parameters are passed the same constant value in all the calls. For @@ -92,6 +107,20 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs, const std::vector<CallRef*>& callRefs, Module* module); +// Helper that localizes all calls to a set of targets, in an entire module. +// This basically calls ChildLocalizer in each function, on the relevant calls. +// This is useful when removeParameters reports Failure, see above. +// +// The set of targets can be function names (the individual functions we want to +// handle calls towards) or heap types (which will then include all functions +// with those types). +void localizeCallsTo(const std::unordered_set<Name>& callTargets, + Module& wasm, + PassRunner* runner); +void localizeCallsTo(const std::unordered_set<HeapType>& callTargets, + Module& wasm, + PassRunner* runner); + } // namespace wasm::ParamUtils -#endif // wasm_ir_function_h +#endif // wasm_pass_param_utils_h |