diff options
author | Alon Zakai <azakai@google.com> | 2024-03-18 14:18:09 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-18 14:18:09 -0700 |
commit | d8086c63a9e3e6bbd1bdc5d7e0843af8433cc4c8 (patch) | |
tree | b46989cd5f1e397b6eda83ca7741a1a5d2716728 /src/passes/param-utils.cpp | |
parent | c166ca015860b337e9ce07a5e02cb707964056ba (diff) | |
download | binaryen-d8086c63a9e3e6bbd1bdc5d7e0843af8433cc4c8.tar.gz binaryen-d8086c63a9e3e6bbd1bdc5d7e0843af8433cc4c8.tar.bz2 binaryen-d8086c63a9e3e6bbd1bdc5d7e0843af8433cc4c8.zip |
DeadArgumentElimination/SignaturePruning: Prune params even if called with effects (#6395)
Before this PR, when we saw a param was unused we sometimes could not remove it.
For example, if there was one call like this:
(call $target
(call $other)
)
That nested call has effects, so we can't just remove it from the outer call - we'd need to
move it first. That motion was hard to integrate which was why it was left out, but it
turns out that is sometimes very important. E.g. in Java it is common to have such calls
that send the this parameter as the result of another call; not being able to remove such
params meant we kept those nested calls alive, creating empty structs just to have
something to send there.
To fix this, this builds on top of #6394 which makes it easier to move all children out of
a parent, leaving only nested things that can be easily moved around and removed. In
more detail, DeadArgumentElimination/SignaturePruning track whether we run into effects that
prevent removing a field. If we do, then we queue an operation to move the children
out, which we do using a new utility ParamUtils::localizeCallsTo. The pass then does
another iteration after that operation.
Alternatively we could try to move things around immediately, but that is quite hard:
those passes already track a lot of state. It is simpler to do the fixup in an entirely
separate utility. That does come at the cost of the utility doing another pass on the
module and the pass itself running another iteration, but this situation is not the most
common.
Diffstat (limited to 'src/passes/param-utils.cpp')
-rw-r--r-- | src/passes/param-utils.cpp | 183 |
1 files changed, 145 insertions, 38 deletions
diff --git a/src/passes/param-utils.cpp b/src/passes/param-utils.cpp index e94ea95b1..0caccff11 100644 --- a/src/passes/param-utils.cpp +++ b/src/passes/param-utils.cpp @@ -14,11 +14,16 @@ * limitations under the License. */ +#include "passes/param-utils.h" +#include "ir/eh-utils.h" #include "ir/function-utils.h" #include "ir/local-graph.h" +#include "ir/localize.h" #include "ir/possible-constant.h" #include "ir/type-updating.h" +#include "pass.h" #include "support/sorted_vector.h" +#include "wasm-traversal.h" #include "wasm.h" namespace wasm::ParamUtils { @@ -45,12 +50,12 @@ std::unordered_set<Index> getUsedParams(Function* func) { return usedParams; } -bool removeParameter(const std::vector<Function*>& funcs, - Index index, - const std::vector<Call*>& calls, - const std::vector<CallRef*>& callRefs, - Module* module, - PassRunner* runner) { +RemovalOutcome removeParameter(const std::vector<Function*>& funcs, + Index index, + const std::vector<Call*>& calls, + const std::vector<CallRef*>& callRefs, + Module* module, + PassRunner* runner) { assert(funcs.size() > 0); auto* first = funcs[0]; #ifndef NDEBUG @@ -74,28 +79,31 @@ bool removeParameter(const std::vector<Function*>& funcs, // propagating that out, or by appending an unreachable after the call, but // for simplicity just ignore such cases; if we are called again later then // if DCE ran meanwhile then we could optimize. - auto hasBadEffects = [&](auto* call) { - auto& operands = call->operands; - bool hasUnremovable = - EffectAnalyzer(runner->options, *module, operands[index]) - .hasUnremovableSideEffects(); - bool wouldChangeType = call->type == Type::unreachable && !call->isReturn && - operands[index]->type == Type::unreachable; - return hasUnremovable || wouldChangeType; + auto checkEffects = [&](auto* call) { + auto* operand = call->operands[index]; + + if (operand->type == Type::unreachable) { + return Failure; + } + + bool hasUnremovable = EffectAnalyzer(runner->options, *module, operand) + .hasUnremovableSideEffects(); + + return hasUnremovable ? Failure : Success; }; - bool callParamsAreValid = - std::none_of(calls.begin(), calls.end(), [&](Call* call) { - return hasBadEffects(call); - }); - if (!callParamsAreValid) { - return false; + + for (auto* call : calls) { + auto result = checkEffects(call); + if (result != Success) { + return result; + } } - bool callRefParamsAreValid = - std::none_of(callRefs.begin(), callRefs.end(), [&](CallRef* call) { - return hasBadEffects(call); - }); - if (!callRefParamsAreValid) { - return false; + + for (auto* call : callRefs) { + auto result = checkEffects(call); + if (result != Success) { + return result; + } } // The type must be valid for us to handle as a local (since we @@ -104,7 +112,7 @@ bool removeParameter(const std::vector<Function*>& funcs, // local bool typeIsValid = TypeUpdating::canHandleAsLocal(first->getLocalType(index)); if (!typeIsValid) { - return false; + return Failure; } // We can do it! @@ -161,17 +169,18 @@ bool removeParameter(const std::vector<Function*>& funcs, call->operands.erase(call->operands.begin() + index); } - return true; + return Success; } -SortedVector removeParameters(const std::vector<Function*>& funcs, - SortedVector indexes, - const std::vector<Call*>& calls, - const std::vector<CallRef*>& callRefs, - Module* module, - PassRunner* runner) { +std::pair<SortedVector, RemovalOutcome> +removeParameters(const std::vector<Function*>& funcs, + SortedVector indexes, + const std::vector<Call*>& calls, + const std::vector<CallRef*>& callRefs, + Module* module, + PassRunner* runner) { if (indexes.empty()) { - return {}; + return {{}, Success}; } assert(funcs.size() > 0); @@ -188,8 +197,8 @@ SortedVector removeParameters(const std::vector<Function*>& funcs, SortedVector removed; while (1) { if (indexes.has(i)) { - if (removeParameter(funcs, i, calls, callRefs, module, runner)) { - // Success! + auto outcome = removeParameter(funcs, i, calls, callRefs, module, runner); + if (outcome == Success) { removed.insert(i); } } @@ -198,7 +207,11 @@ SortedVector removeParameters(const std::vector<Function*>& funcs, } i--; } - return removed; + RemovalOutcome finalOutcome = Success; + if (removed.size() < indexes.size()) { + finalOutcome = Failure; + } + return {removed, finalOutcome}; } SortedVector applyConstantValues(const std::vector<Function*>& funcs, @@ -246,4 +259,98 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs, return optimized; } +void localizeCallsTo(const std::unordered_set<Name>& callTargets, + Module& wasm, + PassRunner* runner) { + struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> { + bool isFunctionParallel() override { return true; } + + std::unique_ptr<Pass> create() override { + return std::make_unique<LocalizerPass>(callTargets); + } + + const std::unordered_set<Name>& callTargets; + + LocalizerPass(const std::unordered_set<Name>& callTargets) + : callTargets(callTargets) {} + + void visitCall(Call* curr) { + if (!callTargets.count(curr->target)) { + return; + } + + ChildLocalizer localizer( + curr, getFunction(), *getModule(), getPassOptions()); + auto* replacement = localizer.getReplacement(); + if (replacement != curr) { + replaceCurrent(replacement); + optimized = true; + } + } + + bool optimized = false; + + void visitFunction(Function* curr) { + if (optimized) { + // Localization can add blocks, which might move pops. + EHUtils::handleBlockNestedPops(curr, *getModule()); + } + } + }; + + LocalizerPass(callTargets).run(runner, &wasm); +} + +void localizeCallsTo(const std::unordered_set<HeapType>& callTargets, + Module& wasm, + PassRunner* runner) { + struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> { + bool isFunctionParallel() override { return true; } + + std::unique_ptr<Pass> create() override { + return std::make_unique<LocalizerPass>(callTargets); + } + + const std::unordered_set<HeapType>& callTargets; + + LocalizerPass(const std::unordered_set<HeapType>& callTargets) + : callTargets(callTargets) {} + + void visitCall(Call* curr) { + handleCall(curr, getModule()->getFunction(curr->target)->type); + } + + void visitCallRef(CallRef* curr) { + auto type = curr->target->type; + if (type.isRef()) { + handleCall(curr, type.getHeapType()); + } + } + + void handleCall(Expression* call, HeapType type) { + if (!callTargets.count(type)) { + return; + } + + ChildLocalizer localizer( + call, getFunction(), *getModule(), getPassOptions()); + auto* replacement = localizer.getReplacement(); + if (replacement != call) { + replaceCurrent(replacement); + optimized = true; + } + } + + bool optimized = false; + + void visitFunction(Function* curr) { + if (optimized) { + EHUtils::handleBlockNestedPops(curr, *getModule()); + } + } + }; + + LocalizerPass(callTargets).run(runner, &wasm); +} + } // namespace wasm::ParamUtils |