summaryrefslogtreecommitdiff
path: root/src/wasm.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm.h')
-rw-r--r--src/wasm.h70
1 files changed, 39 insertions, 31 deletions
diff --git a/src/wasm.h b/src/wasm.h
index a5ad3fe77..72ebe98ee 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -1132,9 +1132,13 @@ class AllocatingModule : public Module {
};
//
-// Simple WebAssembly AST visiting. Useful for anything that wants to do
-// something different for each AST node type, like printing, interpreting,
-// etc.
+// WebAssembly AST visitor. Useful for anything that wants to do something
+// different for each AST node type, like printing, interpreting, etc.
+//
+// This class is specifically designed as a template to avoid virtual function
+// call overhead. To write a visitor, derive from this class as follows:
+//
+// struct MyVisitor : public WasmVisitor<MyVisitor> { .. }
//
template<typename SubType, typename ReturnType>
@@ -1170,33 +1174,33 @@ struct WasmVisitor {
ReturnType visitMemory(Memory *curr) { abort(); }
ReturnType visitModule(Module *curr) { abort(); }
+#define DELEGATE(CLASS_TO_VISIT) \
+ return static_cast<SubType*>(this)-> \
+ visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr))
+
ReturnType visit(Expression *curr) {
assert(curr);
- SubType* self = static_cast<SubType*>(this);
switch (curr->_id) {
- case Expression::Id::BlockId: return self->visitBlock((Block*)curr);
- case Expression::Id::IfId: return self->visitIf((If*)curr);
- case Expression::Id::LoopId: return self->visitLoop((Loop*)curr);
- case Expression::Id::BreakId: return self->visitBreak((Break*)curr);
- case Expression::Id::SwitchId: return self->visitSwitch((Switch*)curr);
- case Expression::Id::CallId: return self->visitCall((Call*)curr);
- case Expression::Id::CallImportId: return self->visitCallImport((CallImport*)curr);
- case Expression::Id::CallIndirectId: return self->visitCallIndirect((CallIndirect*)curr);
- case Expression::Id::GetLocalId: return self->visitGetLocal((GetLocal*)curr);
- case Expression::Id::SetLocalId: return self->visitSetLocal((SetLocal*)curr);
- case Expression::Id::LoadId: return self->visitLoad((Load*)curr);
- case Expression::Id::StoreId: return self->visitStore((Store*)curr);
- case Expression::Id::ConstId: return self->visitConst((Const*)curr);
- case Expression::Id::UnaryId: return self->visitUnary((Unary*)curr);
- case Expression::Id::BinaryId: return self->visitBinary((Binary*)curr);
- case Expression::Id::SelectId: return self->visitSelect((Select*)curr);
- case Expression::Id::HostId: return self->visitHost((Host*)curr);
- case Expression::Id::NopId: return self->visitNop((Nop*)curr);
- case Expression::Id::UnreachableId: return self->visitUnreachable((Unreachable*)curr);
- default: {
- std::cerr << "visiting unknown expression " << curr->_id << '\n';
- abort();
- }
+ case Expression::Id::InvalidId: abort();
+ case Expression::Id::BlockId: DELEGATE(Block);
+ case Expression::Id::IfId: DELEGATE(If);
+ case Expression::Id::LoopId: DELEGATE(Loop);
+ case Expression::Id::BreakId: DELEGATE(Break);
+ case Expression::Id::SwitchId: DELEGATE(Switch);
+ case Expression::Id::CallId: DELEGATE(Call);
+ case Expression::Id::CallImportId: DELEGATE(CallImport);
+ case Expression::Id::CallIndirectId: DELEGATE(CallIndirect);
+ case Expression::Id::GetLocalId: DELEGATE(GetLocal);
+ case Expression::Id::SetLocalId: DELEGATE(SetLocal);
+ case Expression::Id::LoadId: DELEGATE(Load);
+ case Expression::Id::StoreId: DELEGATE(Store);
+ case Expression::Id::ConstId: DELEGATE(Const);
+ case Expression::Id::UnaryId: DELEGATE(Unary);
+ case Expression::Id::BinaryId: DELEGATE(Binary);
+ case Expression::Id::SelectId: DELEGATE(Select);
+ case Expression::Id::HostId: DELEGATE(Host);
+ case Expression::Id::NopId: DELEGATE(Nop);
+ case Expression::Id::UnreachableId: DELEGATE(Unreachable);
}
}
};
@@ -1234,7 +1238,10 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) {
return o;
}
-template<typename SubType, typename ReturnType>
+//
+// Base class for all WasmWalkers
+//
+template<typename SubType, typename ReturnType = void>
struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> {
virtual void walk(Expression*& curr) { abort(); }
virtual void startWalk(Function *func) { abort(); }
@@ -1242,7 +1249,7 @@ struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> {
};
template<typename ParentType>
-struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>, void> {
+struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>> {
ParentType& parent;
ChildWalker(ParentType& parent) : parent(parent) {}
@@ -1330,7 +1337,7 @@ struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>, void> {
// the current expression node. Useful for writing optimization passes.
//
-template<typename SubType, typename ReturnType>
+template<typename SubType, typename ReturnType = void>
struct WasmWalker : public WasmWalkerBase<SubType, ReturnType> {
Expression* replace;
@@ -1374,7 +1381,7 @@ struct WasmWalker : public WasmWalkerBase<SubType, ReturnType> {
void walk(Expression*& curr) override {
if (!curr) return;
- ChildWalker<WasmWalker<SubType, ReturnType> >(*this).visit(curr);
+ ChildWalker<WasmWalker<SubType, ReturnType>>(*this).visit(curr);
this->visit(curr);
@@ -1389,6 +1396,7 @@ struct WasmWalker : public WasmWalkerBase<SubType, ReturnType> {
}
void startWalk(Module *module) override {
+ // Dispatch statically through the SubType.
SubType* self = static_cast<SubType*>(this);
for (auto curr : module->functionTypes) {
self->visitFunctionType(curr);