summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlon Zakai <alonzakai@gmail.com>2017-05-10 14:38:20 -0700
committerGitHub <noreply@github.com>2017-05-10 14:38:20 -0700
commit6efa4a97b2c643bf35756bb6e989b90d6b3e44bc (patch)
tree64593c763ca43a2c641c652f5e39ac13db0618ef /src
parentdbe3384034e12909694c554149a1fb536fbece21 (diff)
downloadbinaryen-6efa4a97b2c643bf35756bb6e989b90d6b3e44bc.tar.gz
binaryen-6efa4a97b2c643bf35756bb6e989b90d6b3e44bc.tar.bz2
binaryen-6efa4a97b2c643bf35756bb6e989b90d6b3e44bc.zip
Flatten control flow pass (#999)
This pass flattens out control flow in order to achieve 2 properties: * Control flow structures (block, loop, if) and control flow operations (br, br_if, br_table, return, unreachable) may only be block children, a loop body, or an if-true or if-false. (I.e., they cannot be nested inside an i32.add, a drop, a call, an if-condition, etc.) * Disallow block, loop, and if return values, i.e., do not use control flow to pass around values. As a result, expressions cannot contain control flow, and overall control flow is simpler, more structured, and more "flat". This should make things like re-relooping wasm code much easier, as they can run after the cfg is flattened
Diffstat (limited to 'src')
-rw-r--r--src/passes/CMakeLists.txt1
-rw-r--r--src/passes/FlattenControlFlow.cpp471
-rw-r--r--src/passes/pass.cpp1
-rw-r--r--src/passes/passes.h1
-rw-r--r--src/wasm-builder.h5
5 files changed, 479 insertions, 0 deletions
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index 637f40d94..d7f17a3ca 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -5,6 +5,7 @@ SET(passes_SOURCES
DeadCodeElimination.cpp
DuplicateFunctionElimination.cpp
ExtractFunction.cpp
+ FlattenControlFlow.cpp
Inlining.cpp
LegalizeJSInterface.cpp
LocalCSE.cpp
diff --git a/src/passes/FlattenControlFlow.cpp b/src/passes/FlattenControlFlow.cpp
new file mode 100644
index 000000000..3da5809c3
--- /dev/null
+++ b/src/passes/FlattenControlFlow.cpp
@@ -0,0 +1,471 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Flattens control flow, e.g.
+//
+// (i32.add
+// (if (..condition..)
+// (..if true..)
+// (..if false..)
+// )
+// (i32.const 1)
+// )
+// =>
+// (if (..condition..)
+// (set_local $temp
+// (..if true..)
+// )
+// (set_local $temp
+// (..if false..)
+// )
+// )
+// (i32.add
+// (get_local $temp)
+// (i32.const 1)
+// )
+//
+// Formally, this pass flattens control flow in the precise sense of
+// making the AST have these properties:
+//
+// 1. Control flow structures (block, loop, if) and control flow
+// operations (br, br_if, br_table, return, unreachable) may
+// only be block children, a loop body, or an if-true or if-false.
+// (I.e., they cannot be nested inside an i32.add, a drop, a
+// call, an if-condition, etc.)
+// 2. Disallow block, loop, and if return values, i.e., do not use
+// control flow to pass around values.
+//
+// Note that we do still allow normal arbitrary nesting of expressions
+// *without* control flow (i.e., this is not a reduction to 3-address
+// code form). We also allow nesting of control flow, but just nested
+// in other control flow, like an if in the true arm of an if, and
+// so forth. What we achieve here is that when you see an expression,
+// you know it has no control flow inside it, it will be fully
+// executed.
+//
+
+#include <wasm.h>
+#include <pass.h>
+#include <wasm-builder.h>
+
+
+namespace wasm {
+
+// Looks for control flow changes and structures, excluding blocks (as we
+// want to put all control flow on them)
+struct ControlFlowChecker : public Visitor<ControlFlowChecker> {
+ static bool is(Expression* node) {
+ ControlFlowChecker finder;
+ finder.visit(node);
+ return finder.hasControlFlow;
+ }
+
+ bool hasControlFlow = false;
+
+ void visitBreak(Break *curr) { hasControlFlow = true; }
+ void visitSwitch(Switch *curr) { hasControlFlow = true; }
+ void visitBlock(Block *curr) { hasControlFlow = true; }
+ void visitLoop(Loop* curr) { hasControlFlow = true; }
+ void visitIf(If* curr) { hasControlFlow = true; }
+ void visitReturn(Return *curr) { hasControlFlow = true; }
+ void visitUnreachable(Unreachable *curr) { hasControlFlow = true; }
+};
+
+struct FlattenControlFlow : public WalkerPass<PostWalker<FlattenControlFlow>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new FlattenControlFlow; }
+
+ std::unique_ptr<Builder> builder;
+ // we get rid of block/if/loop values. this map tells us for
+ // each break target what local index to use.
+ // if this is a flowing value, there might not be a name assigned
+ // (block ending, block with no name; or if value), so we use
+ // the expr (and there will be exactly one set and get of it,
+ // so we don't need a name)
+ std::map<Name, Index> breakNameIndexes;
+ std::map<Expression*, Index> breakExprIndexes;
+
+ void doWalkFunction(Function* func) {
+ builder = make_unique<Builder>(*getModule());
+ walk(func->body);
+ if (func->result != none) {
+ // if the body had a fallthrough, receive it and return it
+ auto iter = breakExprIndexes.find(func->body);
+ if (iter != breakExprIndexes.end()) {
+ func->body = builder->makeSequence(
+ func->body,
+ builder->makeReturn(
+ builder->makeGetLocal(iter->second, func->result)
+ )
+ );
+ }
+ }
+ }
+
+ // returns the index to assign values to for a break target. allocates
+ // the local if this is the first time we see it.
+ // expr is used if this is a flowing value.
+ Index getBreakTargetIndex(Name name, WasmType type, Expression* expr = nullptr, Index index = -1) {
+ assert(isConcreteWasmType(type)); // we shouldn't get here if the value ins't actually set
+ if (name.is()) {
+ auto iter = breakNameIndexes.find(name);
+ if (iter == breakNameIndexes.end()) {
+ if (index == Index(-1)) {
+ index = builder->addVar(getFunction(), type);
+ }
+ breakNameIndexes[name] = index;
+ if (expr) {
+ breakExprIndexes[expr] = index;
+ }
+ return index;
+ }
+ if (expr) {
+ breakExprIndexes[expr] = iter->second;
+ }
+ return iter->second;
+ } else {
+ assert(expr);
+ auto iter = breakExprIndexes.find(expr);
+ if (iter == breakExprIndexes.end()) {
+ if (index == Index(-1)) {
+ index = builder->addVar(getFunction(), type);
+ }
+ return breakExprIndexes[expr] = index;
+ }
+ return iter->second;
+ }
+ }
+
+ // When we reach a fallthrough value, it has already been flattened, and its value
+ // assigned to the proper local. Or, it may not have needed to be flattened,
+ // and we can just assign to a local. This method simply returns the fallthrough
+ // replacement code.
+ Expression* getFallthroughReplacement(Expression* child, Index myIndex) {
+ auto iter = breakExprIndexes.find(child);
+ if (iter != breakExprIndexes.end()) {
+ // it was flattened and saved to a local
+ return builder->makeSequence(
+ child, // which no longer flows a value, now it sets the child index
+ builder->makeSetLocal(
+ myIndex,
+ builder->makeGetLocal(iter->second, getFunction()->getLocalType(iter->second))
+ )
+ );
+ }
+ // a simple expression
+ if (child->type == unreachable) {
+ // no need to even set the local
+ return child;
+ } else {
+ assert(!ControlFlowChecker::is(child));
+ return builder->makeSetLocal(
+ myIndex,
+ child
+ );
+ }
+ }
+
+ // flattening fallthroughs makes them have type none. this gets their true type
+ WasmType getFallthroughType(Expression* child) {
+ auto iter = breakExprIndexes.find(child);
+ if (iter != breakExprIndexes.end()) {
+ // it was flattened and saved to a local
+ return getFunction()->getLocalType(iter->second);
+ }
+ assert(child->type != none);
+ return child->type;
+ }
+
+ // Splitter helper
+ struct Splitter {
+ Splitter(FlattenControlFlow& parent, Expression* node) : parent(parent), node(node) {}
+
+ ~Splitter() {
+ finish();
+ }
+
+ FlattenControlFlow& parent;
+ Expression* node;
+
+ std::vector<Expression**> children; // TODO: reuse in parent, avoiding mallocing on each node
+
+ void note(Expression*& child) {
+ // we accept nullptr inputs, for a non-existing child
+ if (!child) return;
+ children.push_back(&child);
+ }
+
+ Expression* replacement; // the final replacement for the current node
+ bool stop = false; // if a child is unreachable, we can stop
+
+ void finish() {
+ if (children.empty()) return;
+ // first, scan the list
+ bool hasControlFlowChild = false;
+ bool hasUnreachableChild = false;
+ for (auto** childp : children) {
+ // it's enough to look at the child, ignoring the contents, as the contents
+ // have already been processed before we got here, so they must have been
+ // flattened if necessary.
+ auto* child = *childp;
+ if (ControlFlowChecker::is(child)) {
+ hasControlFlowChild = true;
+ }
+ if (child->type == unreachable) {
+ hasUnreachableChild = true;
+ }
+ }
+ if (!hasControlFlowChild) {
+ // nothing to do here.
+ assert(!hasUnreachableChild); // all of them should be executed
+ return;
+ }
+ // we have at least one child we need to split out, so to preserve the order of operations,
+ // split them all out
+ Builder* builder = parent.builder.get();
+ std::vector<Index> tempIndexes;
+ for (auto** childp : children) {
+ auto* child = *childp;
+ if (isConcreteWasmType(child->type)) {
+ tempIndexes.push_back(builder->addVar(parent.getFunction(), child->type));
+ } else {
+ tempIndexes.push_back(-1);
+ }
+ }
+ // create a new replacement block
+ auto* block = builder->makeBlock();
+ for (Index i = 0; i < children.size(); i++) {
+ auto* child = *children[i];
+ auto type = child->type;
+ if (isConcreteWasmType(type)) {
+ // set the child to a local, and use it later
+ block->list.push_back(builder->makeSetLocal(tempIndexes[i], child));
+ *children[i] = builder->makeGetLocal(tempIndexes[i], type);
+ } else if (type == none) {
+ // a nested none can not happen normally, here it occurs after we flattened a nested
+ // we can use the local it already assigned to. TODO: don't even allocate one here
+ block->list.push_back(child);
+ assert(parent.breakExprIndexes.count(child) > 0);
+ auto index = parent.breakExprIndexes[child];
+ *children[i] = builder->makeGetLocal(
+ index,
+ parent.getFunction()->getLocalType(index)
+ );
+ } else if (type == unreachable) {
+ block->list.push_back(child);
+ break; // no need to push any more
+ } else {
+ WASM_UNREACHABLE();
+ }
+ }
+ if (!hasUnreachableChild) {
+ // we reached the end, so we need to emit the expression itself
+ // (which has been modified to replace children usages with get_locals)
+ block->list.push_back(node);
+ }
+ block->finalize();
+ // finally, we just created a new block, ending in node. If node is e.g.
+ // i32.add, then our block would return a value. so we must convert
+ // this new block to return a value through a local
+ parent.visitBlock(block);
+ // the block is now done
+ parent.replaceCurrent(block);
+ // if the node was potentially a flowthrough value, then it has an entry
+ // in breakExprIndexes, and since we are replacing it with this block,
+ // we must note it's index as the same, so it is found by the parent.
+ if (parent.breakExprIndexes.find(node) != parent.breakExprIndexes.end()) {
+ parent.breakExprIndexes[block] = parent.breakExprIndexes[node];
+ }
+ }
+ };
+
+ void visitBlock(Block* curr) {
+ if (isConcreteWasmType(curr->type)) {
+ curr->list.back() = getFallthroughReplacement(curr->list.back(), getBreakTargetIndex(curr->name, curr->type, curr));
+ curr->finalize();
+ }
+ }
+ void visitLoop(Loop* curr) {
+ if (isConcreteWasmType(curr->type)) {
+ curr->body = getFallthroughReplacement(curr->body, getBreakTargetIndex(Name(), curr->type, curr));
+ curr->finalize();
+ }
+ }
+ void visitIf(If* curr) {
+ if (isConcreteWasmType(curr->type)) {
+ auto targetIndex = getBreakTargetIndex(Name(), curr->type, curr);
+ curr->ifTrue = getFallthroughReplacement(curr->ifTrue, targetIndex);
+ curr->ifFalse = getFallthroughReplacement(curr->ifFalse, targetIndex);
+ curr->finalize();
+ }
+ Splitter splitter(*this, curr);
+ splitter.note(curr->condition);
+ }
+ void visitBreak(Break* curr) {
+ Expression* processed = curr;
+ // first of all, get rid of the value if there is one
+ if (curr->value) {
+ if (curr->value->type != unreachable) {
+ auto type = getFallthroughType(curr->value);
+ auto index = getBreakTargetIndex(curr->name, type);
+ auto* value = getFallthroughReplacement(curr->value, index);
+ curr->value = nullptr;
+ curr->finalize();
+ processed = builder->makeSequence(
+ value,
+ curr
+ );
+ replaceCurrent(processed);
+ if (curr->condition) {
+ // we already called getBreakTargetIndex for the value we send to our
+ // break target if we break. as this is a br_if with a value, it also
+ // flows out that value, so our parent needs to know how to receive it.
+ // we note the already-existing index we prepared before, for that value.
+ getBreakTargetIndex(Name(), type, processed, index);
+ }
+ } else {
+ // we have a value, but it has unreachable type. we can just replace
+ // ourselves with it, we won't reach a condition (if there is one) or the br
+ // itself
+ replaceCurrent(curr->value);
+ return;
+ }
+ }
+ Splitter splitter(*this, processed);
+ splitter.note(curr->condition);
+ }
+ void visitSwitch(Switch* curr) {
+ Expression* processed = curr;
+
+ // first of all, get rid of the value if there is one
+ if (curr->value) {
+ if (curr->value->type != unreachable) {
+ auto type = getFallthroughType(curr->value);
+ // we must assign the value to *all* the targets
+ auto temp = builder->addVar(getFunction(), type);
+ auto* value = getFallthroughReplacement(curr->value, temp);
+ curr->value = nullptr;
+ auto* block = builder->makeBlock();
+ block->list.push_back(value);
+ std::set<Name> names;
+ for (auto target : curr->targets) {
+ if (names.insert(target).second) {
+ block->list.push_back(
+ builder->makeSetLocal(
+ getBreakTargetIndex(target, type),
+ builder->makeGetLocal(temp, type)
+ )
+ );
+ }
+ }
+ if (names.insert(curr->default_).second) {
+ block->list.push_back(
+ builder->makeSetLocal(
+ getBreakTargetIndex(curr->default_, type),
+ builder->makeGetLocal(temp, type)
+ )
+ );
+ }
+ block->list.push_back(curr);
+ block->finalize();
+ replaceCurrent(block);
+ } else {
+ // we have a value, but it has unreachable type. we can just replace
+ // ourselves with it, we won't reach a condition (if there is one) or the br
+ // itself
+ replaceCurrent(curr->value);
+ return;
+ }
+ }
+ Splitter splitter(*this, processed);
+ splitter.note(curr->value);
+ splitter.note(curr->condition);
+ }
+ void visitCall(Call* curr) {
+ Splitter splitter(*this, curr);
+ for (auto*& operand : curr->operands) {
+ splitter.note(operand);
+ }
+ }
+ void visitCallImport(CallImport* curr) {
+ Splitter splitter(*this, curr);
+ for (auto*& operand : curr->operands) {
+ splitter.note(operand);
+ }
+ }
+ void visitCallIndirect(CallIndirect* curr) {
+ Splitter splitter(*this, curr);
+ for (auto*& operand : curr->operands) {
+ splitter.note(operand);
+ }
+ splitter.note(curr->target);
+ }
+ void visitSetLocal(SetLocal* curr) {
+ Splitter splitter(*this, curr);
+ splitter.note(curr->value);
+ }
+ void visitSetGlobal(SetGlobal* curr) {
+ Splitter splitter(*this, curr);
+ splitter.note(curr->value);
+ }
+ void visitLoad(Load* curr) {
+ Splitter splitter(*this, curr);
+ splitter.note(curr->ptr);
+ }
+ void visitStore(Store* curr) {
+ Splitter splitter(*this, curr);
+ splitter.note(curr->ptr);
+ splitter.note(curr->value);
+ }
+ void visitUnary(Unary* curr) {
+ Splitter splitter(*this, curr);
+ splitter.note(curr->value);
+ }
+ void visitBinary(Binary* curr) {
+ Splitter splitter(*this, curr);
+ splitter.note(curr->left);
+ splitter.note(curr->right);
+ }
+ void visitSelect(Select* curr) {
+ Splitter splitter(*this, curr);
+ splitter.note(curr->ifTrue);
+ splitter.note(curr->ifFalse);
+ splitter.note(curr->condition);
+ }
+ void visitDrop(Drop* curr) {
+ Splitter splitter(*this, curr);
+ splitter.note(curr->value);
+ }
+ void visitReturn(Return* curr) {
+ Splitter splitter(*this, curr);
+ splitter.note(curr->value);
+ }
+ void visitHost(Host* curr) {
+ Splitter splitter(*this, curr);
+ for (auto*& operand : curr->operands) {
+ splitter.note(operand);
+ }
+ }
+};
+
+Pass *createFlattenControlFlowPass() {
+ return new FlattenControlFlow();
+}
+
+} // namespace wasm
+
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index 51e76cdf7..50c8da46b 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -70,6 +70,7 @@ void PassRegistry::registerPasses() {
registerPass("dce", "removes unreachable code", createDeadCodeEliminationPass);
registerPass("duplicate-function-elimination", "removes duplicate functions", createDuplicateFunctionEliminationPass);
registerPass("extract-function", "leaves just one function (useful for debugging)", createExtractFunctionPass);
+ registerPass("flatten-control-flow", "flattens out control flow to be only on blocks, not nested as expressions", createFlattenControlFlowPass);
registerPass("inlining", "inlines functions (currently only ones with a single use)", createInliningPass);
registerPass("legalize-js-interface", "legalizes i64 types on the import/export boundary", createLegalizeJSInterfacePass);
registerPass("local-cse", "common subexpression elimination inside basic blocks", createLocalCSEPass);
diff --git a/src/passes/passes.h b/src/passes/passes.h
index 709a903f1..302d82b9f 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -28,6 +28,7 @@ Pass *createCodePushingPass();
Pass *createDeadCodeEliminationPass();
Pass *createDuplicateFunctionEliminationPass();
Pass *createExtractFunctionPass();
+Pass *createFlattenControlFlowPass();
Pass *createFullPrinterPass();
Pass *createInliningPass();
Pass *createLegalizeJSInterfacePass();
diff --git a/src/wasm-builder.h b/src/wasm-builder.h
index 61db3c9e8..04e0aa4de 100644
--- a/src/wasm-builder.h
+++ b/src/wasm-builder.h
@@ -76,6 +76,11 @@ public:
}
return ret;
}
+ Block* makeBlock(Name name, Expression* first = nullptr) {
+ auto* ret = makeBlock(first);
+ ret->name = name;
+ return ret;
+ }
If* makeIf(Expression* condition, Expression* ifTrue, Expression* ifFalse = nullptr) {
auto* ret = allocator.alloc<If>();
ret->condition = condition; ret->ifTrue = ifTrue; ret->ifFalse = ifFalse;