diff options
Diffstat (limited to 'src/passes/param-utils.cpp')
-rw-r--r-- | src/passes/param-utils.cpp | 183 |
1 files changed, 145 insertions, 38 deletions
diff --git a/src/passes/param-utils.cpp b/src/passes/param-utils.cpp index e94ea95b1..0caccff11 100644 --- a/src/passes/param-utils.cpp +++ b/src/passes/param-utils.cpp @@ -14,11 +14,16 @@ * limitations under the License. */ +#include "passes/param-utils.h" +#include "ir/eh-utils.h" #include "ir/function-utils.h" #include "ir/local-graph.h" +#include "ir/localize.h" #include "ir/possible-constant.h" #include "ir/type-updating.h" +#include "pass.h" #include "support/sorted_vector.h" +#include "wasm-traversal.h" #include "wasm.h" namespace wasm::ParamUtils { @@ -45,12 +50,12 @@ std::unordered_set<Index> getUsedParams(Function* func) { return usedParams; } -bool removeParameter(const std::vector<Function*>& funcs, - Index index, - const std::vector<Call*>& calls, - const std::vector<CallRef*>& callRefs, - Module* module, - PassRunner* runner) { +RemovalOutcome removeParameter(const std::vector<Function*>& funcs, + Index index, + const std::vector<Call*>& calls, + const std::vector<CallRef*>& callRefs, + Module* module, + PassRunner* runner) { assert(funcs.size() > 0); auto* first = funcs[0]; #ifndef NDEBUG @@ -74,28 +79,31 @@ bool removeParameter(const std::vector<Function*>& funcs, // propagating that out, or by appending an unreachable after the call, but // for simplicity just ignore such cases; if we are called again later then // if DCE ran meanwhile then we could optimize. - auto hasBadEffects = [&](auto* call) { - auto& operands = call->operands; - bool hasUnremovable = - EffectAnalyzer(runner->options, *module, operands[index]) - .hasUnremovableSideEffects(); - bool wouldChangeType = call->type == Type::unreachable && !call->isReturn && - operands[index]->type == Type::unreachable; - return hasUnremovable || wouldChangeType; + auto checkEffects = [&](auto* call) { + auto* operand = call->operands[index]; + + if (operand->type == Type::unreachable) { + return Failure; + } + + bool hasUnremovable = EffectAnalyzer(runner->options, *module, operand) + .hasUnremovableSideEffects(); + + return hasUnremovable ? Failure : Success; }; - bool callParamsAreValid = - std::none_of(calls.begin(), calls.end(), [&](Call* call) { - return hasBadEffects(call); - }); - if (!callParamsAreValid) { - return false; + + for (auto* call : calls) { + auto result = checkEffects(call); + if (result != Success) { + return result; + } } - bool callRefParamsAreValid = - std::none_of(callRefs.begin(), callRefs.end(), [&](CallRef* call) { - return hasBadEffects(call); - }); - if (!callRefParamsAreValid) { - return false; + + for (auto* call : callRefs) { + auto result = checkEffects(call); + if (result != Success) { + return result; + } } // The type must be valid for us to handle as a local (since we @@ -104,7 +112,7 @@ bool removeParameter(const std::vector<Function*>& funcs, // local bool typeIsValid = TypeUpdating::canHandleAsLocal(first->getLocalType(index)); if (!typeIsValid) { - return false; + return Failure; } // We can do it! @@ -161,17 +169,18 @@ bool removeParameter(const std::vector<Function*>& funcs, call->operands.erase(call->operands.begin() + index); } - return true; + return Success; } -SortedVector removeParameters(const std::vector<Function*>& funcs, - SortedVector indexes, - const std::vector<Call*>& calls, - const std::vector<CallRef*>& callRefs, - Module* module, - PassRunner* runner) { +std::pair<SortedVector, RemovalOutcome> +removeParameters(const std::vector<Function*>& funcs, + SortedVector indexes, + const std::vector<Call*>& calls, + const std::vector<CallRef*>& callRefs, + Module* module, + PassRunner* runner) { if (indexes.empty()) { - return {}; + return {{}, Success}; } assert(funcs.size() > 0); @@ -188,8 +197,8 @@ SortedVector removeParameters(const std::vector<Function*>& funcs, SortedVector removed; while (1) { if (indexes.has(i)) { - if (removeParameter(funcs, i, calls, callRefs, module, runner)) { - // Success! + auto outcome = removeParameter(funcs, i, calls, callRefs, module, runner); + if (outcome == Success) { removed.insert(i); } } @@ -198,7 +207,11 @@ SortedVector removeParameters(const std::vector<Function*>& funcs, } i--; } - return removed; + RemovalOutcome finalOutcome = Success; + if (removed.size() < indexes.size()) { + finalOutcome = Failure; + } + return {removed, finalOutcome}; } SortedVector applyConstantValues(const std::vector<Function*>& funcs, @@ -246,4 +259,98 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs, return optimized; } +void localizeCallsTo(const std::unordered_set<Name>& callTargets, + Module& wasm, + PassRunner* runner) { + struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> { + bool isFunctionParallel() override { return true; } + + std::unique_ptr<Pass> create() override { + return std::make_unique<LocalizerPass>(callTargets); + } + + const std::unordered_set<Name>& callTargets; + + LocalizerPass(const std::unordered_set<Name>& callTargets) + : callTargets(callTargets) {} + + void visitCall(Call* curr) { + if (!callTargets.count(curr->target)) { + return; + } + + ChildLocalizer localizer( + curr, getFunction(), *getModule(), getPassOptions()); + auto* replacement = localizer.getReplacement(); + if (replacement != curr) { + replaceCurrent(replacement); + optimized = true; + } + } + + bool optimized = false; + + void visitFunction(Function* curr) { + if (optimized) { + // Localization can add blocks, which might move pops. + EHUtils::handleBlockNestedPops(curr, *getModule()); + } + } + }; + + LocalizerPass(callTargets).run(runner, &wasm); +} + +void localizeCallsTo(const std::unordered_set<HeapType>& callTargets, + Module& wasm, + PassRunner* runner) { + struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> { + bool isFunctionParallel() override { return true; } + + std::unique_ptr<Pass> create() override { + return std::make_unique<LocalizerPass>(callTargets); + } + + const std::unordered_set<HeapType>& callTargets; + + LocalizerPass(const std::unordered_set<HeapType>& callTargets) + : callTargets(callTargets) {} + + void visitCall(Call* curr) { + handleCall(curr, getModule()->getFunction(curr->target)->type); + } + + void visitCallRef(CallRef* curr) { + auto type = curr->target->type; + if (type.isRef()) { + handleCall(curr, type.getHeapType()); + } + } + + void handleCall(Expression* call, HeapType type) { + if (!callTargets.count(type)) { + return; + } + + ChildLocalizer localizer( + call, getFunction(), *getModule(), getPassOptions()); + auto* replacement = localizer.getReplacement(); + if (replacement != call) { + replaceCurrent(replacement); + optimized = true; + } + } + + bool optimized = false; + + void visitFunction(Function* curr) { + if (optimized) { + EHUtils::handleBlockNestedPops(curr, *getModule()); + } + } + }; + + LocalizerPass(callTargets).run(runner, &wasm); +} + } // namespace wasm::ParamUtils |