diff options
Diffstat (limited to 'src/wasm-traversal.h')
-rw-r--r-- | src/wasm-traversal.h | 125 |
1 files changed, 120 insertions, 5 deletions
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index b50ca0fb2..4725237dd 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -53,6 +53,7 @@ struct Visitor { ReturnType visitUnary(Unary *curr) {} ReturnType visitBinary(Binary *curr) {} ReturnType visitSelect(Select *curr) {} + ReturnType visitDrop(Drop *curr) {} ReturnType visitReturn(Return *curr) {} ReturnType visitHost(Host *curr) {} ReturnType visitNop(Nop *curr) {} @@ -93,6 +94,7 @@ struct Visitor { case Expression::Id::UnaryId: DELEGATE(Unary); case Expression::Id::BinaryId: DELEGATE(Binary); case Expression::Id::SelectId: DELEGATE(Select); + case Expression::Id::DropId: DELEGATE(Drop); case Expression::Id::ReturnId: DELEGATE(Return); case Expression::Id::HostId: DELEGATE(Host); case Expression::Id::NopId: DELEGATE(Nop); @@ -132,6 +134,7 @@ struct UnifiedExpressionVisitor : public Visitor<SubType> { ReturnType visitUnary(Unary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitBinary(Binary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitSelect(Select *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitDrop(Drop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitReturn(Return *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitHost(Host *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitNop(Nop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } @@ -153,8 +156,8 @@ struct Walker : public VisitorType { // Note that the visit*() for the result node is not called for you (i.e., // just one visit*() method is called by the traversal; if you replace a node, // and you want to process the output, you must do that explicitly). - void replaceCurrent(Expression *expression) { - replace = expression; + Expression* replaceCurrent(Expression *expression) { + return replace = expression; } // Get the current module @@ -264,14 +267,15 @@ struct Walker : public VisitorType { static void doVisitCallIndirect(SubType* self, Expression** currp) { self->visitCallIndirect((*currp)->cast<CallIndirect>()); } static void doVisitGetLocal(SubType* self, Expression** currp) { self->visitGetLocal((*currp)->cast<GetLocal>()); } static void doVisitSetLocal(SubType* self, Expression** currp) { self->visitSetLocal((*currp)->cast<SetLocal>()); } - static void doVisitGetGlobal(SubType* self, Expression** currp) { self->visitGetGlobal((*currp)->cast<GetGlobal>()); } - static void doVisitSetGlobal(SubType* self, Expression** currp) { self->visitSetGlobal((*currp)->cast<SetGlobal>()); } + static void doVisitGetGlobal(SubType* self, Expression** currp) { self->visitGetGlobal((*currp)->cast<GetGlobal>()); } + static void doVisitSetGlobal(SubType* self, Expression** currp) { self->visitSetGlobal((*currp)->cast<SetGlobal>()); } static void doVisitLoad(SubType* self, Expression** currp) { self->visitLoad((*currp)->cast<Load>()); } static void doVisitStore(SubType* self, Expression** currp) { self->visitStore((*currp)->cast<Store>()); } static void doVisitConst(SubType* self, Expression** currp) { self->visitConst((*currp)->cast<Const>()); } static void doVisitUnary(SubType* self, Expression** currp) { self->visitUnary((*currp)->cast<Unary>()); } static void doVisitBinary(SubType* self, Expression** currp) { self->visitBinary((*currp)->cast<Binary>()); } static void doVisitSelect(SubType* self, Expression** currp) { self->visitSelect((*currp)->cast<Select>()); } + static void doVisitDrop(SubType* self, Expression** currp) { self->visitDrop((*currp)->cast<Drop>()); } static void doVisitReturn(SubType* self, Expression** currp) { self->visitReturn((*currp)->cast<Return>()); } static void doVisitHost(SubType* self, Expression** currp) { self->visitHost((*currp)->cast<Host>()); } static void doVisitNop(SubType* self, Expression** currp) { self->visitNop((*currp)->cast<Nop>()); } @@ -354,10 +358,10 @@ struct PostWalker : public Walker<SubType, VisitorType> { case Expression::Id::CallIndirectId: { self->pushTask(SubType::doVisitCallIndirect, currp); auto& list = curr->cast<CallIndirect>()->operands; + self->pushTask(SubType::scan, &curr->cast<CallIndirect>()->target); for (int i = int(list.size()) - 1; i >= 0; i--) { self->pushTask(SubType::scan, &list[i]); } - self->pushTask(SubType::scan, &curr->cast<CallIndirect>()->target); break; } case Expression::Id::GetLocalId: { @@ -411,6 +415,11 @@ struct PostWalker : public Walker<SubType, VisitorType> { self->pushTask(SubType::scan, &curr->cast<Select>()->ifTrue); break; } + case Expression::Id::DropId: { + self->pushTask(SubType::doVisitDrop, currp); + self->pushTask(SubType::scan, &curr->cast<Drop>()->value); + break; + } case Expression::Id::ReturnId: { self->pushTask(SubType::doVisitReturn, currp); self->maybePushTask(SubType::scan, &curr->cast<Return>()->value); @@ -437,6 +446,112 @@ struct PostWalker : public Walker<SubType, VisitorType> { } }; +// Traversal with a control-flow stack. + +template<typename SubType, typename VisitorType> +struct ControlFlowWalker : public PostWalker<SubType, VisitorType> { + ControlFlowWalker() {} + + std::vector<Expression*> controlFlowStack; // contains blocks, loops, and ifs + + // Uses the control flow stack to find the target of a break to a name + Expression* findBreakTarget(Name name) { + assert(!controlFlowStack.empty()); + Index i = controlFlowStack.size() - 1; + while (1) { + auto* curr = controlFlowStack[i]; + if (Block* block = curr->template dynCast<Block>()) { + if (name == block->name) return curr; + } else if (Loop* loop = curr->template dynCast<Loop>()) { + if (name == loop->name) return curr; + } else { + // an if, ignorable + assert(curr->template is<If>()); + } + if (i == 0) return nullptr; + i--; + } + } + + static void doPreVisitControlFlow(SubType* self, Expression** currp) { + self->controlFlowStack.push_back(*currp); + } + + static void doPostVisitControlFlow(SubType* self, Expression** currp) { + assert(self->controlFlowStack.back() == *currp); + self->controlFlowStack.pop_back(); + } + + static void scan(SubType* self, Expression** currp) { + auto* curr = *currp; + + switch (curr->_id) { + case Expression::Id::BlockId: + case Expression::Id::IfId: + case Expression::Id::LoopId: { + self->pushTask(SubType::doPostVisitControlFlow, currp); + break; + } + default: {} + } + + PostWalker<SubType, VisitorType>::scan(self, currp); + + switch (curr->_id) { + case Expression::Id::BlockId: + case Expression::Id::IfId: + case Expression::Id::LoopId: { + self->pushTask(SubType::doPreVisitControlFlow, currp); + break; + } + default: {} + } + } +}; + +// Traversal with an expression stack. + +template<typename SubType, typename VisitorType> +struct ExpressionStackWalker : public PostWalker<SubType, VisitorType> { + ExpressionStackWalker() {} + + std::vector<Expression*> expressionStack; + + // Uses the control flow stack to find the target of a break to a name + Expression* findBreakTarget(Name name) { + assert(!expressionStack.empty()); + Index i = expressionStack.size() - 1; + while (1) { + auto* curr = expressionStack[i]; + if (Block* block = curr->template dynCast<Block>()) { + if (name == block->name) return curr; + } else if (Loop* loop = curr->template dynCast<Loop>()) { + if (name == loop->name) return curr; + } else { + WASM_UNREACHABLE(); + } + if (i == 0) return nullptr; + i--; + } + } + + static void doPreVisit(SubType* self, Expression** currp) { + self->expressionStack.push_back(*currp); + } + + static void doPostVisit(SubType* self, Expression** currp) { + self->expressionStack.pop_back(); + } + + static void scan(SubType* self, Expression** currp) { + self->pushTask(SubType::doPostVisit, currp); + + PostWalker<SubType, VisitorType>::scan(self, currp); + + self->pushTask(SubType::doPreVisit, currp); + } +}; + // Traversal in the order of execution. This is quick and simple, but // does not provide the same comprehensive information that a full // conversion to basic blocks would. What it does give is a quick |