diff options
Diffstat (limited to 'src/wasm.h')
-rw-r--r-- | src/wasm.h | 410 |
1 files changed, 215 insertions, 195 deletions
diff --git a/src/wasm.h b/src/wasm.h index 9edd7645e..ffca28f9e 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -1133,100 +1133,106 @@ class AllocatingModule : public Module { }; // -// Simple WebAssembly AST visiting. Useful for anything that wants to do -// something different for each AST node type, like printing, interpreting, -// etc. +// WebAssembly AST visitor. Useful for anything that wants to do something +// different for each AST node type, like printing, interpreting, etc. +// +// This class is specifically designed as a template to avoid virtual function +// call overhead. To write a visitor, derive from this class as follows: +// +// struct MyVisitor : public WasmVisitor<MyVisitor> { .. } // -template<class ReturnType> +template<typename SubType, typename ReturnType> struct WasmVisitor { virtual ~WasmVisitor() {} // should be pure virtual, but https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51048 // Expression visitors - virtual ReturnType visitBlock(Block *curr) { abort(); } - virtual ReturnType visitIf(If *curr) { abort(); } - virtual ReturnType visitLoop(Loop *curr) { abort(); } - virtual ReturnType visitBreak(Break *curr) { abort(); } - virtual ReturnType visitSwitch(Switch *curr) { abort(); } - virtual ReturnType visitCall(Call *curr) { abort(); } - virtual ReturnType visitCallImport(CallImport *curr) { abort(); } - virtual ReturnType visitCallIndirect(CallIndirect *curr) { abort(); } - virtual ReturnType visitGetLocal(GetLocal *curr) { abort(); } - virtual ReturnType visitSetLocal(SetLocal *curr) { abort(); } - virtual ReturnType visitLoad(Load *curr) { abort(); } - virtual ReturnType visitStore(Store *curr) { abort(); } - virtual ReturnType visitConst(Const *curr) { abort(); } - virtual ReturnType visitUnary(Unary *curr) { abort(); } - virtual ReturnType visitBinary(Binary *curr) { abort(); } - virtual ReturnType visitSelect(Select *curr) { abort(); } - virtual ReturnType visitHost(Host *curr) { abort(); } - virtual ReturnType visitNop(Nop *curr) { abort(); } - virtual ReturnType visitUnreachable(Unreachable *curr) { abort(); } + ReturnType visitBlock(Block *curr) { abort(); } + ReturnType visitIf(If *curr) { abort(); } + ReturnType visitLoop(Loop *curr) { abort(); } + ReturnType visitBreak(Break *curr) { abort(); } + ReturnType visitSwitch(Switch *curr) { abort(); } + ReturnType visitCall(Call *curr) { abort(); } + ReturnType visitCallImport(CallImport *curr) { abort(); } + ReturnType visitCallIndirect(CallIndirect *curr) { abort(); } + ReturnType visitGetLocal(GetLocal *curr) { abort(); } + ReturnType visitSetLocal(SetLocal *curr) { abort(); } + ReturnType visitLoad(Load *curr) { abort(); } + ReturnType visitStore(Store *curr) { abort(); } + ReturnType visitConst(Const *curr) { abort(); } + ReturnType visitUnary(Unary *curr) { abort(); } + ReturnType visitBinary(Binary *curr) { abort(); } + ReturnType visitSelect(Select *curr) { abort(); } + ReturnType visitHost(Host *curr) { abort(); } + ReturnType visitNop(Nop *curr) { abort(); } + ReturnType visitUnreachable(Unreachable *curr) { abort(); } // Module-level visitors - virtual ReturnType visitFunctionType(FunctionType *curr) { abort(); } - virtual ReturnType visitImport(Import *curr) { abort(); } - virtual ReturnType visitExport(Export *curr) { abort(); } - virtual ReturnType visitFunction(Function *curr) { abort(); } - virtual ReturnType visitTable(Table *curr) { abort(); } - virtual ReturnType visitMemory(Memory *curr) { abort(); } - virtual ReturnType visitModule(Module *curr) { abort(); } + ReturnType visitFunctionType(FunctionType *curr) { abort(); } + ReturnType visitImport(Import *curr) { abort(); } + ReturnType visitExport(Export *curr) { abort(); } + ReturnType visitFunction(Function *curr) { abort(); } + ReturnType visitTable(Table *curr) { abort(); } + ReturnType visitMemory(Memory *curr) { abort(); } + ReturnType visitModule(Module *curr) { abort(); } + +#define DELEGATE(CLASS_TO_VISIT) \ + return static_cast<SubType*>(this)-> \ + visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr)) ReturnType visit(Expression *curr) { assert(curr); switch (curr->_id) { - case Expression::Id::BlockId: return visitBlock((Block*)curr); - case Expression::Id::IfId: return visitIf((If*)curr); - case Expression::Id::LoopId: return visitLoop((Loop*)curr); - case Expression::Id::BreakId: return visitBreak((Break*)curr); - case Expression::Id::SwitchId: return visitSwitch((Switch*)curr); - case Expression::Id::CallId: return visitCall((Call*)curr); - case Expression::Id::CallImportId: return visitCallImport((CallImport*)curr); - case Expression::Id::CallIndirectId: return visitCallIndirect((CallIndirect*)curr); - case Expression::Id::GetLocalId: return visitGetLocal((GetLocal*)curr); - case Expression::Id::SetLocalId: return visitSetLocal((SetLocal*)curr); - case Expression::Id::LoadId: return visitLoad((Load*)curr); - case Expression::Id::StoreId: return visitStore((Store*)curr); - case Expression::Id::ConstId: return visitConst((Const*)curr); - case Expression::Id::UnaryId: return visitUnary((Unary*)curr); - case Expression::Id::BinaryId: return visitBinary((Binary*)curr); - case Expression::Id::SelectId: return visitSelect((Select*)curr); - case Expression::Id::HostId: return visitHost((Host*)curr); - case Expression::Id::NopId: return visitNop((Nop*)curr); - case Expression::Id::UnreachableId: return visitUnreachable((Unreachable*)curr); - default: { - std::cerr << "visiting unknown expression " << curr->_id << '\n'; - abort(); - } + case Expression::Id::InvalidId: abort(); + case Expression::Id::BlockId: DELEGATE(Block); + case Expression::Id::IfId: DELEGATE(If); + case Expression::Id::LoopId: DELEGATE(Loop); + case Expression::Id::BreakId: DELEGATE(Break); + case Expression::Id::SwitchId: DELEGATE(Switch); + case Expression::Id::CallId: DELEGATE(Call); + case Expression::Id::CallImportId: DELEGATE(CallImport); + case Expression::Id::CallIndirectId: DELEGATE(CallIndirect); + case Expression::Id::GetLocalId: DELEGATE(GetLocal); + case Expression::Id::SetLocalId: DELEGATE(SetLocal); + case Expression::Id::LoadId: DELEGATE(Load); + case Expression::Id::StoreId: DELEGATE(Store); + case Expression::Id::ConstId: DELEGATE(Const); + case Expression::Id::UnaryId: DELEGATE(Unary); + case Expression::Id::BinaryId: DELEGATE(Binary); + case Expression::Id::SelectId: DELEGATE(Select); + case Expression::Id::HostId: DELEGATE(Host); + case Expression::Id::NopId: DELEGATE(Nop); + case Expression::Id::UnreachableId: DELEGATE(Unreachable); + default: WASM_UNREACHABLE(); } } }; std::ostream& Expression::print(std::ostream &o, unsigned indent) { - struct ExpressionPrinter : public WasmVisitor<void> { + struct ExpressionPrinter : public WasmVisitor<ExpressionPrinter, void> { std::ostream &o; unsigned indent; ExpressionPrinter(std::ostream &o, unsigned indent) : o(o), indent(indent) {} - void visitBlock(Block *curr) override { curr->doPrint(o, indent); } - void visitIf(If *curr) override { curr->doPrint(o, indent); } - void visitLoop(Loop *curr) override { curr->doPrint(o, indent); } - void visitBreak(Break *curr) override { curr->doPrint(o, indent); } - void visitSwitch(Switch *curr) override { curr->doPrint(o, indent); } - void visitCall(Call *curr) override { curr->doPrint(o, indent); } - void visitCallImport(CallImport *curr) override { curr->doPrint(o, indent); } - void visitCallIndirect(CallIndirect *curr) override { curr->doPrint(o, indent); } - void visitGetLocal(GetLocal *curr) override { curr->doPrint(o, indent); } - void visitSetLocal(SetLocal *curr) override { curr->doPrint(o, indent); } - void visitLoad(Load *curr) override { curr->doPrint(o, indent); } - void visitStore(Store *curr) override { curr->doPrint(o, indent); } - void visitConst(Const *curr) override { curr->doPrint(o, indent); } - void visitUnary(Unary *curr) override { curr->doPrint(o, indent); } - void visitBinary(Binary *curr) override { curr->doPrint(o, indent); } - void visitSelect(Select *curr) override { curr->doPrint(o, indent); } - void visitHost(Host *curr) override { curr->doPrint(o, indent); } - void visitNop(Nop *curr) override { curr->doPrint(o, indent); } - void visitUnreachable(Unreachable *curr) override { curr->doPrint(o, indent); } + void visitBlock(Block *curr) { curr->doPrint(o, indent); } + void visitIf(If *curr) { curr->doPrint(o, indent); } + void visitLoop(Loop *curr) { curr->doPrint(o, indent); } + void visitBreak(Break *curr) { curr->doPrint(o, indent); } + void visitSwitch(Switch *curr) { curr->doPrint(o, indent); } + void visitCall(Call *curr) { curr->doPrint(o, indent); } + void visitCallImport(CallImport *curr) { curr->doPrint(o, indent); } + void visitCallIndirect(CallIndirect *curr) { curr->doPrint(o, indent); } + void visitGetLocal(GetLocal *curr) { curr->doPrint(o, indent); } + void visitSetLocal(SetLocal *curr) { curr->doPrint(o, indent); } + void visitLoad(Load *curr) { curr->doPrint(o, indent); } + void visitStore(Store *curr) { curr->doPrint(o, indent); } + void visitConst(Const *curr) { curr->doPrint(o, indent); } + void visitUnary(Unary *curr) { curr->doPrint(o, indent); } + void visitBinary(Binary *curr) { curr->doPrint(o, indent); } + void visitSelect(Select *curr) { curr->doPrint(o, indent); } + void visitHost(Host *curr) { curr->doPrint(o, indent); } + void visitNop(Nop *curr) { curr->doPrint(o, indent); } + void visitUnreachable(Unreachable *curr) { curr->doPrint(o, indent); } }; ExpressionPrinter(o, indent).visit(this); @@ -1235,12 +1241,106 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { } // +// Base class for all WasmWalkers +// +template<typename SubType, typename ReturnType = void> +struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> { + virtual void walk(Expression*& curr) { abort(); } + virtual void startWalk(Function *func) { abort(); } + virtual void startWalk(Module *module) { abort(); } +}; + +template<typename ParentType> +struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>> { + ParentType& parent; + + ChildWalker(ParentType& parent) : parent(parent) {} + + void visitBlock(Block *curr) { + ExpressionList& list = curr->list; + for (size_t z = 0; z < list.size(); z++) { + parent.walk(list[z]); + } + } + void visitIf(If *curr) { + parent.walk(curr->condition); + parent.walk(curr->ifTrue); + parent.walk(curr->ifFalse); + } + void visitLoop(Loop *curr) { + parent.walk(curr->body); + } + void visitBreak(Break *curr) { + parent.walk(curr->condition); + parent.walk(curr->value); + } + void visitSwitch(Switch *curr) { + parent.walk(curr->value); + for (auto& case_ : curr->cases) { + parent.walk(case_.body); + } + } + void visitCall(Call *curr) { + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + parent.walk(list[z]); + } + } + void visitCallImport(CallImport *curr) { + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + parent.walk(list[z]); + } + } + void visitCallIndirect(CallIndirect *curr) { + parent.walk(curr->target); + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + parent.walk(list[z]); + } + } + void visitGetLocal(GetLocal *curr) {} + void visitSetLocal(SetLocal *curr) { + parent.walk(curr->value); + } + void visitLoad(Load *curr) { + parent.walk(curr->ptr); + } + void visitStore(Store *curr) { + parent.walk(curr->ptr); + parent.walk(curr->value); + } + void visitConst(Const *curr) {} + void visitUnary(Unary *curr) { + parent.walk(curr->value); + } + void visitBinary(Binary *curr) { + parent.walk(curr->left); + parent.walk(curr->right); + } + void visitSelect(Select *curr) { + parent.walk(curr->condition); + parent.walk(curr->ifTrue); + parent.walk(curr->ifFalse); + } + void visitHost(Host *curr) { + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + parent.walk(list[z]); + } + } + void visitNop(Nop *curr) {} + void visitUnreachable(Unreachable *curr) {} +}; + +// // Simple WebAssembly children-first walking (i.e., post-order, if you look // at the children as subtrees of the current node), with the ability to replace // the current expression node. Useful for writing optimization passes. // -struct WasmWalker : public WasmVisitor<void> { +template<typename SubType, typename ReturnType = void> +struct WasmWalker : public WasmWalkerBase<SubType, ReturnType> { Expression* replace; WasmWalker() : replace(nullptr) {} @@ -1251,123 +1351,41 @@ struct WasmWalker : public WasmVisitor<void> { } // By default, do nothing - void visitBlock(Block *curr) override {} - void visitIf(If *curr) override {} - void visitLoop(Loop *curr) override {} - void visitBreak(Break *curr) override {} - void visitSwitch(Switch *curr) override {} - void visitCall(Call *curr) override {} - void visitCallImport(CallImport *curr) override {} - void visitCallIndirect(CallIndirect *curr) override {} - void visitGetLocal(GetLocal *curr) override {} - void visitSetLocal(SetLocal *curr) override {} - void visitLoad(Load *curr) override {} - void visitStore(Store *curr) override {} - void visitConst(Const *curr) override {} - void visitUnary(Unary *curr) override {} - void visitBinary(Binary *curr) override {} - void visitSelect(Select *curr) override {} - void visitHost(Host *curr) override {} - void visitNop(Nop *curr) override {} - void visitUnreachable(Unreachable *curr) override {} - - void visitFunctionType(FunctionType *curr) override {} - void visitImport(Import *curr) override {} - void visitExport(Export *curr) override {} - void visitFunction(Function *curr) override {} - void visitTable(Table *curr) override {} - void visitMemory(Memory *curr) override {} - void visitModule(Module *curr) override {} + ReturnType visitBlock(Block *curr) {} + ReturnType visitIf(If *curr) {} + ReturnType visitLoop(Loop *curr) {} + ReturnType visitBreak(Break *curr) {} + ReturnType visitSwitch(Switch *curr) {} + ReturnType visitCall(Call *curr) {} + ReturnType visitCallImport(CallImport *curr) {} + ReturnType visitCallIndirect(CallIndirect *curr) {} + ReturnType visitGetLocal(GetLocal *curr) {} + ReturnType visitSetLocal(SetLocal *curr) {} + ReturnType visitLoad(Load *curr) {} + ReturnType visitStore(Store *curr) {} + ReturnType visitConst(Const *curr) {} + ReturnType visitUnary(Unary *curr) {} + ReturnType visitBinary(Binary *curr) {} + ReturnType visitSelect(Select *curr) {} + ReturnType visitHost(Host *curr) {} + ReturnType visitNop(Nop *curr) {} + ReturnType visitUnreachable(Unreachable *curr) {} + + ReturnType visitFunctionType(FunctionType *curr) {} + ReturnType visitImport(Import *curr) {} + ReturnType visitExport(Export *curr) {} + ReturnType visitFunction(Function *curr) {} + ReturnType visitTable(Table *curr) {} + ReturnType visitMemory(Memory *curr) {} + ReturnType visitModule(Module *curr) {} // children-first - void walk(Expression*& curr) { + void walk(Expression*& curr) override { if (!curr) return; - struct ChildWalker : public WasmVisitor { - WasmWalker& parent; - - ChildWalker(WasmWalker& parent) : parent(parent) {} - - void visitBlock(Block *curr) override { - ExpressionList& list = curr->list; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } - } - void visitIf(If *curr) override { - parent.walk(curr->condition); - parent.walk(curr->ifTrue); - parent.walk(curr->ifFalse); - } - void visitLoop(Loop *curr) override { - parent.walk(curr->body); - } - void visitBreak(Break *curr) override { - parent.walk(curr->condition); - parent.walk(curr->value); - } - void visitSwitch(Switch *curr) override { - parent.walk(curr->value); - for (auto& case_ : curr->cases) { - parent.walk(case_.body); - } - } - void visitCall(Call *curr) override { - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } - } - void visitCallImport(CallImport *curr) override { - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } - } - void visitCallIndirect(CallIndirect *curr) override { - parent.walk(curr->target); - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } - } - void visitGetLocal(GetLocal *curr) override {} - void visitSetLocal(SetLocal *curr) override { - parent.walk(curr->value); - } - void visitLoad(Load *curr) override { - parent.walk(curr->ptr); - } - void visitStore(Store *curr) override { - parent.walk(curr->ptr); - parent.walk(curr->value); - } - void visitConst(Const *curr) override {} - void visitUnary(Unary *curr) override { - parent.walk(curr->value); - } - void visitBinary(Binary *curr) override { - parent.walk(curr->left); - parent.walk(curr->right); - } - void visitSelect(Select *curr) override { - parent.walk(curr->condition); - parent.walk(curr->ifTrue); - parent.walk(curr->ifFalse); - } - void visitHost(Host *curr) override { - ExpressionList& list = curr->operands; - for (size_t z = 0; z < list.size(); z++) { - parent.walk(list[z]); - } - } - void visitNop(Nop *curr) override {} - void visitUnreachable(Unreachable *curr) override {} - }; - - ChildWalker(*this).visit(curr); + ChildWalker<WasmWalker<SubType, ReturnType>>(*this).visit(curr); - visit(curr); + this->visit(curr); if (replace) { curr = replace; @@ -1375,33 +1393,35 @@ struct WasmWalker : public WasmVisitor<void> { } } - void startWalk(Function *func) { + void startWalk(Function *func) override { walk(func->body); } - void startWalk(Module *module) { + void startWalk(Module *module) override { + // Dispatch statically through the SubType. + SubType* self = static_cast<SubType*>(this); for (auto curr : module->functionTypes) { - visitFunctionType(curr); + self->visitFunctionType(curr); assert(!replace); } for (auto curr : module->imports) { - visitImport(curr); + self->visitImport(curr); assert(!replace); } for (auto curr : module->exports) { - visitExport(curr); + self->visitExport(curr); assert(!replace); } for (auto curr : module->functions) { startWalk(curr); - visitFunction(curr); + self->visitFunction(curr); assert(!replace); } - visitTable(&module->table); + self->visitTable(&module->table); assert(!replace); - visitMemory(&module->memory); + self->visitMemory(&module->memory); assert(!replace); - visitModule(module); + self->visitModule(module); assert(!replace); } }; |