summaryrefslogtreecommitdiff
path: root/src/wasm.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm.h')
-rw-r--r--src/wasm.h175
1 files changed, 94 insertions, 81 deletions
diff --git a/src/wasm.h b/src/wasm.h
index 290fc5d47..25d514368 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -11,6 +11,7 @@
#include <map>
#include <vector>
+#include "simple_ast.h"
#include "colors.h"
namespace wasm {
@@ -800,36 +801,85 @@ public:
};
//
-// Simple WebAssembly AST walker
+// Simple WebAssembly AST visiting and children-first walking
//
-struct WasmWalker {
+template<class ReturnType>
+struct WasmVisitor {
+ virtual ReturnType visitBlock(Block *curr) = 0;
+ virtual ReturnType visitIf(If *curr) = 0;
+ virtual ReturnType visitLoop(Loop *curr) = 0;
+ virtual ReturnType visitLabel(Label *curr) = 0;
+ virtual ReturnType visitBreak(Break *curr) = 0;
+ virtual ReturnType visitSwitch(Switch *curr) = 0;
+ virtual ReturnType visitCall(Call *curr) = 0;
+ virtual ReturnType visitCallImport(CallImport *curr) = 0;
+ virtual ReturnType visitCallIndirect(CallIndirect *curr) = 0;
+ virtual ReturnType visitGetLocal(GetLocal *curr) = 0;
+ virtual ReturnType visitSetLocal(SetLocal *curr) = 0;
+ virtual ReturnType visitLoad(Load *curr) = 0;
+ virtual ReturnType visitStore(Store *curr) = 0;
+ virtual ReturnType visitConst(Const *curr) = 0;
+ virtual ReturnType visitUnary(Unary *curr) = 0;
+ virtual ReturnType visitBinary(Binary *curr) = 0;
+ virtual ReturnType visitCompare(Compare *curr) = 0;
+ virtual ReturnType visitConvert(Convert *curr) = 0;
+ virtual ReturnType visitHost(Host *curr) = 0;
+ virtual ReturnType visitNop(Nop *curr) = 0;
+
+ ReturnType visit(Expression *curr) {
+ assert(curr);
+ if (Block *cast = dynamic_cast<Block*>(curr)) return visitBlock(cast);
+ if (If *cast = dynamic_cast<If*>(curr)) return visitIf(cast);
+ if (Loop *cast = dynamic_cast<Loop*>(curr)) return visitLoop(cast);
+ if (Label *cast = dynamic_cast<Label*>(curr)) return visitLabel(cast);
+ if (Break *cast = dynamic_cast<Break*>(curr)) return visitBreak(cast);
+ if (Switch *cast = dynamic_cast<Switch*>(curr)) return visitSwitch(cast);
+ if (Call *cast = dynamic_cast<Call*>(curr)) return visitCall(cast);
+ if (CallImport *cast = dynamic_cast<CallImport*>(curr)) return visitCallImport(cast);
+ if (CallIndirect *cast = dynamic_cast<CallIndirect*>(curr)) return visitCallIndirect(cast);
+ if (GetLocal *cast = dynamic_cast<GetLocal*>(curr)) return visitGetLocal(cast);
+ if (SetLocal *cast = dynamic_cast<SetLocal*>(curr)) return visitSetLocal(cast);
+ if (Load *cast = dynamic_cast<Load*>(curr)) return visitLoad(cast);
+ if (Store *cast = dynamic_cast<Store*>(curr)) return visitStore(cast);
+ if (Const *cast = dynamic_cast<Const*>(curr)) return visitConst(cast);
+ if (Unary *cast = dynamic_cast<Unary*>(curr)) return visitUnary(cast);
+ if (Binary *cast = dynamic_cast<Binary*>(curr)) return visitBinary(cast);
+ if (Compare *cast = dynamic_cast<Compare*>(curr)) return visitCompare(cast);
+ if (Convert *cast = dynamic_cast<Convert*>(curr)) return visitConvert(cast);
+ if (Host *cast = dynamic_cast<Host*>(curr)) return visitHost(cast);
+ if (Nop *cast = dynamic_cast<Nop*>(curr)) return visitNop(cast);
+ abort();
+ }
+};
+
+struct WasmWalker : public WasmVisitor<Expression*> {
wasm::Arena* allocator; // use an existing allocator, or null if no allocations
WasmWalker() : allocator(nullptr) {}
WasmWalker(wasm::Arena* allocator) : allocator(allocator) {}
// Each method receives an AST pointer, and it is replaced with what is returned.
- virtual Expression* walkBlock(Block *curr) { return curr; };
- virtual Expression* walkIf(If *curr) { return curr; };
- virtual Expression* walkLoop(Loop *curr) { return curr; };
- virtual Expression* walkLabel(Label *curr) { return curr; };
- virtual Expression* walkBreak(Break *curr) { return curr; };
- virtual Expression* walkSwitch(Switch *curr) { return curr; };
- virtual Expression* walkCall(Call *curr) { return curr; };
- virtual Expression* walkCallImport(CallImport *curr) { return curr; };
- virtual Expression* walkCallIndirect(CallIndirect *curr) { return curr; };
- virtual Expression* walkGetLocal(GetLocal *curr) { return curr; };
- virtual Expression* walkSetLocal(SetLocal *curr) { return curr; };
- virtual Expression* walkLoad(Load *curr) { return curr; };
- virtual Expression* walkStore(Store *curr) { return curr; };
- virtual Expression* walkConst(Const *curr) { return curr; };
- virtual Expression* walkUnary(Unary *curr) { return curr; };
- virtual Expression* walkBinary(Binary *curr) { return curr; };
- virtual Expression* walkCompare(Compare *curr) { return curr; };
- virtual Expression* walkConvert(Convert *curr) { return curr; };
- virtual Expression* walkHost(Host *curr) { return curr; };
- virtual Expression* walkNop(Nop *curr) { return curr; };
+ virtual Expression* visitBlock(Block *curr) { return curr; };
+ virtual Expression* visitIf(If *curr) { return curr; };
+ virtual Expression* visitLoop(Loop *curr) { return curr; };
+ virtual Expression* visitLabel(Label *curr) { return curr; };
+ virtual Expression* visitBreak(Break *curr) { return curr; };
+ virtual Expression* visitSwitch(Switch *curr) { return curr; };
+ virtual Expression* visitCall(Call *curr) { return curr; };
+ virtual Expression* visitCallImport(CallImport *curr) { return curr; };
+ virtual Expression* visitCallIndirect(CallIndirect *curr) { return curr; };
+ virtual Expression* visitGetLocal(GetLocal *curr) { return curr; };
+ virtual Expression* visitSetLocal(SetLocal *curr) { return curr; };
+ virtual Expression* visitLoad(Load *curr) { return curr; };
+ virtual Expression* visitStore(Store *curr) { return curr; };
+ virtual Expression* visitConst(Const *curr) { return curr; };
+ virtual Expression* visitUnary(Unary *curr) { return curr; };
+ virtual Expression* visitBinary(Binary *curr) { return curr; };
+ virtual Expression* visitCompare(Compare *curr) { return curr; };
+ virtual Expression* visitConvert(Convert *curr) { return curr; };
+ virtual Expression* visitHost(Host *curr) { return curr; };
+ virtual Expression* visitNop(Nop *curr) { return curr; };
// children-first
Expression *walk(Expression *curr) {
@@ -840,104 +890,67 @@ struct WasmWalker {
for (size_t z = 0; z < list.size(); z++) {
list[z] = walk(list[z]);
}
- return walkBlock(cast);
- }
- if (If *cast = dynamic_cast<If*>(curr)) {
+ } else if (If *cast = dynamic_cast<If*>(curr)) {
cast->condition = walk(cast->condition);
cast->ifTrue = walk(cast->ifTrue);
cast->ifFalse = walk(cast->ifFalse);
- return walkIf(cast);
- }
- if (Loop *cast = dynamic_cast<Loop*>(curr)) {
+ } else if (Loop *cast = dynamic_cast<Loop*>(curr)) {
cast->body = walk(cast->body);
- return walkLoop(cast);
- }
- if (Label *cast = dynamic_cast<Label*>(curr)) {
- return walkLabel(cast);
- }
- if (Break *cast = dynamic_cast<Break*>(curr)) {
+ } else if (Label *cast = dynamic_cast<Label*>(curr)) {
+ } else if (Break *cast = dynamic_cast<Break*>(curr)) {
cast->condition = walk(cast->condition);
cast->value = walk(cast->value);
- return walkBreak(cast);
- }
- if (Switch *cast = dynamic_cast<Switch*>(curr)) {
+ } else if (Switch *cast = dynamic_cast<Switch*>(curr)) {
cast->value = walk(cast->value);
for (auto& curr : cast->cases) {
curr.body = walk(curr.body);
}
cast->default_ = walk(cast->default_);
- return walkSwitch(cast);
- }
- if (Call *cast = dynamic_cast<Call*>(curr)) {
+ } else if (Call *cast = dynamic_cast<Call*>(curr)) {
ExpressionList& list = cast->operands;
for (size_t z = 0; z < list.size(); z++) {
list[z] = walk(list[z]);
}
- return walkCall(cast);
- }
- if (CallImport *cast = dynamic_cast<CallImport*>(curr)) {
+ } else if (CallImport *cast = dynamic_cast<CallImport*>(curr)) {
ExpressionList& list = cast->operands;
for (size_t z = 0; z < list.size(); z++) {
list[z] = walk(list[z]);
}
- return walkCallImport(cast);
- }
- if (CallIndirect *cast = dynamic_cast<CallIndirect*>(curr)) {
+ } else if (CallIndirect *cast = dynamic_cast<CallIndirect*>(curr)) {
cast->target = walk(cast->target);
ExpressionList& list = cast->operands;
for (size_t z = 0; z < list.size(); z++) {
list[z] = walk(list[z]);
}
- return walkCallIndirect(cast);
- }
- if (GetLocal *cast = dynamic_cast<GetLocal*>(curr)) {
- return walkGetLocal(cast);
- }
- if (SetLocal *cast = dynamic_cast<SetLocal*>(curr)) {
+ } else if (GetLocal *cast = dynamic_cast<GetLocal*>(curr)) {
+ } else if (SetLocal *cast = dynamic_cast<SetLocal*>(curr)) {
cast->value = walk(cast->value);
- return walkSetLocal(cast);
- }
- if (Load *cast = dynamic_cast<Load*>(curr)) {
+ } else if (Load *cast = dynamic_cast<Load*>(curr)) {
cast->ptr = walk(cast->ptr);
- return walkLoad(cast);
- }
- if (Store *cast = dynamic_cast<Store*>(curr)) {
+ } else if (Store *cast = dynamic_cast<Store*>(curr)) {
cast->ptr = walk(cast->ptr);
cast->value = walk(cast->value);
- return walkStore(cast);
- }
- if (Const *cast = dynamic_cast<Const*>(curr)) {
- return walkConst(cast);
- }
- if (Unary *cast = dynamic_cast<Unary*>(curr)) {
+ } else if (Const *cast = dynamic_cast<Const*>(curr)) {
+ } else if (Unary *cast = dynamic_cast<Unary*>(curr)) {
cast->value = walk(cast->value);
- return walkUnary(cast);
- }
- if (Binary *cast = dynamic_cast<Binary*>(curr)) {
+ } else if (Binary *cast = dynamic_cast<Binary*>(curr)) {
cast->left = walk(cast->left);
cast->right = walk(cast->right);
- return walkBinary(cast);
- }
- if (Compare *cast = dynamic_cast<Compare*>(curr)) {
+ } else if (Compare *cast = dynamic_cast<Compare*>(curr)) {
cast->left = walk(cast->left);
cast->right = walk(cast->right);
- return walkCompare(cast);
- }
- if (Convert *cast = dynamic_cast<Convert*>(curr)) {
+ } else if (Convert *cast = dynamic_cast<Convert*>(curr)) {
cast->value = walk(cast->value);
- return walkConvert(cast);
- }
- if (Host *cast = dynamic_cast<Host*>(curr)) {
+ } else if (Host *cast = dynamic_cast<Host*>(curr)) {
ExpressionList& list = cast->operands;
for (size_t z = 0; z < list.size(); z++) {
list[z] = walk(list[z]);
}
- return walkHost(cast);
- }
- if (Nop *cast = dynamic_cast<Nop*>(curr)) {
- return walkNop(cast);
+ } else if (Nop *cast = dynamic_cast<Nop*>(curr)) {
+ } else {
+ abort();
}
- abort();
+ return visit(curr);
}
void startWalk(Function *func) {