summaryrefslogtreecommitdiff
path: root/src/wasm.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm.h')
-rw-r--r--src/wasm.h410
1 files changed, 215 insertions, 195 deletions
diff --git a/src/wasm.h b/src/wasm.h
index 9edd7645e..ffca28f9e 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -1133,100 +1133,106 @@ class AllocatingModule : public Module {
};
//
-// Simple WebAssembly AST visiting. Useful for anything that wants to do
-// something different for each AST node type, like printing, interpreting,
-// etc.
+// WebAssembly AST visitor. Useful for anything that wants to do something
+// different for each AST node type, like printing, interpreting, etc.
+//
+// This class is specifically designed as a template to avoid virtual function
+// call overhead. To write a visitor, derive from this class as follows:
+//
+// struct MyVisitor : public WasmVisitor<MyVisitor> { .. }
//
-template<class ReturnType>
+template<typename SubType, typename ReturnType>
struct WasmVisitor {
virtual ~WasmVisitor() {}
// should be pure virtual, but https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51048
// Expression visitors
- virtual ReturnType visitBlock(Block *curr) { abort(); }
- virtual ReturnType visitIf(If *curr) { abort(); }
- virtual ReturnType visitLoop(Loop *curr) { abort(); }
- virtual ReturnType visitBreak(Break *curr) { abort(); }
- virtual ReturnType visitSwitch(Switch *curr) { abort(); }
- virtual ReturnType visitCall(Call *curr) { abort(); }
- virtual ReturnType visitCallImport(CallImport *curr) { abort(); }
- virtual ReturnType visitCallIndirect(CallIndirect *curr) { abort(); }
- virtual ReturnType visitGetLocal(GetLocal *curr) { abort(); }
- virtual ReturnType visitSetLocal(SetLocal *curr) { abort(); }
- virtual ReturnType visitLoad(Load *curr) { abort(); }
- virtual ReturnType visitStore(Store *curr) { abort(); }
- virtual ReturnType visitConst(Const *curr) { abort(); }
- virtual ReturnType visitUnary(Unary *curr) { abort(); }
- virtual ReturnType visitBinary(Binary *curr) { abort(); }
- virtual ReturnType visitSelect(Select *curr) { abort(); }
- virtual ReturnType visitHost(Host *curr) { abort(); }
- virtual ReturnType visitNop(Nop *curr) { abort(); }
- virtual ReturnType visitUnreachable(Unreachable *curr) { abort(); }
+ ReturnType visitBlock(Block *curr) { abort(); }
+ ReturnType visitIf(If *curr) { abort(); }
+ ReturnType visitLoop(Loop *curr) { abort(); }
+ ReturnType visitBreak(Break *curr) { abort(); }
+ ReturnType visitSwitch(Switch *curr) { abort(); }
+ ReturnType visitCall(Call *curr) { abort(); }
+ ReturnType visitCallImport(CallImport *curr) { abort(); }
+ ReturnType visitCallIndirect(CallIndirect *curr) { abort(); }
+ ReturnType visitGetLocal(GetLocal *curr) { abort(); }
+ ReturnType visitSetLocal(SetLocal *curr) { abort(); }
+ ReturnType visitLoad(Load *curr) { abort(); }
+ ReturnType visitStore(Store *curr) { abort(); }
+ ReturnType visitConst(Const *curr) { abort(); }
+ ReturnType visitUnary(Unary *curr) { abort(); }
+ ReturnType visitBinary(Binary *curr) { abort(); }
+ ReturnType visitSelect(Select *curr) { abort(); }
+ ReturnType visitHost(Host *curr) { abort(); }
+ ReturnType visitNop(Nop *curr) { abort(); }
+ ReturnType visitUnreachable(Unreachable *curr) { abort(); }
// Module-level visitors
- virtual ReturnType visitFunctionType(FunctionType *curr) { abort(); }
- virtual ReturnType visitImport(Import *curr) { abort(); }
- virtual ReturnType visitExport(Export *curr) { abort(); }
- virtual ReturnType visitFunction(Function *curr) { abort(); }
- virtual ReturnType visitTable(Table *curr) { abort(); }
- virtual ReturnType visitMemory(Memory *curr) { abort(); }
- virtual ReturnType visitModule(Module *curr) { abort(); }
+ ReturnType visitFunctionType(FunctionType *curr) { abort(); }
+ ReturnType visitImport(Import *curr) { abort(); }
+ ReturnType visitExport(Export *curr) { abort(); }
+ ReturnType visitFunction(Function *curr) { abort(); }
+ ReturnType visitTable(Table *curr) { abort(); }
+ ReturnType visitMemory(Memory *curr) { abort(); }
+ ReturnType visitModule(Module *curr) { abort(); }
+
+#define DELEGATE(CLASS_TO_VISIT) \
+ return static_cast<SubType*>(this)-> \
+ visit##CLASS_TO_VISIT(static_cast<CLASS_TO_VISIT*>(curr))
ReturnType visit(Expression *curr) {
assert(curr);
switch (curr->_id) {
- case Expression::Id::BlockId: return visitBlock((Block*)curr);
- case Expression::Id::IfId: return visitIf((If*)curr);
- case Expression::Id::LoopId: return visitLoop((Loop*)curr);
- case Expression::Id::BreakId: return visitBreak((Break*)curr);
- case Expression::Id::SwitchId: return visitSwitch((Switch*)curr);
- case Expression::Id::CallId: return visitCall((Call*)curr);
- case Expression::Id::CallImportId: return visitCallImport((CallImport*)curr);
- case Expression::Id::CallIndirectId: return visitCallIndirect((CallIndirect*)curr);
- case Expression::Id::GetLocalId: return visitGetLocal((GetLocal*)curr);
- case Expression::Id::SetLocalId: return visitSetLocal((SetLocal*)curr);
- case Expression::Id::LoadId: return visitLoad((Load*)curr);
- case Expression::Id::StoreId: return visitStore((Store*)curr);
- case Expression::Id::ConstId: return visitConst((Const*)curr);
- case Expression::Id::UnaryId: return visitUnary((Unary*)curr);
- case Expression::Id::BinaryId: return visitBinary((Binary*)curr);
- case Expression::Id::SelectId: return visitSelect((Select*)curr);
- case Expression::Id::HostId: return visitHost((Host*)curr);
- case Expression::Id::NopId: return visitNop((Nop*)curr);
- case Expression::Id::UnreachableId: return visitUnreachable((Unreachable*)curr);
- default: {
- std::cerr << "visiting unknown expression " << curr->_id << '\n';
- abort();
- }
+ case Expression::Id::InvalidId: abort();
+ case Expression::Id::BlockId: DELEGATE(Block);
+ case Expression::Id::IfId: DELEGATE(If);
+ case Expression::Id::LoopId: DELEGATE(Loop);
+ case Expression::Id::BreakId: DELEGATE(Break);
+ case Expression::Id::SwitchId: DELEGATE(Switch);
+ case Expression::Id::CallId: DELEGATE(Call);
+ case Expression::Id::CallImportId: DELEGATE(CallImport);
+ case Expression::Id::CallIndirectId: DELEGATE(CallIndirect);
+ case Expression::Id::GetLocalId: DELEGATE(GetLocal);
+ case Expression::Id::SetLocalId: DELEGATE(SetLocal);
+ case Expression::Id::LoadId: DELEGATE(Load);
+ case Expression::Id::StoreId: DELEGATE(Store);
+ case Expression::Id::ConstId: DELEGATE(Const);
+ case Expression::Id::UnaryId: DELEGATE(Unary);
+ case Expression::Id::BinaryId: DELEGATE(Binary);
+ case Expression::Id::SelectId: DELEGATE(Select);
+ case Expression::Id::HostId: DELEGATE(Host);
+ case Expression::Id::NopId: DELEGATE(Nop);
+ case Expression::Id::UnreachableId: DELEGATE(Unreachable);
+ default: WASM_UNREACHABLE();
}
}
};
std::ostream& Expression::print(std::ostream &o, unsigned indent) {
- struct ExpressionPrinter : public WasmVisitor<void> {
+ struct ExpressionPrinter : public WasmVisitor<ExpressionPrinter, void> {
std::ostream &o;
unsigned indent;
ExpressionPrinter(std::ostream &o, unsigned indent) : o(o), indent(indent) {}
- void visitBlock(Block *curr) override { curr->doPrint(o, indent); }
- void visitIf(If *curr) override { curr->doPrint(o, indent); }
- void visitLoop(Loop *curr) override { curr->doPrint(o, indent); }
- void visitBreak(Break *curr) override { curr->doPrint(o, indent); }
- void visitSwitch(Switch *curr) override { curr->doPrint(o, indent); }
- void visitCall(Call *curr) override { curr->doPrint(o, indent); }
- void visitCallImport(CallImport *curr) override { curr->doPrint(o, indent); }
- void visitCallIndirect(CallIndirect *curr) override { curr->doPrint(o, indent); }
- void visitGetLocal(GetLocal *curr) override { curr->doPrint(o, indent); }
- void visitSetLocal(SetLocal *curr) override { curr->doPrint(o, indent); }
- void visitLoad(Load *curr) override { curr->doPrint(o, indent); }
- void visitStore(Store *curr) override { curr->doPrint(o, indent); }
- void visitConst(Const *curr) override { curr->doPrint(o, indent); }
- void visitUnary(Unary *curr) override { curr->doPrint(o, indent); }
- void visitBinary(Binary *curr) override { curr->doPrint(o, indent); }
- void visitSelect(Select *curr) override { curr->doPrint(o, indent); }
- void visitHost(Host *curr) override { curr->doPrint(o, indent); }
- void visitNop(Nop *curr) override { curr->doPrint(o, indent); }
- void visitUnreachable(Unreachable *curr) override { curr->doPrint(o, indent); }
+ void visitBlock(Block *curr) { curr->doPrint(o, indent); }
+ void visitIf(If *curr) { curr->doPrint(o, indent); }
+ void visitLoop(Loop *curr) { curr->doPrint(o, indent); }
+ void visitBreak(Break *curr) { curr->doPrint(o, indent); }
+ void visitSwitch(Switch *curr) { curr->doPrint(o, indent); }
+ void visitCall(Call *curr) { curr->doPrint(o, indent); }
+ void visitCallImport(CallImport *curr) { curr->doPrint(o, indent); }
+ void visitCallIndirect(CallIndirect *curr) { curr->doPrint(o, indent); }
+ void visitGetLocal(GetLocal *curr) { curr->doPrint(o, indent); }
+ void visitSetLocal(SetLocal *curr) { curr->doPrint(o, indent); }
+ void visitLoad(Load *curr) { curr->doPrint(o, indent); }
+ void visitStore(Store *curr) { curr->doPrint(o, indent); }
+ void visitConst(Const *curr) { curr->doPrint(o, indent); }
+ void visitUnary(Unary *curr) { curr->doPrint(o, indent); }
+ void visitBinary(Binary *curr) { curr->doPrint(o, indent); }
+ void visitSelect(Select *curr) { curr->doPrint(o, indent); }
+ void visitHost(Host *curr) { curr->doPrint(o, indent); }
+ void visitNop(Nop *curr) { curr->doPrint(o, indent); }
+ void visitUnreachable(Unreachable *curr) { curr->doPrint(o, indent); }
};
ExpressionPrinter(o, indent).visit(this);
@@ -1235,12 +1241,106 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) {
}
//
+// Base class for all WasmWalkers
+//
+template<typename SubType, typename ReturnType = void>
+struct WasmWalkerBase : public WasmVisitor<SubType, ReturnType> {
+ virtual void walk(Expression*& curr) { abort(); }
+ virtual void startWalk(Function *func) { abort(); }
+ virtual void startWalk(Module *module) { abort(); }
+};
+
+template<typename ParentType>
+struct ChildWalker : public WasmWalkerBase<ChildWalker<ParentType>> {
+ ParentType& parent;
+
+ ChildWalker(ParentType& parent) : parent(parent) {}
+
+ void visitBlock(Block *curr) {
+ ExpressionList& list = curr->list;
+ for (size_t z = 0; z < list.size(); z++) {
+ parent.walk(list[z]);
+ }
+ }
+ void visitIf(If *curr) {
+ parent.walk(curr->condition);
+ parent.walk(curr->ifTrue);
+ parent.walk(curr->ifFalse);
+ }
+ void visitLoop(Loop *curr) {
+ parent.walk(curr->body);
+ }
+ void visitBreak(Break *curr) {
+ parent.walk(curr->condition);
+ parent.walk(curr->value);
+ }
+ void visitSwitch(Switch *curr) {
+ parent.walk(curr->value);
+ for (auto& case_ : curr->cases) {
+ parent.walk(case_.body);
+ }
+ }
+ void visitCall(Call *curr) {
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ parent.walk(list[z]);
+ }
+ }
+ void visitCallImport(CallImport *curr) {
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ parent.walk(list[z]);
+ }
+ }
+ void visitCallIndirect(CallIndirect *curr) {
+ parent.walk(curr->target);
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ parent.walk(list[z]);
+ }
+ }
+ void visitGetLocal(GetLocal *curr) {}
+ void visitSetLocal(SetLocal *curr) {
+ parent.walk(curr->value);
+ }
+ void visitLoad(Load *curr) {
+ parent.walk(curr->ptr);
+ }
+ void visitStore(Store *curr) {
+ parent.walk(curr->ptr);
+ parent.walk(curr->value);
+ }
+ void visitConst(Const *curr) {}
+ void visitUnary(Unary *curr) {
+ parent.walk(curr->value);
+ }
+ void visitBinary(Binary *curr) {
+ parent.walk(curr->left);
+ parent.walk(curr->right);
+ }
+ void visitSelect(Select *curr) {
+ parent.walk(curr->condition);
+ parent.walk(curr->ifTrue);
+ parent.walk(curr->ifFalse);
+ }
+ void visitHost(Host *curr) {
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ parent.walk(list[z]);
+ }
+ }
+ void visitNop(Nop *curr) {}
+ void visitUnreachable(Unreachable *curr) {}
+};
+
+//
// Simple WebAssembly children-first walking (i.e., post-order, if you look
// at the children as subtrees of the current node), with the ability to replace
// the current expression node. Useful for writing optimization passes.
//
-struct WasmWalker : public WasmVisitor<void> {
+template<typename SubType, typename ReturnType = void>
+struct WasmWalker : public WasmWalkerBase<SubType, ReturnType> {
Expression* replace;
WasmWalker() : replace(nullptr) {}
@@ -1251,123 +1351,41 @@ struct WasmWalker : public WasmVisitor<void> {
}
// By default, do nothing
- void visitBlock(Block *curr) override {}
- void visitIf(If *curr) override {}
- void visitLoop(Loop *curr) override {}
- void visitBreak(Break *curr) override {}
- void visitSwitch(Switch *curr) override {}
- void visitCall(Call *curr) override {}
- void visitCallImport(CallImport *curr) override {}
- void visitCallIndirect(CallIndirect *curr) override {}
- void visitGetLocal(GetLocal *curr) override {}
- void visitSetLocal(SetLocal *curr) override {}
- void visitLoad(Load *curr) override {}
- void visitStore(Store *curr) override {}
- void visitConst(Const *curr) override {}
- void visitUnary(Unary *curr) override {}
- void visitBinary(Binary *curr) override {}
- void visitSelect(Select *curr) override {}
- void visitHost(Host *curr) override {}
- void visitNop(Nop *curr) override {}
- void visitUnreachable(Unreachable *curr) override {}
-
- void visitFunctionType(FunctionType *curr) override {}
- void visitImport(Import *curr) override {}
- void visitExport(Export *curr) override {}
- void visitFunction(Function *curr) override {}
- void visitTable(Table *curr) override {}
- void visitMemory(Memory *curr) override {}
- void visitModule(Module *curr) override {}
+ ReturnType visitBlock(Block *curr) {}
+ ReturnType visitIf(If *curr) {}
+ ReturnType visitLoop(Loop *curr) {}
+ ReturnType visitBreak(Break *curr) {}
+ ReturnType visitSwitch(Switch *curr) {}
+ ReturnType visitCall(Call *curr) {}
+ ReturnType visitCallImport(CallImport *curr) {}
+ ReturnType visitCallIndirect(CallIndirect *curr) {}
+ ReturnType visitGetLocal(GetLocal *curr) {}
+ ReturnType visitSetLocal(SetLocal *curr) {}
+ ReturnType visitLoad(Load *curr) {}
+ ReturnType visitStore(Store *curr) {}
+ ReturnType visitConst(Const *curr) {}
+ ReturnType visitUnary(Unary *curr) {}
+ ReturnType visitBinary(Binary *curr) {}
+ ReturnType visitSelect(Select *curr) {}
+ ReturnType visitHost(Host *curr) {}
+ ReturnType visitNop(Nop *curr) {}
+ ReturnType visitUnreachable(Unreachable *curr) {}
+
+ ReturnType visitFunctionType(FunctionType *curr) {}
+ ReturnType visitImport(Import *curr) {}
+ ReturnType visitExport(Export *curr) {}
+ ReturnType visitFunction(Function *curr) {}
+ ReturnType visitTable(Table *curr) {}
+ ReturnType visitMemory(Memory *curr) {}
+ ReturnType visitModule(Module *curr) {}
// children-first
- void walk(Expression*& curr) {
+ void walk(Expression*& curr) override {
if (!curr) return;
- struct ChildWalker : public WasmVisitor {
- WasmWalker& parent;
-
- ChildWalker(WasmWalker& parent) : parent(parent) {}
-
- void visitBlock(Block *curr) override {
- ExpressionList& list = curr->list;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
- }
- void visitIf(If *curr) override {
- parent.walk(curr->condition);
- parent.walk(curr->ifTrue);
- parent.walk(curr->ifFalse);
- }
- void visitLoop(Loop *curr) override {
- parent.walk(curr->body);
- }
- void visitBreak(Break *curr) override {
- parent.walk(curr->condition);
- parent.walk(curr->value);
- }
- void visitSwitch(Switch *curr) override {
- parent.walk(curr->value);
- for (auto& case_ : curr->cases) {
- parent.walk(case_.body);
- }
- }
- void visitCall(Call *curr) override {
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
- }
- void visitCallImport(CallImport *curr) override {
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
- }
- void visitCallIndirect(CallIndirect *curr) override {
- parent.walk(curr->target);
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
- }
- void visitGetLocal(GetLocal *curr) override {}
- void visitSetLocal(SetLocal *curr) override {
- parent.walk(curr->value);
- }
- void visitLoad(Load *curr) override {
- parent.walk(curr->ptr);
- }
- void visitStore(Store *curr) override {
- parent.walk(curr->ptr);
- parent.walk(curr->value);
- }
- void visitConst(Const *curr) override {}
- void visitUnary(Unary *curr) override {
- parent.walk(curr->value);
- }
- void visitBinary(Binary *curr) override {
- parent.walk(curr->left);
- parent.walk(curr->right);
- }
- void visitSelect(Select *curr) override {
- parent.walk(curr->condition);
- parent.walk(curr->ifTrue);
- parent.walk(curr->ifFalse);
- }
- void visitHost(Host *curr) override {
- ExpressionList& list = curr->operands;
- for (size_t z = 0; z < list.size(); z++) {
- parent.walk(list[z]);
- }
- }
- void visitNop(Nop *curr) override {}
- void visitUnreachable(Unreachable *curr) override {}
- };
-
- ChildWalker(*this).visit(curr);
+ ChildWalker<WasmWalker<SubType, ReturnType>>(*this).visit(curr);
- visit(curr);
+ this->visit(curr);
if (replace) {
curr = replace;
@@ -1375,33 +1393,35 @@ struct WasmWalker : public WasmVisitor<void> {
}
}
- void startWalk(Function *func) {
+ void startWalk(Function *func) override {
walk(func->body);
}
- void startWalk(Module *module) {
+ void startWalk(Module *module) override {
+ // Dispatch statically through the SubType.
+ SubType* self = static_cast<SubType*>(this);
for (auto curr : module->functionTypes) {
- visitFunctionType(curr);
+ self->visitFunctionType(curr);
assert(!replace);
}
for (auto curr : module->imports) {
- visitImport(curr);
+ self->visitImport(curr);
assert(!replace);
}
for (auto curr : module->exports) {
- visitExport(curr);
+ self->visitExport(curr);
assert(!replace);
}
for (auto curr : module->functions) {
startWalk(curr);
- visitFunction(curr);
+ self->visitFunction(curr);
assert(!replace);
}
- visitTable(&module->table);
+ self->visitTable(&module->table);
assert(!replace);
- visitMemory(&module->memory);
+ self->visitMemory(&module->memory);
assert(!replace);
- visitModule(module);
+ self->visitModule(module);
assert(!replace);
}
};