From 2d6a5ed822740a6ac53d1807fe2111bee5098bc6 Mon Sep 17 00:00:00 2001 From: Alon Zakai Date: Fri, 30 Oct 2015 21:01:12 -0700 Subject: refactoring --- src/wasm.h | 175 +++++++++++++++++++++++++++++++++---------------------------- 1 file changed, 94 insertions(+), 81 deletions(-) (limited to 'src/wasm.h') diff --git a/src/wasm.h b/src/wasm.h index 290fc5d47..25d514368 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -11,6 +11,7 @@ #include #include +#include "simple_ast.h" #include "colors.h" namespace wasm { @@ -800,36 +801,85 @@ public: }; // -// Simple WebAssembly AST walker +// Simple WebAssembly AST visiting and children-first walking // -struct WasmWalker { +template +struct WasmVisitor { + virtual ReturnType visitBlock(Block *curr) = 0; + virtual ReturnType visitIf(If *curr) = 0; + virtual ReturnType visitLoop(Loop *curr) = 0; + virtual ReturnType visitLabel(Label *curr) = 0; + virtual ReturnType visitBreak(Break *curr) = 0; + virtual ReturnType visitSwitch(Switch *curr) = 0; + virtual ReturnType visitCall(Call *curr) = 0; + virtual ReturnType visitCallImport(CallImport *curr) = 0; + virtual ReturnType visitCallIndirect(CallIndirect *curr) = 0; + virtual ReturnType visitGetLocal(GetLocal *curr) = 0; + virtual ReturnType visitSetLocal(SetLocal *curr) = 0; + virtual ReturnType visitLoad(Load *curr) = 0; + virtual ReturnType visitStore(Store *curr) = 0; + virtual ReturnType visitConst(Const *curr) = 0; + virtual ReturnType visitUnary(Unary *curr) = 0; + virtual ReturnType visitBinary(Binary *curr) = 0; + virtual ReturnType visitCompare(Compare *curr) = 0; + virtual ReturnType visitConvert(Convert *curr) = 0; + virtual ReturnType visitHost(Host *curr) = 0; + virtual ReturnType visitNop(Nop *curr) = 0; + + ReturnType visit(Expression *curr) { + assert(curr); + if (Block *cast = dynamic_cast(curr)) return visitBlock(cast); + if (If *cast = dynamic_cast(curr)) return visitIf(cast); + if (Loop *cast = dynamic_cast(curr)) return visitLoop(cast); + if (Label *cast = dynamic_cast(curr)) return visitLabel(cast); + if (Break *cast = dynamic_cast(curr)) return visitBreak(cast); + if (Switch *cast = dynamic_cast(curr)) return visitSwitch(cast); + if (Call *cast = dynamic_cast(curr)) return visitCall(cast); + if (CallImport *cast = dynamic_cast(curr)) return visitCallImport(cast); + if (CallIndirect *cast = dynamic_cast(curr)) return visitCallIndirect(cast); + if (GetLocal *cast = dynamic_cast(curr)) return visitGetLocal(cast); + if (SetLocal *cast = dynamic_cast(curr)) return visitSetLocal(cast); + if (Load *cast = dynamic_cast(curr)) return visitLoad(cast); + if (Store *cast = dynamic_cast(curr)) return visitStore(cast); + if (Const *cast = dynamic_cast(curr)) return visitConst(cast); + if (Unary *cast = dynamic_cast(curr)) return visitUnary(cast); + if (Binary *cast = dynamic_cast(curr)) return visitBinary(cast); + if (Compare *cast = dynamic_cast(curr)) return visitCompare(cast); + if (Convert *cast = dynamic_cast(curr)) return visitConvert(cast); + if (Host *cast = dynamic_cast(curr)) return visitHost(cast); + if (Nop *cast = dynamic_cast(curr)) return visitNop(cast); + abort(); + } +}; + +struct WasmWalker : public WasmVisitor { wasm::Arena* allocator; // use an existing allocator, or null if no allocations WasmWalker() : allocator(nullptr) {} WasmWalker(wasm::Arena* allocator) : allocator(allocator) {} // Each method receives an AST pointer, and it is replaced with what is returned. - virtual Expression* walkBlock(Block *curr) { return curr; }; - virtual Expression* walkIf(If *curr) { return curr; }; - virtual Expression* walkLoop(Loop *curr) { return curr; }; - virtual Expression* walkLabel(Label *curr) { return curr; }; - virtual Expression* walkBreak(Break *curr) { return curr; }; - virtual Expression* walkSwitch(Switch *curr) { return curr; }; - virtual Expression* walkCall(Call *curr) { return curr; }; - virtual Expression* walkCallImport(CallImport *curr) { return curr; }; - virtual Expression* walkCallIndirect(CallIndirect *curr) { return curr; }; - virtual Expression* walkGetLocal(GetLocal *curr) { return curr; }; - virtual Expression* walkSetLocal(SetLocal *curr) { return curr; }; - virtual Expression* walkLoad(Load *curr) { return curr; }; - virtual Expression* walkStore(Store *curr) { return curr; }; - virtual Expression* walkConst(Const *curr) { return curr; }; - virtual Expression* walkUnary(Unary *curr) { return curr; }; - virtual Expression* walkBinary(Binary *curr) { return curr; }; - virtual Expression* walkCompare(Compare *curr) { return curr; }; - virtual Expression* walkConvert(Convert *curr) { return curr; }; - virtual Expression* walkHost(Host *curr) { return curr; }; - virtual Expression* walkNop(Nop *curr) { return curr; }; + virtual Expression* visitBlock(Block *curr) { return curr; }; + virtual Expression* visitIf(If *curr) { return curr; }; + virtual Expression* visitLoop(Loop *curr) { return curr; }; + virtual Expression* visitLabel(Label *curr) { return curr; }; + virtual Expression* visitBreak(Break *curr) { return curr; }; + virtual Expression* visitSwitch(Switch *curr) { return curr; }; + virtual Expression* visitCall(Call *curr) { return curr; }; + virtual Expression* visitCallImport(CallImport *curr) { return curr; }; + virtual Expression* visitCallIndirect(CallIndirect *curr) { return curr; }; + virtual Expression* visitGetLocal(GetLocal *curr) { return curr; }; + virtual Expression* visitSetLocal(SetLocal *curr) { return curr; }; + virtual Expression* visitLoad(Load *curr) { return curr; }; + virtual Expression* visitStore(Store *curr) { return curr; }; + virtual Expression* visitConst(Const *curr) { return curr; }; + virtual Expression* visitUnary(Unary *curr) { return curr; }; + virtual Expression* visitBinary(Binary *curr) { return curr; }; + virtual Expression* visitCompare(Compare *curr) { return curr; }; + virtual Expression* visitConvert(Convert *curr) { return curr; }; + virtual Expression* visitHost(Host *curr) { return curr; }; + virtual Expression* visitNop(Nop *curr) { return curr; }; // children-first Expression *walk(Expression *curr) { @@ -840,104 +890,67 @@ struct WasmWalker { for (size_t z = 0; z < list.size(); z++) { list[z] = walk(list[z]); } - return walkBlock(cast); - } - if (If *cast = dynamic_cast(curr)) { + } else if (If *cast = dynamic_cast(curr)) { cast->condition = walk(cast->condition); cast->ifTrue = walk(cast->ifTrue); cast->ifFalse = walk(cast->ifFalse); - return walkIf(cast); - } - if (Loop *cast = dynamic_cast(curr)) { + } else if (Loop *cast = dynamic_cast(curr)) { cast->body = walk(cast->body); - return walkLoop(cast); - } - if (Label *cast = dynamic_cast(curr)) { - return walkLabel(cast); - } - if (Break *cast = dynamic_cast(curr)) { + } else if (Label *cast = dynamic_cast(curr)) { + } else if (Break *cast = dynamic_cast(curr)) { cast->condition = walk(cast->condition); cast->value = walk(cast->value); - return walkBreak(cast); - } - if (Switch *cast = dynamic_cast(curr)) { + } else if (Switch *cast = dynamic_cast(curr)) { cast->value = walk(cast->value); for (auto& curr : cast->cases) { curr.body = walk(curr.body); } cast->default_ = walk(cast->default_); - return walkSwitch(cast); - } - if (Call *cast = dynamic_cast(curr)) { + } else if (Call *cast = dynamic_cast(curr)) { ExpressionList& list = cast->operands; for (size_t z = 0; z < list.size(); z++) { list[z] = walk(list[z]); } - return walkCall(cast); - } - if (CallImport *cast = dynamic_cast(curr)) { + } else if (CallImport *cast = dynamic_cast(curr)) { ExpressionList& list = cast->operands; for (size_t z = 0; z < list.size(); z++) { list[z] = walk(list[z]); } - return walkCallImport(cast); - } - if (CallIndirect *cast = dynamic_cast(curr)) { + } else if (CallIndirect *cast = dynamic_cast(curr)) { cast->target = walk(cast->target); ExpressionList& list = cast->operands; for (size_t z = 0; z < list.size(); z++) { list[z] = walk(list[z]); } - return walkCallIndirect(cast); - } - if (GetLocal *cast = dynamic_cast(curr)) { - return walkGetLocal(cast); - } - if (SetLocal *cast = dynamic_cast(curr)) { + } else if (GetLocal *cast = dynamic_cast(curr)) { + } else if (SetLocal *cast = dynamic_cast(curr)) { cast->value = walk(cast->value); - return walkSetLocal(cast); - } - if (Load *cast = dynamic_cast(curr)) { + } else if (Load *cast = dynamic_cast(curr)) { cast->ptr = walk(cast->ptr); - return walkLoad(cast); - } - if (Store *cast = dynamic_cast(curr)) { + } else if (Store *cast = dynamic_cast(curr)) { cast->ptr = walk(cast->ptr); cast->value = walk(cast->value); - return walkStore(cast); - } - if (Const *cast = dynamic_cast(curr)) { - return walkConst(cast); - } - if (Unary *cast = dynamic_cast(curr)) { + } else if (Const *cast = dynamic_cast(curr)) { + } else if (Unary *cast = dynamic_cast(curr)) { cast->value = walk(cast->value); - return walkUnary(cast); - } - if (Binary *cast = dynamic_cast(curr)) { + } else if (Binary *cast = dynamic_cast(curr)) { cast->left = walk(cast->left); cast->right = walk(cast->right); - return walkBinary(cast); - } - if (Compare *cast = dynamic_cast(curr)) { + } else if (Compare *cast = dynamic_cast(curr)) { cast->left = walk(cast->left); cast->right = walk(cast->right); - return walkCompare(cast); - } - if (Convert *cast = dynamic_cast(curr)) { + } else if (Convert *cast = dynamic_cast(curr)) { cast->value = walk(cast->value); - return walkConvert(cast); - } - if (Host *cast = dynamic_cast(curr)) { + } else if (Host *cast = dynamic_cast(curr)) { ExpressionList& list = cast->operands; for (size_t z = 0; z < list.size(); z++) { list[z] = walk(list[z]); } - return walkHost(cast); - } - if (Nop *cast = dynamic_cast(curr)) { - return walkNop(cast); + } else if (Nop *cast = dynamic_cast(curr)) { + } else { + abort(); } - abort(); + return visit(curr); } void startWalk(Function *func) { -- cgit v1.2.3