summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/ast_utils.h26
-rw-r--r--src/wasm-traversal.h94
2 files changed, 118 insertions, 2 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h
index 6e5251860..4a8d9ff80 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -76,7 +76,7 @@ struct ExpressionAnalyzer {
// vs
// (block (unreachable))
// This converts to the latter form.
-struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> {
+struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, OverriddenVisitor<ReFinalize>>> {
bool isFunctionParallel() override { return true; }
Pass* create() override { return new ReFinalize; }
@@ -154,6 +154,8 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> {
void visitStore(Store *curr) { curr->finalize(); }
void visitAtomicRMW(AtomicRMW *curr) { curr->finalize(); }
void visitAtomicCmpxchg(AtomicCmpxchg *curr) { curr->finalize(); }
+ void visitAtomicWait(AtomicWait* curr) { curr->finalize(); }
+ void visitAtomicWake(AtomicWake* curr) { curr->finalize(); }
void visitConst(Const *curr) { curr->finalize(); }
void visitUnary(Unary *curr) { curr->finalize(); }
void visitBinary(Binary *curr) { curr->finalize(); }
@@ -173,6 +175,14 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> {
}
}
+ void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); }
+ void visitImport(Import* curr) { WASM_UNREACHABLE(); }
+ void visitExport(Export* curr) { WASM_UNREACHABLE(); }
+ void visitGlobal(Global* curr) { WASM_UNREACHABLE(); }
+ void visitTable(Table* curr) { WASM_UNREACHABLE(); }
+ void visitMemory(Memory* curr) { WASM_UNREACHABLE(); }
+ void visitModule(Module* curr) { WASM_UNREACHABLE(); }
+
WasmType getValueType(Expression* value) {
return value ? value->type : none;
}
@@ -186,7 +196,7 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize>> {
// Re-finalize a single node. This is slow, if you want to refinalize
// an entire ast, use ReFinalize
-struct ReFinalizeNode : public Visitor<ReFinalizeNode> {
+struct ReFinalizeNode : public OverriddenVisitor<ReFinalizeNode> {
void visitBlock(Block *curr) { curr->finalize(); }
void visitIf(If *curr) { curr->finalize(); }
void visitLoop(Loop *curr) { curr->finalize(); }
@@ -201,6 +211,10 @@ struct ReFinalizeNode : public Visitor<ReFinalizeNode> {
void visitSetGlobal(SetGlobal *curr) { curr->finalize(); }
void visitLoad(Load *curr) { curr->finalize(); }
void visitStore(Store *curr) { curr->finalize(); }
+ void visitAtomicRMW(AtomicRMW* curr) { curr->finalize(); }
+ void visitAtomicCmpxchg(AtomicCmpxchg* curr) { curr->finalize(); }
+ void visitAtomicWait(AtomicWait* curr) { curr->finalize(); }
+ void visitAtomicWake(AtomicWake* curr) { curr->finalize(); }
void visitConst(Const *curr) { curr->finalize(); }
void visitUnary(Unary *curr) { curr->finalize(); }
void visitBinary(Binary *curr) { curr->finalize(); }
@@ -211,6 +225,14 @@ struct ReFinalizeNode : public Visitor<ReFinalizeNode> {
void visitNop(Nop *curr) { curr->finalize(); }
void visitUnreachable(Unreachable *curr) { curr->finalize(); }
+ void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); }
+ void visitImport(Import* curr) { WASM_UNREACHABLE(); }
+ void visitExport(Export* curr) { WASM_UNREACHABLE(); }
+ void visitGlobal(Global* curr) { WASM_UNREACHABLE(); }
+ void visitTable(Table* curr) { WASM_UNREACHABLE(); }
+ void visitMemory(Memory* curr) { WASM_UNREACHABLE(); }
+ void visitModule(Module* curr) { WASM_UNREACHABLE(); }
+
// given a stack of nested expressions, update them all from child to parent
static void updateStack(std::vector<Expression*>& expressionStack) {
for (int i = int(expressionStack.size()) - 1; i >= 0; i--) {
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index 8384c6a6f..0c5088917 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -32,6 +32,8 @@
namespace wasm {
+// A generic visitor, defaulting to doing nothing on each visit
+
template<typename SubType, typename ReturnType = void>
struct Visitor {
// Expression visitors
@@ -115,6 +117,98 @@ struct Visitor {
}
};
+// A visitor which must be overridden for each visitor that is reached.
+
+template<typename SubType, typename ReturnType = void>
+struct OverriddenVisitor {
+ // Expression visitors, which must be overridden
+ #define UNIMPLEMENTED(CLASS_TO_VISIT) \
+ ReturnType visit##CLASS_TO_VISIT(CLASS_TO_VISIT* curr) { \
+ static_assert(&SubType::visit##CLASS_TO_VISIT != &OverriddenVisitor<SubType, ReturnType>::visit##CLASS_TO_VISIT, "Derived class must implement visit" #CLASS_TO_VISIT); \
+ WASM_UNREACHABLE(); \
+ }
+
+ UNIMPLEMENTED(Block);
+ UNIMPLEMENTED(If);
+ UNIMPLEMENTED(Loop);
+ UNIMPLEMENTED(Break);
+ UNIMPLEMENTED(Switch);
+ UNIMPLEMENTED(Call);
+ UNIMPLEMENTED(CallImport);
+ UNIMPLEMENTED(CallIndirect);
+ UNIMPLEMENTED(GetLocal);
+ UNIMPLEMENTED(SetLocal);
+ UNIMPLEMENTED(GetGlobal);
+ UNIMPLEMENTED(SetGlobal);
+ UNIMPLEMENTED(Load);
+ UNIMPLEMENTED(Store);
+ UNIMPLEMENTED(AtomicRMW);
+ UNIMPLEMENTED(AtomicCmpxchg);
+ UNIMPLEMENTED(AtomicWait);
+ UNIMPLEMENTED(AtomicWake);
+ UNIMPLEMENTED(Const);
+ UNIMPLEMENTED(Unary);
+ UNIMPLEMENTED(Binary);
+ UNIMPLEMENTED(Select);
+ UNIMPLEMENTED(Drop);
+ UNIMPLEMENTED(Return);
+ UNIMPLEMENTED(Host);
+ UNIMPLEMENTED(Nop);
+ UNIMPLEMENTED(Unreachable);
+ UNIMPLEMENTED(FunctionType);
+ UNIMPLEMENTED(Import);
+ UNIMPLEMENTED(Export);
+ UNIMPLEMENTED(Global);
+ UNIMPLEMENTED(Function);
+ UNIMPLEMENTED(Table);
+ UNIMPLEMENTED(Memory);
+ UNIMPLEMENTED(Module);
+
+ #undef UNIMPLEMENTED
+
+ ReturnType visit(Expression* curr) {
+ assert(curr);
+
+ #define DELEGATE(CLASS_TO_VISIT) \
+ return static_cast<SubType*>(this)-> \
+ visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr))
+
+ switch (curr->_id) {
+ case Expression::Id::BlockId: DELEGATE(Block);
+ case Expression::Id::IfId: DELEGATE(If);
+ case Expression::Id::LoopId: DELEGATE(Loop);
+ case Expression::Id::BreakId: DELEGATE(Break);
+ case Expression::Id::SwitchId: DELEGATE(Switch);
+ case Expression::Id::CallId: DELEGATE(Call);
+ case Expression::Id::CallImportId: DELEGATE(CallImport);
+ case Expression::Id::CallIndirectId: DELEGATE(CallIndirect);
+ case Expression::Id::GetLocalId: DELEGATE(GetLocal);
+ case Expression::Id::SetLocalId: DELEGATE(SetLocal);
+ case Expression::Id::GetGlobalId: DELEGATE(GetGlobal);
+ case Expression::Id::SetGlobalId: DELEGATE(SetGlobal);
+ case Expression::Id::LoadId: DELEGATE(Load);
+ case Expression::Id::StoreId: DELEGATE(Store);
+ case Expression::Id::AtomicRMWId: DELEGATE(AtomicRMW);
+ case Expression::Id::AtomicCmpxchgId: DELEGATE(AtomicCmpxchg);
+ case Expression::Id::AtomicWaitId: DELEGATE(AtomicWait);
+ case Expression::Id::AtomicWakeId: DELEGATE(AtomicWake);
+ case Expression::Id::ConstId: DELEGATE(Const);
+ case Expression::Id::UnaryId: DELEGATE(Unary);
+ case Expression::Id::BinaryId: DELEGATE(Binary);
+ case Expression::Id::SelectId: DELEGATE(Select);
+ case Expression::Id::DropId: DELEGATE(Drop);
+ case Expression::Id::ReturnId: DELEGATE(Return);
+ case Expression::Id::HostId: DELEGATE(Host);
+ case Expression::Id::NopId: DELEGATE(Nop);
+ case Expression::Id::UnreachableId: DELEGATE(Unreachable);
+ case Expression::Id::InvalidId:
+ default: WASM_UNREACHABLE();
+ }
+
+ #undef DELEGATE
+ }
+};
+
// Visit with a single unified visitor, called on every node, instead of
// separate visit* per node