diff options
Diffstat (limited to 'src/wasm.h')
-rw-r--r-- | src/wasm.h | 361 |
1 files changed, 256 insertions, 105 deletions
diff --git a/src/wasm.h b/src/wasm.h index 4ef01395c..3deae22ef 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -234,13 +234,46 @@ enum HostOp { class Expression { public: + enum Id { + InvalidId = 0, + BlockId = 1, + IfId = 2, + LoopId = 3, + LabelId = 4, + BreakId = 5, + SwitchId =6 , + CallId = 7, + CallImportId = 8, + CallIndirectId = 9, + GetLocalId = 10, + SetLocalId = 11, + LoadId = 12, + StoreId = 13, + ConstId = 14, + UnaryId = 15, + BinaryId = 16, + CompareId = 17, + ConvertId = 18, + HostId = 19, + NopId = 20 + }; + Id _id; + + Expression() : _id(InvalidId) {} + Expression(Id id) : _id(id) {} + WasmType type; // the type of the expression: its output, not necessarily its input(s) - virtual std::ostream& print(std::ostream &o, unsigned indent) = 0; + std::ostream& print(std::ostream &o, unsigned indent); // avoid virtual here, for performance template<class T> bool is() { - return !!dynamic_cast<T*>(this); + return _id == T()._id; + } + + template<class T> + T* dyn_cast() { + return _id == T()._id ? (T*)this : nullptr; } }; @@ -269,17 +302,22 @@ std::ostream& printMinorOpening(std::ostream &o, const char *str) { typedef std::vector<Expression*> ExpressionList; // TODO: optimize class Nop : public Expression { - std::ostream& print(std::ostream &o, unsigned indent) override { +public: + Nop() : Expression(NopId) {} + + std::ostream& doPrint(std::ostream &o, unsigned indent) { return printMinorOpening(o, "nop") << ')'; } }; class Block : public Expression { public: + Block() : Expression(BlockId) {} + Name name; ExpressionList list; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { printOpening(o, "block"); if (name.is()) { o << ' ' << name; @@ -294,9 +332,11 @@ public: class If : public Expression { public: + If() : Expression(IfId) {} + Expression *condition, *ifTrue, *ifFalse; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { printOpening(o, "if"); incIndent(o, indent); printFullLine(o, indent, condition); @@ -308,10 +348,12 @@ public: class Loop : public Expression { public: + Loop() : Expression(LoopId) {} + Name out, in; Expression *body; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { printOpening(o, "loop"); if (out.is()) { o << ' ' << out; @@ -327,16 +369,27 @@ public: class Label : public Expression { public: + Label() : Expression(LabelId) {} + Name name; Expression* body; + + std::ostream& doPrint(std::ostream &o, unsigned indent) { + printOpening(o, "label ") << name; + incIndent(o, indent); + printFullLine(o, indent, body); + return decIndent(o, indent); + } }; class Break : public Expression { public: + Break() : Expression(BreakId) {} + Name name; Expression *condition, *value; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { printOpening(o, "break ") << name; incIndent(o, indent); if (condition) printFullLine(o, indent, condition); @@ -347,6 +400,8 @@ public: class Switch : public Expression { public: + Switch() : Expression(SwitchId) {} + struct Case { Literal value; Expression *body; @@ -358,7 +413,7 @@ public: std::vector<Case> cases; Expression *default_; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { printOpening(o, "switch ") << name; incIndent(o, indent); printFullLine(o, indent, value); @@ -370,6 +425,8 @@ public: class Call : public Expression { public: + Call() : Expression(CallId) {} + Name target; ExpressionList operands; @@ -387,14 +444,19 @@ public: return o; } - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { printOpening(o, "call "); return printBody(o, indent); } }; class CallImport : public Call { - std::ostream& print(std::ostream &o, unsigned indent) override { +public: + CallImport() { + _id = CallImportId; + } + + std::ostream& doPrint(std::ostream &o, unsigned indent) { printOpening(o, "call_import "); return printBody(o, indent); } @@ -444,11 +506,13 @@ public: class CallIndirect : public Expression { public: + CallIndirect() : Expression(CallIndirectId) {} + FunctionType *type; Expression *target; ExpressionList operands; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { printOpening(o, "call_indirect ") << type->name; incIndent(o, indent); printFullLine(o, indent, target); @@ -461,19 +525,23 @@ public: class GetLocal : public Expression { public: + GetLocal() : Expression(GetLocalId) {} + Name name; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { return printOpening(o, "get_local ") << name << ')'; } }; class SetLocal : public Expression { public: + SetLocal() : Expression(SetLocalId) {} + Name name; Expression *value; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { printOpening(o, "set_local ") << name; incIndent(o, indent); printFullLine(o, indent, value); @@ -483,6 +551,8 @@ public: class Load : public Expression { public: + Load() : Expression(LoadId) {} + unsigned bytes; bool signed_; bool float_; @@ -490,7 +560,7 @@ public: unsigned align; Expression *ptr; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { o << '('; prepareColor(o) << printWasmType(getWasmType(bytes, float_)) << ".load"; if (bytes < 4) { @@ -514,13 +584,15 @@ public: class Store : public Expression { public: + Store() : Expression(StoreId) {} + unsigned bytes; bool float_; int offset; unsigned align; Expression *ptr, *value; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { o << '('; prepareColor(o) << printWasmType(getWasmType(bytes, float_)) << ".store"; if (bytes < 4) { @@ -544,6 +616,8 @@ public: class Const : public Expression { public: + Const() : Expression(ConstId) {} + Literal value; Const* set(Literal value_) { @@ -552,17 +626,19 @@ public: return this; } - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { return o << value; } }; class Unary : public Expression { public: + Unary() : Expression(UnaryId) {} + UnaryOp op; Expression *value; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { o << '('; prepareColor(o) << printWasmType(type) << '.'; switch (op) { @@ -579,10 +655,12 @@ public: class Binary : public Expression { public: + Binary() : Expression(BinaryId) {} + BinaryOp op; Expression *left, *right; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { o << '('; prepareColor(o) << printWasmType(type) << '.'; switch (op) { @@ -615,15 +693,15 @@ public: class Compare : public Expression { public: + Compare() : Expression(CompareId) { + type = WasmType::i32; // output is always i32 + } + RelationalOp op; WasmType inputType; Expression *left, *right; - Compare() { - type = WasmType::i32; // output is always i32 - } - - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { o << '('; prepareColor(o) << printWasmType(inputType) << '.'; switch (op) { @@ -653,10 +731,12 @@ public: class Convert : public Expression { public: + Convert() : Expression(ConvertId) {} + ConvertOp op; Expression *value; - std::ostream& print(std::ostream &o, unsigned indent) override { + std::ostream& doPrint(std::ostream &o, unsigned indent) { o << '('; prepareColor(o); switch (op) { @@ -674,8 +754,14 @@ public: class Host : public Expression { public: + Host() : Expression(HostId) {} + HostOp op; ExpressionList operands; + + std::ostream& doPrint(std::ostream &o, unsigned indent) { + abort(); + } }; // Globals @@ -805,7 +891,7 @@ public: }; // -// Simple WebAssembly AST visiting and children-first walking +// Simple WebAssembly AST visiting // template<class ReturnType> @@ -833,30 +919,73 @@ struct WasmVisitor { ReturnType visit(Expression *curr) { assert(curr); - if (Block *cast = dynamic_cast<Block*>(curr)) return visitBlock(cast); - if (If *cast = dynamic_cast<If*>(curr)) return visitIf(cast); - if (Loop *cast = dynamic_cast<Loop*>(curr)) return visitLoop(cast); - if (Label *cast = dynamic_cast<Label*>(curr)) return visitLabel(cast); - if (Break *cast = dynamic_cast<Break*>(curr)) return visitBreak(cast); - if (Switch *cast = dynamic_cast<Switch*>(curr)) return visitSwitch(cast); - if (Call *cast = dynamic_cast<Call*>(curr)) return visitCall(cast); - if (CallImport *cast = dynamic_cast<CallImport*>(curr)) return visitCallImport(cast); - if (CallIndirect *cast = dynamic_cast<CallIndirect*>(curr)) return visitCallIndirect(cast); - if (GetLocal *cast = dynamic_cast<GetLocal*>(curr)) return visitGetLocal(cast); - if (SetLocal *cast = dynamic_cast<SetLocal*>(curr)) return visitSetLocal(cast); - if (Load *cast = dynamic_cast<Load*>(curr)) return visitLoad(cast); - if (Store *cast = dynamic_cast<Store*>(curr)) return visitStore(cast); - if (Const *cast = dynamic_cast<Const*>(curr)) return visitConst(cast); - if (Unary *cast = dynamic_cast<Unary*>(curr)) return visitUnary(cast); - if (Binary *cast = dynamic_cast<Binary*>(curr)) return visitBinary(cast); - if (Compare *cast = dynamic_cast<Compare*>(curr)) return visitCompare(cast); - if (Convert *cast = dynamic_cast<Convert*>(curr)) return visitConvert(cast); - if (Host *cast = dynamic_cast<Host*>(curr)) return visitHost(cast); - if (Nop *cast = dynamic_cast<Nop*>(curr)) return visitNop(cast); - abort(); + switch (curr->_id) { + case Expression::Id::BlockId: return visitBlock((Block*)curr); + case Expression::Id::IfId: return visitIf((If*)curr); + case Expression::Id::LoopId: return visitLoop((Loop*)curr); + case Expression::Id::LabelId: return visitLabel((Label*)curr); + case Expression::Id::BreakId: return visitBreak((Break*)curr); + case Expression::Id::SwitchId: return visitSwitch((Switch*)curr); + case Expression::Id::CallId: return visitCall((Call*)curr); + case Expression::Id::CallImportId: return visitCallImport((CallImport*)curr); + case Expression::Id::CallIndirectId: return visitCallIndirect((CallIndirect*)curr); + case Expression::Id::GetLocalId: return visitGetLocal((GetLocal*)curr); + case Expression::Id::SetLocalId: return visitSetLocal((SetLocal*)curr); + case Expression::Id::LoadId: return visitLoad((Load*)curr); + case Expression::Id::StoreId: return visitStore((Store*)curr); + case Expression::Id::ConstId: return visitConst((Const*)curr); + case Expression::Id::UnaryId: return visitUnary((Unary*)curr); + case Expression::Id::BinaryId: return visitBinary((Binary*)curr); + case Expression::Id::CompareId: return visitCompare((Compare*)curr); + case Expression::Id::ConvertId: return visitConvert((Convert*)curr); + case Expression::Id::HostId: return visitHost((Host*)curr); + case Expression::Id::NopId: return visitNop((Nop*)curr); + default: { + std::cerr << "visiting unknown expression " << curr->_id << '\n'; + abort(); + } + } } }; +std::ostream& Expression::print(std::ostream &o, unsigned indent) { + struct ExpressionPrinter : public WasmVisitor<void> { + std::ostream &o; + unsigned indent; + + ExpressionPrinter(std::ostream &o, unsigned indent) : o(o), indent(indent) {} + + void visitBlock(Block *curr) override { curr->doPrint(o, indent); } + void visitIf(If *curr) override { curr->doPrint(o, indent); } + void visitLoop(Loop *curr) override { curr->doPrint(o, indent); } + void visitLabel(Label *curr) override { curr->doPrint(o, indent); } + void visitBreak(Break *curr) override { curr->doPrint(o, indent); } + void visitSwitch(Switch *curr) override { curr->doPrint(o, indent); } + void visitCall(Call *curr) override { curr->doPrint(o, indent); } + void visitCallImport(CallImport *curr) override { curr->doPrint(o, indent); } + void visitCallIndirect(CallIndirect *curr) override { curr->doPrint(o, indent); } + void visitGetLocal(GetLocal *curr) override { curr->doPrint(o, indent); } + void visitSetLocal(SetLocal *curr) override { curr->doPrint(o, indent); } + void visitLoad(Load *curr) override { curr->doPrint(o, indent); } + void visitStore(Store *curr) override { curr->doPrint(o, indent); } + void visitConst(Const *curr) override { curr->doPrint(o, indent); } + void visitUnary(Unary *curr) override { curr->doPrint(o, indent); } + void visitBinary(Binary *curr) override { curr->doPrint(o, indent); } + void visitCompare(Compare *curr) override { curr->doPrint(o, indent); } + void visitConvert(Convert *curr) override { curr->doPrint(o, indent); } + void visitHost(Host *curr) override { curr->doPrint(o, indent); } + void visitNop(Nop *curr) override { curr->doPrint(o, indent); } + }; + + ExpressionPrinter(o, indent).visit(this); + + return o; +} + +// +// Simple WebAssembly children-first walking +// + struct WasmWalker : public WasmVisitor<Expression*> { wasm::Arena* allocator; // use an existing allocator, or null if no allocations @@ -889,71 +1018,93 @@ struct WasmWalker : public WasmVisitor<Expression*> { Expression *walk(Expression *curr) { if (!curr) return curr; - if (Block *cast = dynamic_cast<Block*>(curr)) { - ExpressionList& list = cast->list; - for (size_t z = 0; z < list.size(); z++) { - list[z] = walk(list[z]); + struct ChildWalker : public WasmVisitor<void> { + WasmWalker& parent; + + ChildWalker(WasmWalker& parent) : parent(parent) {} + + void visitBlock(Block *curr) override { + ExpressionList& list = curr->list; + for (size_t z = 0; z < list.size(); z++) { + list[z] = parent.walk(list[z]); + } } - } else if (If *cast = dynamic_cast<If*>(curr)) { - cast->condition = walk(cast->condition); - cast->ifTrue = walk(cast->ifTrue); - cast->ifFalse = walk(cast->ifFalse); - } else if (Loop *cast = dynamic_cast<Loop*>(curr)) { - cast->body = walk(cast->body); - } else if (Label *cast = dynamic_cast<Label*>(curr)) { - } else if (Break *cast = dynamic_cast<Break*>(curr)) { - cast->condition = walk(cast->condition); - cast->value = walk(cast->value); - } else if (Switch *cast = dynamic_cast<Switch*>(curr)) { - cast->value = walk(cast->value); - for (auto& curr : cast->cases) { - curr.body = walk(curr.body); + void visitIf(If *curr) override { + curr->condition = parent.walk(curr->condition); + curr->ifTrue = parent.walk(curr->ifTrue); + curr->ifFalse = parent.walk(curr->ifFalse); } - cast->default_ = walk(cast->default_); - } else if (Call *cast = dynamic_cast<Call*>(curr)) { - ExpressionList& list = cast->operands; - for (size_t z = 0; z < list.size(); z++) { - list[z] = walk(list[z]); + void visitLoop(Loop *curr) override { + curr->body = parent.walk(curr->body); } - } else if (CallImport *cast = dynamic_cast<CallImport*>(curr)) { - ExpressionList& list = cast->operands; - for (size_t z = 0; z < list.size(); z++) { - list[z] = walk(list[z]); + void visitLabel(Label *curr) override {} + void visitBreak(Break *curr) override { + curr->condition = parent.walk(curr->condition); + curr->value = parent.walk(curr->value); } - } else if (CallIndirect *cast = dynamic_cast<CallIndirect*>(curr)) { - cast->target = walk(cast->target); - ExpressionList& list = cast->operands; - for (size_t z = 0; z < list.size(); z++) { - list[z] = walk(list[z]); + void visitSwitch(Switch *curr) override { + curr->value = parent.walk(curr->value); + for (auto& case_ : curr->cases) { + case_.body = parent.walk(case_.body); + } + curr->default_ = parent.walk(curr->default_); } - } else if (GetLocal *cast = dynamic_cast<GetLocal*>(curr)) { - } else if (SetLocal *cast = dynamic_cast<SetLocal*>(curr)) { - cast->value = walk(cast->value); - } else if (Load *cast = dynamic_cast<Load*>(curr)) { - cast->ptr = walk(cast->ptr); - } else if (Store *cast = dynamic_cast<Store*>(curr)) { - cast->ptr = walk(cast->ptr); - cast->value = walk(cast->value); - } else if (Const *cast = dynamic_cast<Const*>(curr)) { - } else if (Unary *cast = dynamic_cast<Unary*>(curr)) { - cast->value = walk(cast->value); - } else if (Binary *cast = dynamic_cast<Binary*>(curr)) { - cast->left = walk(cast->left); - cast->right = walk(cast->right); - } else if (Compare *cast = dynamic_cast<Compare*>(curr)) { - cast->left = walk(cast->left); - cast->right = walk(cast->right); - } else if (Convert *cast = dynamic_cast<Convert*>(curr)) { - cast->value = walk(cast->value); - } else if (Host *cast = dynamic_cast<Host*>(curr)) { - ExpressionList& list = cast->operands; - for (size_t z = 0; z < list.size(); z++) { - list[z] = walk(list[z]); + void visitCall(Call *curr) override { + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + list[z] = parent.walk(list[z]); + } } - } else if (Nop *cast = dynamic_cast<Nop*>(curr)) { - } else { - abort(); - } + void visitCallImport(CallImport *curr) override { + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + list[z] = parent.walk(list[z]); + } + } + void visitCallIndirect(CallIndirect *curr) override { + curr->target = parent.walk(curr->target); + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + list[z] = parent.walk(list[z]); + } + } + void visitGetLocal(GetLocal *curr) override {} + void visitSetLocal(SetLocal *curr) override { + curr->value = parent.walk(curr->value); + } + void visitLoad(Load *curr) override { + curr->ptr = parent.walk(curr->ptr); + } + void visitStore(Store *curr) override { + curr->ptr = parent.walk(curr->ptr); + curr->value = parent.walk(curr->value); + } + void visitConst(Const *curr) override {} + void visitUnary(Unary *curr) override { + curr->value = parent.walk(curr->value); + } + void visitBinary(Binary *curr) override { + curr->left = parent.walk(curr->left); + curr->right = parent.walk(curr->right); + } + void visitCompare(Compare *curr) override { + curr->left = parent.walk(curr->left); + curr->right = parent.walk(curr->right); + } + void visitConvert(Convert *curr) override { + curr->value = parent.walk(curr->value); + } + void visitHost(Host *curr) override { + ExpressionList& list = curr->operands; + for (size_t z = 0; z < list.size(); z++) { + list[z] = parent.walk(list[z]); + } + } + void visitNop(Nop *curr) override {} + }; + + ChildWalker(*this).visit(curr); + return visit(curr); } |