/*
 * Copyright 2016 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// WebAssembly AST visitor. Useful for anything that wants to do something
// different for each AST node type, like printing, interpreting, etc.
//
// This class is specifically designed as a template to avoid virtual function
// call overhead. To write a visitor, derive from this class and pass your own
// type as the subtype template argument, so that calls are dispatched
// statically via static_cast on `this` (CRTP), e.g.:
//
//   struct MyVisitor : public Visitor<MyVisitor> { .. }
//
// NOTE(review): in this copy of the file, text inside angle brackets appears
// to have been stripped: template parameter lists and template arguments are
// missing (e.g. `template struct Visitor`, `static_cast(this)`,
// `(*currp)->cast()`, `std::vector controlFlowStack`), which is not valid C++
// as written. Restore them from version control before building.
//

#ifndef wasm_wasm_traversal_h
#define wasm_wasm_traversal_h

#include "wasm.h"
#include "support/threads.h"

namespace wasm {

// Base visitor. Every visit*() method below has an empty default body; a
// subtype hides the ones it cares about. visit() dispatches an Expression to
// the subtype's visit*() method for the node's concrete class, based on _id.
// NOTE(review): if ReturnType is not void, these empty default bodies return
// nothing; confirm that such subtypes override every visit*() they can reach.
template struct Visitor {
  // Expression visitors
  ReturnType visitBlock(Block* curr) {}
  ReturnType visitIf(If* curr) {}
  ReturnType visitLoop(Loop* curr) {}
  ReturnType visitBreak(Break* curr) {}
  ReturnType visitSwitch(Switch* curr) {}
  ReturnType visitCall(Call* curr) {}
  ReturnType visitCallImport(CallImport* curr) {}
  ReturnType visitCallIndirect(CallIndirect* curr) {}
  ReturnType visitGetLocal(GetLocal* curr) {}
  ReturnType visitSetLocal(SetLocal* curr) {}
  ReturnType visitGetGlobal(GetGlobal* curr) {}
  ReturnType visitSetGlobal(SetGlobal* curr) {}
  ReturnType visitLoad(Load* curr) {}
  ReturnType visitStore(Store* curr) {}
  ReturnType visitConst(Const* curr) {}
  ReturnType visitUnary(Unary* curr) {}
  ReturnType visitBinary(Binary* curr) {}
  ReturnType visitSelect(Select* curr) {}
  ReturnType visitDrop(Drop* curr) {}
  ReturnType visitReturn(Return* curr) {}
  ReturnType visitHost(Host* curr) {}
  ReturnType visitNop(Nop* curr) {}
  ReturnType visitUnreachable(Unreachable* curr) {}
  // Module-level visitors
  ReturnType visitFunctionType(FunctionType* curr) {}
  ReturnType visitImport(Import* curr) {}
  ReturnType visitExport(Export* curr) {}
  ReturnType visitGlobal(Global* curr) {}
  ReturnType visitFunction(Function* curr) {}
  ReturnType visitTable(Table* curr) {}
  ReturnType visitMemory(Memory* curr) {}
  ReturnType visitModule(Module* curr) {}

  // Dispatches curr to the subtype's visit*() for its concrete class and
  // returns that method's result. InvalidId or an unknown id is a hard error.
  ReturnType visit(Expression* curr) {
    assert(curr);

    // Casts `this` to the subtype and `curr` to the concrete node class, then
    // returns the result of the matching visit*() call.
    #define DELEGATE(CLASS_TO_VISIT) \
      return static_cast(this)-> \
          visit##CLASS_TO_VISIT(static_cast(curr))

    switch (curr->_id) {
      case Expression::Id::BlockId: DELEGATE(Block);
      case Expression::Id::IfId: DELEGATE(If);
      case Expression::Id::LoopId: DELEGATE(Loop);
      case Expression::Id::BreakId: DELEGATE(Break);
      case Expression::Id::SwitchId: DELEGATE(Switch);
      case Expression::Id::CallId: DELEGATE(Call);
      case Expression::Id::CallImportId: DELEGATE(CallImport);
      case Expression::Id::CallIndirectId: DELEGATE(CallIndirect);
      case Expression::Id::GetLocalId: DELEGATE(GetLocal);
      case Expression::Id::SetLocalId: DELEGATE(SetLocal);
      case Expression::Id::GetGlobalId: DELEGATE(GetGlobal);
      case Expression::Id::SetGlobalId: DELEGATE(SetGlobal);
      case Expression::Id::LoadId: DELEGATE(Load);
      case Expression::Id::StoreId: DELEGATE(Store);
      case Expression::Id::ConstId: DELEGATE(Const);
      case Expression::Id::UnaryId: DELEGATE(Unary);
      case Expression::Id::BinaryId: DELEGATE(Binary);
      case Expression::Id::SelectId: DELEGATE(Select);
      case Expression::Id::DropId: DELEGATE(Drop);
      case Expression::Id::ReturnId: DELEGATE(Return);
      case Expression::Id::HostId: DELEGATE(Host);
      case Expression::Id::NopId: DELEGATE(Nop);
      case Expression::Id::UnreachableId: DELEGATE(Unreachable);
      case Expression::Id::InvalidId:
      default: WASM_UNREACHABLE();
    }

    #undef DELEGATE
  }
};

// Visit with a single unified visitor, called on every node, instead of
// separate visit* per node. Each expression visit*() below simply forwards to
// the subtype's visitExpression(), so a subtype only needs to provide that one
// method. Module-level visit*() methods are inherited unchanged from Visitor.
template struct UnifiedExpressionVisitor : public Visitor {
  // called on each node
  ReturnType visitExpression(Expression* curr) {}

  // redirects
  ReturnType visitBlock(Block* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitIf(If* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitLoop(Loop* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitBreak(Break* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitSwitch(Switch* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitCall(Call* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitCallImport(CallImport* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitCallIndirect(CallIndirect* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitGetLocal(GetLocal* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitSetLocal(SetLocal* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitGetGlobal(GetGlobal* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitSetGlobal(SetGlobal* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitLoad(Load* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitStore(Store* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitConst(Const* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitUnary(Unary* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitBinary(Binary* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitSelect(Select* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitDrop(Drop* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitReturn(Return* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitHost(Host* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitNop(Nop* curr) { return static_cast(this)->visitExpression(curr); }
  ReturnType visitUnreachable(Unreachable* curr) { return static_cast(this)->visitExpression(curr); }
};

//
// Base class for all Walkers, which can traverse an AST
// and provide the option to replace nodes while doing so.
//
// Subclass and implement the visit*()
// calls to run code on different node types. The order in which nodes are
// visited is defined by the subclass's static scan() method.
//
template struct Walker : public VisitorType {
  // Useful methods for visitor implementations

  // Replace the current node. You can call this in your visit*() methods.
  // Note that the visit*() for the result node is not called for you (i.e.,
  // just one visit*() method is called by the traversal; if you replace a node,
  // and you want to process the output, you must do that explicitly).
  // The replacement is written into the node's parent slot by walk() once the
  // current task returns.
  Expression* replaceCurrent(Expression* expression) {
    return replace = expression;
  }

  // Get the current module
  Module* getModule() {
    return currModule;
  }

  // Get the current function
  Function* getFunction() {
    return currFunction;
  }

  // Walk starting points

  // Walks the global's init expression, then calls visitGlobal().
  void walkGlobal(Global* global) {
    walk(global->init);
    static_cast(this)->visitGlobal(global);
  }

  // Sets the current function for the duration of the walk, runs
  // doWalkFunction() and then visitFunction(), and clears it again.
  void walkFunction(Function* func) {
    setFunction(func);
    static_cast(this)->doWalkFunction(func);
    static_cast(this)->visitFunction(func);
    setFunction(nullptr);
  }

  // override this to provide custom functionality
  void doWalkFunction(Function* func) {
    walk(func->body);
  }

  // Walks each segment's offset expression, then calls visitTable().
  void walkTable(Table* table) {
    for (auto& segment : table->segments) {
      walk(segment.offset);
    }
    static_cast(this)->visitTable(table);
  }

  // Walks each segment's offset expression, then calls visitMemory().
  void walkMemory(Memory* memory) {
    for (auto& segment : memory->segments) {
      walk(segment.offset);
    }
    static_cast(this)->visitMemory(memory);
  }

  // Sets the current module for the duration of the walk, runs
  // doWalkModule() and then visitModule(), and clears it again.
  void walkModule(Module* module) {
    setModule(module);
    static_cast(this)->doWalkModule(module);
    static_cast(this)->visitModule(module);
    setModule(nullptr);
  }

  // override this to provide custom functionality
  // Default order: function types, imports, exports, globals, functions,
  // then the table and the memory.
  void doWalkModule(Module* module) {
    // Dispatch statically through the SubType.
    SubType* self = static_cast(this);
    for (auto& curr : module->functionTypes) {
      self->visitFunctionType(curr.get());
    }
    for (auto& curr : module->imports) {
      self->visitImport(curr.get());
    }
    for (auto& curr : module->exports) {
      self->visitExport(curr.get());
    }
    for (auto& curr : module->globals) {
      self->walkGlobal(curr.get());
    }
    for (auto& curr : module->functions) {
      self->walkFunction(curr.get());
    }
    self->walkTable(&module->table);
    self->walkMemory(&module->memory);
  }

  // Walk implementation. We don't use recursion as ASTs may be highly
  // nested.

  // Tasks receive the this pointer and a pointer to the pointer to operate on
  typedef void (*TaskFunc)(SubType*, Expression**);

  struct Task {
    TaskFunc func;
    Expression** currp;
    Task(TaskFunc func, Expression** currp) : func(func), currp(currp) {}
  };

  // Schedules func on the slot currp. Tasks run last-in, first-out, so a
  // scan() that wants A to run before B must push B first.
  void pushTask(TaskFunc func, Expression** currp) {
    stack.emplace_back(func, currp);
  }

  // Like pushTask(), but does nothing when the slot is empty (a null child).
  void maybePushTask(TaskFunc func, Expression** currp) {
    if (*currp) {
      stack.emplace_back(func, currp);
    }
  }

  Task popTask() {
    auto ret = stack.back();
    stack.pop_back();
    return ret;
  }

  // Iterative traversal driven by the task stack: schedules SubType::scan on
  // root and runs tasks until the stack is empty. After each task, a pending
  // replaceCurrent() result is written into that task's slot. The assert on
  // entry means walk() must not be re-entered while a walk is in progress.
  void walk(Expression*& root) {
    assert(stack.size() == 0);
    pushTask(SubType::scan, &root);
    while (stack.size() > 0) {
      auto task = popTask();
      assert(*task.currp);
      task.func(static_cast(this), task.currp);
      if (replace) {
        *task.currp = replace;
        replace = nullptr;
      }
    }
  }

  // subclasses implement this to define the proper order of execution
  static void scan(SubType* self, Expression** currp) { abort(); }

  // task hooks to call visitors

  static void doVisitBlock(SubType* self, Expression** currp) {
    self->visitBlock((*currp)->cast());
  }
  static void doVisitIf(SubType* self, Expression** currp) {
    self->visitIf((*currp)->cast());
  }
  static void doVisitLoop(SubType* self, Expression** currp) {
    self->visitLoop((*currp)->cast());
  }
  static void doVisitBreak(SubType* self, Expression** currp) {
    self->visitBreak((*currp)->cast());
  }
  static void doVisitSwitch(SubType* self, Expression** currp) {
    self->visitSwitch((*currp)->cast());
  }
  static void doVisitCall(SubType* self, Expression** currp) {
    self->visitCall((*currp)->cast());
  }
  static void doVisitCallImport(SubType* self, Expression** currp) {
    self->visitCallImport((*currp)->cast());
  }
  static void doVisitCallIndirect(SubType* self, Expression** currp) {
    self->visitCallIndirect((*currp)->cast());
  }
  static void doVisitGetLocal(SubType* self, Expression** currp) {
    self->visitGetLocal((*currp)->cast());
  }
  static void doVisitSetLocal(SubType* self, Expression** currp) {
    self->visitSetLocal((*currp)->cast());
  }
  static void doVisitGetGlobal(SubType* self, Expression** currp) {
    self->visitGetGlobal((*currp)->cast());
  }
  static void doVisitSetGlobal(SubType* self, Expression** currp) {
    self->visitSetGlobal((*currp)->cast());
  }
  static void doVisitLoad(SubType* self, Expression** currp) {
    self->visitLoad((*currp)->cast());
  }
  static void doVisitStore(SubType* self, Expression** currp) {
    self->visitStore((*currp)->cast());
  }
  static void doVisitConst(SubType* self, Expression** currp) {
    self->visitConst((*currp)->cast());
  }
  static void doVisitUnary(SubType* self, Expression** currp) {
    self->visitUnary((*currp)->cast());
  }
  static void doVisitBinary(SubType* self, Expression** currp) {
    self->visitBinary((*currp)->cast());
  }
  // NOTE(review): text appears to be missing from this copy starting inside
  // doVisitSelect below. Its body runs straight into what looks like the tail
  // of a post-order scan() switch (Select/Drop/Return/Host/Nop/Unreachable
  // cases), and the braces no longer balance. Not present anywhere in this
  // copy, though used: the doVisitDrop/Return/Host/Nop/Unreachable hooks that
  // the cases below push, setModule()/setFunction(), the replace / stack /
  // currModule / currFunction members, the end of Walker, and the definition
  // of the PostWalker that later walkers derive from. Restore from version
  // control.
  static void doVisitSelect(SubType* self, Expression** currp) {
    self->visitSelect((*currp)->cast()->condition);
    self->pushTask(SubType::scan, &curr->cast()->ifTrue);
    break;
  }
      case Expression::Id::DropId: {
        self->pushTask(SubType::doVisitDrop, currp);
        self->pushTask(SubType::scan, &curr->cast()->value);
        break;
      }
      case Expression::Id::ReturnId: {
        self->pushTask(SubType::doVisitReturn, currp);
        self->maybePushTask(SubType::scan, &curr->cast()->value);
        break;
      }
      case Expression::Id::HostId: {
        self->pushTask(SubType::doVisitHost, currp);
        // Pushed in reverse so that operands are scanned first-to-last.
        auto& list = curr->cast()->operands;
        for (int i = int(list.size()) - 1; i >= 0; i--) {
          self->pushTask(SubType::scan, &list[i]);
        }
        break;
      }
      case Expression::Id::NopId: {
        self->pushTask(SubType::doVisitNop, currp);
        break;
      }
      case Expression::Id::UnreachableId: {
        self->pushTask(SubType::doVisitUnreachable, currp);
        break;
      }
      default: WASM_UNREACHABLE();
    }
  }
};

// Traversal with a control-flow stack.
// controlFlowStack holds the Block, If and Loop nodes that enclose the current
// point of the traversal, innermost last. Each such node is pushed before its
// children are scanned and popped after its own post-order tasks have run.
template> struct ControlFlowWalker : public PostWalker {
  ControlFlowWalker() {}

  std::vector controlFlowStack; // contains blocks, loops, and ifs

  // Uses the control flow stack to find the target of a break to a name.
  // Searches innermost-first and returns the nearest Block or Loop whose name
  // matches, or nullptr if none does. Ifs are skipped. The stack must not be
  // empty.
  Expression* findBreakTarget(Name name) {
    assert(!controlFlowStack.empty());
    Index i = controlFlowStack.size() - 1;
    while (1) {
      auto* curr = controlFlowStack[i];
      if (Block* block = curr->template dynCast()) {
        if (name == block->name) return curr;
      } else if (Loop* loop = curr->template dynCast()) {
        if (name == loop->name) return curr;
      } else {
        // an if, ignorable
        assert(curr->template is());
      }
      if (i == 0) return nullptr;
      i--;
    }
  }

  static void doPreVisitControlFlow(SubType* self, Expression** currp) {
    self->controlFlowStack.push_back(*currp);
  }

  // NOTE(review): if visitBlock/visitIf/visitLoop calls replaceCurrent(),
  // walk() overwrites *currp before this task runs, so this assert would no
  // longer match the pushed node; confirm that replacing control-flow nodes
  // is unsupported with this walker.
  static void doPostVisitControlFlow(SubType* self, Expression** currp) {
    assert(self->controlFlowStack.back() == *currp);
    self->controlFlowStack.pop_back();
  }

  // Brackets the post-order scan of control-flow nodes with push/pop tasks.
  // Tasks run LIFO: the post task (pushed first) runs last, the pre task
  // (pushed last) runs first.
  static void scan(SubType* self, Expression** currp) {
    auto* curr = *currp;

    switch (curr->_id) {
      case Expression::Id::BlockId:
      case Expression::Id::IfId:
      case Expression::Id::LoopId: {
        self->pushTask(SubType::doPostVisitControlFlow, currp);
        break;
      }
      default: {}
    }

    PostWalker::scan(self, currp);

    switch (curr->_id) {
      case Expression::Id::BlockId:
      case Expression::Id::IfId:
      case Expression::Id::LoopId: {
        self->pushTask(SubType::doPreVisitControlFlow, currp);
        break;
      }
      default: {}
    }
  }
};

// Traversal with an expression stack.
template> struct ExpressionStackWalker : public PostWalker { ExpressionStackWalker() {} std::vector expressionStack; // Uses the control flow stack to find the target of a break to a name Expression* findBreakTarget(Name name) { assert(!expressionStack.empty()); Index i = expressionStack.size() - 1; while (1) { auto* curr = expressionStack[i]; if (Block* block = curr->template dynCast()) { if (name == block->name) return curr; } else if (Loop* loop = curr->template dynCast()) { if (name == loop->name) return curr; } else { WASM_UNREACHABLE(); } if (i == 0) return nullptr; i--; } } static void doPreVisit(SubType* self, Expression** currp) { self->expressionStack.push_back(*currp); } static void doPostVisit(SubType* self, Expression** currp) { self->expressionStack.pop_back(); } static void scan(SubType* self, Expression** currp) { self->pushTask(SubType::doPostVisit, currp); PostWalker::scan(self, currp); self->pushTask(SubType::doPreVisit, currp); } }; // Traversal in the order of execution. This is quick and simple, but // does not provide the same comprehensive information that a full // conversion to basic blocks would. What it does give is a quick // way to view straightline execution traces, i.e., that have no // branching. This can let optimizations get most of what they // want without the cost of creating another AST. // // When execution is no longer linear, this notifies via a call // to noteNonLinear(). 
template> struct LinearExecutionWalker : public PostWalker { LinearExecutionWalker() {} // subclasses should implement this void noteNonLinear(Expression* curr) { abort(); } static void doNoteNonLinear(SubType* self, Expression** currp) { self->noteNonLinear(*currp); } static void scan(SubType* self, Expression** currp) { Expression* curr = *currp; switch (curr->_id) { case Expression::Id::InvalidId: abort(); case Expression::Id::BlockId: { self->pushTask(SubType::doVisitBlock, currp); if (curr->cast()->name.is()) { self->pushTask(SubType::doNoteNonLinear, currp); } auto& list = curr->cast()->list; for (int i = int(list.size()) - 1; i >= 0; i--) { self->pushTask(SubType::scan, &list[i]); } break; } case Expression::Id::IfId: { self->pushTask(SubType::doVisitIf, currp); self->pushTask(SubType::doNoteNonLinear, currp); self->maybePushTask(SubType::scan, &curr->cast()->ifFalse); self->pushTask(SubType::doNoteNonLinear, currp); self->pushTask(SubType::scan, &curr->cast()->ifTrue); self->pushTask(SubType::doNoteNonLinear, currp); self->pushTask(SubType::scan, &curr->cast()->condition); break; } case Expression::Id::LoopId: { self->pushTask(SubType::doVisitLoop, currp); self->pushTask(SubType::scan, &curr->cast()->body); self->pushTask(SubType::doNoteNonLinear, currp); break; } case Expression::Id::BreakId: { self->pushTask(SubType::doVisitBreak, currp); self->pushTask(SubType::doNoteNonLinear, currp); self->maybePushTask(SubType::scan, &curr->cast()->condition); self->maybePushTask(SubType::scan, &curr->cast()->value); break; } case Expression::Id::SwitchId: { self->pushTask(SubType::doVisitSwitch, currp); self->pushTask(SubType::doNoteNonLinear, currp); self->maybePushTask(SubType::scan, &curr->cast()->value); self->pushTask(SubType::scan, &curr->cast()->condition); break; } case Expression::Id::ReturnId: { self->pushTask(SubType::doVisitReturn, currp); self->pushTask(SubType::doNoteNonLinear, currp); self->maybePushTask(SubType::scan, &curr->cast()->value); break; } 
case Expression::Id::UnreachableId: { self->pushTask(SubType::doVisitUnreachable, currp); self->pushTask(SubType::doNoteNonLinear, currp); break; } default: { // other node types do not have control flow, use regular post-order PostWalker::scan(self, currp); } } } }; } // namespace wasm #endif // wasm_wasm_traversal_h