diff options
author | Jérôme Vouillon <jerome.vouillon@gmail.com> | 2024-04-10 17:06:56 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-10 14:06:56 -0700 |
commit | f7ea0c4cff5e3d80e3feb2bce15037bd2c6b9383 (patch) | |
tree | 19b439aeee22049c679e073ea64a5ea01e7e9539 /src | |
parent | 738e8fca4bea0c37e859e5bf0d37866ff432714e (diff) | |
download | binaryen-f7ea0c4cff5e3d80e3feb2bce15037bd2c6b9383.tar.gz binaryen-f7ea0c4cff5e3d80e3feb2bce15037bd2c6b9383.tar.bz2 binaryen-f7ea0c4cff5e3d80e3feb2bce15037bd2c6b9383.zip |
Improve inlining of `return_call*` (#6477)
Use the previous implementation for any return_call that is not inside a try
block. This avoids moving code around (as a sibling of the caller body or the
inlined body), which should allow more local optimizations after inlining.
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/Inlining.cpp | 88 | ||||
-rw-r--r-- | src/wasm-traversal.h | 47 |
2 files changed, 107 insertions, 28 deletions
diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index 0643389c2..04bc6d78a 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -242,9 +242,10 @@ private: struct InliningAction { Expression** callSite; Function* contents; + bool insideATry; - InliningAction(Expression** callSite, Function* contents) - : callSite(callSite), contents(contents) {} + InliningAction(Expression** callSite, Function* contents, bool insideATry) + : callSite(callSite), contents(contents), insideATry(insideATry) {} }; struct InliningState { @@ -254,7 +255,7 @@ struct InliningState { std::unordered_map<Name, std::vector<InliningAction>> actionsForFunction; }; -struct Planner : public WalkerPass<PostWalker<Planner>> { +struct Planner : public WalkerPass<TryDepthWalker<Planner>> { bool isFunctionParallel() override { return true; } Planner(InliningState* state) : state(state) {} @@ -287,7 +288,7 @@ struct Planner : public WalkerPass<PostWalker<Planner>> { // can't add a new element in parallel assert(state->actionsForFunction.count(getFunction()->name) > 0); state->actionsForFunction[getFunction()->name].emplace_back( - &block->list[0], getModule()->getFunction(curr->target)); + &block->list[0], getModule()->getFunction(curr->target), tryDepth > 0); } } @@ -295,7 +296,7 @@ private: InliningState* state; }; -struct Updater : public PostWalker<Updater> { +struct Updater : public TryDepthWalker<Updater> { Module* module; std::map<Index, Index> localMapping; Name returnName; @@ -335,19 +336,38 @@ struct Updater : public PostWalker<Updater> { return; } - // Set the children to locals as necessary, then add a branch out of the - // inlined body. The branch label will be set later when we create branch - // targets for the calls. 
- Block* childBlock = ChildLocalizer(curr, getFunction(), *module, options) - .getChildrenReplacement(); - Break* branch = builder->makeBreak(Name()); - childBlock->list.push_back(branch); - childBlock->type = Type::unreachable; - replaceCurrent(childBlock); - - curr->isReturn = false; - curr->type = sig.results; - returnCallInfos.push_back({curr, branch}); + if (tryDepth == 0) { + // Return calls in inlined functions should only break out of + // the scope of the inlined code, not the entire function they + // are being inlined into. To achieve this, make the call a + // non-return call and add a break. This does not cause + // unbounded stack growth because inlining and return calling + // both avoid creating a new stack frame. + curr->isReturn = false; + curr->type = sig.results; + // There might still be unreachable children causing this to be + // unreachable. + curr->finalize(); + if (sig.results.isConcrete()) { + replaceCurrent(builder->makeBreak(returnName, curr)); + } else { + replaceCurrent(builder->blockify(curr, builder->makeBreak(returnName))); + } + } else { + // Set the children to locals as necessary, then add a branch out of the + // inlined body. The branch label will be set later when we create branch + // targets for the calls. + Block* childBlock = ChildLocalizer(curr, getFunction(), *module, options) + .getChildrenReplacement(); + Break* branch = builder->makeBreak(Name()); + childBlock->list.push_back(branch); + childBlock->type = Type::unreachable; + replaceCurrent(childBlock); + + curr->isReturn = false; + curr->type = sig.results; + returnCallInfos.push_back({curr, branch}); + } } void visitCall(Call* curr) { @@ -464,14 +484,15 @@ static Expression* doInlining(Module* module, // // (In this case we could use a second block and define the named block $X // after the call's parameters, but that adds work for an extremely rare - // situation.) 
The latter case does not apply if the call is a return_call, - // because in that case the call's children do not appear inside the same - // block as the inlined body. + // situation.) The latter case does not apply if the call is a + // return_call inside a try, because in that case the call's + // children do not appear inside the same block as the inlined body. + bool hoistCall = call->isReturn && action.insideATry; if (BranchUtils::hasBranchTarget(from->body, block->name) || - (!call->isReturn && BranchUtils::BranchSeeker::has(call, block->name))) { + (!hoistCall && BranchUtils::BranchSeeker::has(call, block->name))) { auto fromNames = BranchUtils::getBranchTargets(from->body); - auto callNames = call->isReturn ? BranchUtils::NameSet{} - : BranchUtils::BranchAccumulator::get(call); + auto callNames = hoistCall ? BranchUtils::NameSet{} + : BranchUtils::BranchAccumulator::get(call); block->name = Names::getValidName(block->name, [&](Name test) { return !fromNames.count(test) && !callNames.count(test); }); @@ -490,7 +511,7 @@ static Expression* doInlining(Module* module, updater.localMapping[i] = builder.addVar(into, from->getLocalType(i)); } - if (call->isReturn) { + if (hoistCall) { // Wrap the existing function body in a block we can branch out of before // entering the inlined function body. This block must have a name that is // different from any other block name above the branch. 
@@ -544,7 +565,16 @@ static Expression* doInlining(Module* module, builder.makeLocalSet(updater.localMapping[from->getVarIndexBase() + i], LiteralUtils::makeZero(type, *module))); } - *action.callSite = block; + if (call->isReturn) { + assert(!action.insideATry); + if (retType.isConcrete()) { + *action.callSite = builder.makeReturn(block); + } else { + *action.callSite = builder.makeSequence(block, builder.makeReturn()); + } + } else { + *action.callSite = block; + } } // Generate and update the inlined contents @@ -1396,8 +1426,10 @@ struct InlineMainPass : public Pass { // No call at all. return; } - doInlining( - module, main, InliningAction(callSite, originalMain), getPassOptions()); + doInlining(module, + main, + InliningAction(callSite, originalMain, true), + getPassOptions()); } }; diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index f5f25dd1f..c8078164b 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -537,6 +537,53 @@ struct ExpressionStackWalker : public PostWalker<SubType, VisitorType> { } }; +// Traversal keeping track of try depth + +// This is used to keep track of whether we are in the scope of an +// exception handler. This matters since return_call is not equivalent +// to return + call within an exception handler. If another kind of +// handler scope is added, this code will need to be updated. 
+template<typename SubType, typename VisitorType = Visitor<SubType>> +struct TryDepthWalker : public PostWalker<SubType, VisitorType> { + TryDepthWalker() = default; + + size_t tryDepth = 0; + + static void doEnterTry(SubType* self, Expression** currp) { + self->tryDepth++; + } + + static void doLeaveTry(SubType* self, Expression** currp) { + self->tryDepth--; + } + + static void scan(SubType* self, Expression** currp) { + auto* curr = *currp; + + if (curr->is<Try>()) { + self->pushTask(SubType::doVisitTry, currp); + auto& catchBodies = curr->cast<Try>()->catchBodies; + for (int i = int(catchBodies.size()) - 1; i >= 0; i--) { + self->pushTask(SubType::scan, &catchBodies[i]); + } + self->pushTask(SubType::doLeaveTry, currp); + self->pushTask(SubType::scan, &curr->cast<Try>()->body); + self->pushTask(SubType::doEnterTry, currp); + return; + } + + if (curr->is<TryTable>()) { + self->pushTask(SubType::doLeaveTry, currp); + } + + PostWalker<SubType, VisitorType>::scan(self, currp); + + if (curr->is<TryTable>()) { + self->pushTask(SubType::doEnterTry, currp); + } + } +}; + } // namespace wasm #endif // wasm_wasm_traversal_h |