summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Alon Zakai <azakai@google.com> 2020-09-03 16:11:05 -0700
committer GitHub <noreply@github.com> 2020-09-03 16:11:05 -0700
commit 44df23efd69fd2dd4c260755c82ddede226c40ff (patch)
tree d828947439ac4dd47fc023b176b86d4949890b81
parent 132c72bb5e93591de34a9bfc267e4a2007908626 (diff)
downloadbinaryen-44df23efd69fd2dd4c260755c82ddede226c40ff.tar.gz
binaryen-44df23efd69fd2dd4c260755c82ddede226c40ff.tar.bz2
binaryen-44df23efd69fd2dd4c260755c82ddede226c40ff.zip
Optimize MergeBlocks by caching branch results (#3102)
BranchSeekerCache caches the set of branches in a node + its children, and helps compute new results by looking in the cache and using data for the children. This avoids quadratic time in the common case of a post-walk on a tower of nested blocks, which is common in a switch. Fixes #3090. On the testcase there, this pass goes from over a minute to less than a second.
-rw-r--r-- src/ir/branch-utils.h 102
-rw-r--r-- src/passes/MergeBlocks.cpp 28
-rw-r--r-- src/wasm/wasm-validator.cpp 1
3 files changed, 115 insertions, 16 deletions
diff --git a/src/ir/branch-utils.h b/src/ir/branch-utils.h
index 363a9c9e2..9b439be89 100644
--- a/src/ir/branch-utils.h
+++ b/src/ir/branch-utils.h
@@ -17,6 +17,7 @@
#ifndef wasm_ir_branch_h
#define wasm_ir_branch_h
+#include "ir/iteration.h"
#include "wasm-traversal.h"
#include "wasm.h"
@@ -52,10 +53,12 @@ inline bool isBranchReachable(Expression* expr) {
WASM_UNREACHABLE("unexpected expression type");
}
-inline std::set<Name> getUniqueTargets(Break* br) { return {br->name}; }
+using NameSet = std::set<Name>;
-inline std::set<Name> getUniqueTargets(Switch* sw) {
- std::set<Name> ret;
+inline NameSet getUniqueTargets(Break* br) { return {br->name}; }
+
+inline NameSet getUniqueTargets(Switch* sw) {
+ NameSet ret;
for (auto target : sw->targets) {
ret.insert(target);
}
@@ -63,8 +66,20 @@ inline std::set<Name> getUniqueTargets(Switch* sw) {
return ret;
}
-inline std::set<Name> getUniqueTargets(BrOnExn* br) { return {br->name}; }
+inline NameSet getUniqueTargets(BrOnExn* br) { return {br->name}; }
+// Generic entry point: returns the unique branch targets of a single
+// expression by dispatching to the Break/Switch/BrOnExn overloads. Any other
+// kind of expression yields an empty set. Only |expr| itself is inspected,
+// not its children.
+inline NameSet getUniqueTargets(Expression* expr) {
+  NameSet targets;
+  if (auto* asBreak = expr->dynCast<Break>()) {
+    targets = getUniqueTargets(asBreak);
+  } else if (auto* asSwitch = expr->dynCast<Switch>()) {
+    targets = getUniqueTargets(asSwitch);
+  } else if (auto* asBrOnExn = expr->dynCast<BrOnExn>()) {
+    targets = getUniqueTargets(asBrOnExn);
+  }
+  return targets;
+}
// If we branch to 'from', change that to 'to' instead.
inline bool replacePossibleTarget(Expression* branch, Name from, Name to) {
bool worked = false;
@@ -97,9 +112,9 @@ inline bool replacePossibleTarget(Expression* branch, Name from, Name to) {
// returns the set of targets to which we branch that are
// outside of a node
-inline std::set<Name> getExitingBranches(Expression* ast) {
+inline NameSet getExitingBranches(Expression* ast) {
struct Scanner : public PostWalker<Scanner> {
- std::set<Name> targets;
+ NameSet targets;
void visitBreak(Break* curr) { targets.insert(curr->name); }
void visitSwitch(Switch* curr) {
@@ -128,9 +143,9 @@ inline std::set<Name> getExitingBranches(Expression* ast) {
// returns the list of all branch targets in a node
-inline std::set<Name> getBranchTargets(Expression* ast) {
+inline NameSet getBranchTargets(Expression* ast) {
struct Scanner : public PostWalker<Scanner> {
- std::set<Name> targets;
+ NameSet targets;
void visitBlock(Block* curr) {
if (curr->name.is()) {
@@ -218,6 +233,77 @@ struct BranchSeeker : public PostWalker<BranchSeeker> {
}
};
+// Walks an entire expression tree and gathers, into |branches|, the names
+// targeted by every branching expression encountered along the way.
+struct BranchAccumulator
+  : public PostWalker<BranchAccumulator,
+                      UnifiedExpressionVisitor<BranchAccumulator>> {
+  // Union of the branch targets of all expressions visited so far.
+  NameSet branches;
+
+  // Called for every expression in the walk; merges in that expression's own
+  // targets (an empty set for expressions that do not branch).
+  void visitExpression(Expression* curr) {
+    for (const auto& name : getUniqueTargets(curr)) {
+      branches.insert(name);
+    }
+  }
+};
+
+// A helper structure for the common case of post-walking some IR while querying
+// whether a branch is present. We can cache results for children in order to
+// avoid quadratic time searches.
+// We assume that a node will be scanned *once* here. That means that if we
+// scan a node, we can discard all information for its children. This avoids
+// linearly increasing memory usage over time.
+class BranchSeekerCache {
+  // Maps an expression to the set of all branch targets present in it and in
+  // all its nested children.
+  std::unordered_map<Expression*, NameSet> branches;
+
+public:
+  // Returns the set of branch targets used anywhere inside |curr| (including
+  // |curr| itself). A cached result is returned directly if present;
+  // otherwise the result is assembled from the children's results and stored
+  // in the cache.
+  //
+  // Computing a result for |curr| removes the cache entries of its direct
+  // children, so references previously returned for those children must not
+  // be used afterwards.
+  const NameSet& getBranches(Expression* curr) {
+    auto iter = branches.find(curr);
+    if (iter != branches.end()) {
+      return iter->second;
+    }
+    NameSet currBranches;
+    // Merges |moreBranches| into |currBranches|. Note that |moreBranches| may
+    // be consumed (left empty) by this call.
+    auto add = [&](NameSet& moreBranches) {
+      // Make sure to do a fast swap for the first set of branches to arrive.
+      // This helps the case of the first child being a block with a very large
+      // set of names.
+      if (currBranches.empty()) {
+        currBranches.swap(moreBranches);
+      } else {
+        currBranches.insert(moreBranches.begin(), moreBranches.end());
+      }
+    };
+    // Add from the children, which are hopefully cached.
+    for (auto child : ChildIterator(curr)) {
+      auto iter = branches.find(child);
+      if (iter != branches.end()) {
+        add(iter->second);
+        // We are scanning the parent, which means we assume the child will
+        // never be visited again.
+        branches.erase(iter);
+      } else {
+        // The child was not cached. Scan it manually.
+        BranchAccumulator childBranches;
+        childBranches.walk(child);
+        add(childBranches.branches);
+        // Don't bother caching anything - we are scanning the parent, so the
+        // child will presumably not be scanned again.
+      }
+    }
+    // Finish with the parent's own branches.
+    auto selfBranches = getUniqueTargets(curr);
+    add(selfBranches);
+    return branches[curr] = std::move(currBranches);
+  }
+
+  // Returns whether |curr| or any of its nested children branches to
+  // |target|. When BRANCH_UTILS_DEBUG is defined, the cached answer is
+  // cross-checked against a full BranchSeeker scan.
+  bool hasBranch(Expression* curr, Name target) {
+    bool result = getBranches(curr).count(target);
+#ifdef BRANCH_UTILS_DEBUG
+    // Fixed: this previously referenced an undeclared `bresult`, which broke
+    // the build whenever BRANCH_UTILS_DEBUG was enabled.
+    assert(result == BranchSeeker::has(curr, target));
+#endif
+    return result;
+  }
+};
+
} // namespace BranchUtils
} // namespace wasm
diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp
index 1401cd2f4..4ecec6669 100644
--- a/src/passes/MergeBlocks.cpp
+++ b/src/passes/MergeBlocks.cpp
@@ -145,8 +145,11 @@ struct ProblemFinder : public ControlFlowWalker<ProblemFinder> {
struct BreakValueDropper : public ControlFlowWalker<BreakValueDropper> {
Name origin;
PassOptions& passOptions;
+ BranchUtils::BranchSeekerCache& branchInfo;
- BreakValueDropper(PassOptions& passOptions) : passOptions(passOptions) {}
+ BreakValueDropper(PassOptions& passOptions,
+ BranchUtils::BranchSeekerCache& branchInfo)
+ : passOptions(passOptions), branchInfo(branchInfo) {}
void visitBlock(Block* curr);
@@ -198,8 +201,10 @@ static bool hasDeadCode(Block* block) {
}
// core block optimizer routine
-static void
-optimizeBlock(Block* curr, Module* module, PassOptions& passOptions) {
+static void optimizeBlock(Block* curr,
+ Module* module,
+ PassOptions& passOptions,
+ BranchUtils::BranchSeekerCache& branchInfo) {
auto& list = curr->list;
// Main merging loop.
bool more = true;
@@ -237,7 +242,7 @@ optimizeBlock(Block* curr, Module* module, PassOptions& passOptions) {
childBlock = nullptr;
} else {
// fix up breaks
- BreakValueDropper fixer(passOptions);
+ BreakValueDropper fixer(passOptions, branchInfo);
fixer.origin = childBlock->name;
fixer.setModule(module);
fixer.walk(expression);
@@ -294,7 +299,7 @@ optimizeBlock(Block* curr, Module* module, PassOptions& passOptions) {
auto childName = childBlock->name;
for (size_t j = 0; j < childSize; j++) {
auto* item = childList[j];
- if (BranchUtils::BranchSeeker::has(item, childName)) {
+ if (branchInfo.hasBranch(item, childName)) {
// We can't remove this from the child.
keepStart = j;
keepEnd = childSize;
@@ -360,6 +365,13 @@ optimizeBlock(Block* curr, Module* module, PassOptions& passOptions) {
if (loop) {
loop->finalize();
}
+ // Note that we modify the child block here, which invalidates info
+ // in branchInfo. However, as we have scanned the parent, we have
+ // already forgotten the child's info, so there is nothing to do here
+ // for the child.
+ // (We also don't need to do anything for the parent - we move code
+ // from a child into the parent, but that doesn't change the total
+ // branches in the parent.)
}
// Add the rest of the parent block after the child.
for (size_t j = i + 1; j < list.size(); j++) {
@@ -387,7 +399,7 @@ optimizeBlock(Block* curr, Module* module, PassOptions& passOptions) {
}
void BreakValueDropper::visitBlock(Block* curr) {
- optimizeBlock(curr, getModule(), passOptions);
+ optimizeBlock(curr, getModule(), passOptions, branchInfo);
}
struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> {
@@ -395,8 +407,10 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> {
Pass* create() override { return new MergeBlocks; }
+ BranchUtils::BranchSeekerCache branchInfo;
+
void visitBlock(Block* curr) {
- optimizeBlock(curr, getModule(), getPassOptions());
+ optimizeBlock(curr, getModule(), getPassOptions(), branchInfo);
}
// given
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index 8410cc0b2..3e1a057ce 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -19,7 +19,6 @@
#include <sstream>
#include <unordered_set>
-#include "ir/branch-utils.h"
#include "ir/features.h"
#include "ir/global-utils.h"
#include "ir/module-utils.h"