summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJérôme Vouillon <jerome.vouillon@gmail.com>2024-04-10 17:06:56 -0400
committerGitHub <noreply@github.com>2024-04-10 14:06:56 -0700
commitf7ea0c4cff5e3d80e3feb2bce15037bd2c6b9383 (patch)
tree19b439aeee22049c679e073ea64a5ea01e7e9539 /src
parent738e8fca4bea0c37e859e5bf0d37866ff432714e (diff)
downloadbinaryen-f7ea0c4cff5e3d80e3feb2bce15037bd2c6b9383.tar.gz
binaryen-f7ea0c4cff5e3d80e3feb2bce15037bd2c6b9383.tar.bz2
binaryen-f7ea0c4cff5e3d80e3feb2bce15037bd2c6b9383.zip
Improve inlining of `return_call*` (#6477)
Use the previous implementation when no return_call is in a try block. This avoids moving code around (as a sibling of the caller body or the inlined body), so that should allow more local optimizations after inlining.
Diffstat (limited to 'src')
-rw-r--r--src/passes/Inlining.cpp88
-rw-r--r--src/wasm-traversal.h47
2 files changed, 107 insertions, 28 deletions
diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp
index 0643389c2..04bc6d78a 100644
--- a/src/passes/Inlining.cpp
+++ b/src/passes/Inlining.cpp
@@ -242,9 +242,10 @@ private:
struct InliningAction {
Expression** callSite;
Function* contents;
+ bool insideATry;
- InliningAction(Expression** callSite, Function* contents)
- : callSite(callSite), contents(contents) {}
+ InliningAction(Expression** callSite, Function* contents, bool insideATry)
+ : callSite(callSite), contents(contents), insideATry(insideATry) {}
};
struct InliningState {
@@ -254,7 +255,7 @@ struct InliningState {
std::unordered_map<Name, std::vector<InliningAction>> actionsForFunction;
};
-struct Planner : public WalkerPass<PostWalker<Planner>> {
+struct Planner : public WalkerPass<TryDepthWalker<Planner>> {
bool isFunctionParallel() override { return true; }
Planner(InliningState* state) : state(state) {}
@@ -287,7 +288,7 @@ struct Planner : public WalkerPass<PostWalker<Planner>> {
// can't add a new element in parallel
assert(state->actionsForFunction.count(getFunction()->name) > 0);
state->actionsForFunction[getFunction()->name].emplace_back(
- &block->list[0], getModule()->getFunction(curr->target));
+ &block->list[0], getModule()->getFunction(curr->target), tryDepth > 0);
}
}
@@ -295,7 +296,7 @@ private:
InliningState* state;
};
-struct Updater : public PostWalker<Updater> {
+struct Updater : public TryDepthWalker<Updater> {
Module* module;
std::map<Index, Index> localMapping;
Name returnName;
@@ -335,19 +336,38 @@ struct Updater : public PostWalker<Updater> {
return;
}
- // Set the children to locals as necessary, then add a branch out of the
- // inlined body. The branch label will be set later when we create branch
- // targets for the calls.
- Block* childBlock = ChildLocalizer(curr, getFunction(), *module, options)
- .getChildrenReplacement();
- Break* branch = builder->makeBreak(Name());
- childBlock->list.push_back(branch);
- childBlock->type = Type::unreachable;
- replaceCurrent(childBlock);
-
- curr->isReturn = false;
- curr->type = sig.results;
- returnCallInfos.push_back({curr, branch});
+ if (tryDepth == 0) {
+ // Return calls in inlined functions should only break out of
+ // the scope of the inlined code, not the entire function they
+ // are being inlined into. To achieve this, make the call a
+ // non-return call and add a break. This does not cause
+ // unbounded stack growth because inlining and return calling
+ // both avoid creating a new stack frame.
+ curr->isReturn = false;
+ curr->type = sig.results;
+ // There might still be unreachable children causing this to be
+ // unreachable.
+ curr->finalize();
+ if (sig.results.isConcrete()) {
+ replaceCurrent(builder->makeBreak(returnName, curr));
+ } else {
+ replaceCurrent(builder->blockify(curr, builder->makeBreak(returnName)));
+ }
+ } else {
+ // Set the children to locals as necessary, then add a branch out of the
+ // inlined body. The branch label will be set later when we create branch
+ // targets for the calls.
+ Block* childBlock = ChildLocalizer(curr, getFunction(), *module, options)
+ .getChildrenReplacement();
+ Break* branch = builder->makeBreak(Name());
+ childBlock->list.push_back(branch);
+ childBlock->type = Type::unreachable;
+ replaceCurrent(childBlock);
+
+ curr->isReturn = false;
+ curr->type = sig.results;
+ returnCallInfos.push_back({curr, branch});
+ }
}
void visitCall(Call* curr) {
@@ -464,14 +484,15 @@ static Expression* doInlining(Module* module,
//
// (In this case we could use a second block and define the named block $X
// after the call's parameters, but that adds work for an extremely rare
- // situation.) The latter case does not apply if the call is a return_call,
- // because in that case the call's children do not appear inside the same
- // block as the inlined body.
+ // situation.) The latter case does not apply if the call is a
+ // return_call inside a try, because in that case the call's
+ // children do not appear inside the same block as the inlined body.
+ bool hoistCall = call->isReturn && action.insideATry;
if (BranchUtils::hasBranchTarget(from->body, block->name) ||
- (!call->isReturn && BranchUtils::BranchSeeker::has(call, block->name))) {
+ (!hoistCall && BranchUtils::BranchSeeker::has(call, block->name))) {
auto fromNames = BranchUtils::getBranchTargets(from->body);
- auto callNames = call->isReturn ? BranchUtils::NameSet{}
- : BranchUtils::BranchAccumulator::get(call);
+ auto callNames = hoistCall ? BranchUtils::NameSet{}
+ : BranchUtils::BranchAccumulator::get(call);
block->name = Names::getValidName(block->name, [&](Name test) {
return !fromNames.count(test) && !callNames.count(test);
});
@@ -490,7 +511,7 @@ static Expression* doInlining(Module* module,
updater.localMapping[i] = builder.addVar(into, from->getLocalType(i));
}
- if (call->isReturn) {
+ if (hoistCall) {
// Wrap the existing function body in a block we can branch out of before
// entering the inlined function body. This block must have a name that is
// different from any other block name above the branch.
@@ -544,7 +565,16 @@ static Expression* doInlining(Module* module,
builder.makeLocalSet(updater.localMapping[from->getVarIndexBase() + i],
LiteralUtils::makeZero(type, *module)));
}
- *action.callSite = block;
+ if (call->isReturn) {
+ assert(!action.insideATry);
+ if (retType.isConcrete()) {
+ *action.callSite = builder.makeReturn(block);
+ } else {
+ *action.callSite = builder.makeSequence(block, builder.makeReturn());
+ }
+ } else {
+ *action.callSite = block;
+ }
}
// Generate and update the inlined contents
@@ -1396,8 +1426,10 @@ struct InlineMainPass : public Pass {
// No call at all.
return;
}
- doInlining(
- module, main, InliningAction(callSite, originalMain), getPassOptions());
+ doInlining(module,
+ main,
+ InliningAction(callSite, originalMain, true),
+ getPassOptions());
}
};
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index f5f25dd1f..c8078164b 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -537,6 +537,53 @@ struct ExpressionStackWalker : public PostWalker<SubType, VisitorType> {
}
};
+// Traversal keeping track of try depth
+
+// This is used to keep track of whether we are in the scope of an
+// exception handler. This matters since return_call is not equivalent
+// to return + call within an exception handler. If another kind of
+// handler scope is added, this code will need to be updated.
+template<typename SubType, typename VisitorType = Visitor<SubType>>
+struct TryDepthWalker : public PostWalker<SubType, VisitorType> {
+ TryDepthWalker() = default;
+
+ size_t tryDepth = 0;
+
+ static void doEnterTry(SubType* self, Expression** currp) {
+ self->tryDepth++;
+ }
+
+ static void doLeaveTry(SubType* self, Expression** currp) {
+ self->tryDepth--;
+ }
+
+ static void scan(SubType* self, Expression** currp) {
+ auto* curr = *currp;
+
+ if (curr->is<Try>()) {
+ self->pushTask(SubType::doVisitTry, currp);
+ auto& catchBodies = curr->cast<Try>()->catchBodies;
+ for (int i = int(catchBodies.size()) - 1; i >= 0; i--) {
+ self->pushTask(SubType::scan, &catchBodies[i]);
+ }
+ self->pushTask(SubType::doLeaveTry, currp);
+ self->pushTask(SubType::scan, &curr->cast<Try>()->body);
+ self->pushTask(SubType::doEnterTry, currp);
+ return;
+ }
+
+ if (curr->is<TryTable>()) {
+ self->pushTask(SubType::doLeaveTry, currp);
+ }
+
+ PostWalker<SubType, VisitorType>::scan(self, currp);
+
+ if (curr->is<TryTable>()) {
+ self->pushTask(SubType::doEnterTry, currp);
+ }
+ }
+};
+
} // namespace wasm
#endif // wasm_wasm_traversal_h