summaryrefslogtreecommitdiff
path: root/src/wasm.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm.h')
-rw-r--r--src/wasm.h361
1 files changed, 256 insertions, 105 deletions
diff --git a/src/wasm.h b/src/wasm.h
index 4ef01395c..3deae22ef 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -234,13 +234,46 @@ enum HostOp {
class Expression {
public:
+ enum Id {
+ InvalidId = 0,
+ BlockId = 1,
+ IfId = 2,
+ LoopId = 3,
+ LabelId = 4,
+ BreakId = 5,
+ SwitchId =6 ,
+ CallId = 7,
+ CallImportId = 8,
+ CallIndirectId = 9,
+ GetLocalId = 10,
+ SetLocalId = 11,
+ LoadId = 12,
+ StoreId = 13,
+ ConstId = 14,
+ UnaryId = 15,
+ BinaryId = 16,
+ CompareId = 17,
+ ConvertId = 18,
+ HostId = 19,
+ NopId = 20
+ };
+ Id _id;
+
+ Expression() : _id(InvalidId) {}
+ Expression(Id id) : _id(id) {}
+
WasmType type; // the type of the expression: its output, not necessarily its input(s)
- virtual std::ostream& print(std::ostream &o, unsigned indent) = 0;
+ std::ostream& print(std::ostream &o, unsigned indent); // avoid virtual here, for performance
template<class T>
bool is() {
- return !!dynamic_cast<T*>(this);
+ return _id == T()._id;
+ }
+
+ template<class T>
+ T* dyn_cast() {
+ return _id == T()._id ? (T*)this : nullptr;
}
};
@@ -269,17 +302,22 @@ std::ostream& printMinorOpening(std::ostream &o, const char *str) {
typedef std::vector<Expression*> ExpressionList; // TODO: optimize
class Nop : public Expression {
- std::ostream& print(std::ostream &o, unsigned indent) override {
+public:
+ Nop() : Expression(NopId) {}
+
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
return printMinorOpening(o, "nop") << ')';
}
};
class Block : public Expression {
public:
+ Block() : Expression(BlockId) {}
+
Name name;
ExpressionList list;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
printOpening(o, "block");
if (name.is()) {
o << ' ' << name;
@@ -294,9 +332,11 @@ public:
class If : public Expression {
public:
+ If() : Expression(IfId) {}
+
Expression *condition, *ifTrue, *ifFalse;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
printOpening(o, "if");
incIndent(o, indent);
printFullLine(o, indent, condition);
@@ -308,10 +348,12 @@ public:
class Loop : public Expression {
public:
+ Loop() : Expression(LoopId) {}
+
Name out, in;
Expression *body;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
printOpening(o, "loop");
if (out.is()) {
o << ' ' << out;
@@ -327,16 +369,27 @@ public:
class Label : public Expression {
public:
+ Label() : Expression(LabelId) {}
+
Name name;
Expression* body;
+
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
+ printOpening(o, "label ") << name;
+ incIndent(o, indent);
+ printFullLine(o, indent, body);
+ return decIndent(o, indent);
+ }
};
class Break : public Expression {
public:
+ Break() : Expression(BreakId) {}
+
Name name;
Expression *condition, *value;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
printOpening(o, "break ") << name;
incIndent(o, indent);
if (condition) printFullLine(o, indent, condition);
@@ -347,6 +400,8 @@ public:
class Switch : public Expression {
public:
+ Switch() : Expression(SwitchId) {}
+
struct Case {
Literal value;
Expression *body;
@@ -358,7 +413,7 @@ public:
std::vector<Case> cases;
Expression *default_;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
printOpening(o, "switch ") << name;
incIndent(o, indent);
printFullLine(o, indent, value);
@@ -370,6 +425,8 @@ public:
class Call : public Expression {
public:
+ Call() : Expression(CallId) {}
+
Name target;
ExpressionList operands;
@@ -387,14 +444,19 @@ public:
return o;
}
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
printOpening(o, "call ");
return printBody(o, indent);
}
};
class CallImport : public Call {
- std::ostream& print(std::ostream &o, unsigned indent) override {
+public:
+ CallImport() {
+ _id = CallImportId;
+ }
+
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
printOpening(o, "call_import ");
return printBody(o, indent);
}
@@ -444,11 +506,13 @@ public:
class CallIndirect : public Expression {
public:
+ CallIndirect() : Expression(CallIndirectId) {}
+
FunctionType *type;
Expression *target;
ExpressionList operands;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
printOpening(o, "call_indirect ") << type->name;
incIndent(o, indent);
printFullLine(o, indent, target);
@@ -461,19 +525,23 @@ public:
class GetLocal : public Expression {
public:
+ GetLocal() : Expression(GetLocalId) {}
+
Name name;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
return printOpening(o, "get_local ") << name << ')';
}
};
class SetLocal : public Expression {
public:
+ SetLocal() : Expression(SetLocalId) {}
+
Name name;
Expression *value;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
printOpening(o, "set_local ") << name;
incIndent(o, indent);
printFullLine(o, indent, value);
@@ -483,6 +551,8 @@ public:
class Load : public Expression {
public:
+ Load() : Expression(LoadId) {}
+
unsigned bytes;
bool signed_;
bool float_;
@@ -490,7 +560,7 @@ public:
unsigned align;
Expression *ptr;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
o << '(';
prepareColor(o) << printWasmType(getWasmType(bytes, float_)) << ".load";
if (bytes < 4) {
@@ -514,13 +584,15 @@ public:
class Store : public Expression {
public:
+ Store() : Expression(StoreId) {}
+
unsigned bytes;
bool float_;
int offset;
unsigned align;
Expression *ptr, *value;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
o << '(';
prepareColor(o) << printWasmType(getWasmType(bytes, float_)) << ".store";
if (bytes < 4) {
@@ -544,6 +616,8 @@ public:
class Const : public Expression {
public:
+ Const() : Expression(ConstId) {}
+
Literal value;
Const* set(Literal value_) {
@@ -552,17 +626,19 @@ public:
return this;
}
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
return o << value;
}
};
class Unary : public Expression {
public:
+ Unary() : Expression(UnaryId) {}
+
UnaryOp op;
Expression *value;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
o << '(';
prepareColor(o) << printWasmType(type) << '.';
switch (op) {
@@ -579,10 +655,12 @@ public:
class Binary : public Expression {
public:
+ Binary() : Expression(BinaryId) {}
+
BinaryOp op;
Expression *left, *right;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
o << '(';
prepareColor(o) << printWasmType(type) << '.';
switch (op) {
@@ -615,15 +693,15 @@ public:
class Compare : public Expression {
public:
+ Compare() : Expression(CompareId) {
+ type = WasmType::i32; // output is always i32
+ }
+
RelationalOp op;
WasmType inputType;
Expression *left, *right;
- Compare() {
- type = WasmType::i32; // output is always i32
- }
-
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
o << '(';
prepareColor(o) << printWasmType(inputType) << '.';
switch (op) {
@@ -653,10 +731,12 @@ public:
class Convert : public Expression {
public:
+ Convert() : Expression(ConvertId) {}
+
ConvertOp op;
Expression *value;
- std::ostream& print(std::ostream &o, unsigned indent) override {
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
o << '(';
prepareColor(o);
switch (op) {
@@ -674,8 +754,14 @@ public:
class Host : public Expression {
public:
+ Host() : Expression(HostId) {}
+
HostOp op;
ExpressionList operands;
+
+ std::ostream& doPrint(std::ostream &o, unsigned indent) {
+ abort();
+ }
};
// Globals
@@ -805,7 +891,7 @@ public:
};
//
-// Simple WebAssembly AST visiting and children-first walking
+// Simple WebAssembly AST visiting
//
template<class ReturnType>
@@ -833,30 +919,73 @@ struct WasmVisitor {
ReturnType visit(Expression *curr) {
assert(curr);
- if (Block *cast = dynamic_cast<Block*>(curr)) return visitBlock(cast);
- if (If *cast = dynamic_cast<If*>(curr)) return visitIf(cast);
- if (Loop *cast = dynamic_cast<Loop*>(curr)) return visitLoop(cast);
- if (Label *cast = dynamic_cast<Label*>(curr)) return visitLabel(cast);
- if (Break *cast = dynamic_cast<Break*>(curr)) return visitBreak(cast);
- if (Switch *cast = dynamic_cast<Switch*>(curr)) return visitSwitch(cast);
- if (Call *cast = dynamic_cast<Call*>(curr)) return visitCall(cast);
- if (CallImport *cast = dynamic_cast<CallImport*>(curr)) return visitCallImport(cast);
- if (CallIndirect *cast = dynamic_cast<CallIndirect*>(curr)) return visitCallIndirect(cast);
- if (GetLocal *cast = dynamic_cast<GetLocal*>(curr)) return visitGetLocal(cast);
- if (SetLocal *cast = dynamic_cast<SetLocal*>(curr)) return visitSetLocal(cast);
- if (Load *cast = dynamic_cast<Load*>(curr)) return visitLoad(cast);
- if (Store *cast = dynamic_cast<Store*>(curr)) return visitStore(cast);
- if (Const *cast = dynamic_cast<Const*>(curr)) return visitConst(cast);
- if (Unary *cast = dynamic_cast<Unary*>(curr)) return visitUnary(cast);
- if (Binary *cast = dynamic_cast<Binary*>(curr)) return visitBinary(cast);
- if (Compare *cast = dynamic_cast<Compare*>(curr)) return visitCompare(cast);
- if (Convert *cast = dynamic_cast<Convert*>(curr)) return visitConvert(cast);
- if (Host *cast = dynamic_cast<Host*>(curr)) return visitHost(cast);
- if (Nop *cast = dynamic_cast<Nop*>(curr)) return visitNop(cast);
- abort();
+ switch (curr->_id) {
+ case Expression::Id::BlockId: return visitBlock((Block*)curr);
+ case Expression::Id::IfId: return visitIf((If*)curr);
+ case Expression::Id::LoopId: return visitLoop((Loop*)curr);
+ case Expression::Id::LabelId: return visitLabel((Label*)curr);
+ case Expression::Id::BreakId: return visitBreak((Break*)curr);
+ case Expression::Id::SwitchId: return visitSwitch((Switch*)curr);
+ case Expression::Id::CallId: return visitCall((Call*)curr);
+ case Expression::Id::CallImportId: return visitCallImport((CallImport*)curr);
+ case Expression::Id::CallIndirectId: return visitCallIndirect((CallIndirect*)curr);
+ case Expression::Id::GetLocalId: return visitGetLocal((GetLocal*)curr);
+ case Expression::Id::SetLocalId: return visitSetLocal((SetLocal*)curr);
+ case Expression::Id::LoadId: return visitLoad((Load*)curr);
+ case Expression::Id::StoreId: return visitStore((Store*)curr);
+ case Expression::Id::ConstId: return visitConst((Const*)curr);
+ case Expression::Id::UnaryId: return visitUnary((Unary*)curr);
+ case Expression::Id::BinaryId: return visitBinary((Binary*)curr);
+ case Expression::Id::CompareId: return visitCompare((Compare*)curr);
+ case Expression::Id::ConvertId: return visitConvert((Convert*)curr);
+ case Expression::Id::HostId: return visitHost((Host*)curr);
+ case Expression::Id::NopId: return visitNop((Nop*)curr);
+ default: {
+ std::cerr << "visiting unknown expression " << curr->_id << '\n';
+ abort();
+ }
+ }
}
};
+std::ostream& Expression::print(std::ostream &o, unsigned indent) {
+ struct ExpressionPrinter : public WasmVisitor<void> {
+ std::ostream &o;
+ unsigned indent;
+
+ ExpressionPrinter(std::ostream &o, unsigned indent) : o(o), indent(indent) {}
+
+ void visitBlock(Block *curr) override { curr->doPrint(o, indent); }
+ void visitIf(If *curr) override { curr->doPrint(o, indent); }
+ void visitLoop(Loop *curr) override { curr->doPrint(o, indent); }
+ void visitLabel(Label *curr) override { curr->doPrint(o, indent); }
+ void visitBreak(Break *curr) override { curr->doPrint(o, indent); }
+ void visitSwitch(Switch *curr) override { curr->doPrint(o, indent); }
+ void visitCall(Call *curr) override { curr->doPrint(o, indent); }
+ void visitCallImport(CallImport *curr) override { curr->doPrint(o, indent); }
+ void visitCallIndirect(CallIndirect *curr) override { curr->doPrint(o, indent); }
+ void visitGetLocal(GetLocal *curr) override { curr->doPrint(o, indent); }
+ void visitSetLocal(SetLocal *curr) override { curr->doPrint(o, indent); }
+ void visitLoad(Load *curr) override { curr->doPrint(o, indent); }
+ void visitStore(Store *curr) override { curr->doPrint(o, indent); }
+ void visitConst(Const *curr) override { curr->doPrint(o, indent); }
+ void visitUnary(Unary *curr) override { curr->doPrint(o, indent); }
+ void visitBinary(Binary *curr) override { curr->doPrint(o, indent); }
+ void visitCompare(Compare *curr) override { curr->doPrint(o, indent); }
+ void visitConvert(Convert *curr) override { curr->doPrint(o, indent); }
+ void visitHost(Host *curr) override { curr->doPrint(o, indent); }
+ void visitNop(Nop *curr) override { curr->doPrint(o, indent); }
+ };
+
+ ExpressionPrinter(o, indent).visit(this);
+
+ return o;
+}
+
+//
+// Simple WebAssembly children-first walking
+//
+
struct WasmWalker : public WasmVisitor<Expression*> {
wasm::Arena* allocator; // use an existing allocator, or null if no allocations
@@ -889,71 +1018,93 @@ struct WasmWalker : public WasmVisitor<Expression*> {
Expression *walk(Expression *curr) {
if (!curr) return curr;
- if (Block *cast = dynamic_cast<Block*>(curr)) {
- ExpressionList& list = cast->list;
- for (size_t z = 0; z < list.size(); z++) {
- list[z] = walk(list[z]);
+ struct ChildWalker : public WasmVisitor<void> {
+ WasmWalker& parent;
+
+ ChildWalker(WasmWalker& parent) : parent(parent) {}
+
+ void visitBlock(Block *curr) override {
+ ExpressionList& list = curr->list;
+ for (size_t z = 0; z < list.size(); z++) {
+ list[z] = parent.walk(list[z]);
+ }
}
- } else if (If *cast = dynamic_cast<If*>(curr)) {
- cast->condition = walk(cast->condition);
- cast->ifTrue = walk(cast->ifTrue);
- cast->ifFalse = walk(cast->ifFalse);
- } else if (Loop *cast = dynamic_cast<Loop*>(curr)) {
- cast->body = walk(cast->body);
- } else if (Label *cast = dynamic_cast<Label*>(curr)) {
- } else if (Break *cast = dynamic_cast<Break*>(curr)) {
- cast->condition = walk(cast->condition);
- cast->value = walk(cast->value);
- } else if (Switch *cast = dynamic_cast<Switch*>(curr)) {
- cast->value = walk(cast->value);
- for (auto& curr : cast->cases) {
- curr.body = walk(curr.body);
+ void visitIf(If *curr) override {
+ curr->condition = parent.walk(curr->condition);
+ curr->ifTrue = parent.walk(curr->ifTrue);
+ curr->ifFalse = parent.walk(curr->ifFalse);
}
- cast->default_ = walk(cast->default_);
- } else if (Call *cast = dynamic_cast<Call*>(curr)) {
- ExpressionList& list = cast->operands;
- for (size_t z = 0; z < list.size(); z++) {
- list[z] = walk(list[z]);
+ void visitLoop(Loop *curr) override {
+ curr->body = parent.walk(curr->body);
}
- } else if (CallImport *cast = dynamic_cast<CallImport*>(curr)) {
- ExpressionList& list = cast->operands;
- for (size_t z = 0; z < list.size(); z++) {
- list[z] = walk(list[z]);
+ void visitLabel(Label *curr) override {}
+ void visitBreak(Break *curr) override {
+ curr->condition = parent.walk(curr->condition);
+ curr->value = parent.walk(curr->value);
}
- } else if (CallIndirect *cast = dynamic_cast<CallIndirect*>(curr)) {
- cast->target = walk(cast->target);
- ExpressionList& list = cast->operands;
- for (size_t z = 0; z < list.size(); z++) {
- list[z] = walk(list[z]);
+ void visitSwitch(Switch *curr) override {
+ curr->value = parent.walk(curr->value);
+ for (auto& case_ : curr->cases) {
+ case_.body = parent.walk(case_.body);
+ }
+ curr->default_ = parent.walk(curr->default_);
}
- } else if (GetLocal *cast = dynamic_cast<GetLocal*>(curr)) {
- } else if (SetLocal *cast = dynamic_cast<SetLocal*>(curr)) {
- cast->value = walk(cast->value);
- } else if (Load *cast = dynamic_cast<Load*>(curr)) {
- cast->ptr = walk(cast->ptr);
- } else if (Store *cast = dynamic_cast<Store*>(curr)) {
- cast->ptr = walk(cast->ptr);
- cast->value = walk(cast->value);
- } else if (Const *cast = dynamic_cast<Const*>(curr)) {
- } else if (Unary *cast = dynamic_cast<Unary*>(curr)) {
- cast->value = walk(cast->value);
- } else if (Binary *cast = dynamic_cast<Binary*>(curr)) {
- cast->left = walk(cast->left);
- cast->right = walk(cast->right);
- } else if (Compare *cast = dynamic_cast<Compare*>(curr)) {
- cast->left = walk(cast->left);
- cast->right = walk(cast->right);
- } else if (Convert *cast = dynamic_cast<Convert*>(curr)) {
- cast->value = walk(cast->value);
- } else if (Host *cast = dynamic_cast<Host*>(curr)) {
- ExpressionList& list = cast->operands;
- for (size_t z = 0; z < list.size(); z++) {
- list[z] = walk(list[z]);
+ void visitCall(Call *curr) override {
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ list[z] = parent.walk(list[z]);
+ }
}
- } else if (Nop *cast = dynamic_cast<Nop*>(curr)) {
- } else {
- abort();
- }
+ void visitCallImport(CallImport *curr) override {
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ list[z] = parent.walk(list[z]);
+ }
+ }
+ void visitCallIndirect(CallIndirect *curr) override {
+ curr->target = parent.walk(curr->target);
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ list[z] = parent.walk(list[z]);
+ }
+ }
+ void visitGetLocal(GetLocal *curr) override {}
+ void visitSetLocal(SetLocal *curr) override {
+ curr->value = parent.walk(curr->value);
+ }
+ void visitLoad(Load *curr) override {
+ curr->ptr = parent.walk(curr->ptr);
+ }
+ void visitStore(Store *curr) override {
+ curr->ptr = parent.walk(curr->ptr);
+ curr->value = parent.walk(curr->value);
+ }
+ void visitConst(Const *curr) override {}
+ void visitUnary(Unary *curr) override {
+ curr->value = parent.walk(curr->value);
+ }
+ void visitBinary(Binary *curr) override {
+ curr->left = parent.walk(curr->left);
+ curr->right = parent.walk(curr->right);
+ }
+ void visitCompare(Compare *curr) override {
+ curr->left = parent.walk(curr->left);
+ curr->right = parent.walk(curr->right);
+ }
+ void visitConvert(Convert *curr) override {
+ curr->value = parent.walk(curr->value);
+ }
+ void visitHost(Host *curr) override {
+ ExpressionList& list = curr->operands;
+ for (size_t z = 0; z < list.size(); z++) {
+ list[z] = parent.walk(list[z]);
+ }
+ }
+ void visitNop(Nop *curr) override {}
+ };
+
+ ChildWalker(*this).visit(curr);
+
return visit(curr);
}