diff options
-rwxr-xr-x | build.sh | 1 | ||||
-rw-r--r-- | src/asm2wasm.cpp | 5 | ||||
-rw-r--r-- | src/wasm-interpreter.cpp | 3 | ||||
-rw-r--r-- | src/wasm.h | 175 |
4 files changed, 100 insertions, 84 deletions
@@ -1,2 +1,3 @@ +#g++ -std=c++11 src/wasm-interpreter.cpp -g -o bin/wasm g++ -std=c++11 src/asm2wasm.cpp src/parser.cpp src/simple_ast.cpp -g -o bin/asm2wasm diff --git a/src/asm2wasm.cpp b/src/asm2wasm.cpp index bdbef2d78..68717e67e 100644 --- a/src/asm2wasm.cpp +++ b/src/asm2wasm.cpp @@ -1,5 +1,4 @@ -#include "simple_ast.h" #include "wasm.h" #include "optimizer.h" @@ -1028,7 +1027,7 @@ void Asm2WasmBuilder::optimize() { struct BlockRemover : public WasmWalker { BlockRemover() : WasmWalker(nullptr) {} - Expression* walkBlock(Block *curr) override { + Expression* visitBlock(Block *curr) override { if (curr->list.size() != 1) return curr; // just one element; maybe we can return just the element if (curr->name.isNull()) return curr->list[0]; @@ -1042,7 +1041,7 @@ void Asm2WasmBuilder::optimize() { BreakSeeker(IString target) : target(target), found(false) {} - Expression* walkBreak(Break *curr) override { + Expression* visitBreak(Break *curr) override { if (curr->name == target) found++; } }; diff --git a/src/wasm-interpreter.cpp b/src/wasm-interpreter.cpp index 8e1294624..ae7b2e793 100644 --- a/src/wasm-interpreter.cpp +++ b/src/wasm-interpreter.cpp @@ -4,6 +4,9 @@ #include "wasm.h" +using namespace cashew; +using namespace wasm; + namespace wasm { // An instance of a WebAssembly module diff --git a/src/wasm.h b/src/wasm.h index 290fc5d47..25d514368 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -11,6 +11,7 @@ #include <map> #include <vector> +#include "simple_ast.h" #include "colors.h" namespace wasm { @@ -800,36 +801,85 @@ public: }; // -// Simple WebAssembly AST walker +// Simple WebAssembly AST visiting and children-first walking // -struct WasmWalker { +template<class ReturnType> +struct WasmVisitor { + virtual ReturnType visitBlock(Block *curr) = 0; + virtual ReturnType visitIf(If *curr) = 0; + virtual ReturnType visitLoop(Loop *curr) = 0; + virtual ReturnType visitLabel(Label *curr) = 0; + virtual ReturnType visitBreak(Break *curr) = 0; + virtual ReturnType visitSwitch(Switch *curr) = 0; + virtual ReturnType visitCall(Call *curr) = 0; + virtual ReturnType visitCallImport(CallImport *curr) = 0; + virtual ReturnType visitCallIndirect(CallIndirect *curr) = 0; + virtual ReturnType visitGetLocal(GetLocal *curr) = 0; + virtual ReturnType visitSetLocal(SetLocal *curr) = 0; + virtual ReturnType visitLoad(Load *curr) = 0; + virtual ReturnType visitStore(Store *curr) = 0; + virtual ReturnType visitConst(Const *curr) = 0; + virtual ReturnType visitUnary(Unary *curr) = 0; + virtual ReturnType visitBinary(Binary *curr) = 0; + virtual ReturnType visitCompare(Compare *curr) = 0; + virtual ReturnType visitConvert(Convert *curr) = 0; + virtual ReturnType visitHost(Host *curr) = 0; + virtual ReturnType visitNop(Nop *curr) = 0; + + ReturnType visit(Expression *curr) { + assert(curr); + if (Block *cast = dynamic_cast<Block*>(curr)) return visitBlock(cast); + if (If *cast = dynamic_cast<If*>(curr)) return visitIf(cast); + if (Loop *cast = dynamic_cast<Loop*>(curr)) return visitLoop(cast); + if (Label *cast = dynamic_cast<Label*>(curr)) return visitLabel(cast); + if (Break *cast = dynamic_cast<Break*>(curr)) return visitBreak(cast); + if (Switch *cast = dynamic_cast<Switch*>(curr)) return visitSwitch(cast); + if (Call *cast = dynamic_cast<Call*>(curr)) return visitCall(cast); + if (CallImport *cast = dynamic_cast<CallImport*>(curr)) return visitCallImport(cast); + if (CallIndirect *cast = dynamic_cast<CallIndirect*>(curr)) return visitCallIndirect(cast); + if (GetLocal *cast = dynamic_cast<GetLocal*>(curr)) return visitGetLocal(cast); + if (SetLocal *cast = dynamic_cast<SetLocal*>(curr)) return visitSetLocal(cast); + if (Load *cast = dynamic_cast<Load*>(curr)) return visitLoad(cast); + if (Store *cast = dynamic_cast<Store*>(curr)) return visitStore(cast); + if (Const *cast = dynamic_cast<Const*>(curr)) return visitConst(cast); + if (Unary *cast = dynamic_cast<Unary*>(curr)) return visitUnary(cast); + if (Binary *cast = dynamic_cast<Binary*>(curr)) return visitBinary(cast); + if (Compare *cast = dynamic_cast<Compare*>(curr)) return visitCompare(cast); + if (Convert *cast = dynamic_cast<Convert*>(curr)) return visitConvert(cast); + if (Host *cast = dynamic_cast<Host*>(curr)) return visitHost(cast); + if (Nop *cast = dynamic_cast<Nop*>(curr)) return visitNop(cast); + abort(); + } +}; + +struct WasmWalker : public WasmVisitor<Expression*> { wasm::Arena* allocator; // use an existing allocator, or null if no allocations WasmWalker() : allocator(nullptr) {} WasmWalker(wasm::Arena* allocator) : allocator(allocator) {} // Each method receives an AST pointer, and it is replaced with what is returned. - virtual Expression* walkBlock(Block *curr) { return curr; }; - virtual Expression* walkIf(If *curr) { return curr; }; - virtual Expression* walkLoop(Loop *curr) { return curr; }; - virtual Expression* walkLabel(Label *curr) { return curr; }; - virtual Expression* walkBreak(Break *curr) { return curr; }; - virtual Expression* walkSwitch(Switch *curr) { return curr; }; - virtual Expression* walkCall(Call *curr) { return curr; }; - virtual Expression* walkCallImport(CallImport *curr) { return curr; }; - virtual Expression* walkCallIndirect(CallIndirect *curr) { return curr; }; - virtual Expression* walkGetLocal(GetLocal *curr) { return curr; }; - virtual Expression* walkSetLocal(SetLocal *curr) { return curr; }; - virtual Expression* walkLoad(Load *curr) { return curr; }; - virtual Expression* walkStore(Store *curr) { return curr; }; - virtual Expression* walkConst(Const *curr) { return curr; }; - virtual Expression* walkUnary(Unary *curr) { return curr; }; - virtual Expression* walkBinary(Binary *curr) { return curr; }; - virtual Expression* walkCompare(Compare *curr) { return curr; }; - virtual Expression* walkConvert(Convert *curr) { return curr; }; - virtual Expression* walkHost(Host *curr) { return curr; }; - virtual Expression* walkNop(Nop *curr) { return curr; }; + virtual Expression* visitBlock(Block *curr) { return curr; }; + virtual Expression* visitIf(If *curr) { return curr; }; + virtual Expression* visitLoop(Loop *curr) { return curr; }; + virtual Expression* visitLabel(Label *curr) { return curr; }; + virtual Expression* visitBreak(Break *curr) { return curr; }; + virtual Expression* visitSwitch(Switch *curr) { return curr; }; + virtual Expression* visitCall(Call *curr) { return curr; }; + virtual Expression* visitCallImport(CallImport *curr) { return curr; }; + virtual Expression* visitCallIndirect(CallIndirect *curr) { return curr; }; + virtual Expression* visitGetLocal(GetLocal *curr) { return curr; }; + virtual Expression* visitSetLocal(SetLocal *curr) { return curr; }; + virtual Expression* visitLoad(Load *curr) { return curr; }; + virtual Expression* visitStore(Store *curr) { return curr; }; + virtual Expression* visitConst(Const *curr) { return curr; }; + virtual Expression* visitUnary(Unary *curr) { return curr; }; + virtual Expression* visitBinary(Binary *curr) { return curr; }; + virtual Expression* visitCompare(Compare *curr) { return curr; }; + virtual Expression* visitConvert(Convert *curr) { return curr; }; + virtual Expression* visitHost(Host *curr) { return curr; }; + virtual Expression* visitNop(Nop *curr) { return curr; }; // children-first Expression *walk(Expression *curr) { @@ -840,104 +890,67 @@ struct WasmWalker { for (size_t z = 0; z < list.size(); z++) { list[z] = walk(list[z]); } - return walkBlock(cast); - } - if (If *cast = dynamic_cast<If*>(curr)) { + } else if (If *cast = dynamic_cast<If*>(curr)) { cast->condition = walk(cast->condition); cast->ifTrue = walk(cast->ifTrue); cast->ifFalse = walk(cast->ifFalse); - return walkIf(cast); - } - if (Loop *cast = dynamic_cast<Loop*>(curr)) { + } else if (Loop *cast = dynamic_cast<Loop*>(curr)) { cast->body = walk(cast->body); - return walkLoop(cast); - } - if (Label *cast = dynamic_cast<Label*>(curr)) { - return walkLabel(cast); - } - if (Break *cast = dynamic_cast<Break*>(curr)) { + } else if (Label *cast = dynamic_cast<Label*>(curr)) { + } else if (Break *cast = dynamic_cast<Break*>(curr)) { cast->condition = walk(cast->condition); cast->value = walk(cast->value); - return walkBreak(cast); - } - if (Switch *cast = dynamic_cast<Switch*>(curr)) { + } else if (Switch *cast = dynamic_cast<Switch*>(curr)) { cast->value = walk(cast->value); for (auto& curr : cast->cases) { curr.body = walk(curr.body); } cast->default_ = walk(cast->default_); - return walkSwitch(cast); - } - if (Call *cast = dynamic_cast<Call*>(curr)) { + } else if (Call *cast = dynamic_cast<Call*>(curr)) { ExpressionList& list = cast->operands; for (size_t z = 0; z < list.size(); z++) { list[z] = walk(list[z]); } - return walkCall(cast); - } - if (CallImport *cast = dynamic_cast<CallImport*>(curr)) { + } else if (CallImport *cast = dynamic_cast<CallImport*>(curr)) { ExpressionList& list = cast->operands; for (size_t z = 0; z < list.size(); z++) { list[z] = walk(list[z]); } - return walkCallImport(cast); - } - if (CallIndirect *cast = dynamic_cast<CallIndirect*>(curr)) { + } else if (CallIndirect *cast = dynamic_cast<CallIndirect*>(curr)) { cast->target = walk(cast->target); ExpressionList& list = cast->operands; for (size_t z = 0; z < list.size(); z++) { list[z] = walk(list[z]); } - return walkCallIndirect(cast); - } - if (GetLocal *cast = dynamic_cast<GetLocal*>(curr)) { - return walkGetLocal(cast); - } - if (SetLocal *cast = dynamic_cast<SetLocal*>(curr)) { + } else if (GetLocal *cast = dynamic_cast<GetLocal*>(curr)) { + } else if (SetLocal *cast = dynamic_cast<SetLocal*>(curr)) { cast->value = walk(cast->value); - return walkSetLocal(cast); - } - if (Load *cast = dynamic_cast<Load*>(curr)) { + } else if (Load *cast = dynamic_cast<Load*>(curr)) { cast->ptr = walk(cast->ptr); - return walkLoad(cast); - } - if (Store *cast = dynamic_cast<Store*>(curr)) { + } else if (Store *cast = dynamic_cast<Store*>(curr)) { cast->ptr = walk(cast->ptr); cast->value = walk(cast->value); - return walkStore(cast); - } - if (Const *cast = dynamic_cast<Const*>(curr)) { - return walkConst(cast); - } - if (Unary *cast = dynamic_cast<Unary*>(curr)) { + } else if (Const *cast = dynamic_cast<Const*>(curr)) { + } else if (Unary *cast = dynamic_cast<Unary*>(curr)) { cast->value = walk(cast->value); - return walkUnary(cast); - } - if (Binary *cast = dynamic_cast<Binary*>(curr)) { + } else if (Binary *cast = dynamic_cast<Binary*>(curr)) { cast->left = walk(cast->left); cast->right = walk(cast->right); - return walkBinary(cast); - } - if (Compare *cast = dynamic_cast<Compare*>(curr)) { + } else if (Compare *cast = dynamic_cast<Compare*>(curr)) { cast->left = walk(cast->left); cast->right = walk(cast->right); - return walkCompare(cast); - } - if (Convert *cast = dynamic_cast<Convert*>(curr)) { + } else if (Convert *cast = dynamic_cast<Convert*>(curr)) { cast->value = walk(cast->value); - return walkConvert(cast); - } - if (Host *cast = dynamic_cast<Host*>(curr)) { + } else if (Host *cast = dynamic_cast<Host*>(curr)) { ExpressionList& list = cast->operands; for (size_t z = 0; z < list.size(); z++) { list[z] = walk(list[z]); } - return walkHost(cast); - } - if (Nop *cast = dynamic_cast<Nop*>(curr)) { - return walkNop(cast); + } else if (Nop *cast = dynamic_cast<Nop*>(curr)) { + } else { + abort(); } - abort(); + return visit(curr); } void startWalk(Function *func) { |