diff options
Diffstat (limited to 'src/wasm.h')
-rw-r--r-- | src/wasm.h | 70 |
1 files changed, 39 insertions, 31 deletions
diff --git a/src/wasm.h b/src/wasm.h index a5ad3fe77..72ebe98ee 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -1132,9 +1132,13 @@ class AllocatingModule : public Module { }; // -// Simple WebAssembly AST visiting. Useful for anything that wants to do -// something different for each AST node type, like printing, interpreting, -// etc. +// WebAssembly AST visitor. Useful for anything that wants to do something +// different for each AST node type, like printing, interpreting, etc. +// +// This class is specifically designed as a template to avoid virtual function +// call overhead. To write a visitor, derive from this class as follows: +// +// struct MyVisitor : public WasmVisitor<MyVisitor> { .. } // template<typename SubType, typename ReturnType> @@ -1170,33 +1174,33 @@ struct WasmVisitor { ReturnType visitMemory(Memory *curr) { abort(); } ReturnType visitModule(Module *curr) { abort(); } +#define DELEGATE(CLASS_TO_VISIT) \ + return static_cast<SubType*>(this)-> \ + visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr)) + ReturnType visit(Expression *curr) { assert(curr); - SubType* self = static_cast<SubType*>(this); switch (curr->_id) { - case Expression::Id::BlockId: return self->visitBlock((Block*)curr); - case Expression::Id::IfId: return self->visitIf((If*)curr); - case Expression::Id::LoopId: return self->visitLoop((Loop*)curr); - case Expression::Id::BreakId: return self->visitBreak((Break*)curr); - case Expression::Id::SwitchId: return self->visitSwitch((Switch*)curr); - case Expression::Id::CallId: return self->visitCall((Call*)curr); - case Expression::Id::CallImportId: return self->visitCallImport((CallImport*)curr); - case Expression::Id::CallIndirectId: return self->visitCallIndirect((CallIndirect*)curr); - case Expression::Id::GetLocalId: return self->visitGetLocal((GetLocal*)curr); - case Expression::Id::SetLocalId: return self->visitSetLocal((SetLocal*)curr); - case Expression::Id::LoadId: return self->visitLoad((Load*)curr); - case Expression::Id::StoreId: return self->visitStore((Store*)curr); - case Expression::Id::ConstId: return self->visitConst((Const*)curr); - case Expression::Id::UnaryId: return self->visitUnary((Unary*)curr); - case Expression::Id::BinaryId: return self->visitBinary((Binary*)curr); - case Expression::Id::SelectId: return self->visitSelect((Select*)curr); - case Expression::Id::HostId: return self->visitHost((Host*)curr); - case Expression::Id::NopId: return self->visitNop((Nop*)curr); - case Expression::Id::UnreachableId: return self->visitUnreachable((Unreachable*)curr); - default: { - std::cerr << "visiting unknown expression " << curr->_id << '\n'; - abort(); - } + case Expression::Id::InvalidId: abort(); + case Expression::Id::BlockId: DELEGATE(Block); + case Expression::Id::IfId: DELEGATE(If); + case Expression::Id::LoopId: DELEGATE(Loop); + case Expression::Id::BreakId: DELEGATE(Break); + case Expression::Id::SwitchId: DELEGATE(Switch); + case Expression::Id::CallId: DELEGATE(Call); + case Expression::Id::CallImportId: DELEGATE(CallImport); + case Expression::Id::CallIndirectId: DELEGATE(CallIndirect); + case Expression::Id::GetLocalId: DELEGATE(GetLocal); + case Expression::Id::SetLocalId: DELEGATE(SetLocal); + case Expression::Id::LoadId: DELEGATE(Load); + case Expression::Id::StoreId: DELEGATE(Store); + case Expression::Id::ConstId: DELEGATE(Const); + case Expression::Id::UnaryId: DELEGATE(Unary); + case Expression::Id::BinaryId: DELEGATE(Binary); + case Expression::Id::SelectId: DELEGATE(Select); + case Expression::Id::HostId: DELEGATE(Host); + case Expression::Id::NopId: DELEGATE(Nop); + case Expression::Id::UnreachableId: DELEGATE(Unreachable); } } }; @@ -1234,7 +1238,10 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { return o; } -template<typename SubType, typename ReturnType> +// +// Base class for all WasmWalkers +// +template<typename SubType, typename ReturnType = void> struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> { virtual void walk(Expression*& curr) { abort(); } virtual void startWalk(Function *func) { abort(); } @@ -1242,7 +1249,7 @@ struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> { }; template<typename ParentType> -struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>, void> { +struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>> { ParentType& parent; ChildWalker(ParentType& parent) : parent(parent) {} @@ -1330,7 +1337,7 @@ struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>, void> { // the current expression node. Useful for writing optimization passes. // -template<typename SubType, typename ReturnType> +template<typename SubType, typename ReturnType = void> struct WasmWalker : public WasmWalkerBase<SubType, ReturnType> { Expression* replace; @@ -1374,7 +1381,7 @@ struct WasmWalker : public WasmWalkerBase<SubType, ReturnType> { void walk(Expression*& curr) override { if (!curr) return; - ChildWalker<WasmWalker<SubType, ReturnType> >(*this).visit(curr); + ChildWalker<WasmWalker<SubType, ReturnType>>(*this).visit(curr); this->visit(curr); @@ -1389,6 +1396,7 @@ struct WasmWalker : public WasmWalkerBase<SubType, ReturnType> { } void startWalk(Module *module) override { + // Dispatch statically through the SubType. SubType* self = static_cast<SubType*>(this); for (auto curr : module->functionTypes) { self->visitFunctionType(curr); |