diff options
Diffstat (limited to 'src/wasm-traversal.h')
-rw-r--r-- | src/wasm-traversal.h | 89 |
1 files changed, 58 insertions, 31 deletions
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 21f9b3256..555243621 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -100,6 +100,37 @@ struct Visitor { } }; +// Visit with a single unified visitor, called on every node, instead of +// separate visit* per node + +template<typename SubType, typename ReturnType = void> +struct UnifiedExpressionVisitor : public Visitor<SubType> { + // called on each node + ReturnType visitExpression(Expression* curr) {} + + // redirects + ReturnType visitBlock(Block *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitIf(If *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitLoop(Loop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitBreak(Break *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitSwitch(Switch *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitCall(Call *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitCallImport(CallImport *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitCallIndirect(CallIndirect *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitGetLocal(GetLocal *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitSetLocal(SetLocal *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitLoad(Load *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitStore(Store *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitConst(Const *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitUnary(Unary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitBinary(Binary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitSelect(Select *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitReturn(Return *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitHost(Host *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitNop(Nop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitUnreachable(Unreachable *curr) { return static_cast<SubType*>(this)->visitExpression(curr); } +}; + // // Base class for all WasmWalkers, which can traverse an AST // and provide the option to replace nodes while doing so. @@ -107,12 +138,8 @@ struct Visitor { // Subclass and implement the visit*() // calls to run code on different node types. // -template<typename SubType> -struct Walker : public Visitor<SubType> { - // Extra generic visitor, called before each node's specific visitor. Useful for - // passes that need to do the same thing for every node type. - void visitExpression(Expression* curr) {} - +template<typename SubType, typename VisitorType> +struct Walker : public VisitorType { // Function parallelism. By default, walks are not run in parallel, but you // can override this method to say that functions are parallelizable. This // should always be safe *unless* you do something in the pass that makes it @@ -245,26 +272,26 @@ struct Walker : public Visitor<SubType> { // task hooks to call visitors - static void doVisitBlock(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitBlock((*currp)->cast<Block>()); } - static void doVisitIf(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitIf((*currp)->cast<If>()); } - static void doVisitLoop(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitLoop((*currp)->cast<Loop>()); } - static void doVisitBreak(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitBreak((*currp)->cast<Break>()); } - static void doVisitSwitch(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitSwitch((*currp)->cast<Switch>()); } - static void doVisitCall(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitCall((*currp)->cast<Call>()); } - static void doVisitCallImport(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitCallImport((*currp)->cast<CallImport>()); } - static void doVisitCallIndirect(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitCallIndirect((*currp)->cast<CallIndirect>()); } - static void doVisitGetLocal(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitGetLocal((*currp)->cast<GetLocal>()); } - static void doVisitSetLocal(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitSetLocal((*currp)->cast<SetLocal>()); } - static void doVisitLoad(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitLoad((*currp)->cast<Load>()); } - static void doVisitStore(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitStore((*currp)->cast<Store>()); } - static void doVisitConst(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitConst((*currp)->cast<Const>()); } - static void doVisitUnary(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitUnary((*currp)->cast<Unary>()); } - static void doVisitBinary(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitBinary((*currp)->cast<Binary>()); } - static void doVisitSelect(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitSelect((*currp)->cast<Select>()); } - static void doVisitReturn(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitReturn((*currp)->cast<Return>()); } - static void doVisitHost(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitHost((*currp)->cast<Host>()); } - static void doVisitNop(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitNop((*currp)->cast<Nop>()); } - static void doVisitUnreachable(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitUnreachable((*currp)->cast<Unreachable>()); } + static void doVisitBlock(SubType* self, Expression** currp) { self->visitBlock((*currp)->cast<Block>()); } + static void doVisitIf(SubType* self, Expression** currp) { self->visitIf((*currp)->cast<If>()); } + static void doVisitLoop(SubType* self, Expression** currp) { self->visitLoop((*currp)->cast<Loop>()); } + static void doVisitBreak(SubType* self, Expression** currp) { self->visitBreak((*currp)->cast<Break>()); } + static void doVisitSwitch(SubType* self, Expression** currp) { self->visitSwitch((*currp)->cast<Switch>()); } + static void doVisitCall(SubType* self, Expression** currp) { self->visitCall((*currp)->cast<Call>()); } + static void doVisitCallImport(SubType* self, Expression** currp) { self->visitCallImport((*currp)->cast<CallImport>()); } + static void doVisitCallIndirect(SubType* self, Expression** currp) { self->visitCallIndirect((*currp)->cast<CallIndirect>()); } + static void doVisitGetLocal(SubType* self, Expression** currp) { self->visitGetLocal((*currp)->cast<GetLocal>()); } + static void doVisitSetLocal(SubType* self, Expression** currp) { self->visitSetLocal((*currp)->cast<SetLocal>()); } + static void doVisitLoad(SubType* self, Expression** currp) { self->visitLoad((*currp)->cast<Load>()); } + static void doVisitStore(SubType* self, Expression** currp) { self->visitStore((*currp)->cast<Store>()); } + static void doVisitConst(SubType* self, Expression** currp) { self->visitConst((*currp)->cast<Const>()); } + static void doVisitUnary(SubType* self, Expression** currp) { self->visitUnary((*currp)->cast<Unary>()); } + static void doVisitBinary(SubType* self, Expression** currp) { self->visitBinary((*currp)->cast<Binary>()); } + static void doVisitSelect(SubType* self, Expression** currp) { self->visitSelect((*currp)->cast<Select>()); } + static void doVisitReturn(SubType* self, Expression** currp) { self->visitReturn((*currp)->cast<Return>()); } + static void doVisitHost(SubType* self, Expression** currp) { self->visitHost((*currp)->cast<Host>()); } + static void doVisitNop(SubType* self, Expression** currp) { self->visitNop((*currp)->cast<Nop>()); } + static void doVisitUnreachable(SubType* self, Expression** currp) { self->visitUnreachable((*currp)->cast<Unreachable>()); } void setFunction(Function *func) { currFunction = func; @@ -279,8 +306,8 @@ private: // Walks in post-order, i.e., children first. When there isn't an obvious // order to operands, we follow them in order of execution. -template<typename SubType> -struct PostWalker : public Walker<SubType> { +template<typename SubType, typename VisitorType> +struct PostWalker : public Walker<SubType, VisitorType> { static void scan(SubType* self, Expression** currp) { @@ -422,8 +449,8 @@ struct PostWalker : public Walker<SubType> { // When execution is no longer linear, this notifies via a call // to noteNonLinear(). -template<typename SubType> -struct LinearExecutionWalker : public PostWalker<SubType> { +template<typename SubType, typename VisitorType> +struct LinearExecutionWalker : public PostWalker<SubType, VisitorType> { LinearExecutionWalker() {} // subclasses should implement this @@ -486,7 +513,7 @@ struct LinearExecutionWalker : public PostWalker<SubType> { } default: { // other node types do not have control flow, use regular post-order - PostWalker<SubType>::scan(self, currp); + PostWalker<SubType, VisitorType>::scan(self, currp); } } } |