diff options
-rw-r--r-- | src/ast_utils.h | 26 | ||||
-rw-r--r-- | src/wasm-traversal.h | 94 |
2 files changed, 118 insertions, 2 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h index 6e5251860..4a8d9ff80 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -76,7 +76,7 @@ struct ExpressionAnalyzer { // vs // (block (unreachable)) // This converts to the latter form. -struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> { +struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, OverriddenVisitor<ReFinalize>>> { bool isFunctionParallel() override { return true; } Pass* create() override { return new ReFinalize; } @@ -154,6 +154,8 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> { void visitStore(Store *curr) { curr->finalize(); } void visitAtomicRMW(AtomicRMW *curr) { curr->finalize(); } void visitAtomicCmpxchg(AtomicCmpxchg *curr) { curr->finalize(); } + void visitAtomicWait(AtomicWait* curr) { curr->finalize(); } + void visitAtomicWake(AtomicWake* curr) { curr->finalize(); } void visitConst(Const *curr) { curr->finalize(); } void visitUnary(Unary *curr) { curr->finalize(); } void visitBinary(Binary *curr) { curr->finalize(); } @@ -173,6 +175,14 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> { } } + void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); } + void visitImport(Import* curr) { WASM_UNREACHABLE(); } + void visitExport(Export* curr) { WASM_UNREACHABLE(); } + void visitGlobal(Global* curr) { WASM_UNREACHABLE(); } + void visitTable(Table* curr) { WASM_UNREACHABLE(); } + void visitMemory(Memory* curr) { WASM_UNREACHABLE(); } + void visitModule(Module* curr) { WASM_UNREACHABLE(); } + WasmType getValueType(Expression* value) { return value ? value->type : none; } @@ -186,7 +196,7 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> { // Re-finalize a single node. This is slow, if you want to refinalize // an entire ast, use ReFinalize -struct ReFinalizeNode : public Visitor<ReFinalizeNode> { +struct ReFinalizeNode : public OverriddenVisitor<ReFinalizeNode> { void visitBlock(Block *curr) { curr->finalize(); } void visitIf(If *curr) { curr->finalize(); } void visitLoop(Loop *curr) { curr->finalize(); } @@ -201,6 +211,10 @@ struct ReFinalizeNode : public Visitor<ReFinalizeNode> { void visitSetGlobal(SetGlobal *curr) { curr->finalize(); } void visitLoad(Load *curr) { curr->finalize(); } void visitStore(Store *curr) { curr->finalize(); } + void visitAtomicRMW(AtomicRMW* curr) { curr->finalize(); } + void visitAtomicCmpxchg(AtomicCmpxchg* curr) { curr->finalize(); } + void visitAtomicWait(AtomicWait* curr) { curr->finalize(); } + void visitAtomicWake(AtomicWake* curr) { curr->finalize(); } void visitConst(Const *curr) { curr->finalize(); } void visitUnary(Unary *curr) { curr->finalize(); } void visitBinary(Binary *curr) { curr->finalize(); } @@ -211,6 +225,14 @@ struct ReFinalizeNode : public Visitor<ReFinalizeNode> { void visitNop(Nop *curr) { curr->finalize(); } void visitUnreachable(Unreachable *curr) { curr->finalize(); } + void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); } + void visitImport(Import* curr) { WASM_UNREACHABLE(); } + void visitExport(Export* curr) { WASM_UNREACHABLE(); } + void visitGlobal(Global* curr) { WASM_UNREACHABLE(); } + void visitTable(Table* curr) { WASM_UNREACHABLE(); } + void visitMemory(Memory* curr) { WASM_UNREACHABLE(); } + void visitModule(Module* curr) { WASM_UNREACHABLE(); } + // given a stack of nested expressions, update them all from child to parent static void updateStack(std::vector<Expression*>& expressionStack) { for (int i = int(expressionStack.size()) - 1; i >= 0; i--) { diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 8384c6a6f..0c5088917 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -32,6 +32,8 @@ namespace wasm { +// A generic visitor, defaulting to doing nothing on each visit + template<typename SubType, typename ReturnType = void> struct Visitor { // Expression visitors @@ -115,6 +117,98 @@ struct Visitor { } }; +// A visitor which must be overridden for each visitor that is reached. + +template<typename SubType, typename ReturnType = void> +struct OverriddenVisitor { + // Expression visitors, which must be overridden + #define UNIMPLEMENTED(CLASS_TO_VISIT) \ + ReturnType visit##CLASS_TO_VISIT(CLASS_TO_VISIT* curr) { \ + static_assert(&SubType::visit##CLASS_TO_VISIT != &OverriddenVisitor<SubType, ReturnType>::visit##CLASS_TO_VISIT, "Derived class must implement visit" #CLASS_TO_VISIT); \ + WASM_UNREACHABLE(); \ + } + + UNIMPLEMENTED(Block); + UNIMPLEMENTED(If); + UNIMPLEMENTED(Loop); + UNIMPLEMENTED(Break); + UNIMPLEMENTED(Switch); + UNIMPLEMENTED(Call); + UNIMPLEMENTED(CallImport); + UNIMPLEMENTED(CallIndirect); + UNIMPLEMENTED(GetLocal); + UNIMPLEMENTED(SetLocal); + UNIMPLEMENTED(GetGlobal); + UNIMPLEMENTED(SetGlobal); + UNIMPLEMENTED(Load); + UNIMPLEMENTED(Store); + UNIMPLEMENTED(AtomicRMW); + UNIMPLEMENTED(AtomicCmpxchg); + UNIMPLEMENTED(AtomicWait); + UNIMPLEMENTED(AtomicWake); + UNIMPLEMENTED(Const); + UNIMPLEMENTED(Unary); + UNIMPLEMENTED(Binary); + UNIMPLEMENTED(Select); + UNIMPLEMENTED(Drop); + UNIMPLEMENTED(Return); + UNIMPLEMENTED(Host); + UNIMPLEMENTED(Nop); + UNIMPLEMENTED(Unreachable); + UNIMPLEMENTED(FunctionType); + UNIMPLEMENTED(Import); + UNIMPLEMENTED(Export); + UNIMPLEMENTED(Global); + UNIMPLEMENTED(Function); + UNIMPLEMENTED(Table); + UNIMPLEMENTED(Memory); + UNIMPLEMENTED(Module); + + #undef UNIMPLEMENTED + + ReturnType visit(Expression* curr) { + assert(curr); + + #define DELEGATE(CLASS_TO_VISIT) \ + return static_cast<SubType*>(this)-> \ + visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr)) + + switch (curr->_id) { + case Expression::Id::BlockId: DELEGATE(Block); + case Expression::Id::IfId: DELEGATE(If); + case Expression::Id::LoopId: DELEGATE(Loop); + case Expression::Id::BreakId: DELEGATE(Break); + case Expression::Id::SwitchId: DELEGATE(Switch); + case Expression::Id::CallId: DELEGATE(Call); + case Expression::Id::CallImportId: DELEGATE(CallImport); + case Expression::Id::CallIndirectId: DELEGATE(CallIndirect); + case Expression::Id::GetLocalId: DELEGATE(GetLocal); + case Expression::Id::SetLocalId: DELEGATE(SetLocal); + case Expression::Id::GetGlobalId: DELEGATE(GetGlobal); + case Expression::Id::SetGlobalId: DELEGATE(SetGlobal); + case Expression::Id::LoadId: DELEGATE(Load); + case Expression::Id::StoreId: DELEGATE(Store); + case Expression::Id::AtomicRMWId: DELEGATE(AtomicRMW); + case Expression::Id::AtomicCmpxchgId: DELEGATE(AtomicCmpxchg); + case Expression::Id::AtomicWaitId: DELEGATE(AtomicWait); + case Expression::Id::AtomicWakeId: DELEGATE(AtomicWake); + case Expression::Id::ConstId: DELEGATE(Const); + case Expression::Id::UnaryId: DELEGATE(Unary); + case Expression::Id::BinaryId: DELEGATE(Binary); + case Expression::Id::SelectId: DELEGATE(Select); + case Expression::Id::DropId: DELEGATE(Drop); + case Expression::Id::ReturnId: DELEGATE(Return); + case Expression::Id::HostId: DELEGATE(Host); + case Expression::Id::NopId: DELEGATE(Nop); + case Expression::Id::UnreachableId: DELEGATE(Unreachable); + case Expression::Id::InvalidId: + default: WASM_UNREACHABLE(); + } + + #undef DELEGATE + } +}; + // Visit with a single unified visitor, called on every node, instead of // separate visit* per node |