summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlon Zakai <azakai@google.com>2024-03-18 14:18:09 -0700
committerGitHub <noreply@github.com>2024-03-18 14:18:09 -0700
commitd8086c63a9e3e6bbd1bdc5d7e0843af8433cc4c8 (patch)
treeb46989cd5f1e397b6eda83ca7741a1a5d2716728 /src
parentc166ca015860b337e9ce07a5e02cb707964056ba (diff)
downloadbinaryen-d8086c63a9e3e6bbd1bdc5d7e0843af8433cc4c8.tar.gz
binaryen-d8086c63a9e3e6bbd1bdc5d7e0843af8433cc4c8.tar.bz2
binaryen-d8086c63a9e3e6bbd1bdc5d7e0843af8433cc4c8.zip
DeadArgumentElimination/SignaturePruning: Prune params even if called with effects (#6395)
Before this PR, when we saw a param was unused we sometimes could not remove it. For example, if there was one call like this: (call $target (call $other) ) That nested call has effects, so we can't just remove it from the outer call - we'd need to move it first. That motion was hard to integrate which was why it was left out, but it turns out that is sometimes very important. E.g. in Java it is common to have such calls that send the this parameter as the result of another call; not being able to remove such params meant we kept those nested calls alive, creating empty structs just to have something to send there. To fix this, this builds on top of #6394 which makes it easier to move all children out of a parent, leaving only nested things that can be easily moved around and removed. In more detail, DeadArgumentElimination/SignaturePruning track whether we run into effects that prevent removing a field. If we do, then we queue an operation to move the children out, which we do using a new utility ParamUtils::localizeCallsTo. The pass then does another iteration after that operation. Alternatively we could try to move things around immediately, but that is quite hard: those passes already track a lot of state. It is simpler to do the fixup in an entirely separate utility. That does come at the cost of the utility doing another pass on the module and the pass itself running another iteration, but this situation is not the most common.
Diffstat (limited to 'src')
-rw-r--r--src/passes/DeadArgumentElimination.cpp27
-rw-r--r--src/passes/SignaturePruning.cpp68
-rw-r--r--src/passes/param-utils.cpp183
-rw-r--r--src/passes/param-utils.h61
4 files changed, 277 insertions, 62 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp
index d0705961b..4a341571e 100644
--- a/src/passes/DeadArgumentElimination.cpp
+++ b/src/passes/DeadArgumentElimination.cpp
@@ -216,11 +216,26 @@ struct DAE : public Pass {
allDroppedCalls[name] = calls;
}
}
+
// Track which functions we changed, and optimize them later if necessary.
std::unordered_set<Function*> changed;
+
// If we refine return types then we will need to do more type updating
// at the end.
bool refinedReturnTypes = false;
+
+ // If we find that localizing call arguments can help (by moving their
+ // effects outside, so ParamUtils::removeParameters can handle them), then
+ // we do that at the end and perform another cycle. It is simpler to just do
+ // another cycle than to track the locations of calls, which is tricky as
+ // localization might move a call (if a call happens to be another call's
+ // param). In practice it is rare to find call arguments we want to remove,
+ // and even more rare to find effects get in the way, so this should not
+ // cause much overhead.
+ //
+ // This set tracks the functions for whom calls to it should be modified.
+ std::unordered_set<Name> callTargetsToLocalize;
+
// We now have a mapping of all call sites for each function, and can look
// for optimization opportunities.
for (auto& [name, calls] : allCalls) {
@@ -263,12 +278,15 @@ struct DAE : public Pass {
if (numParams == 0) {
continue;
}
- auto removedIndexes = ParamUtils::removeParameters(
+ auto [removedIndexes, outcome] = ParamUtils::removeParameters(
{func}, infoMap[name].unusedParams, calls, {}, module, getPassRunner());
if (!removedIndexes.empty()) {
// Success!
changed.insert(func);
}
+ if (outcome == ParamUtils::RemovalOutcome::Failure) {
+ callTargetsToLocalize.insert(name);
+ }
}
// We can also tell which calls have all their return values dropped. Note
// that we can't do this if we changed anything so far, as we may have
@@ -307,10 +325,15 @@ struct DAE : public Pass {
changed.insert(func.get());
}
}
+ if (!callTargetsToLocalize.empty()) {
+ ParamUtils::localizeCallsTo(
+ callTargetsToLocalize, *module, getPassRunner());
+ }
if (optimize && !changed.empty()) {
OptUtils::optimizeAfterInlining(changed, module, getPassRunner());
}
- return !changed.empty() || refinedReturnTypes;
+ return !changed.empty() || refinedReturnTypes ||
+ !callTargetsToLocalize.empty();
}
private:
diff --git a/src/passes/SignaturePruning.cpp b/src/passes/SignaturePruning.cpp
index 23295a66a..2e4be89e8 100644
--- a/src/passes/SignaturePruning.cpp
+++ b/src/passes/SignaturePruning.cpp
@@ -67,6 +67,16 @@ struct SignaturePruning : public Pass {
return;
}
+ // The first iteration may suggest additional work is possible. If so, run
+ // another cycle. (Even more cycles may help, but limit ourselves to 2 for
+ // now.)
+ if (iteration(module)) {
+ iteration(module);
+ }
+ }
+
+ // Returns true if more work is possible.
+ bool iteration(Module* module) {
// First, find all the information we need. Start by collecting inside each
// function in parallel.
@@ -101,6 +111,16 @@ struct SignaturePruning : public Pass {
// Map heap types to all functions with that type.
InsertOrderedMap<HeapType, std::vector<Function*>> sigFuncs;
+ // Heap types of call targets that we found we should localize calls to, in
+ // order to fully handle them. (See similar code in DeadArgumentElimination
+ // for individual functions; here we handle a HeapType at a time.) A slight
+ // complication is that we cannot track heap types here: heap types are
+ // rewritten using |GlobalTypeRewriter::updateSignatures| below, and even
+ // types that we do not modify end up replaced (as the entire set of types
+ // becomes one new big rec group). We therefore need something more stable
+ // to track here, which we do using either a Call or a Call Ref.
+ std::unordered_set<Expression*> callTargetsToLocalize;
+
// Combine all the information we gathered into that map, iterating in a
// deterministic order as we build up vectors where the order matters.
for (auto& f : module->functions) {
@@ -215,12 +235,23 @@ struct SignaturePruning : public Pass {
}
auto oldParams = sig.params;
- auto removedIndexes = ParamUtils::removeParameters(funcs,
- unusedParams,
- info.calls,
- info.callRefs,
- module,
- getPassRunner());
+ auto [removedIndexes, outcome] =
+ ParamUtils::removeParameters(funcs,
+ unusedParams,
+ info.calls,
+ info.callRefs,
+ module,
+ getPassRunner());
+ if (outcome == ParamUtils::RemovalOutcome::Failure) {
+ // Use either a Call or a CallRef that has this type (see explanation
+ // above on |callTargetsToLocalize|.
+ if (!info.calls.empty()) {
+ callTargetsToLocalize.insert(info.calls[0]);
+ } else {
+ assert(!info.callRefs.empty());
+ callTargetsToLocalize.insert(info.callRefs[0]);
+ }
+ }
if (removedIndexes.empty()) {
continue;
}
@@ -262,6 +293,31 @@ struct SignaturePruning : public Pass {
// Rewrite the types.
GlobalTypeRewriter::updateSignatures(newSignatures, *module);
+
+ if (callTargetsToLocalize.empty()) {
+ return false;
+ }
+
+ // Localize after updating signatures, to not interfere with that
+ // operation (localization adds locals, and the indexes of locals must be
+ // taken into account in |GlobalTypeRewriter::updateSignatures| (as var
+ // indexes change when params are pruned).
+ std::unordered_set<HeapType> callTargetTypes;
+ for (auto* call : callTargetsToLocalize) {
+ HeapType type;
+ if (auto* c = call->dynCast<Call>()) {
+ type = module->getFunction(c->target)->type;
+ } else if (auto* c = call->dynCast<CallRef>()) {
+ type = c->target->type.getHeapType();
+ } else {
+ WASM_UNREACHABLE("bad call");
+ }
+ callTargetTypes.insert(type);
+ }
+
+ ParamUtils::localizeCallsTo(callTargetTypes, *module, getPassRunner());
+
+ return true;
}
};
diff --git a/src/passes/param-utils.cpp b/src/passes/param-utils.cpp
index e94ea95b1..0caccff11 100644
--- a/src/passes/param-utils.cpp
+++ b/src/passes/param-utils.cpp
@@ -14,11 +14,16 @@
* limitations under the License.
*/
+#include "passes/param-utils.h"
+#include "ir/eh-utils.h"
#include "ir/function-utils.h"
#include "ir/local-graph.h"
+#include "ir/localize.h"
#include "ir/possible-constant.h"
#include "ir/type-updating.h"
+#include "pass.h"
#include "support/sorted_vector.h"
+#include "wasm-traversal.h"
#include "wasm.h"
namespace wasm::ParamUtils {
@@ -45,12 +50,12 @@ std::unordered_set<Index> getUsedParams(Function* func) {
return usedParams;
}
-bool removeParameter(const std::vector<Function*>& funcs,
- Index index,
- const std::vector<Call*>& calls,
- const std::vector<CallRef*>& callRefs,
- Module* module,
- PassRunner* runner) {
+RemovalOutcome removeParameter(const std::vector<Function*>& funcs,
+ Index index,
+ const std::vector<Call*>& calls,
+ const std::vector<CallRef*>& callRefs,
+ Module* module,
+ PassRunner* runner) {
assert(funcs.size() > 0);
auto* first = funcs[0];
#ifndef NDEBUG
@@ -74,28 +79,31 @@ bool removeParameter(const std::vector<Function*>& funcs,
// propagating that out, or by appending an unreachable after the call, but
// for simplicity just ignore such cases; if we are called again later then
// if DCE ran meanwhile then we could optimize.
- auto hasBadEffects = [&](auto* call) {
- auto& operands = call->operands;
- bool hasUnremovable =
- EffectAnalyzer(runner->options, *module, operands[index])
- .hasUnremovableSideEffects();
- bool wouldChangeType = call->type == Type::unreachable && !call->isReturn &&
- operands[index]->type == Type::unreachable;
- return hasUnremovable || wouldChangeType;
+ auto checkEffects = [&](auto* call) {
+ auto* operand = call->operands[index];
+
+ if (operand->type == Type::unreachable) {
+ return Failure;
+ }
+
+ bool hasUnremovable = EffectAnalyzer(runner->options, *module, operand)
+ .hasUnremovableSideEffects();
+
+ return hasUnremovable ? Failure : Success;
};
- bool callParamsAreValid =
- std::none_of(calls.begin(), calls.end(), [&](Call* call) {
- return hasBadEffects(call);
- });
- if (!callParamsAreValid) {
- return false;
+
+ for (auto* call : calls) {
+ auto result = checkEffects(call);
+ if (result != Success) {
+ return result;
+ }
}
- bool callRefParamsAreValid =
- std::none_of(callRefs.begin(), callRefs.end(), [&](CallRef* call) {
- return hasBadEffects(call);
- });
- if (!callRefParamsAreValid) {
- return false;
+
+ for (auto* call : callRefs) {
+ auto result = checkEffects(call);
+ if (result != Success) {
+ return result;
+ }
}
// The type must be valid for us to handle as a local (since we
@@ -104,7 +112,7 @@ bool removeParameter(const std::vector<Function*>& funcs,
// local
bool typeIsValid = TypeUpdating::canHandleAsLocal(first->getLocalType(index));
if (!typeIsValid) {
- return false;
+ return Failure;
}
// We can do it!
@@ -161,17 +169,18 @@ bool removeParameter(const std::vector<Function*>& funcs,
call->operands.erase(call->operands.begin() + index);
}
- return true;
+ return Success;
}
-SortedVector removeParameters(const std::vector<Function*>& funcs,
- SortedVector indexes,
- const std::vector<Call*>& calls,
- const std::vector<CallRef*>& callRefs,
- Module* module,
- PassRunner* runner) {
+std::pair<SortedVector, RemovalOutcome>
+removeParameters(const std::vector<Function*>& funcs,
+ SortedVector indexes,
+ const std::vector<Call*>& calls,
+ const std::vector<CallRef*>& callRefs,
+ Module* module,
+ PassRunner* runner) {
if (indexes.empty()) {
- return {};
+ return {{}, Success};
}
assert(funcs.size() > 0);
@@ -188,8 +197,8 @@ SortedVector removeParameters(const std::vector<Function*>& funcs,
SortedVector removed;
while (1) {
if (indexes.has(i)) {
- if (removeParameter(funcs, i, calls, callRefs, module, runner)) {
- // Success!
+ auto outcome = removeParameter(funcs, i, calls, callRefs, module, runner);
+ if (outcome == Success) {
removed.insert(i);
}
}
@@ -198,7 +207,11 @@ SortedVector removeParameters(const std::vector<Function*>& funcs,
}
i--;
}
- return removed;
+ RemovalOutcome finalOutcome = Success;
+ if (removed.size() < indexes.size()) {
+ finalOutcome = Failure;
+ }
+ return {removed, finalOutcome};
}
SortedVector applyConstantValues(const std::vector<Function*>& funcs,
@@ -246,4 +259,98 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs,
return optimized;
}
+void localizeCallsTo(const std::unordered_set<Name>& callTargets,
+ Module& wasm,
+ PassRunner* runner) {
+ struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> {
+ bool isFunctionParallel() override { return true; }
+
+ std::unique_ptr<Pass> create() override {
+ return std::make_unique<LocalizerPass>(callTargets);
+ }
+
+ const std::unordered_set<Name>& callTargets;
+
+ LocalizerPass(const std::unordered_set<Name>& callTargets)
+ : callTargets(callTargets) {}
+
+ void visitCall(Call* curr) {
+ if (!callTargets.count(curr->target)) {
+ return;
+ }
+
+ ChildLocalizer localizer(
+ curr, getFunction(), *getModule(), getPassOptions());
+ auto* replacement = localizer.getReplacement();
+ if (replacement != curr) {
+ replaceCurrent(replacement);
+ optimized = true;
+ }
+ }
+
+ bool optimized = false;
+
+ void visitFunction(Function* curr) {
+ if (optimized) {
+ // Localization can add blocks, which might move pops.
+ EHUtils::handleBlockNestedPops(curr, *getModule());
+ }
+ }
+ };
+
+ LocalizerPass(callTargets).run(runner, &wasm);
+}
+
+void localizeCallsTo(const std::unordered_set<HeapType>& callTargets,
+ Module& wasm,
+ PassRunner* runner) {
+ struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> {
+ bool isFunctionParallel() override { return true; }
+
+ std::unique_ptr<Pass> create() override {
+ return std::make_unique<LocalizerPass>(callTargets);
+ }
+
+ const std::unordered_set<HeapType>& callTargets;
+
+ LocalizerPass(const std::unordered_set<HeapType>& callTargets)
+ : callTargets(callTargets) {}
+
+ void visitCall(Call* curr) {
+ handleCall(curr, getModule()->getFunction(curr->target)->type);
+ }
+
+ void visitCallRef(CallRef* curr) {
+ auto type = curr->target->type;
+ if (type.isRef()) {
+ handleCall(curr, type.getHeapType());
+ }
+ }
+
+ void handleCall(Expression* call, HeapType type) {
+ if (!callTargets.count(type)) {
+ return;
+ }
+
+ ChildLocalizer localizer(
+ call, getFunction(), *getModule(), getPassOptions());
+ auto* replacement = localizer.getReplacement();
+ if (replacement != call) {
+ replaceCurrent(replacement);
+ optimized = true;
+ }
+ }
+
+ bool optimized = false;
+
+ void visitFunction(Function* curr) {
+ if (optimized) {
+ EHUtils::handleBlockNestedPops(curr, *getModule());
+ }
+ }
+ };
+
+ LocalizerPass(callTargets).run(runner, &wasm);
+}
+
} // namespace wasm::ParamUtils
diff --git a/src/passes/param-utils.h b/src/passes/param-utils.h
index 202c8b007..4c458390a 100644
--- a/src/passes/param-utils.h
+++ b/src/passes/param-utils.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef wasm_ir_function_h
-#define wasm_ir_function_h
+#ifndef wasm_pass_param_utils_h
+#define wasm_pass_param_utils_h
#include "pass.h"
#include "support/sorted_vector.h"
@@ -44,6 +44,16 @@ namespace wasm::ParamUtils {
// }
std::unordered_set<Index> getUsedParams(Function* func);
+// The outcome of an attempt to remove a parameter(s).
+enum RemovalOutcome {
+ // We removed successfully.
+ Success = 0,
+ // We failed, but only because of fixable nested effects. The caller can move
+ // those effects out (e.g. using ChildLocalizer, or the helper localizeCallsTo
+ // below) and repeat.
+ Failure = 1,
+};
+
// Try to remove a parameter from a set of functions and replace it with a local
// instead. This may not succeed if the parameter type cannot be used in a
// local, or if we hit another limitation, in which case this returns false and
@@ -64,21 +74,26 @@ std::unordered_set<Index> getUsedParams(Function* func);
// need adjusting and it is easier to do it all in one place. Also, the caller
// can update all the types at once throughout the program after making
// multiple calls to removeParameter().
-bool removeParameter(const std::vector<Function*>& funcs,
- Index index,
- const std::vector<Call*>& calls,
- const std::vector<CallRef*>& callRefs,
- Module* module,
- PassRunner* runner);
+RemovalOutcome removeParameter(const std::vector<Function*>& funcs,
+ Index index,
+ const std::vector<Call*>& calls,
+ const std::vector<CallRef*>& callRefs,
+ Module* module,
+ PassRunner* runner);
// The same as removeParameter, but gets a sorted list of indexes. It tries to
-// remove them all, and returns which we removed.
-SortedVector removeParameters(const std::vector<Function*>& funcs,
- SortedVector indexes,
- const std::vector<Call*>& calls,
- const std::vector<CallRef*>& callRefs,
- Module* module,
- PassRunner* runner);
+// remove them all, and returns which we removed, as well as an indication as
+// to whether we might remove more if effects were not in the way (specifically,
+// we return Success if we removed any index, Failure if we removed none, and
+// FailureDueToEffects if at least one index could have been removed but for
+// effects).
+std::pair<SortedVector, RemovalOutcome>
+removeParameters(const std::vector<Function*>& funcs,
+ SortedVector indexes,
+ const std::vector<Call*>& calls,
+ const std::vector<CallRef*>& callRefs,
+ Module* module,
+ PassRunner* runner);
// Given a set of functions and the calls and call_refs that reach them, find
// which parameters are passed the same constant value in all the calls. For
@@ -92,6 +107,20 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs,
const std::vector<CallRef*>& callRefs,
Module* module);
+// Helper that localizes all calls to a set of targets, in an entire module.
+// This basically calls ChildLocalizer in each function, on the relevant calls.
+// This is useful when we get FailureDueToEffects, see above.
+//
+// The set of targets can be function names (the individual functions we want to
+// handle calls towards) or heap types (which will then include all functions
+// with those types).
+void localizeCallsTo(const std::unordered_set<Name>& callTargets,
+ Module& wasm,
+ PassRunner* runner);
+void localizeCallsTo(const std::unordered_set<HeapType>& callTargets,
+ Module& wasm,
+ PassRunner* runner);
+
} // namespace wasm::ParamUtils
-#endif // wasm_ir_function_h
+#endif // wasm_pass_param_utils_h