diff options
author | Michael Bebenita <mbebenita@gmail.com> | 2016-01-13 20:13:04 -0800 |
---|---|---|
committer | Michael Bebenita <mbebenita@gmail.com> | 2016-01-13 20:13:04 -0800 |
commit | e165020f87f807179d27203195843c88fb8afe55 (patch) | |
tree | 6b8f19534b32e8a8de32e6d380f5a537a6965368 /src/wasm.h | |
parent | 7e3bdd00f9b390c36461291fa5b884ace55e82d6 (diff) | |
download | binaryen-e165020f87f807179d27203195843c88fb8afe55.tar.gz binaryen-e165020f87f807179d27203195843c88fb8afe55.tar.bz2 binaryen-e165020f87f807179d27203195843c88fb8afe55.zip |
Use LLVM style static polymorphism for WasmVisitors.
Diffstat (limited to 'src/wasm.h')
-rw-r--r-- | src/wasm.h | 257 |
1 files changed, 131 insertions, 126 deletions
diff --git a/src/wasm.h b/src/wasm.h index c6942acfe..a5ad3fe77 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -1137,61 +1137,62 @@ class AllocatingModule : public Module { // etc. // -template<class ReturnType> +template<typename SubType, typename ReturnType> struct WasmVisitor { virtual ~WasmVisitor() {} // should be pure virtual, but https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51048 // Expression visitors - virtual ReturnType visitBlock(Block *curr) { abort(); } - virtual ReturnType visitIf(If *curr) { abort(); } - virtual ReturnType visitLoop(Loop *curr) { abort(); } - virtual ReturnType visitBreak(Break *curr) { abort(); } - virtual ReturnType visitSwitch(Switch *curr) { abort(); } - virtual ReturnType visitCall(Call *curr) { abort(); } - virtual ReturnType visitCallImport(CallImport *curr) { abort(); } - virtual ReturnType visitCallIndirect(CallIndirect *curr) { abort(); } - virtual ReturnType visitGetLocal(GetLocal *curr) { abort(); } - virtual ReturnType visitSetLocal(SetLocal *curr) { abort(); } - virtual ReturnType visitLoad(Load *curr) { abort(); } - virtual ReturnType visitStore(Store *curr) { abort(); } - virtual ReturnType visitConst(Const *curr) { abort(); } - virtual ReturnType visitUnary(Unary *curr) { abort(); } - virtual ReturnType visitBinary(Binary *curr) { abort(); } - virtual ReturnType visitSelect(Select *curr) { abort(); } - virtual ReturnType visitHost(Host *curr) { abort(); } - virtual ReturnType visitNop(Nop *curr) { abort(); } - virtual ReturnType visitUnreachable(Unreachable *curr) { abort(); } + ReturnType visitBlock(Block *curr) { abort(); } + ReturnType visitIf(If *curr) { abort(); } + ReturnType visitLoop(Loop *curr) { abort(); } + ReturnType visitBreak(Break *curr) { abort(); } + ReturnType visitSwitch(Switch *curr) { abort(); } + ReturnType visitCall(Call *curr) { abort(); } + ReturnType visitCallImport(CallImport *curr) { abort(); } + ReturnType visitCallIndirect(CallIndirect *curr) { abort(); } + ReturnType visitGetLocal(GetLocal *curr) { abort(); } + ReturnType visitSetLocal(SetLocal *curr) { abort(); } + ReturnType visitLoad(Load *curr) { abort(); } + ReturnType visitStore(Store *curr) { abort(); } + ReturnType visitConst(Const *curr) { abort(); } + ReturnType visitUnary(Unary *curr) { abort(); } + ReturnType visitBinary(Binary *curr) { abort(); } + ReturnType visitSelect(Select *curr) { abort(); } + ReturnType visitHost(Host *curr) { abort(); } + ReturnType visitNop(Nop *curr) { abort(); } + ReturnType visitUnreachable(Unreachable *curr) { abort(); } // Module-level visitors - virtual ReturnType visitFunctionType(FunctionType *curr) { abort(); } - virtual ReturnType visitImport(Import *curr) { abort(); } - virtual ReturnType visitExport(Export *curr) { abort(); } - virtual ReturnType visitFunction(Function *curr) { abort(); } - virtual ReturnType visitTable(Table *curr) { abort(); } - virtual ReturnType visitMemory(Memory *curr) { abort(); } - virtual ReturnType visitModule(Module *curr) { abort(); } + ReturnType visitFunctionType(FunctionType *curr) { abort(); } + ReturnType visitImport(Import *curr) { abort(); } + ReturnType visitExport(Export *curr) { abort(); } + ReturnType visitFunction(Function *curr) { abort(); } + ReturnType visitTable(Table *curr) { abort(); } + ReturnType visitMemory(Memory *curr) { abort(); } + ReturnType visitModule(Module *curr) { abort(); } ReturnType visit(Expression *curr) { assert(curr); + SubType* self = static_cast<SubType*>(this); switch (curr->_id) { - case Expression::Id::BlockId: return visitBlock((Block*)curr); - case Expression::Id::IfId: return visitIf((If*)curr); - case Expression::Id::LoopId: return visitLoop((Loop*)curr); - case Expression::Id::BreakId: return visitBreak((Break*)curr); - case Expression::Id::SwitchId: return visitSwitch((Switch*)curr); - case Expression::Id::CallId: return visitCall((Call*)curr); - case Expression::Id::CallImportId: return visitCallImport((CallImport*)curr); - case Expression::Id::CallIndirectId: return visitCallIndirect((CallIndirect*)curr); - case Expression::Id::GetLocalId: return visitGetLocal((GetLocal*)curr); - case Expression::Id::SetLocalId: return visitSetLocal((SetLocal*)curr); - case Expression::Id::LoadId: return visitLoad((Load*)curr); - case Expression::Id::StoreId: return visitStore((Store*)curr); - case Expression::Id::ConstId: return visitConst((Const*)curr); - case Expression::Id::UnaryId: return visitUnary((Unary*)curr); - case Expression::Id::BinaryId: return visitBinary((Binary*)curr); - case Expression::Id::SelectId: return visitSelect((Select*)curr); - case Expression::Id::HostId: return visitHost((Host*)curr); - case Expression::Id::NopId: return visitNop((Nop*)curr); - case Expression::Id::UnreachableId: return visitUnreachable((Unreachable*)curr); + case Expression::Id::BlockId: return self->visitBlock((Block*)curr); + case Expression::Id::IfId: return self->visitIf((If*)curr); + case Expression::Id::LoopId: return self->visitLoop((Loop*)curr); + case Expression::Id::BreakId: return self->visitBreak((Break*)curr); + case Expression::Id::SwitchId: return self->visitSwitch((Switch*)curr); + case Expression::Id::CallId: return self->visitCall((Call*)curr); + case Expression::Id::CallImportId: return self->visitCallImport((CallImport*)curr); + case Expression::Id::CallIndirectId: return self->visitCallIndirect((CallIndirect*)curr); + case Expression::Id::GetLocalId: return self->visitGetLocal((GetLocal*)curr); + case Expression::Id::SetLocalId: return self->visitSetLocal((SetLocal*)curr); + case Expression::Id::LoadId: return self->visitLoad((Load*)curr); + case Expression::Id::StoreId: return self->visitStore((Store*)curr); + case Expression::Id::ConstId: return self->visitConst((Const*)curr); + case Expression::Id::UnaryId: return self->visitUnary((Unary*)curr); + case Expression::Id::BinaryId: return self->visitBinary((Binary*)curr); + case Expression::Id::SelectId: return self->visitSelect((Select*)curr); + case Expression::Id::HostId: return self->visitHost((Host*)curr); + case Expression::Id::NopId: return self->visitNop((Nop*)curr); + case Expression::Id::UnreachableId: return self->visitUnreachable((Unreachable*)curr); default: { std::cerr << "visiting unknown expression " << curr->_id << '\n'; abort(); @@ -1201,31 +1202,31 @@ struct WasmVisitor { }; std::ostream& Expression::print(std::ostream &o, unsigned indent) { - struct ExpressionPrinter : public WasmVisitor<void> { + struct ExpressionPrinter : public WasmVisitor<ExpressionPrinter, void> { std::ostream &o; unsigned indent; ExpressionPrinter(std::ostream &o, unsigned indent) : o(o), indent(indent) {} - void visitBlock(Block *curr) override { curr->doPrint(o, indent); } - void visitIf(If *curr) override { curr->doPrint(o, indent); } - void visitLoop(Loop *curr) override { curr->doPrint(o, indent); } - void visitBreak(Break *curr) override { curr->doPrint(o, indent); } - void visitSwitch(Switch *curr) override { curr->doPrint(o, indent); } - void visitCall(Call *curr) override { curr->doPrint(o, indent); } - void visitCallImport(CallImport *curr) override { curr->doPrint(o, indent); } - void visitCallIndirect(CallIndirect *curr) override { curr->doPrint(o, indent); } - void visitGetLocal(GetLocal *curr) override { curr->doPrint(o, indent); } - void visitSetLocal(SetLocal *curr) override { curr->doPrint(o, indent); } - void visitLoad(Load *curr) override { curr->doPrint(o, indent); } - void visitStore(Store *curr) override { curr->doPrint(o, indent); } - void visitConst(Const *curr) override { curr->doPrint(o, indent); } - void visitUnary(Unary *curr) override { curr->doPrint(o, indent); } - void visitBinary(Binary *curr) override { curr->doPrint(o, indent); } - void visitSelect(Select *curr) override { curr->doPrint(o, indent); } - void visitHost(Host *curr) override { curr->doPrint(o, indent); } - void visitNop(Nop *curr) override { curr->doPrint(o, indent); } - void visitUnreachable(Unreachable *curr) override { curr->doPrint(o, indent); } + void visitBlock(Block *curr) { curr->doPrint(o, indent); } + void visitIf(If *curr) { curr->doPrint(o, indent); } + void visitLoop(Loop *curr) { curr->doPrint(o, indent); } + void visitBreak(Break *curr) { curr->doPrint(o, indent); } + void visitSwitch(Switch *curr) { curr->doPrint(o, indent); } + void visitCall(Call *curr) { curr->doPrint(o, indent); } + void visitCallImport(CallImport *curr) { curr->doPrint(o, indent); } + void visitCallIndirect(CallIndirect *curr) { curr->doPrint(o, indent); } + void visitGetLocal(GetLocal *curr) { curr->doPrint(o, indent); } + void visitSetLocal(SetLocal *curr) { curr->doPrint(o, indent); } + void visitLoad(Load *curr) { curr->doPrint(o, indent); } + void visitStore(Store *curr) { curr->doPrint(o, indent); } + void visitConst(Const *curr) { curr->doPrint(o, indent); } + void visitUnary(Unary *curr) { curr->doPrint(o, indent); } + void visitBinary(Binary *curr) { curr->doPrint(o, indent); } + void visitSelect(Select *curr) { curr->doPrint(o, indent); } + void visitHost(Host *curr) { curr->doPrint(o, indent); } + void visitNop(Nop *curr) { curr->doPrint(o, indent); } + void visitUnreachable(Unreachable *curr) { curr->doPrint(o, indent); } }; ExpressionPrinter(o, indent).visit(this); @@ -1233,92 +1234,94 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { return o; } -struct WasmWalkerBase : public WasmVisitor<void> { +template<typename SubType, typename ReturnType> +struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> { virtual void walk(Expression*& curr) { abort(); } virtual void startWalk(Function *func) { abort(); } virtual void startWalk(Module *module) { abort(); } }; -struct ChildWalker : public WasmWalkerBase { - WasmWalkerBase& parent; +template<typename ParentType> +struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>, void> { + ParentType& parent; - ChildWalker(WasmWalkerBase& parent) : parent(parent) {} + ChildWalker(ParentType& parent) : parent(parent) {} - void visitBlock(Block *curr) override { + void visitBlock(Block *curr) { ExpressionList& list = curr->list; for (size_t z = 0; z < list.size(); z++) { parent.walk(list[z]); } } - void visitIf(If *curr) override { + void visitIf(If *curr) { parent.walk(curr->condition); parent.walk(curr->ifTrue); parent.walk(curr->ifFalse); } - void visitLoop(Loop *curr) override { + void visitLoop(Loop *curr) { parent.walk(curr->body); } - void visitBreak(Break *curr) override { + void visitBreak(Break *curr) { parent.walk(curr->condition); parent.walk(curr->value); } - void visitSwitch(Switch *curr) override { + void visitSwitch(Switch *curr) { parent.walk(curr->value); for (auto& case_ : curr->cases) { parent.walk(case_.body); } } - void visitCall(Call *curr) override { + void visitCall(Call *curr) { ExpressionList& list = curr->operands; for (size_t z = 0; z < list.size(); z++) { parent.walk(list[z]); } } - void visitCallImport(CallImport *curr) override { + void visitCallImport(CallImport *curr) { ExpressionList& list = curr->operands; for (size_t z = 0; z < list.size(); z++) { parent.walk(list[z]); } } - void visitCallIndirect(CallIndirect *curr) override { + void visitCallIndirect(CallIndirect *curr) { parent.walk(curr->target); ExpressionList& list = curr->operands; for (size_t z = 0; z < list.size(); z++) { parent.walk(list[z]); } } - void visitGetLocal(GetLocal *curr) override {} - void visitSetLocal(SetLocal *curr) override { + void visitGetLocal(GetLocal *curr) {} + void visitSetLocal(SetLocal *curr) { parent.walk(curr->value); } - void visitLoad(Load *curr) override { + void visitLoad(Load *curr) { parent.walk(curr->ptr); } - void visitStore(Store *curr) override { + void visitStore(Store *curr) { parent.walk(curr->ptr); parent.walk(curr->value); } - void visitConst(Const *curr) override {} - void visitUnary(Unary *curr) override { + void visitConst(Const *curr) {} + void visitUnary(Unary *curr) { parent.walk(curr->value); } - void visitBinary(Binary *curr) override { + void visitBinary(Binary *curr) { parent.walk(curr->left); parent.walk(curr->right); } - void visitSelect(Select *curr) override { + void visitSelect(Select *curr) { parent.walk(curr->condition); parent.walk(curr->ifTrue); parent.walk(curr->ifFalse); } - void visitHost(Host *curr) override { + void visitHost(Host *curr) { ExpressionList& list = curr->operands; for (size_t z = 0; z < list.size(); z++) { parent.walk(list[z]); } } - void visitNop(Nop *curr) override {} - void visitUnreachable(Unreachable *curr) override {} + void visitNop(Nop *curr) {} + void visitUnreachable(Unreachable *curr) {} }; // @@ -1327,7 +1330,8 @@ struct ChildWalker : public WasmWalkerBase { // the current expression node. Useful for writing optimization passes. // -struct WasmWalker : public WasmWalkerBase { +template<typename SubType, typename ReturnType> +struct WasmWalker : public WasmWalkerBase<SubType, ReturnType> { Expression* replace; WasmWalker() : replace(nullptr) {} @@ -1338,41 +1342,41 @@ struct WasmWalker : public WasmWalkerBase { } // By default, do nothing - void visitBlock(Block *curr) override {} - void visitIf(If *curr) override {} - void visitLoop(Loop *curr) override {} - void visitBreak(Break *curr) override {} - void visitSwitch(Switch *curr) override {} - void visitCall(Call *curr) override {} - void visitCallImport(CallImport *curr) override {} - void visitCallIndirect(CallIndirect *curr) override {} - void visitGetLocal(GetLocal *curr) override {} - void visitSetLocal(SetLocal *curr) override {} - void visitLoad(Load *curr) override {} - void visitStore(Store *curr) override {} - void visitConst(Const *curr) override {} - void visitUnary(Unary *curr) override {} - void visitBinary(Binary *curr) override {} - void visitSelect(Select *curr) override {} - void visitHost(Host *curr) override {} - void visitNop(Nop *curr) override {} - void visitUnreachable(Unreachable *curr) override {} - - void visitFunctionType(FunctionType *curr) override {} - void visitImport(Import *curr) override {} - void visitExport(Export *curr) override {} - void visitFunction(Function *curr) override {} - void visitTable(Table *curr) override {} - void visitMemory(Memory *curr) override {} - void visitModule(Module *curr) override {} + ReturnType visitBlock(Block *curr) {} + ReturnType visitIf(If *curr) {} + ReturnType visitLoop(Loop *curr) {} + ReturnType visitBreak(Break *curr) {} + ReturnType visitSwitch(Switch *curr) {} + ReturnType visitCall(Call *curr) {} + ReturnType visitCallImport(CallImport *curr) {} + ReturnType visitCallIndirect(CallIndirect *curr) {} + ReturnType visitGetLocal(GetLocal *curr) {} + ReturnType visitSetLocal(SetLocal *curr) {} + ReturnType visitLoad(Load *curr) {} + ReturnType visitStore(Store *curr) {} + ReturnType visitConst(Const *curr) {} + ReturnType visitUnary(Unary *curr) {} + ReturnType visitBinary(Binary *curr) {} + ReturnType visitSelect(Select *curr) {} + ReturnType visitHost(Host *curr) {} + ReturnType visitNop(Nop *curr) {} + ReturnType visitUnreachable(Unreachable *curr) {} + + ReturnType visitFunctionType(FunctionType *curr) {} + ReturnType visitImport(Import *curr) {} + ReturnType visitExport(Export *curr) {} + ReturnType visitFunction(Function *curr) {} + ReturnType visitTable(Table *curr) {} + ReturnType visitMemory(Memory *curr) {} + ReturnType visitModule(Module *curr) {} // children-first void walk(Expression*& curr) override { if (!curr) return; - ChildWalker(*this).visit(curr); + ChildWalker<WasmWalker<SubType, ReturnType> >(*this).visit(curr); - visit(curr); + this->visit(curr); if (replace) { curr = replace; @@ -1385,28 +1389,29 @@ struct WasmWalker : public WasmWalkerBase { } void startWalk(Module *module) override { + SubType* self = static_cast<SubType*>(this); for (auto curr : module->functionTypes) { - visitFunctionType(curr); + self->visitFunctionType(curr); assert(!replace); } for (auto curr : module->imports) { - visitImport(curr); + self->visitImport(curr); assert(!replace); } for (auto curr : module->exports) { - visitExport(curr); + self->visitExport(curr); assert(!replace); } for (auto curr : module->functions) { startWalk(curr); - visitFunction(curr); + self->visitFunction(curr); assert(!replace); } - visitTable(&module->table); + self->visitTable(&module->table); assert(!replace); - visitMemory(&module->memory); + self->visitMemory(&module->memory); assert(!replace); - visitModule(module); + self->visitModule(module); assert(!replace); } }; |