summaryrefslogtreecommitdiff
path: root/src/passes/param-utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/param-utils.cpp')
-rw-r--r--src/passes/param-utils.cpp183
1 files changed, 145 insertions, 38 deletions
diff --git a/src/passes/param-utils.cpp b/src/passes/param-utils.cpp
index e94ea95b1..0caccff11 100644
--- a/src/passes/param-utils.cpp
+++ b/src/passes/param-utils.cpp
@@ -14,11 +14,16 @@
* limitations under the License.
*/
+#include "passes/param-utils.h"
+#include "ir/eh-utils.h"
#include "ir/function-utils.h"
#include "ir/local-graph.h"
+#include "ir/localize.h"
#include "ir/possible-constant.h"
#include "ir/type-updating.h"
+#include "pass.h"
#include "support/sorted_vector.h"
+#include "wasm-traversal.h"
#include "wasm.h"
namespace wasm::ParamUtils {
@@ -45,12 +50,12 @@ std::unordered_set<Index> getUsedParams(Function* func) {
return usedParams;
}
-bool removeParameter(const std::vector<Function*>& funcs,
- Index index,
- const std::vector<Call*>& calls,
- const std::vector<CallRef*>& callRefs,
- Module* module,
- PassRunner* runner) {
+RemovalOutcome removeParameter(const std::vector<Function*>& funcs,
+ Index index,
+ const std::vector<Call*>& calls,
+ const std::vector<CallRef*>& callRefs,
+ Module* module,
+ PassRunner* runner) {
assert(funcs.size() > 0);
auto* first = funcs[0];
#ifndef NDEBUG
@@ -74,28 +79,31 @@ bool removeParameter(const std::vector<Function*>& funcs,
// propagating that out, or by appending an unreachable after the call, but
// for simplicity just ignore such cases; if we are called again later then
// if DCE ran meanwhile then we could optimize.
- auto hasBadEffects = [&](auto* call) {
- auto& operands = call->operands;
- bool hasUnremovable =
- EffectAnalyzer(runner->options, *module, operands[index])
- .hasUnremovableSideEffects();
- bool wouldChangeType = call->type == Type::unreachable && !call->isReturn &&
- operands[index]->type == Type::unreachable;
- return hasUnremovable || wouldChangeType;
+ auto checkEffects = [&](auto* call) {
+ auto* operand = call->operands[index];
+
+ if (operand->type == Type::unreachable) {
+ return Failure;
+ }
+
+ bool hasUnremovable = EffectAnalyzer(runner->options, *module, operand)
+ .hasUnremovableSideEffects();
+
+ return hasUnremovable ? Failure : Success;
};
- bool callParamsAreValid =
- std::none_of(calls.begin(), calls.end(), [&](Call* call) {
- return hasBadEffects(call);
- });
- if (!callParamsAreValid) {
- return false;
+
+ for (auto* call : calls) {
+ auto result = checkEffects(call);
+ if (result != Success) {
+ return result;
+ }
}
- bool callRefParamsAreValid =
- std::none_of(callRefs.begin(), callRefs.end(), [&](CallRef* call) {
- return hasBadEffects(call);
- });
- if (!callRefParamsAreValid) {
- return false;
+
+ for (auto* call : callRefs) {
+ auto result = checkEffects(call);
+ if (result != Success) {
+ return result;
+ }
}
// The type must be valid for us to handle as a local (since we
@@ -104,7 +112,7 @@ bool removeParameter(const std::vector<Function*>& funcs,
// local
bool typeIsValid = TypeUpdating::canHandleAsLocal(first->getLocalType(index));
if (!typeIsValid) {
- return false;
+ return Failure;
}
// We can do it!
@@ -161,17 +169,18 @@ bool removeParameter(const std::vector<Function*>& funcs,
call->operands.erase(call->operands.begin() + index);
}
- return true;
+ return Success;
}
-SortedVector removeParameters(const std::vector<Function*>& funcs,
- SortedVector indexes,
- const std::vector<Call*>& calls,
- const std::vector<CallRef*>& callRefs,
- Module* module,
- PassRunner* runner) {
+std::pair<SortedVector, RemovalOutcome>
+removeParameters(const std::vector<Function*>& funcs,
+ SortedVector indexes,
+ const std::vector<Call*>& calls,
+ const std::vector<CallRef*>& callRefs,
+ Module* module,
+ PassRunner* runner) {
if (indexes.empty()) {
- return {};
+ return {{}, Success};
}
assert(funcs.size() > 0);
@@ -188,8 +197,8 @@ SortedVector removeParameters(const std::vector<Function*>& funcs,
SortedVector removed;
while (1) {
if (indexes.has(i)) {
- if (removeParameter(funcs, i, calls, callRefs, module, runner)) {
- // Success!
+ auto outcome = removeParameter(funcs, i, calls, callRefs, module, runner);
+ if (outcome == Success) {
removed.insert(i);
}
}
@@ -198,7 +207,11 @@ SortedVector removeParameters(const std::vector<Function*>& funcs,
}
i--;
}
- return removed;
+ RemovalOutcome finalOutcome = Success;
+ if (removed.size() < indexes.size()) {
+ finalOutcome = Failure;
+ }
+ return {removed, finalOutcome};
}
SortedVector applyConstantValues(const std::vector<Function*>& funcs,
@@ -246,4 +259,98 @@ SortedVector applyConstantValues(const std::vector<Function*>& funcs,
return optimized;
}
+void localizeCallsTo(const std::unordered_set<Name>& callTargets,
+ Module& wasm,
+ PassRunner* runner) {
+ struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> {
+ bool isFunctionParallel() override { return true; }
+
+ std::unique_ptr<Pass> create() override {
+ return std::make_unique<LocalizerPass>(callTargets);
+ }
+
+ const std::unordered_set<Name>& callTargets;
+
+ LocalizerPass(const std::unordered_set<Name>& callTargets)
+ : callTargets(callTargets) {}
+
+ void visitCall(Call* curr) {
+ if (!callTargets.count(curr->target)) {
+ return;
+ }
+
+ ChildLocalizer localizer(
+ curr, getFunction(), *getModule(), getPassOptions());
+ auto* replacement = localizer.getReplacement();
+ if (replacement != curr) {
+ replaceCurrent(replacement);
+ optimized = true;
+ }
+ }
+
+ bool optimized = false;
+
+ void visitFunction(Function* curr) {
+ if (optimized) {
+ // Localization can add blocks, which might move pops.
+ EHUtils::handleBlockNestedPops(curr, *getModule());
+ }
+ }
+ };
+
+ LocalizerPass(callTargets).run(runner, &wasm);
+}
+
+void localizeCallsTo(const std::unordered_set<HeapType>& callTargets,
+ Module& wasm,
+ PassRunner* runner) {
+ struct LocalizerPass : public WalkerPass<PostWalker<LocalizerPass>> {
+ bool isFunctionParallel() override { return true; }
+
+ std::unique_ptr<Pass> create() override {
+ return std::make_unique<LocalizerPass>(callTargets);
+ }
+
+ const std::unordered_set<HeapType>& callTargets;
+
+ LocalizerPass(const std::unordered_set<HeapType>& callTargets)
+ : callTargets(callTargets) {}
+
+ void visitCall(Call* curr) {
+ handleCall(curr, getModule()->getFunction(curr->target)->type);
+ }
+
+ void visitCallRef(CallRef* curr) {
+ auto type = curr->target->type;
+ if (type.isRef()) {
+ handleCall(curr, type.getHeapType());
+ }
+ }
+
+ void handleCall(Expression* call, HeapType type) {
+ if (!callTargets.count(type)) {
+ return;
+ }
+
+ ChildLocalizer localizer(
+ call, getFunction(), *getModule(), getPassOptions());
+ auto* replacement = localizer.getReplacement();
+ if (replacement != call) {
+ replaceCurrent(replacement);
+ optimized = true;
+ }
+ }
+
+ bool optimized = false;
+
+ void visitFunction(Function* curr) {
+ if (optimized) {
+ EHUtils::handleBlockNestedPops(curr, *getModule());
+ }
+ }
+ };
+
+ LocalizerPass(callTargets).run(runner, &wasm);
+}
+
} // namespace wasm::ParamUtils