summaryrefslogtreecommitdiff
path: root/src/wasm-traversal.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm-traversal.h')
-rw-r--r--src/wasm-traversal.h89
1 files changed, 58 insertions, 31 deletions
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index 21f9b3256..555243621 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -100,6 +100,37 @@ struct Visitor {
}
};
+// Visit with a single unified visitor, called on every node, instead of
+// separate visit* per node
+
+template<typename SubType, typename ReturnType = void>
+struct UnifiedExpressionVisitor : public Visitor<SubType> {
+ // called on each node
+ ReturnType visitExpression(Expression* curr) {}
+
+ // redirects
+ ReturnType visitBlock(Block *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitIf(If *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitLoop(Loop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitBreak(Break *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitSwitch(Switch *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitCall(Call *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitCallImport(CallImport *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitCallIndirect(CallIndirect *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitGetLocal(GetLocal *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitSetLocal(SetLocal *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitLoad(Load *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitStore(Store *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitConst(Const *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitUnary(Unary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitBinary(Binary *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitSelect(Select *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitReturn(Return *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitHost(Host *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitNop(Nop *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+ ReturnType visitUnreachable(Unreachable *curr) { return static_cast<SubType*>(this)->visitExpression(curr); }
+};
+
//
// Base class for all WasmWalkers, which can traverse an AST
// and provide the option to replace nodes while doing so.
@@ -107,12 +138,8 @@ struct Visitor {
// Subclass and implement the visit*()
// calls to run code on different node types.
//
-template<typename SubType>
-struct Walker : public Visitor<SubType> {
- // Extra generic visitor, called before each node's specific visitor. Useful for
- // passes that need to do the same thing for every node type.
- void visitExpression(Expression* curr) {}
-
+template<typename SubType, typename VisitorType>
+struct Walker : public VisitorType {
// Function parallelism. By default, walks are not run in parallel, but you
// can override this method to say that functions are parallelizable. This
// should always be safe *unless* you do something in the pass that makes it
@@ -245,26 +272,26 @@ struct Walker : public Visitor<SubType> {
// task hooks to call visitors
- static void doVisitBlock(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitBlock((*currp)->cast<Block>()); }
- static void doVisitIf(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitIf((*currp)->cast<If>()); }
- static void doVisitLoop(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitLoop((*currp)->cast<Loop>()); }
- static void doVisitBreak(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitBreak((*currp)->cast<Break>()); }
- static void doVisitSwitch(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitSwitch((*currp)->cast<Switch>()); }
- static void doVisitCall(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitCall((*currp)->cast<Call>()); }
- static void doVisitCallImport(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitCallImport((*currp)->cast<CallImport>()); }
- static void doVisitCallIndirect(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitCallIndirect((*currp)->cast<CallIndirect>()); }
- static void doVisitGetLocal(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitGetLocal((*currp)->cast<GetLocal>()); }
- static void doVisitSetLocal(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitSetLocal((*currp)->cast<SetLocal>()); }
- static void doVisitLoad(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitLoad((*currp)->cast<Load>()); }
- static void doVisitStore(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitStore((*currp)->cast<Store>()); }
- static void doVisitConst(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitConst((*currp)->cast<Const>()); }
- static void doVisitUnary(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitUnary((*currp)->cast<Unary>()); }
- static void doVisitBinary(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitBinary((*currp)->cast<Binary>()); }
- static void doVisitSelect(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitSelect((*currp)->cast<Select>()); }
- static void doVisitReturn(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitReturn((*currp)->cast<Return>()); }
- static void doVisitHost(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitHost((*currp)->cast<Host>()); }
- static void doVisitNop(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitNop((*currp)->cast<Nop>()); }
- static void doVisitUnreachable(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitUnreachable((*currp)->cast<Unreachable>()); }
+ static void doVisitBlock(SubType* self, Expression** currp) { self->visitBlock((*currp)->cast<Block>()); }
+ static void doVisitIf(SubType* self, Expression** currp) { self->visitIf((*currp)->cast<If>()); }
+ static void doVisitLoop(SubType* self, Expression** currp) { self->visitLoop((*currp)->cast<Loop>()); }
+ static void doVisitBreak(SubType* self, Expression** currp) { self->visitBreak((*currp)->cast<Break>()); }
+ static void doVisitSwitch(SubType* self, Expression** currp) { self->visitSwitch((*currp)->cast<Switch>()); }
+ static void doVisitCall(SubType* self, Expression** currp) { self->visitCall((*currp)->cast<Call>()); }
+ static void doVisitCallImport(SubType* self, Expression** currp) { self->visitCallImport((*currp)->cast<CallImport>()); }
+ static void doVisitCallIndirect(SubType* self, Expression** currp) { self->visitCallIndirect((*currp)->cast<CallIndirect>()); }
+ static void doVisitGetLocal(SubType* self, Expression** currp) { self->visitGetLocal((*currp)->cast<GetLocal>()); }
+ static void doVisitSetLocal(SubType* self, Expression** currp) { self->visitSetLocal((*currp)->cast<SetLocal>()); }
+ static void doVisitLoad(SubType* self, Expression** currp) { self->visitLoad((*currp)->cast<Load>()); }
+ static void doVisitStore(SubType* self, Expression** currp) { self->visitStore((*currp)->cast<Store>()); }
+ static void doVisitConst(SubType* self, Expression** currp) { self->visitConst((*currp)->cast<Const>()); }
+ static void doVisitUnary(SubType* self, Expression** currp) { self->visitUnary((*currp)->cast<Unary>()); }
+ static void doVisitBinary(SubType* self, Expression** currp) { self->visitBinary((*currp)->cast<Binary>()); }
+ static void doVisitSelect(SubType* self, Expression** currp) { self->visitSelect((*currp)->cast<Select>()); }
+ static void doVisitReturn(SubType* self, Expression** currp) { self->visitReturn((*currp)->cast<Return>()); }
+ static void doVisitHost(SubType* self, Expression** currp) { self->visitHost((*currp)->cast<Host>()); }
+ static void doVisitNop(SubType* self, Expression** currp) { self->visitNop((*currp)->cast<Nop>()); }
+ static void doVisitUnreachable(SubType* self, Expression** currp) { self->visitUnreachable((*currp)->cast<Unreachable>()); }
void setFunction(Function *func) {
currFunction = func;
@@ -279,8 +306,8 @@ private:
// Walks in post-order, i.e., children first. When there isn't an obvious
// order to operands, we follow them in order of execution.
-template<typename SubType>
-struct PostWalker : public Walker<SubType> {
+template<typename SubType, typename VisitorType>
+struct PostWalker : public Walker<SubType, VisitorType> {
static void scan(SubType* self, Expression** currp) {
@@ -422,8 +449,8 @@ struct PostWalker : public Walker<SubType> {
// When execution is no longer linear, this notifies via a call
// to noteNonLinear().
-template<typename SubType>
-struct LinearExecutionWalker : public PostWalker<SubType> {
+template<typename SubType, typename VisitorType>
+struct LinearExecutionWalker : public PostWalker<SubType, VisitorType> {
LinearExecutionWalker() {}
// subclasses should implement this
@@ -486,7 +513,7 @@ struct LinearExecutionWalker : public PostWalker<SubType> {
}
default: {
// other node types do not have control flow, use regular post-order
- PostWalker<SubType>::scan(self, currp);
+ PostWalker<SubType, VisitorType>::scan(self, currp);
}
}
}