summaryrefslogtreecommitdiff
path: root/src/wasm.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm.h')
-rw-r--r--src/wasm.h347
1 files changed, 173 insertions, 174 deletions
diff --git a/src/wasm.h b/src/wasm.h
index d9e25783a..af3917744 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -15,6 +15,17 @@
// * Validation: See wasm-validator.h.
//
+//
+// wasm.js internal WebAssembly representation design:
+//
+// * Unify where possible. Where size isn't a concern, combine
+// classes, so binary ops and relational ops are joined. This
+// simplifies that AST and makes traversals easier.
+// * Optimize for size? This might justify separating if and if_else
+// (so that if doesn't have an always-empty else; also it avoids
+// a branch).
+//
+
#ifndef __wasm_h__
#define __wasm_h__
@@ -153,7 +164,7 @@ struct Literal {
}
}
- void printFloat(std::ostream &o, float f) {
+ static void printFloat(std::ostream &o, float f) {
if (isnan(f)) {
union {
float ff;
@@ -166,7 +177,7 @@ struct Literal {
printDouble(o, f);
}
- void printDouble(std::ostream &o, double d) {
+ static void printDouble(std::ostream &o, double d) {
if (d == 0 && 1/d < 0) {
o << "-0";
return;
@@ -210,26 +221,22 @@ struct Literal {
enum UnaryOp {
Clz, Ctz, Popcnt, // int
- Neg, Abs, Ceil, Floor, Trunc, Nearest, Sqrt // float
+ Neg, Abs, Ceil, Floor, Trunc, Nearest, Sqrt, // float
+ // conversions
+ ExtendSInt32, ExtendUInt32, WrapInt64, TruncSFloat32, TruncUFloat32, TruncSFloat64, TruncUFloat64, ReinterpretFloat, // int
+ ConvertSInt32, ConvertUInt32, ConvertSInt64, ConvertUInt64, PromoteFloat32, DemoteFloat64, ReinterpretInt // float
};
enum BinaryOp {
Add, Sub, Mul, // int or float
DivS, DivU, RemS, RemU, And, Or, Xor, Shl, ShrU, ShrS, // int
- Div, CopySign, Min, Max // float
-};
-
-enum RelationalOp {
+ Div, CopySign, Min, Max, // float
+ // relational ops
Eq, Ne, // int or float
LtS, LtU, LeS, LeU, GtS, GtU, GeS, GeU, // int
Lt, Le, Gt, Ge // float
};
-enum ConvertOp {
- ExtendSInt32, ExtendUInt32, WrapInt64, TruncSFloat32, TruncUFloat32, TruncSFloat64, TruncUFloat64, ReinterpretFloat, // int
- ConvertSInt32, ConvertUInt32, ConvertSInt64, ConvertUInt64, PromoteFloat32, DemoteFloat64, ReinterpretInt // float
-};
-
enum HostOp {
PageSize, MemorySize, GrowMemory, HasFeature
};
@@ -253,28 +260,25 @@ class Expression {
public:
enum Id {
InvalidId = 0,
- BlockId = 1,
- IfId = 2,
- LoopId = 3,
- LabelId = 4,
- BreakId = 5,
- SwitchId =6 ,
- CallId = 7,
- CallImportId = 8,
- CallIndirectId = 9,
- GetLocalId = 10,
- SetLocalId = 11,
- LoadId = 12,
- StoreId = 13,
- ConstId = 14,
- UnaryId = 15,
- BinaryId = 16,
- CompareId = 17,
- ConvertId = 18,
- SelectId = 19,
- HostId = 20,
- NopId = 21,
- UnreachableId = 22
+ BlockId,
+ IfId,
+ LoopId,
+ BreakId,
+ SwitchId,
+ CallId,
+ CallImportId,
+ CallIndirectId,
+ GetLocalId,
+ SetLocalId,
+ LoadId,
+ StoreId,
+ ConstId,
+ UnaryId,
+ BinaryId,
+ SelectId,
+ HostId,
+ NopId,
+ UnreachableId
};
Id _id;
@@ -293,6 +297,12 @@ public:
return _id == T()._id ? (T*)this : nullptr;
}
+ template<class T>
+ T* cast() {
+ assert(_id == T()._id);
+ return (T*)this;
+ }
+
inline std::ostream& print(std::ostream &o, unsigned indent); // avoid virtual here, for performance
friend std::ostream& operator<<(std::ostream &o, Expression* expression) {
@@ -319,7 +329,9 @@ public:
class Block : public Expression {
public:
- Block() : Expression(BlockId) {}
+ Block() : Expression(BlockId) {
+ type = none; // blocks by default do not return, but if their last statement does, they might
+ }
Name name;
ExpressionList list;
@@ -339,7 +351,9 @@ public:
class If : public Expression {
public:
- If() : Expression(IfId), ifFalse(nullptr) {}
+ If() : Expression(IfId), ifFalse(nullptr) {
+ type = none; // by default none; if-else can have one, though
+ }
Expression *condition, *ifTrue, *ifFalse;
@@ -375,31 +389,29 @@ public:
}
};
-class Label : public Expression {
-public:
- Label() : Expression(LabelId) {}
-
- Name name;
- Expression* body;
-
- std::ostream& doPrint(std::ostream &o, unsigned indent) {
- printOpening(o, "label ") << name;
- incIndent(o, indent);
- printFullLine(o, indent, body);
- return decIndent(o, indent);
- }
-};
-
class Break : public Expression {
public:
- Break() : Expression(BreakId), value(nullptr) {}
+ Break() : Expression(BreakId), condition(nullptr), value(nullptr) {}
+ Expression *condition;
Name name;
Expression *value;
std::ostream& doPrint(std::ostream &o, unsigned indent) {
- printOpening(o, "br ") << name;
- incIndent(o, indent);
+ if (condition) {
+ printOpening(o, "br_if");
+ incIndent(o, indent);
+ printFullLine(o, indent, condition);
+ doIndent(o, indent) << name << '\n';
+ } else {
+ printOpening(o, "br ") << name;
+ if (!value) {
+ // avoid a new line just for the parens
+ o << ")";
+ return o;
+ }
+ incIndent(o, indent);
+ }
if (value) printFullLine(o, indent, value);
return decIndent(o, indent);
}
@@ -428,14 +440,15 @@ public:
incIndent(o, indent);
printFullLine(o, indent, value);
doIndent(o, indent) << "(table";
- assert(default_.is());
for (auto& t : targets) {
o << " (case " << (t.is() ? t : default_) << ")";
}
- o << ") (case " << default_ << ")\n";
+ o << ")";
+ if (default_.is()) o << " (case " << default_ << ")";
+ o << "\n";
for (auto& c : cases) {
doIndent(o, indent);
- printMinorOpening(o, "case ") << c.name.str;
+ printMinorOpening(o, "case ") << c.name;
incIndent(o, indent);
printFullLine(o, indent, c.body);
decIndent(o, indent) << '\n';
@@ -445,12 +458,18 @@ public:
};
-class Call : public Expression {
+class CallBase : public Expression {
public:
- Call() : Expression(CallId) {}
+ CallBase(Id which) : Expression(which) {}
- Name target;
ExpressionList operands;
+};
+
+class Call : public CallBase {
+public:
+ Call() : CallBase(CallId) {}
+
+ Name target;
std::ostream& printBody(std::ostream &o, unsigned indent) {
o << target;
@@ -526,16 +545,15 @@ public:
}
};
-class CallIndirect : public Expression {
+class CallIndirect : public CallBase {
public:
- CallIndirect() : Expression(CallIndirectId) {}
+ CallIndirect() : CallBase(CallIndirectId) {}
- FunctionType *type;
+ FunctionType *fullType;
Expression *target;
- ExpressionList operands;
std::ostream& doPrint(std::ostream &o, unsigned indent) {
- printOpening(o, "call_indirect ") << type->name;
+ printOpening(o, "call_indirect ") << fullType->name;
incIndent(o, indent);
printFullLine(o, indent, target);
for (auto operand : operands) {
@@ -670,16 +688,31 @@ public:
o << '(';
prepareColor(o) << printWasmType(type) << '.';
switch (op) {
- case Clz: o << "clz"; break;
- case Ctz: o << "ctz"; break;
- case Popcnt: o << "popcnt"; break;
- case Neg: o << "neg"; break;
- case Abs: o << "abs"; break;
- case Ceil: o << "ceil"; break;
- case Floor: o << "floor"; break;
- case Trunc: o << "trunc"; break;
- case Nearest: o << "nearest"; break;
- case Sqrt: o << "sqrt"; break;
+ case Clz: o << "clz"; break;
+ case Ctz: o << "ctz"; break;
+ case Popcnt: o << "popcnt"; break;
+ case Neg: o << "neg"; break;
+ case Abs: o << "abs"; break;
+ case Ceil: o << "ceil"; break;
+ case Floor: o << "floor"; break;
+ case Trunc: o << "trunc"; break;
+ case Nearest: o << "nearest"; break;
+ case Sqrt: o << "sqrt"; break;
+ case ExtendSInt32: o << "extend_s/i32"; break;
+ case ExtendUInt32: o << "extend_u/i32"; break;
+ case WrapInt64: o << "wrap/i64"; break;
+ case TruncSFloat32: o << "trunc_s/f32"; break;
+ case TruncUFloat32: o << "trunc_u/f32"; break;
+ case TruncSFloat64: o << "trunc_s/f64"; break;
+ case TruncUFloat64: o << "trunc_u/f64"; break;
+ case ReinterpretFloat: o << "reinterpret/" << (type == i64 ? "f64" : "f32"); break;
+ case ConvertUInt32: o << "convert_u/i32"; break;
+ case ConvertSInt32: o << "convert_s/i32"; break;
+ case ConvertUInt64: o << "convert_u/i64"; break;
+ case ConvertSInt64: o << "convert_s/i64"; break;
+ case PromoteFloat32: o << "promote/f32"; break;
+ case DemoteFloat64: o << "demote/f64"; break;
+ case ReinterpretInt: o << "reinterpret" << (type == f64 ? "i64" : "i32"); break;
default: abort();
}
incIndent(o, indent);
@@ -697,7 +730,7 @@ public:
std::ostream& doPrint(std::ostream &o, unsigned indent) {
o << '(';
- prepareColor(o) << printWasmType(type) << '.';
+ prepareColor(o) << printWasmType(isRelational() ? left->type : type) << '.';
switch (op) {
case Add: o << "add"; break;
case Sub: o << "sub"; break;
@@ -716,7 +749,21 @@ public:
case CopySign: o << "copysign"; break;
case Min: o << "min"; break;
case Max: o << "max"; break;
- default: abort();
+ case Eq: o << "eq"; break;
+ case Ne: o << "ne"; break;
+ case LtS: o << "lt_s"; break;
+ case LtU: o << "lt_u"; break;
+ case LeS: o << "le_s"; break;
+ case LeU: o << "le_u"; break;
+ case GtS: o << "gt_s"; break;
+ case GtU: o << "gt_u"; break;
+ case GeS: o << "ge_s"; break;
+ case GeU: o << "ge_u"; break;
+ case Lt: o << "lt"; break;
+ case Le: o << "le"; break;
+ case Gt: o << "gt"; break;
+ case Ge: o << "ge"; break;
+ default: abort();
}
restoreNormalColor(o);
incIndent(o, indent);
@@ -725,82 +772,18 @@ public:
return decIndent(o, indent);
}
- // the type is always the type of the operands
- void finalize() {
- type = left->type;
- }
-};
+ // the type is always the type of the operands,
+ // except for relationals
-class Compare : public Expression {
-public:
- Compare() : Expression(CompareId) {
- type = WasmType::i32; // output is always i32
- }
+ bool isRelational() { return op >= Eq; }
- RelationalOp op;
- WasmType inputType;
- Expression *left, *right;
-
- std::ostream& doPrint(std::ostream &o, unsigned indent) {
- o << '(';
- prepareColor(o) << printWasmType(inputType) << '.';
- switch (op) {
- case Eq: o << "eq"; break;
- case Ne: o << "ne"; break;
- case LtS: o << "lt_s"; break;
- case LtU: o << "lt_u"; break;
- case LeS: o << "le_s"; break;
- case LeU: o << "le_u"; break;
- case GtS: o << "gt_s"; break;
- case GtU: o << "gt_u"; break;
- case GeS: o << "ge_s"; break;
- case GeU: o << "ge_u"; break;
- case Lt: o << "lt"; break;
- case Le: o << "le"; break;
- case Gt: o << "gt"; break;
- case Ge: o << "ge"; break;
- default: abort();
- }
- restoreNormalColor(o);
- incIndent(o, indent);
- printFullLine(o, indent, left);
- printFullLine(o, indent, right);
- return decIndent(o, indent);
- }
-};
-
-class Convert : public Expression {
-public:
- Convert() : Expression(ConvertId) {}
-
- ConvertOp op;
- Expression *value;
-
- std::ostream& doPrint(std::ostream &o, unsigned indent) {
- o << '(';
- prepareColor(o) << printWasmType(type) << '.';
- switch (op) {
- case ExtendSInt32: o << "extend_s/i32"; break;
- case ExtendUInt32: o << "extend_u/i32"; break;
- case WrapInt64: o << "wrap/i64"; break;
- case TruncSFloat32: o << "trunc_s/f32"; break;
- case TruncUFloat32: o << "trunc_u/f32"; break;
- case TruncSFloat64: o << "trunc_s/f64"; break;
- case TruncUFloat64: o << "trunc_u/f64"; break;
- case ReinterpretFloat: o << "reinterpret/" << (type == i64 ? "f64" : "f32"); break;
- case ConvertUInt32: o << "convert_u/i32"; break;
- case ConvertSInt32: o << "convert_s/i32"; break;
- case ConvertUInt64: o << "convert_u/i64"; break;
- case ConvertSInt64: o << "convert_s/i64"; break;
- case PromoteFloat32: o << "promote/f32"; break;
- case DemoteFloat64: o << "demote/f64"; break;
- case ReinterpretInt: o << "reinterpret" << (type == f64 ? "i64" : "i32"); break;
- default: abort();
+ void finalize() {
+ if (isRelational()) {
+ type = i32;
+ } else {
+ assert(left->type == right->type);
+ type = left->type;
}
- restoreNormalColor(o);
- incIndent(o, indent);
- printFullLine(o, indent, value);
- return decIndent(o, indent);
}
};
@@ -832,7 +815,7 @@ public:
std::ostream& doPrint(std::ostream &o, unsigned indent) {
switch (op) {
case PageSize: printOpening(o, "pagesize") << ')'; break;
- case MemorySize: printOpening(o, "memorysize") << ')'; break;
+ case MemorySize: printOpening(o, "memory_size") << ')'; break;
case GrowMemory: {
printOpening(o, "grow_memory");
incIndent(o, indent);
@@ -845,6 +828,20 @@ public:
}
return o;
}
+
+ void finalize() {
+ switch (op) {
+ case PageSize: case MemorySize: case HasFeature: {
+ type = i32;
+ break;
+ }
+ case GrowMemory: {
+ type = none;
+ break;
+ }
+ default: abort();
+ }
+ }
};
class Unreachable : public Expression {
@@ -906,7 +903,7 @@ public:
std::ostream& print(std::ostream &o, unsigned indent) {
printOpening(o, "import ") << name << ' ';
printText(o, module.str) << ' ';
- printText(o, base.str) << ' ';
+ printText(o, base.str);
type.print(o, indent);
return o << ')';
}
@@ -1028,7 +1025,28 @@ public:
printOpening(o, "memory") << " " << module.memory.initial;
if (module.memory.max) o << " " << module.memory.max;
for (auto segment : module.memory.segments) {
- o << " (segment " << segment.offset << " \"" << segment.data << "\")";
+ o << " (segment " << segment.offset << " \"";
+ for (size_t i = 0; i < segment.size; i++) {
+ unsigned char c = segment.data[i];
+ switch (c) {
+ case '\n': o << "\\n"; break;
+ case '\r': o << "\\0d"; break;
+ case '\t': o << "\\t"; break;
+ case '\f': o << "\\0c"; break;
+ case '\b': o << "\\08"; break;
+ case '\\': o << "\\\\"; break;
+ case '"' : o << "\\\""; break;
+ case '\'' : o << "\\'"; break;
+ default: {
+ if (c >= 32 && c < 127) {
+ o << c;
+ } else {
+ o << std::hex << '\\' << (c/16) << (c%16) << std::dec;
+ }
+ }
+ }
+ }
+ o << "\")";
}
o << ")\n";
for (auto& curr : module.functionTypes) {
@@ -1077,7 +1095,6 @@ struct WasmVisitor {
virtual ReturnType visitBlock(Block *curr) { abort(); }
virtual ReturnType visitIf(If *curr) { abort(); }
virtual ReturnType visitLoop(Loop *curr) { abort(); }
- virtual ReturnType visitLabel(Label *curr) { abort(); }
virtual ReturnType visitBreak(Break *curr) { abort(); }
virtual ReturnType visitSwitch(Switch *curr) { abort(); }
virtual ReturnType visitCall(Call *curr) { abort(); }
@@ -1090,8 +1107,6 @@ struct WasmVisitor {
virtual ReturnType visitConst(Const *curr) { abort(); }
virtual ReturnType visitUnary(Unary *curr) { abort(); }
virtual ReturnType visitBinary(Binary *curr) { abort(); }
- virtual ReturnType visitCompare(Compare *curr) { abort(); }
- virtual ReturnType visitConvert(Convert *curr) { abort(); }
virtual ReturnType visitSelect(Select *curr) { abort(); }
virtual ReturnType visitHost(Host *curr) { abort(); }
virtual ReturnType visitNop(Nop *curr) { abort(); }
@@ -1111,7 +1126,6 @@ struct WasmVisitor {
case Expression::Id::BlockId: return visitBlock((Block*)curr);
case Expression::Id::IfId: return visitIf((If*)curr);
case Expression::Id::LoopId: return visitLoop((Loop*)curr);
- case Expression::Id::LabelId: return visitLabel((Label*)curr);
case Expression::Id::BreakId: return visitBreak((Break*)curr);
case Expression::Id::SwitchId: return visitSwitch((Switch*)curr);
case Expression::Id::CallId: return visitCall((Call*)curr);
@@ -1124,8 +1138,6 @@ struct WasmVisitor {
case Expression::Id::ConstId: return visitConst((Const*)curr);
case Expression::Id::UnaryId: return visitUnary((Unary*)curr);
case Expression::Id::BinaryId: return visitBinary((Binary*)curr);
- case Expression::Id::CompareId: return visitCompare((Compare*)curr);
- case Expression::Id::ConvertId: return visitConvert((Convert*)curr);
case Expression::Id::SelectId: return visitSelect((Select*)curr);
case Expression::Id::HostId: return visitHost((Host*)curr);
case Expression::Id::NopId: return visitNop((Nop*)curr);
@@ -1148,7 +1160,6 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) {
void visitBlock(Block *curr) override { curr->doPrint(o, indent); }
void visitIf(If *curr) override { curr->doPrint(o, indent); }
void visitLoop(Loop *curr) override { curr->doPrint(o, indent); }
- void visitLabel(Label *curr) override { curr->doPrint(o, indent); }
void visitBreak(Break *curr) override { curr->doPrint(o, indent); }
void visitSwitch(Switch *curr) override { curr->doPrint(o, indent); }
void visitCall(Call *curr) override { curr->doPrint(o, indent); }
@@ -1161,8 +1172,6 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) {
void visitConst(Const *curr) override { curr->doPrint(o, indent); }
void visitUnary(Unary *curr) override { curr->doPrint(o, indent); }
void visitBinary(Binary *curr) override { curr->doPrint(o, indent); }
- void visitCompare(Compare *curr) override { curr->doPrint(o, indent); }
- void visitConvert(Convert *curr) override { curr->doPrint(o, indent); }
void visitSelect(Select *curr) override { curr->doPrint(o, indent); }
void visitHost(Host *curr) override { curr->doPrint(o, indent); }
void visitNop(Nop *curr) override { curr->doPrint(o, indent); }
@@ -1176,7 +1185,7 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) {
//
// Simple WebAssembly children-first walking (i.e., post-order, if you look
-// at the children as subtrees of the current node), with the ability to
+// at the children as subtrees of the current node), with the ability to replace
// the current expression node. Useful for writing optimization passes.
//
@@ -1194,7 +1203,6 @@ struct WasmWalker : public WasmVisitor<void> {
void visitBlock(Block *curr) override {}
void visitIf(If *curr) override {}
void visitLoop(Loop *curr) override {}
- void visitLabel(Label *curr) override {}
void visitBreak(Break *curr) override {}
void visitSwitch(Switch *curr) override {}
void visitCall(Call *curr) override {}
@@ -1207,8 +1215,6 @@ struct WasmWalker : public WasmVisitor<void> {
void visitConst(Const *curr) override {}
void visitUnary(Unary *curr) override {}
void visitBinary(Binary *curr) override {}
- void visitCompare(Compare *curr) override {}
- void visitConvert(Convert *curr) override {}
void visitSelect(Select *curr) override {}
void visitHost(Host *curr) override {}
void visitNop(Nop *curr) override {}
@@ -1245,8 +1251,8 @@ struct WasmWalker : public WasmVisitor<void> {
void visitLoop(Loop *curr) override {
parent.walk(curr->body);
}
- void visitLabel(Label *curr) override {}
void visitBreak(Break *curr) override {
+ parent.walk(curr->condition);
parent.walk(curr->value);
}
void visitSwitch(Switch *curr) override {
@@ -1293,13 +1299,6 @@ struct WasmWalker : public WasmVisitor<void> {
parent.walk(curr->left);
parent.walk(curr->right);
}
- void visitCompare(Compare *curr) override {
- parent.walk(curr->left);
- parent.walk(curr->right);
- }
- void visitConvert(Convert *curr) override {
- parent.walk(curr->value);
- }
void visitSelect(Select *curr) override {
parent.walk(curr->condition);
parent.walk(curr->ifTrue);