summaryrefslogtreecommitdiff
path: root/src/cfg/cfg-traversal.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cfg/cfg-traversal.h')
-rw-r--r--src/cfg/cfg-traversal.h313
1 files changed, 313 insertions, 0 deletions
diff --git a/src/cfg/cfg-traversal.h b/src/cfg/cfg-traversal.h
new file mode 100644
index 000000000..f6c055542
--- /dev/null
+++ b/src/cfg/cfg-traversal.h
@@ -0,0 +1,313 @@
+/*
+ * Copyright 2016 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Convert the AST to a CFG, while traversing it.
+//
+// Note that this is not the same as the relooper CFG. The relooper is
+// designed for compilation to an AST, this is for processing. There is
+// no built-in support for transforming this CFG into the AST back
+// again, it is just metadata on the side for computation purposes.
+//
+// Usage: As the traversal proceeds, you can note information and add it to
+// the current basic block using currBasicBlock, on the contents
+// property, whose type is user-defined.
+//
+
+#ifndef cfg_traversal_h
+#define cfg_traversal_h
+
+#include "wasm.h"
+#include "wasm-traversal.h"
+
+namespace wasm {
+
+template<typename SubType, typename VisitorType, typename Contents>
+struct CFGWalker : public PostWalker<SubType, VisitorType> {
+
+ // public interface
+
+ struct BasicBlock {
+ Contents contents; // custom contents
+ std::vector<BasicBlock*> out, in;
+ };
+
+ BasicBlock* entry; // the entry block
+
+ BasicBlock* makeBasicBlock() { // override this with code to create a BasicBlock if necessary
+ return new BasicBlock();
+ }
+
+ // internal details
+
+ std::vector<std::unique_ptr<BasicBlock>> basicBlocks; // all the blocks
+
+ // traversal state
+ BasicBlock* currBasicBlock; // the current block in play during traversal
+ std::map<Name, std::vector<BasicBlock*>> branches;
+ std::vector<BasicBlock*> ifStack;
+ std::vector<BasicBlock*> loopStack;
+
+ static void doStartBasicBlock(SubType* self, Expression** currp) {
+ self->currBasicBlock = self->makeBasicBlock();
+ self->basicBlocks.push_back(std::unique_ptr<BasicBlock>(self->currBasicBlock));
+ }
+
+ void link(BasicBlock* from, BasicBlock* to) {
+ from->out.push_back(to);
+ to->in.push_back(from);
+ }
+
+ static void doEndBlock(SubType* self, Expression** currp) {
+ // TODO: if no name, no need for a new block
+ auto* last = self->currBasicBlock;
+ doStartBasicBlock(self, currp);
+ self->link(last, self->currBasicBlock); // fallthrough
+ // branches to the new one
+ auto* curr = (*currp)->cast<Block>();
+ if (curr->name.is()) {
+ auto& origins = self->branches[curr->name];
+ for (auto* origin : origins) {
+ self->link(origin, self->currBasicBlock);
+ }
+ self->branches.erase(curr->name);
+ }
+ }
+
+ static void doStartIfTrue(SubType* self, Expression** currp) {
+ auto* last = self->currBasicBlock;
+ doStartBasicBlock(self, currp);
+ self->link(last, self->currBasicBlock); // ifTrue
+ self->ifStack.push_back(last); // the block before the ifTrue
+ }
+
+ static void doStartIfFalse(SubType* self, Expression** currp) {
+ self->ifStack.push_back(self->currBasicBlock); // the ifTrue fallthrough
+ doStartBasicBlock(self, currp);
+ self->link(self->ifStack[self->ifStack.size() - 2], self->currBasicBlock); // before if -> ifFalse
+ }
+
+ static void doEndIf(SubType* self, Expression** currp) {
+ auto* last = self->currBasicBlock;
+ doStartBasicBlock(self, currp);
+ self->link(last, self->currBasicBlock); // last one is ifFalse's fallthrough if there was one, otherwise it's the ifTrue fallthrough
+ if ((*currp)->cast<If>()->ifFalse) {
+ // we just linked ifFalse, need to link ifTrue to the end
+ self->link(self->ifStack.back(), self->currBasicBlock);
+ self->ifStack.pop_back();
+ } else {
+ // no ifFalse, so add a fallthrough for if the if is not taken
+ self->link(self->ifStack.back(), self->currBasicBlock);
+ }
+ self->ifStack.pop_back();
+ }
+
+ static void doStartLoop(SubType* self, Expression** currp) {
+ auto* last = self->currBasicBlock;
+ doStartBasicBlock(self, currp);
+ self->link(last, self->currBasicBlock);
+ self->loopStack.push_back(self->currBasicBlock);
+ }
+
+ static void doEndLoop(SubType* self, Expression** currp) {
+ auto* last = self->currBasicBlock;
+ doStartBasicBlock(self, currp);
+ self->link(last, self->currBasicBlock); // fallthrough
+ // branches to the new one
+ auto* curr = (*currp)->cast<Loop>();
+ if (curr->out.is()) {
+ auto& origins = self->branches[curr->out];
+ for (auto* origin : origins) {
+ self->link(origin, self->currBasicBlock);
+ }
+ self->branches.erase(curr->out);
+ }
+ // branches to the top of the loop
+ if (curr->in.is()) {
+ auto* loopStart = self->loopStack.back();
+ auto& origins = self->branches[curr->in];
+ for (auto* origin : origins) {
+ self->link(origin, loopStart);
+ }
+ self->branches.erase(curr->in);
+ }
+ self->loopStack.pop_back();
+ }
+
+ static void doEndBreak(SubType* self, Expression** currp) {
+ auto* curr = (*currp)->cast<Break>();
+ self->branches[curr->name].push_back(self->currBasicBlock); // branch to the target
+ auto* last = self->currBasicBlock;
+ doStartBasicBlock(self, currp);
+ if (curr->condition) {
+ self->link(last, self->currBasicBlock); // we might fall through
+ }
+ }
+
+ static void doEndSwitch(SubType* self, Expression** currp) {
+ auto* curr = (*currp)->cast<Switch>();
+ std::set<Name> seen; // we might see the same label more than once; do not spam branches
+ for (Name target : curr->targets) {
+ if (!seen.count(target)) {
+ self->branches[target].push_back(self->currBasicBlock); // branch to the target
+ seen.insert(target);
+ }
+ }
+ if (!seen.count(curr->default_)) {
+ self->branches[curr->default_].push_back(self->currBasicBlock); // branch to the target
+ }
+ doStartBasicBlock(self, currp);
+ }
+
+ static void scan(SubType* self, Expression** currp) {
+ Expression* curr = *currp;
+
+ switch (curr->_id) {
+ case Expression::Id::BlockId: {
+ self->pushTask(SubType::doEndBlock, currp);
+ self->pushTask(SubType::doVisitBlock, currp);
+ auto& list = curr->cast<Block>()->list;
+ for (int i = int(list.size()) - 1; i >= 0; i--) {
+ self->pushTask(SubType::scan, &list[i]);
+ }
+ break;
+ }
+ case Expression::Id::IfId: {
+ self->pushTask(SubType::doEndIf, currp);
+ self->pushTask(SubType::doVisitIf, currp);
+ auto* ifFalse = curr->cast<If>()->ifFalse;
+ if (ifFalse) {
+ self->pushTask(SubType::scan, &curr->cast<If>()->ifFalse);
+ self->pushTask(SubType::doStartIfFalse, currp);
+ }
+ self->pushTask(SubType::scan, &curr->cast<If>()->ifTrue);
+ self->pushTask(SubType::doStartIfTrue, currp);
+ self->pushTask(SubType::scan, &curr->cast<If>()->condition);
+ break;
+ }
+ case Expression::Id::LoopId: {
+ self->pushTask(SubType::doEndLoop, currp);
+ self->pushTask(SubType::doVisitLoop, currp);
+ self->pushTask(SubType::scan, &curr->cast<Loop>()->body);
+ self->pushTask(SubType::doStartLoop, currp);
+ break;
+ }
+ case Expression::Id::BreakId: {
+ self->pushTask(SubType::doEndBreak, currp);
+ self->pushTask(SubType::doVisitBreak, currp);
+ self->maybePushTask(SubType::scan, &curr->cast<Break>()->condition);
+ self->maybePushTask(SubType::scan, &curr->cast<Break>()->value);
+ break;
+ }
+ case Expression::Id::SwitchId: {
+ self->pushTask(SubType::doEndSwitch, currp);
+ self->pushTask(SubType::doVisitSwitch, currp);
+ self->maybePushTask(SubType::scan, &curr->cast<Switch>()->value);
+ self->pushTask(SubType::scan, &curr->cast<Switch>()->condition);
+ break;
+ }
+ case Expression::Id::ReturnId: {
+ self->pushTask(SubType::doStartBasicBlock, currp);
+ self->pushTask(SubType::doVisitReturn, currp);
+ self->maybePushTask(SubType::scan, &curr->cast<Return>()->value);
+ break;
+ }
+ case Expression::Id::UnreachableId: {
+ self->pushTask(SubType::doStartBasicBlock, currp);
+ self->pushTask(SubType::doVisitUnreachable, currp);
+ break;
+ }
+ default: {
+ // other node types do not have control flow, use regular post-order
+ PostWalker<SubType, VisitorType>::scan(self, currp);
+ }
+ }
+ }
+
+ void walk(Expression*& root) {
+ basicBlocks.clear();
+
+ doStartBasicBlock(static_cast<SubType*>(this), nullptr);
+ entry = currBasicBlock;
+ PostWalker<SubType, VisitorType>::walk(root);
+
+ assert(branches.size() == 0);
+ assert(ifStack.size() == 0);
+ assert(loopStack.size() == 0);
+ }
+
+ std::unordered_set<BasicBlock*> findLiveBlocks() {
+ std::unordered_set<BasicBlock*> alive;
+ std::unordered_set<BasicBlock*> queue;
+ queue.insert(entry);
+ while (queue.size() > 0) {
+ auto iter = queue.begin();
+ auto* curr = *iter;
+ queue.erase(iter);
+ alive.insert(curr);
+ for (auto* out : curr->out) {
+ if (!alive.count(out)) queue.insert(out);
+ }
+ }
+ return alive;
+ }
+
+ void unlinkDeadBlocks(std::unordered_set<BasicBlock*> alive) {
+ for (auto& block : basicBlocks) {
+ if (!alive.count(block.get())) {
+ block->in.clear();
+ block->out.clear();
+ continue;
+ }
+ block->in.erase(std::remove_if(block->in.begin(), block->in.end(), [&alive](BasicBlock* other) {
+ return !alive.count(other);
+ }), block->in.end());
+ block->out.erase(std::remove_if(block->out.begin(), block->out.end(), [&alive](BasicBlock* other) {
+ return !alive.count(other);
+ }), block->out.end());
+ }
+ }
+
+ // TODO: utility method for optimizing cfg, removing empty blocks depending on their .content
+
+ std::map<BasicBlock*, size_t> debugIds;
+
+ void dumpCFG(std::string message, Function* func) {
+ std::cout << "<==\nCFG [" << message << "]:\n";
+ for (auto& block : basicBlocks) {
+ debugIds[block.get()] = debugIds.size();
+ }
+ for (auto& block : basicBlocks) {
+ assert(debugIds.count(block.get()) > 0);
+ std::cout << " block " << debugIds[block.get()] << ":\n";
+ block->contents.dump(func);
+ for (auto& in : block->in) {
+ assert(debugIds.count(in) > 0);
+ assert(std::find(in->out.begin(), in->out.end(), block.get()) != in->out.end()); // must be a parallel link back
+ }
+ for (auto& out : block->out) {
+ assert(debugIds.count(out) > 0);
+ std::cout << " out: " << debugIds[out] << "\n";
+ assert(std::find(out->in.begin(), out->in.end(), block.get()) != out->in.end()); // must be a parallel link back
+ }
+ }
+ std::cout << "==>\n";
+ }
+};
+
+} // namespace wasm
+
+#endif // cfg_traversal_h