diff options
Diffstat (limited to 'src/wasm.h')
-rw-r--r-- | src/wasm.h | 119 |
1 files changed, 65 insertions, 54 deletions
diff --git a/src/wasm.h b/src/wasm.h index 9ef39e42f..95b44c344 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -985,39 +985,45 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { // Simple WebAssembly children-first walking // -struct WasmWalker : public WasmVisitor<Expression*> { +struct WasmWalker : public WasmVisitor<void> { wasm::Arena* allocator; // use an existing allocator, or null if no allocations + Expression* replace; - WasmWalker() : allocator(nullptr) {} - WasmWalker(wasm::Arena* allocator) : allocator(allocator) {} - - // Each method receives an AST pointer, and it is replaced with what is returned. - Expression* visitBlock(Block *curr) override { return curr; }; - Expression* visitIf(If *curr) override { return curr; }; - Expression* visitLoop(Loop *curr) override { return curr; }; - Expression* visitLabel(Label *curr) override { return curr; }; - Expression* visitBreak(Break *curr) override { return curr; }; - Expression* visitSwitch(Switch *curr) override { return curr; }; - Expression* visitCall(Call *curr) override { return curr; }; - Expression* visitCallImport(CallImport *curr) override { return curr; }; - Expression* visitCallIndirect(CallIndirect *curr) override { return curr; }; - Expression* visitGetLocal(GetLocal *curr) override { return curr; }; - Expression* visitSetLocal(SetLocal *curr) override { return curr; }; - Expression* visitLoad(Load *curr) override { return curr; }; - Expression* visitStore(Store *curr) override { return curr; }; - Expression* visitConst(Const *curr) override { return curr; }; - Expression* visitUnary(Unary *curr) override { return curr; }; - Expression* visitBinary(Binary *curr) override { return curr; }; - Expression* visitCompare(Compare *curr) override { return curr; }; - Expression* visitConvert(Convert *curr) override { return curr; }; - Expression* visitHost(Host *curr) override { return curr; }; - Expression* visitNop(Nop *curr) override { return curr; }; + WasmWalker() : allocator(nullptr), replace(nullptr) {} + WasmWalker(wasm::Arena* allocator) : allocator(allocator), replace(nullptr) {} + + // the visit* methods can call this to replace the current node + void replaceCurrent(Expression *expression) { + replace = expression; + } + + // By default, do nothing + void visitBlock(Block *curr) override {}; + void visitIf(If *curr) override {}; + void visitLoop(Loop *curr) override {}; + void visitLabel(Label *curr) override {}; + void visitBreak(Break *curr) override {}; + void visitSwitch(Switch *curr) override {}; + void visitCall(Call *curr) override {}; + void visitCallImport(CallImport *curr) override {}; + void visitCallIndirect(CallIndirect *curr) override {}; + void visitGetLocal(GetLocal *curr) override {}; + void visitSetLocal(SetLocal *curr) override {}; + void visitLoad(Load *curr) override {}; + void visitStore(Store *curr) override {}; + void visitConst(Const *curr) override {}; + void visitUnary(Unary *curr) override {}; + void visitBinary(Binary *curr) override {}; + void visitCompare(Compare *curr) override {}; + void visitConvert(Convert *curr) override {}; + void visitHost(Host *curr) override {}; + void visitNop(Nop *curr) override {}; // children-first - Expression *walk(Expression *curr) { - if (!curr) return curr; + void walk(Expression*& curr) { + if (!curr) return; - struct ChildWalker : public WasmVisitor<void> { + struct ChildWalker : public WasmVisitor { WasmWalker& parent; ChildWalker(WasmWalker& parent) : parent(parent) {} @@ -1025,77 +1031,77 @@ struct WasmWalker : public WasmVisitor<Expression*> { void visitBlock(Block *curr) override { ExpressionList& list = curr->list; for (size_t z = 0; z < list.size(); z++) { - list[z] = parent.walk(list[z]); + parent.walk(list[z]); } } void visitIf(If *curr) override { - curr->condition = parent.walk(curr->condition); - curr->ifTrue = parent.walk(curr->ifTrue); - curr->ifFalse = parent.walk(curr->ifFalse); + parent.walk(curr->condition); + parent.walk(curr->ifTrue); + parent.walk(curr->ifFalse); } void visitLoop(Loop *curr) override { - curr->body = parent.walk(curr->body); + parent.walk(curr->body); } void visitLabel(Label *curr) override {} void visitBreak(Break *curr) override { - curr->value = parent.walk(curr->value); + parent.walk(curr->value); } void visitSwitch(Switch *curr) override { - curr->value = parent.walk(curr->value); + parent.walk(curr->value); for (auto& case_ : curr->cases) { - case_.body = parent.walk(case_.body); + parent.walk(case_.body); } - curr->default_ = parent.walk(curr->default_); + parent.walk(curr->default_); } void visitCall(Call *curr) override { ExpressionList& list = curr->operands; for (size_t z = 0; z < list.size(); z++) { - list[z] = parent.walk(list[z]); + parent.walk(list[z]); } } void visitCallImport(CallImport *curr) override { ExpressionList& list = curr->operands; for (size_t z = 0; z < list.size(); z++) { - list[z] = parent.walk(list[z]); + parent.walk(list[z]); } } void visitCallIndirect(CallIndirect *curr) override { - curr->target = parent.walk(curr->target); + parent.walk(curr->target); ExpressionList& list = curr->operands; for (size_t z = 0; z < list.size(); z++) { - list[z] = parent.walk(list[z]); + parent.walk(list[z]); } } void visitGetLocal(GetLocal *curr) override {} void visitSetLocal(SetLocal *curr) override { - curr->value = parent.walk(curr->value); + parent.walk(curr->value); } void visitLoad(Load *curr) override { - curr->ptr = parent.walk(curr->ptr); + parent.walk(curr->ptr); } void visitStore(Store *curr) override { - curr->ptr = parent.walk(curr->ptr); - curr->value = parent.walk(curr->value); + parent.walk(curr->ptr); + parent.walk(curr->value); } void visitConst(Const *curr) override {} void visitUnary(Unary *curr) override { - curr->value = parent.walk(curr->value); + parent.walk(curr->value); } void visitBinary(Binary *curr) override { - curr->left = parent.walk(curr->left); - curr->right = parent.walk(curr->right); + parent.walk(curr->left); + parent.walk(curr->right); } void visitCompare(Compare *curr) override { - curr->left = parent.walk(curr->left); - curr->right = parent.walk(curr->right); + parent.walk(curr->left); + parent.walk(curr->right); } void visitConvert(Convert *curr) override { - curr->value = parent.walk(curr->value); + parent.walk(curr->value); } void visitHost(Host *curr) override { ExpressionList& list = curr->operands; for (size_t z = 0; z < list.size(); z++) { - list[z] = parent.walk(list[z]); + parent.walk(list[z]); } } void visitNop(Nop *curr) override {} @@ -1103,11 +1109,16 @@ struct WasmWalker : public WasmVisitor<Expression*> { ChildWalker(*this).visit(curr); - return visit(curr); + visit(curr); + + if (replace) { + curr = replace; + replace = nullptr; + } } void startWalk(Function *func) { - func->body = walk(func->body); + walk(func->body); } }; |