summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/passes/DeadArgumentElimination.cpp27
-rw-r--r--src/passes/SignaturePruning.cpp68
-rw-r--r--src/passes/param-utils.cpp183
-rw-r--r--src/passes/param-utils.h61
4 files changed, 277 insertions, 62 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp
index d0705961b..4a341571e 100644
--- a/src/passes/DeadArgumentElimination.cpp
+++ b/src/passes/DeadArgumentElimination.cpp
@@ -216,11 +216,26 @@ struct DAE : public Pass {
allDroppedCalls[name] = calls;
}
}
+
// Track which functions we changed, and optimize them later if necessary.
std::unordered_set<Function*> changed;
+
// If we refine return types then we will need to do more type updating
// at the end.
bool refinedReturnTypes = false;
+
+ // If we find that localizing call arguments can help (by moving their
+ // effects outside, so ParamUtils::removeParameters can handle them), then
+ // we do that at the end and perform another cycle. It is simpler to just do
+ // another cycle than to track the locations of calls, which is tricky as
+ // localization might move a call (if a call happens to be another call's
+ // param). In practice it is rare to find call arguments we want to remove,
+ // and even more rare to find effects get in the way, so this should not
+ // cause much overhead.
+ //
+ // This set tracks the functions whose call sites should be localized.
+ std::unordered_set<Name> callTargetsToLocalize;
+
// We now have a mapping of all call sites for each function, and can look
// for optimization opportunities.
for (auto& [name, calls] : allCalls) {
@@ -263,12 +278,15 @@ struct DAE : public Pass {
if (numParams == 0) {
continue;
}
- auto removedIndexes = ParamUtils::removeParameters(
+ auto [removedIndexes, outcome] = ParamUtils::removeParameters(
{func}, infoMap[name].unusedParams, calls, {}, module, getPassRunner());
if (!removedIndexes.empty()) {
// Success!
changed.insert(func);
}
+ if (outcome == ParamUtils::RemovalOutcome::Failure) {
+ callTargetsToLocalize.insert(name);
+ }
}
// We can also tell which calls have all their return values dropped. Note
// that we can't do this if we changed anything so far, as we may have
@@ -307,10 +325,15 @@ struct DAE : public Pass {
changed.insert(func.get());
}
}
+ if (!callTargetsToLocalize.empty()) {
+ ParamUtils::localizeCallsTo(
+ callTargetsToLocalize, *module, getPassRunner());
+ }
if (optimize && !changed.empty()) {
OptUtils::optimizeAfterInlining(changed, module, getPassRunner());
}
- return !changed.empty() || refinedReturnTypes;
+ return !changed.empty() || refinedReturnTypes ||
+ !callTargetsToLocalize.empty();
}
private:
diff --git a/src/passes/SignaturePruning.cpp b/src/passes/SignaturePruning.cpp
index 23295a66a..2e4be89e8 100644
--- a/src/passes/SignaturePruning.cpp
+++ b/src/passes/SignaturePruning.cpp
@@ -67,6 +67,16 @@ struct SignaturePruning : public Pass {
return;
}
+ // The first iteration may suggest additional work is possible. If so, run
+ // another cycle. (Even more cycles may help, but limit ourselves to 2 for
+ // now.)
+ if (iteration(module)) {
+ iteration(module);
+ }
+ }
+
+ // Returns true if more work is possible.
+ bool iteration(Module* module) {
// First, find all the information we need. Start by collecting inside each
// function in parallel.
@@ -101,6 +111,16 @@ struct SignaturePruning : public Pass {
// Map heap types to all functions with that type.
InsertOrderedMap<HeapType, std::vector<Function*>> sigFuncs;
+ // Heap types of call targets that we found we should localize calls to, in
+ // order to fully handle them. (See similar code in DeadArgumentElimination
+ // for individual functions; here we handle a HeapType at a time.) A slight
+ // complication is that we cannot track heap types here: heap types are
+ // rewritten using |GlobalTypeRewriter::updateSignatures| below, and even
+ // types that we do not modify end up replaced (as the entire set of types
+ // becomes one new big rec group). We therefore need something more stable
+ // to track here, which we do using either a Call or a CallRef.
+ std::unordered_set<Expression*> callTargetsToLocalize;
+
// Combine all the information we gathered into that map, iterating in a
// deterministic order as we build up vectors where the order matters.
for (auto& f : module->functions) {
@@ -215,12 +235,23 @@ struct SignaturePruning : public Pass {
}
auto oldParams = sig.params;
- auto removedIndexes = ParamUtils::removeParameters(funcs,
- unusedParams,
- info.calls,
- info.callRefs,
- module,
- getPassRunner());
+ auto [removedIndexes, outcome] =
+ ParamUtils::removeParameters(funcs,
+ unusedParams,
+ info.calls,
+ info.callRefs,
+ module,
+ getPassRunner());
+ if (outcome == ParamUtils::RemovalOutcome::Failure) {
+ // Use either a Call or a CallRef that has this type (see explanation
+ // above on |callTargetsToLocalize|).
+ if (!info.calls.empty()) {
+ callTargetsToLocalize.insert(info.calls[0]);
+ } else {
+ assert(!info.callRefs.empty());
+ callTargetsToLocalize.insert(info.callRefs[0]);
+ }
+ }
if (removedIndexes.empty()) {
continue;
}
@@ -262,6 +293,31 @@ struct SignaturePruning : public Pass {
// Rewrite the types.
GlobalTypeRewriter::updateSignatures(newSignatures, *module);
+
+ if (callTargetsToLocalize.empty()) {
+ return false;
+ }
+
+ // Localize after updating signatures, so as not to interfere with that
+ // operation: localization adds locals, and the indexes of locals must be
+ // taken into account in |GlobalTypeRewriter::updateSignatures| (as var
+ // indexes change when params are pruned).
+ std::unordered_set<HeapType> callTargetTypes;
+ for (auto* call : callTargetsToLocalize) {
+ HeapType type;
+ if (auto* c = call->dynCast<Call>()) {
+ type = module->getFunction(c->target)->type;
+ } else if (auto* c = call->dynCast<CallRef>()) {
+ type = c->target->type.getHeapType();
+ } else {
+ WASM_UNREACHABLE("bad call");
+ }
+ callTargetTypes.insert(type);
+ }
+
+ ParamUtils::localizeCallsTo(callTargetTypes, *module, getPassRunner());
+
+ return true;
}
};
diff --git a/src/passes/param-utils.cpp b/src/passes/param-utils.cpp
index e94ea95b1..0caccff11 100644
--- a/src/passes/param-utils.cpp
+++ b/src/passes/param-utils.cpp
@@ -14,11 +14,16 @@
* limitations under the License.
*/
+#include "passes/param-utils.h"
+#include "ir/eh-utils.h"
#include "ir/function-utils.h"
#include "ir/local-graph.h"
+#include "ir/localize.h"
#include "ir/possible-constant.h"
#include "ir/type-updating.h"
+#include "pass.h"
#include "support/sorted_vector.h"
+#include "wasm-traversal.h"
#include "wasm.h"
namespace wasm::ParamUtils {
@@ -45,12 +50,12 @@ std::unordered_set<Index> getUsedParams(Function* func) {
return usedParams;
}
-bool removeParameter(const std::vector<Function*>& funcs,
- Index index,
- const std::vector<Call*>& calls,
- const std::vector<CallRef*>& callRefs,
- Module* module,
- PassRunner* runner) {
+RemovalOutcome removeParameter(const std::vector<Function*>& funcs,
+ Index index,
+ const std::vector<Call*>& calls,
+ const std::vector<CallRef*>& callRefs,
+ Module* module,
+ PassRunner* runner) {
assert(funcs.size() > 0);
auto* first = funcs[0];
#ifndef NDEBUG
@@ -74,28 +79,31 @@ bool removeParameter(const std::vector<Function*>& funcs,
// propagating that out, or by appending an unreachable after the call, but
// for simplicity just ignore such cases; if we are called again later then
// if DCE ran meanwhile then we could optimize.
- auto hasBadEffects = [&](auto* call) {
- auto& operands = call->operands;
- bool hasUnremovable =
- EffectAnalyzer(runner->options, *module, operands[index])
- .hasUnremovableSideEffects();
- bool wouldChangeType = call->type == Type::unreachable && !call->isReturn &&
- operands[index]->type == Type::unreachable;
- return hasUnremovable || wouldChangeType;
+ auto checkEffects = [&](auto* call) {
+ auto* operand = call->operands[index];
+
+ if (operand->type == Type::unreachable) {
+ return Failure;
+ }
+
+ bool hasUnremovable = EffectAnalyzer(runner->options, *module, operand)
+ .hasUnremovableSideEffects();
+
+ return hasUnremovable ? Failure : Success;
};
- bool callParamsAreValid =
- std::none_of(calls.begin(), calls.end(), [&](Call* call) {
- return hasBadEffects(call);
- });
- if (!callParamsAreValid) {
- return false;
+
+ for (auto* call : calls) {
+ auto result = checkEffects(call);
+ if (result != Success) {
+ return result;
+ }
}
- bool callRefParamsAreValid =
- std::none_of(callRefs.begin(), callRefs.end(), [&](CallRef* call) {
- return hasBadEffects(call);
- });
- if (!callRefParamsAreValid) {
- return false;
+
+ for (auto* call : callRefs) {
+ auto result = checkEffects(call);
+ if (result != Success) {
+ return result;
+ }
}
// The type must be valid for us to handle as a local (since we
@@ -104,7 +112,7 @@ bool removeParameter(const std::vector<Function*>& funcs,
// local
bool typeIsValid = TypeUpdating::canHandleAsLocal(first->getLocalType(index));
if (!typeIsValid) {
- return false;
+ return Failure;
}
// We can do it!
@@ -161,17 +169,18 @@ bool removeParameter(const std::vector<Function*>& funcs,
call->operands.erase(call->operands.begin() + index);
}
- return true;
+ return Success;
}
-SortedVector removeParameters(const std::vector<Function*>& funcs,
- SortedVector indexes,
- const std::vector<Call*>& calls,
- const std::vector<CallRef*>& callRefs,
- Module* module,
- PassRunner* runner) {
+std::pair<SortedVector, RemovalOutcome>
+removeParameters(const std::vector<Function*>& funcs,
+ SortedVector indexes,
+ const std::vector<Call*>& calls,
+ const std::vector<CallRef*>& callRefs,
+ Module* module,
+ PassRunner* runner) {
if (indexes.empty()) {
- return {};
+ return {{}, Success};
}
assert(funcs.size() > 0);
@@ -188,8 +197,8 @@ SortedVector removeParameters(const std::vector<Function*>& funcs,
SortedVector removed;
while (1) {
if (indexes.has(i)) {
- if (removeParameter(funcs, i, calls, callRefs, module, runner)) {
- // Success!
+ auto outcome = removeParameter(funcs, i, calls, callRefs, module, runner);
+ if (outcome == Success) {
removed.insert(i);
}
}
@@ -198,7 +207,11 @@ SortedVector removeParameters(const std::vector<Function*>& funcs,
}
i--;
}
- return removed;
+ RemovalOutcome finalOutcome = Success;
+ if (removed.size() < indexes.size()) {
+ finalOutcome = Failure;
+ }
+ return {removed, finalOutcome};
}
SortedVector applyConstantValues(const std::vector<Function*>& funcs,
@@ -246,4 +259,98 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs,
return optimized;
}
+void localizeCallsTo(const std::unordered_set<Name>& callTargets,
+ Module& wasm,
+ PassRunner* runner) {
+ struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> {
+ bool isFunctionParallel() override { return true; }
+
+ std::unique_ptr<Pass> create() override {
+ return std::make_unique<LocalizerPass>(callTargets);
+ }
+
+ const std::unordered_set<Name>& callTargets;
+
+ LocalizerPass(const std::unordered_set<Name>& callTargets)
+ : callTargets(callTargets) {}
+
+ void visitCall(Call* curr) {
+ if (!callTargets.count(curr->target)) {
+ return;
+ }
+
+ ChildLocalizer localizer(
+ curr, getFunction(), *getModule(), getPassOptions());
+ auto* replacement = localizer.getReplacement();
+ if (replacement != curr) {
+ replaceCurrent(replacement);
+ optimized = true;
+ }
+ }
+
+ bool optimized = false;
+
+ void visitFunction(Function* curr) {
+ if (optimized) {
+ // Localization can add blocks, which might move pops.
+ EHUtils::handleBlockNestedPops(curr, *getModule());
+ }
+ }
+ };
+
+ LocalizerPass(callTargets).run(runner, &wasm);
+}
+
+void localizeCallsTo(const std::unordered_set<HeapType>& callTargets,
+ Module& wasm,
+ PassRunner* runner) {
+ struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> {
+ bool isFunctionParallel() override { return true; }
+
+ std::unique_ptr<Pass> create() override {
+ return std::make_unique<LocalizerPass>(callTargets);
+ }
+
+ const std::unordered_set<HeapType>& callTargets;
+
+ LocalizerPass(const std::unordered_set<HeapType>& callTargets)
+ : callTargets(callTargets) {}
+
+ void visitCall(Call* curr) {
+ handleCall(curr, getModule()->getFunction(curr->target)->type);
+ }
+
+ void visitCallRef(CallRef* curr) {
+ auto type = curr->target->type;
+ if (type.isRef()) {
+ handleCall(curr, type.getHeapType());
+ }
+ }
+
+ void handleCall(Expression* call, HeapType type) {
+ if (!callTargets.count(type)) {
+ return;
+ }
+
+ ChildLocalizer localizer(
+ call, getFunction(), *getModule(), getPassOptions());
+ auto* replacement = localizer.getReplacement();
+ if (replacement != call) {
+ replaceCurrent(replacement);
+ optimized = true;
+ }
+ }
+
+ bool optimized = false;
+
+ void visitFunction(Function* curr) {
+ if (optimized) {
+ EHUtils::handleBlockNestedPops(curr, *getModule());
+ }
+ }
+ };
+
+ LocalizerPass(callTargets).run(runner, &wasm);
+}
+
} // namespace wasm::ParamUtils
diff --git a/src/passes/param-utils.h b/src/passes/param-utils.h
index 202c8b007..4c458390a 100644
--- a/src/passes/param-utils.h
+++ b/src/passes/param-utils.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef wasm_ir_function_h
-#define wasm_ir_function_h
+#ifndef wasm_pass_param_utils_h
+#define wasm_pass_param_utils_h
#include "pass.h"
#include "support/sorted_vector.h"
@@ -44,6 +44,16 @@ namespace wasm::ParamUtils {
// }
std::unordered_set<Index> getUsedParams(Function* func);
+// The outcome of an attempt to remove one or more parameters.
+enum RemovalOutcome {
+ // We removed successfully.
+ Success = 0,
+ // We failed to remove. One possible cause is nested effects in a call
+ // operand, which the caller can move out (e.g. using ChildLocalizer, or the
+ // helper localizeCallsTo below) and then repeat. Note that removeParameter
+ // reports its other failure causes with this same value.
+};
+
// Try to remove a parameter from a set of functions and replace it with a local
// instead. This may not succeed if the parameter type cannot be used in a
// local, or if we hit another limitation, in which case this returns false and
@@ -64,21 +74,26 @@ std::unordered_set<Index> getUsedParams(Function* func);
// need adjusting and it is easier to do it all in one place. Also, the caller
// can update all the types at once throughout the program after making
// multiple calls to removeParameter().
-bool removeParameter(const std::vector<Function*>& funcs,
- Index index,
- const std::vector<Call*>& calls,
- const std::vector<CallRef*>& callRefs,
- Module* module,
- PassRunner* runner);
+RemovalOutcome removeParameter(const std::vector<Function*>& funcs,
+ Index index,
+ const std::vector<Call*>& calls,
+ const std::vector<CallRef*>& callRefs,
+ Module* module,
+ PassRunner* runner);
// The same as removeParameter, but gets a sorted list of indexes. It tries to
-// remove them all, and returns which we removed.
-SortedVector removeParameters(const std::vector<Function*>& funcs,
- SortedVector indexes,
- const std::vector<Call*>& calls,
- const std::vector<CallRef*>& callRefs,
- Module* module,
- PassRunner* runner);
+// remove them all, and returns which we removed, as well as an indication of
+// whether any requested index was left in place, in which case the caller may
+// want to localize call operands and try again (specifically, we return
+// Success if every requested index was removed, including the trivial case
+// of an empty list of indexes, and Failure if at least one was not removed).
+std::pair<SortedVector, RemovalOutcome>
+removeParameters(const std::vector<Function*>& funcs,
+ SortedVector indexes,
+ const std::vector<Call*>& calls,
+ const std::vector<CallRef*>& callRefs,
+ Module* module,
+ PassRunner* runner);
// Given a set of functions and the calls and call_refs that reach them, find
// which parameters are passed the same constant value in all the calls. For
@@ -92,6 +107,20 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs,
const std::vector<CallRef*>& callRefs,
Module* module);
+// Helper that localizes all calls to a set of targets, in an entire module.
+// This basically calls ChildLocalizer in each function, on the relevant calls.
+// This is useful when removeParameters returns Failure, see above.
+//
+// The set of targets can be function names (the individual functions we want to
+// handle calls towards) or heap types (which will then include all functions
+// with those types).
+void localizeCallsTo(const std::unordered_set<Name>& callTargets,
+ Module& wasm,
+ PassRunner* runner);
+void localizeCallsTo(const std::unordered_set<HeapType>& callTargets,
+ Module& wasm,
+ PassRunner* runner);
+
} // namespace wasm::ParamUtils
-#endif // wasm_ir_function_h
+#endif // wasm_pass_param_utils_h