diff options
author | Alon Zakai <azakai@google.com> | 2020-09-03 16:11:05 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-03 16:11:05 -0700 |
commit | 44df23efd69fd2dd4c260755c82ddede226c40ff (patch) | |
tree | d828947439ac4dd47fc023b176b86d4949890b81 | |
parent | 132c72bb5e93591de34a9bfc267e4a2007908626 (diff) | |
download | binaryen-44df23efd69fd2dd4c260755c82ddede226c40ff.tar.gz binaryen-44df23efd69fd2dd4c260755c82ddede226c40ff.tar.bz2 binaryen-44df23efd69fd2dd4c260755c82ddede226c40ff.zip |
Optimize MergeBlocks by caching branch results (#3102)
BranchSeekerCache caches the set of branches in a node and
its children, and helps compute new results by looking in the cache
and reusing the data already computed for the children. This avoids
quadratic time in the common case of a post-walk on a tower of nested
blocks, a shape that frequently arises from a switch.
Fixes #3090. On the testcase there, this pass goes from
over a minute to less than a second.
-rw-r--r-- | src/ir/branch-utils.h | 102 | ||||
-rw-r--r-- | src/passes/MergeBlocks.cpp | 28 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 1 |
3 files changed, 115 insertions, 16 deletions
diff --git a/src/ir/branch-utils.h b/src/ir/branch-utils.h index 363a9c9e2..9b439be89 100644 --- a/src/ir/branch-utils.h +++ b/src/ir/branch-utils.h @@ -17,6 +17,7 @@ #ifndef wasm_ir_branch_h #define wasm_ir_branch_h +#include "ir/iteration.h" #include "wasm-traversal.h" #include "wasm.h" @@ -52,10 +53,12 @@ inline bool isBranchReachable(Expression* expr) { WASM_UNREACHABLE("unexpected expression type"); } -inline std::set<Name> getUniqueTargets(Break* br) { return {br->name}; } +using NameSet = std::set<Name>; -inline std::set<Name> getUniqueTargets(Switch* sw) { - std::set<Name> ret; +inline NameSet getUniqueTargets(Break* br) { return {br->name}; } + +inline NameSet getUniqueTargets(Switch* sw) { + NameSet ret; for (auto target : sw->targets) { ret.insert(target); } @@ -63,8 +66,20 @@ inline std::set<Name> getUniqueTargets(Switch* sw) { return ret; } -inline std::set<Name> getUniqueTargets(BrOnExn* br) { return {br->name}; } +inline NameSet getUniqueTargets(BrOnExn* br) { return {br->name}; } +inline NameSet getUniqueTargets(Expression* expr) { + if (auto* br = expr->dynCast<Break>()) { + return getUniqueTargets(br); + } + if (auto* br = expr->dynCast<Switch>()) { + return getUniqueTargets(br); + } + if (auto* br = expr->dynCast<BrOnExn>()) { + return getUniqueTargets(br); + } + return {}; +} // If we branch to 'from', change that to 'to' instead. 
inline bool replacePossibleTarget(Expression* branch, Name from, Name to) { bool worked = false; @@ -97,9 +112,9 @@ inline bool replacePossibleTarget(Expression* branch, Name from, Name to) { // returns the set of targets to which we branch that are // outside of a node -inline std::set<Name> getExitingBranches(Expression* ast) { +inline NameSet getExitingBranches(Expression* ast) { struct Scanner : public PostWalker<Scanner> { - std::set<Name> targets; + NameSet targets; void visitBreak(Break* curr) { targets.insert(curr->name); } void visitSwitch(Switch* curr) { @@ -128,9 +143,9 @@ inline std::set<Name> getExitingBranches(Expression* ast) { // returns the list of all branch targets in a node -inline std::set<Name> getBranchTargets(Expression* ast) { +inline NameSet getBranchTargets(Expression* ast) { struct Scanner : public PostWalker<Scanner> { - std::set<Name> targets; + NameSet targets; void visitBlock(Block* curr) { if (curr->name.is()) { @@ -218,6 +233,77 @@ struct BranchSeeker : public PostWalker<BranchSeeker> { } }; +// Accumulates all the branches in an entire tree. +struct BranchAccumulator + : public PostWalker<BranchAccumulator, + UnifiedExpressionVisitor<BranchAccumulator>> { + NameSet branches; + + void visitExpression(Expression* curr) { + auto selfBranches = getUniqueTargets(curr); + branches.insert(selfBranches.begin(), selfBranches.end()); + } +}; + +// A helper structure for the common case of post-walking some IR while querying +// whether a branch is present. We can cache results for children in order to +// avoid quadratic time searches. +// We assume that a node will be scanned *once* here. That means that if we +// scan a node, we can discard all information for its children. This avoids +// linearly increasing memory usage over time. +class BranchSeekerCache { + // Maps all the branches present in an expression and all its nested children. 
+ std::unordered_map<Expression*, NameSet> branches; + +public: + const NameSet& getBranches(Expression* curr) { + auto iter = branches.find(curr); + if (iter != branches.end()) { + return iter->second; + } + NameSet currBranches; + auto add = [&](NameSet& moreBranches) { + // Make sure to do a fast swap for the first set of branches to arrive. + // This helps the case of the first child being a block with a very large + // set of names. + if (currBranches.empty()) { + currBranches.swap(moreBranches); + } else { + currBranches.insert(moreBranches.begin(), moreBranches.end()); + } + }; + // Add from the children, which are hopefully cached. + for (auto child : ChildIterator(curr)) { + auto iter = branches.find(child); + if (iter != branches.end()) { + add(iter->second); + // We are scanning the parent, which means we assume the child will + // never be visited again. + branches.erase(iter); + } else { + // The child was not cached. Scan it manually. + BranchAccumulator childBranches; + childBranches.walk(child); + add(childBranches.branches); + // Don't bother caching anything - we are scanning the parent, so the + // child will presumably not be scanned again. + } + } + // Finish with the parent's own branches. 
+ auto selfBranches = getUniqueTargets(curr); + add(selfBranches); + return branches[curr] = std::move(currBranches); + } + + bool hasBranch(Expression* curr, Name target) { + bool result = getBranches(curr).count(target); +#ifdef BRANCH_UTILS_DEBUG + assert(result == BranchSeeker::has(curr, target)); +#endif + return result; + } +}; + } // namespace BranchUtils } // namespace wasm diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp index 1401cd2f4..4ecec6669 100644 --- a/src/passes/MergeBlocks.cpp +++ b/src/passes/MergeBlocks.cpp @@ -145,8 +145,11 @@ struct ProblemFinder : public ControlFlowWalker<ProblemFinder> { struct BreakValueDropper : public ControlFlowWalker<BreakValueDropper> { Name origin; PassOptions& passOptions; + BranchUtils::BranchSeekerCache& branchInfo; - BreakValueDropper(PassOptions& passOptions) : passOptions(passOptions) {} + BreakValueDropper(PassOptions& passOptions, + BranchUtils::BranchSeekerCache& branchInfo) + : passOptions(passOptions), branchInfo(branchInfo) {} void visitBlock(Block* curr); @@ -198,8 +201,10 @@ static bool hasDeadCode(Block* block) { } // core block optimizer routine -static void -optimizeBlock(Block* curr, Module* module, PassOptions& passOptions) { +static void optimizeBlock(Block* curr, + Module* module, + PassOptions& passOptions, + BranchUtils::BranchSeekerCache& branchInfo) { auto& list = curr->list; // Main merging loop. 
bool more = true; @@ -237,7 +242,7 @@ optimizeBlock(Block* curr, Module* module, PassOptions& passOptions) { childBlock = nullptr; } else { // fix up breaks - BreakValueDropper fixer(passOptions); + BreakValueDropper fixer(passOptions, branchInfo); fixer.origin = childBlock->name; fixer.setModule(module); fixer.walk(expression); @@ -294,7 +299,7 @@ optimizeBlock(Block* curr, Module* module, PassOptions& passOptions) { auto childName = childBlock->name; for (size_t j = 0; j < childSize; j++) { auto* item = childList[j]; - if (BranchUtils::BranchSeeker::has(item, childName)) { + if (branchInfo.hasBranch(item, childName)) { // We can't remove this from the child. keepStart = j; keepEnd = childSize; @@ -360,6 +365,13 @@ optimizeBlock(Block* curr, Module* module, PassOptions& passOptions) { if (loop) { loop->finalize(); } + // Note that we modify the child block here, which invalidates info + // in branchInfo. However, as we have scanned the parent, we have + // already forgotten the child's info, so there is nothing to do here + // for the child. + // (We also don't need to do anything for the parent - we move code + // from a child into the parent, but that doesn't change the total + // branches in the parent.) } // Add the rest of the parent block after the child. 
for (size_t j = i + 1; j < list.size(); j++) { @@ -387,7 +399,7 @@ optimizeBlock(Block* curr, Module* module, PassOptions& passOptions) { } void BreakValueDropper::visitBlock(Block* curr) { - optimizeBlock(curr, getModule(), passOptions); + optimizeBlock(curr, getModule(), passOptions, branchInfo); } struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> { @@ -395,8 +407,10 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> { Pass* create() override { return new MergeBlocks; } + BranchUtils::BranchSeekerCache branchInfo; + void visitBlock(Block* curr) { - optimizeBlock(curr, getModule(), getPassOptions()); + optimizeBlock(curr, getModule(), getPassOptions(), branchInfo); } // given diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 8410cc0b2..3e1a057ce 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -19,7 +19,6 @@ #include <sstream> #include <unordered_set> -#include "ir/branch-utils.h" #include "ir/features.h" #include "ir/global-utils.h" #include "ir/module-utils.h" |