diff options
Diffstat (limited to 'src/wasm.h')
-rw-r--r-- | src/wasm.h | 347 |
1 files changed, 173 insertions, 174 deletions
diff --git a/src/wasm.h b/src/wasm.h index d9e25783a..af3917744 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -15,6 +15,17 @@ // * Validation: See wasm-validator.h. // +// +// wasm.js internal WebAssembly representation design: +// +// * Unify where possible. Where size isn't a concern, combine +// classes, so binary ops and relational ops are joined. This +// simplifies that AST and makes traversals easier. +// * Optimize for size? This might justify separating if and if_else +// (so that if doesn't have an always-empty else; also it avoids +// a branch). +// + #ifndef __wasm_h__ #define __wasm_h__ @@ -153,7 +164,7 @@ struct Literal { } } - void printFloat(std::ostream &o, float f) { + static void printFloat(std::ostream &o, float f) { if (isnan(f)) { union { float ff; @@ -166,7 +177,7 @@ struct Literal { printDouble(o, f); } - void printDouble(std::ostream &o, double d) { + static void printDouble(std::ostream &o, double d) { if (d == 0 && 1/d < 0) { o << "-0"; return; @@ -210,26 +221,22 @@ struct Literal { enum UnaryOp { Clz, Ctz, Popcnt, // int - Neg, Abs, Ceil, Floor, Trunc, Nearest, Sqrt // float + Neg, Abs, Ceil, Floor, Trunc, Nearest, Sqrt, // float + // conversions + ExtendSInt32, ExtendUInt32, WrapInt64, TruncSFloat32, TruncUFloat32, TruncSFloat64, TruncUFloat64, ReinterpretFloat, // int + ConvertSInt32, ConvertUInt32, ConvertSInt64, ConvertUInt64, PromoteFloat32, DemoteFloat64, ReinterpretInt // float }; enum BinaryOp { Add, Sub, Mul, // int or float DivS, DivU, RemS, RemU, And, Or, Xor, Shl, ShrU, ShrS, // int - Div, CopySign, Min, Max // float -}; - -enum RelationalOp { + Div, CopySign, Min, Max, // float + // relational ops Eq, Ne, // int or float LtS, LtU, LeS, LeU, GtS, GtU, GeS, GeU, // int Lt, Le, Gt, Ge // float }; -enum ConvertOp { - ExtendSInt32, ExtendUInt32, WrapInt64, TruncSFloat32, TruncUFloat32, TruncSFloat64, TruncUFloat64, ReinterpretFloat, // int - ConvertSInt32, ConvertUInt32, ConvertSInt64, ConvertUInt64, PromoteFloat32, DemoteFloat64, ReinterpretInt // float -}; - enum HostOp { PageSize, MemorySize, GrowMemory, HasFeature }; @@ -253,28 +260,25 @@ class Expression { public: enum Id { InvalidId = 0, - BlockId = 1, - IfId = 2, - LoopId = 3, - LabelId = 4, - BreakId = 5, - SwitchId =6 , - CallId = 7, - CallImportId = 8, - CallIndirectId = 9, - GetLocalId = 10, - SetLocalId = 11, - LoadId = 12, - StoreId = 13, - ConstId = 14, - UnaryId = 15, - BinaryId = 16, - CompareId = 17, - ConvertId = 18, - SelectId = 19, - HostId = 20, - NopId = 21, - UnreachableId = 22 + BlockId, + IfId, + LoopId, + BreakId, + SwitchId, + CallId, + CallImportId, + CallIndirectId, + GetLocalId, + SetLocalId, + LoadId, + StoreId, + ConstId, + UnaryId, + BinaryId, + SelectId, + HostId, + NopId, + UnreachableId }; Id _id; @@ -293,6 +297,12 @@ public: return _id == T()._id ? (T*)this : nullptr; } + template<class T> + T* cast() { + assert(_id == T()._id); + return (T*)this; + } + inline std::ostream& print(std::ostream &o, unsigned indent); // avoid virtual here, for performance friend std::ostream& operator<<(std::ostream &o, Expression* expression) { @@ -319,7 +329,9 @@ public: class Block : public Expression { public: - Block() : Expression(BlockId) {} + Block() : Expression(BlockId) { + type = none; // blocks by default do not return, but if their last statement does, they might + } Name name; ExpressionList list; @@ -339,7 +351,9 @@ public: class If : public Expression { public: - If() : Expression(IfId), ifFalse(nullptr) {} + If() : Expression(IfId), ifFalse(nullptr) { + type = none; // by default none; if-else can have one, though + } Expression *condition, *ifTrue, *ifFalse; @@ -375,31 +389,29 @@ public: } }; -class Label : public Expression { -public: - Label() : Expression(LabelId) {} - - Name name; - Expression* body; - - std::ostream& doPrint(std::ostream &o, unsigned indent) { - printOpening(o, "label ") << name; - incIndent(o, indent); - printFullLine(o, indent, body); - return decIndent(o, indent); - } -}; - class Break : public Expression { public: - Break() : Expression(BreakId), value(nullptr) {} + Break() : Expression(BreakId), condition(nullptr), value(nullptr) {} + Expression *condition; Name name; Expression *value; std::ostream& doPrint(std::ostream &o, unsigned indent) { - printOpening(o, "br ") << name; - incIndent(o, indent); + if (condition) { + printOpening(o, "br_if"); + incIndent(o, indent); + printFullLine(o, indent, condition); + doIndent(o, indent) << name << '\n'; + } else { + printOpening(o, "br ") << name; + if (!value) { + // avoid a new line just for the parens + o << ")"; + return o; + } + incIndent(o, indent); + } if (value) printFullLine(o, indent, value); return decIndent(o, indent); } @@ -428,14 +440,15 @@ public: incIndent(o, indent); printFullLine(o, indent, value); doIndent(o, indent) << "(table"; - assert(default_.is()); for (auto& t : targets) { o << " (case " << (t.is() ? t : default_) << ")"; } - o << ") (case " << default_ << ")\n"; + o << ")"; + if (default_.is()) o << " (case " << default_ << ")"; + o << "\n"; for (auto& c : cases) { doIndent(o, indent); - printMinorOpening(o, "case ") << c.name.str; + printMinorOpening(o, "case ") << c.name; incIndent(o, indent); printFullLine(o, indent, c.body); decIndent(o, indent) << '\n'; @@ -445,12 +458,18 @@ public: }; -class Call : public Expression { +class CallBase : public Expression { public: - Call() : Expression(CallId) {} + CallBase(Id which) : Expression(which) {} - Name target; ExpressionList operands; +}; + +class Call : public CallBase { +public: + Call() : CallBase(CallId) {} + + Name target; std::ostream& printBody(std::ostream &o, unsigned indent) { o << target; @@ -526,16 +545,15 @@ public: } }; -class CallIndirect : public Expression { +class CallIndirect : public CallBase { public: - CallIndirect() : Expression(CallIndirectId) {} + CallIndirect() : CallBase(CallIndirectId) {} - FunctionType *type; + FunctionType *fullType; Expression *target; - ExpressionList operands; std::ostream& doPrint(std::ostream &o, unsigned indent) { - printOpening(o, "call_indirect ") << type->name; + printOpening(o, "call_indirect ") << fullType->name; incIndent(o, indent); printFullLine(o, indent, target); for (auto operand : operands) { @@ -670,16 +688,31 @@ public: o << '('; prepareColor(o) << printWasmType(type) << '.'; switch (op) { - case Clz: o << "clz"; break; - case Ctz: o << "ctz"; break; - case Popcnt: o << "popcnt"; break; - case Neg: o << "neg"; break; - case Abs: o << "abs"; break; - case Ceil: o << "ceil"; break; - case Floor: o << "floor"; break; - case Trunc: o << "trunc"; break; - case Nearest: o << "nearest"; break; - case Sqrt: o << "sqrt"; break; + case Clz: o << "clz"; break; + case Ctz: o << "ctz"; break; + case Popcnt: o << "popcnt"; break; + case Neg: o << "neg"; break; + case Abs: o << "abs"; break; + case Ceil: o << "ceil"; break; + case Floor: o << "floor"; break; + case Trunc: o << "trunc"; break; + case Nearest: o << "nearest"; break; + case Sqrt: o << "sqrt"; break; + case ExtendSInt32: o << "extend_s/i32"; break; + case ExtendUInt32: o << "extend_u/i32"; break; + case WrapInt64: o << "wrap/i64"; break; + case TruncSFloat32: o << "trunc_s/f32"; break; + case TruncUFloat32: o << "trunc_u/f32"; break; + case TruncSFloat64: o << "trunc_s/f64"; break; + case TruncUFloat64: o << "trunc_u/f64"; break; + case ReinterpretFloat: o << "reinterpret/" << (type == i64 ? "f64" : "f32"); break; + case ConvertUInt32: o << "convert_u/i32"; break; + case ConvertSInt32: o << "convert_s/i32"; break; + case ConvertUInt64: o << "convert_u/i64"; break; + case ConvertSInt64: o << "convert_s/i64"; break; + case PromoteFloat32: o << "promote/f32"; break; + case DemoteFloat64: o << "demote/f64"; break; + case ReinterpretInt: o << "reinterpret" << (type == f64 ? "i64" : "i32"); break; default: abort(); } incIndent(o, indent); @@ -697,7 +730,7 @@ public: std::ostream& doPrint(std::ostream &o, unsigned indent) { o << '('; - prepareColor(o) << printWasmType(type) << '.'; + prepareColor(o) << printWasmType(isRelational() ? left->type : type) << '.'; switch (op) { case Add: o << "add"; break; case Sub: o << "sub"; break; @@ -716,7 +749,21 @@ public: case CopySign: o << "copysign"; break; case Min: o << "min"; break; case Max: o << "max"; break; - default: abort(); + case Eq: o << "eq"; break; + case Ne: o << "ne"; break; + case LtS: o << "lt_s"; break; + case LtU: o << "lt_u"; break; + case LeS: o << "le_s"; break; + case LeU: o << "le_u"; break; + case GtS: o << "gt_s"; break; + case GtU: o << "gt_u"; break; + case GeS: o << "ge_s"; break; + case GeU: o << "ge_u"; break; + case Lt: o << "lt"; break; + case Le: o << "le"; break; + case Gt: o << "gt"; break; + case Ge: o << "ge"; break; + default: abort(); } restoreNormalColor(o); incIndent(o, indent); @@ -725,82 +772,18 @@ public: return decIndent(o, indent); } - // the type is always the type of the operands - void finalize() { - type = left->type; - } -}; + // the type is always the type of the operands, + // except for relationals -class Compare : public Expression { -public: - Compare() : Expression(CompareId) { - type = WasmType::i32; // output is always i32 - } + bool isRelational() { return op >= Eq; } - RelationalOp op; - WasmType inputType; - Expression *left, *right; - - std::ostream& doPrint(std::ostream &o, unsigned indent) { - o << '('; - prepareColor(o) << printWasmType(inputType) << '.'; - switch (op) { - case Eq: o << "eq"; break; - case Ne: o << "ne"; break; - case LtS: o << "lt_s"; break; - case LtU: o << "lt_u"; break; - case LeS: o << "le_s"; break; - case LeU: o << "le_u"; break; - case GtS: o << "gt_s"; break; - case GtU: o << "gt_u"; break; - case GeS: o << "ge_s"; break; - case GeU: o << "ge_u"; break; - case Lt: o << "lt"; break; - case Le: o << "le"; break; - case Gt: o << "gt"; break; - case Ge: o << "ge"; break; - default: abort(); - } - restoreNormalColor(o); - incIndent(o, indent); - printFullLine(o, indent, left); - printFullLine(o, indent, right); - return decIndent(o, indent); - } -}; - -class Convert : public Expression { -public: - Convert() : Expression(ConvertId) {} - - ConvertOp op; - Expression *value; - - std::ostream& doPrint(std::ostream &o, unsigned indent) { - o << '('; - prepareColor(o) << printWasmType(type) << '.'; - switch (op) { - case ExtendSInt32: o << "extend_s/i32"; break; - case ExtendUInt32: o << "extend_u/i32"; break; - case WrapInt64: o << "wrap/i64"; break; - case TruncSFloat32: o << "trunc_s/f32"; break; - case TruncUFloat32: o << "trunc_u/f32"; break; - case TruncSFloat64: o << "trunc_s/f64"; break; - case TruncUFloat64: o << "trunc_u/f64"; break; - case ReinterpretFloat: o << "reinterpret/" << (type == i64 ? "f64" : "f32"); break; - case ConvertUInt32: o << "convert_u/i32"; break; - case ConvertSInt32: o << "convert_s/i32"; break; - case ConvertUInt64: o << "convert_u/i64"; break; - case ConvertSInt64: o << "convert_s/i64"; break; - case PromoteFloat32: o << "promote/f32"; break; - case DemoteFloat64: o << "demote/f64"; break; - case ReinterpretInt: o << "reinterpret" << (type == f64 ? "i64" : "i32"); break; - default: abort(); + void finalize() { + if (isRelational()) { + type = i32; + } else { + assert(left->type == right->type); + type = left->type; } - restoreNormalColor(o); - incIndent(o, indent); - printFullLine(o, indent, value); - return decIndent(o, indent); } }; @@ -832,7 +815,7 @@ public: std::ostream& doPrint(std::ostream &o, unsigned indent) { switch (op) { case PageSize: printOpening(o, "pagesize") << ')'; break; - case MemorySize: printOpening(o, "memorysize") << ')'; break; + case MemorySize: printOpening(o, "memory_size") << ')'; break; case GrowMemory: { printOpening(o, "grow_memory"); incIndent(o, indent); @@ -845,6 +828,20 @@ public: } return o; } + + void finalize() { + switch (op) { + case PageSize: case MemorySize: case HasFeature: { + type = i32; + break; + } + case GrowMemory: { + type = none; + break; + } + default: abort(); + } + } }; class Unreachable : public Expression { @@ -906,7 +903,7 @@ public: std::ostream& print(std::ostream &o, unsigned indent) { printOpening(o, "import ") << name << ' '; printText(o, module.str) << ' '; - printText(o, base.str) << ' '; + printText(o, base.str); type.print(o, indent); return o << ')'; } @@ -1028,7 +1025,28 @@ public: printOpening(o, "memory") << " " << module.memory.initial; if (module.memory.max) o << " " << module.memory.max; for (auto segment : module.memory.segments) { - o << " (segment " << segment.offset << " \"" << segment.data << "\")"; + o << " (segment " << segment.offset << " \""; + for (size_t i = 0; i < segment.size; i++) { + unsigned char c = segment.data[i]; + switch (c) { + case '\n': o << "\\n"; break; + case '\r': o << "\\0d"; break; + case '\t': o << "\\t"; break; + case '\f': o << "\\0c"; break; + case '\b': o << "\\08"; break; + case '\\': o << "\\\\"; break; + case '"' : o << "\\\""; break; + case '\'' : o << "\\'"; break; + default: { + if (c >= 32 && c < 127) { + o << c; + } else { + o << std::hex << '\\' << (c/16) << (c%16) << std::dec; + } + } + } + } + o << "\")"; } o << ")\n"; for (auto& curr : module.functionTypes) { @@ -1077,7 +1095,6 @@ struct WasmVisitor { virtual ReturnType visitBlock(Block *curr) { abort(); } virtual ReturnType visitIf(If *curr) { abort(); } virtual ReturnType visitLoop(Loop *curr) { abort(); } - virtual ReturnType visitLabel(Label *curr) { abort(); } virtual ReturnType visitBreak(Break *curr) { abort(); } virtual ReturnType visitSwitch(Switch *curr) { abort(); } virtual ReturnType visitCall(Call *curr) { abort(); } @@ -1090,8 +1107,6 @@ struct WasmVisitor { virtual ReturnType visitConst(Const *curr) { abort(); } virtual ReturnType visitUnary(Unary *curr) { abort(); } virtual ReturnType visitBinary(Binary *curr) { abort(); } - virtual ReturnType visitCompare(Compare *curr) { abort(); } - virtual ReturnType visitConvert(Convert *curr) { abort(); } virtual ReturnType visitSelect(Select *curr) { abort(); } virtual ReturnType visitHost(Host *curr) { abort(); } virtual ReturnType visitNop(Nop *curr) { abort(); } @@ -1111,7 +1126,6 @@ struct WasmVisitor { case Expression::Id::BlockId: return visitBlock((Block*)curr); case Expression::Id::IfId: return visitIf((If*)curr); case Expression::Id::LoopId: return visitLoop((Loop*)curr); - case Expression::Id::LabelId: return visitLabel((Label*)curr); case Expression::Id::BreakId: return visitBreak((Break*)curr); case Expression::Id::SwitchId: return visitSwitch((Switch*)curr); case Expression::Id::CallId: return visitCall((Call*)curr); @@ -1124,8 +1138,6 @@ struct WasmVisitor { case Expression::Id::ConstId: return visitConst((Const*)curr); case Expression::Id::UnaryId: return visitUnary((Unary*)curr); case Expression::Id::BinaryId: return visitBinary((Binary*)curr); - case Expression::Id::CompareId: return visitCompare((Compare*)curr); - case Expression::Id::ConvertId: return visitConvert((Convert*)curr); case Expression::Id::SelectId: return visitSelect((Select*)curr); case Expression::Id::HostId: return visitHost((Host*)curr); case Expression::Id::NopId: return visitNop((Nop*)curr); @@ -1148,7 +1160,6 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { void visitBlock(Block *curr) override { curr->doPrint(o, indent); } void visitIf(If *curr) override { curr->doPrint(o, indent); } void visitLoop(Loop *curr) override { curr->doPrint(o, indent); } - void visitLabel(Label *curr) override { curr->doPrint(o, indent); } void visitBreak(Break *curr) override { curr->doPrint(o, indent); } void visitSwitch(Switch *curr) override { curr->doPrint(o, indent); } void visitCall(Call *curr) override { curr->doPrint(o, indent); } @@ -1161,8 +1172,6 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { void visitConst(Const *curr) override { curr->doPrint(o, indent); } void visitUnary(Unary *curr) override { curr->doPrint(o, indent); } void visitBinary(Binary *curr) override { curr->doPrint(o, indent); } - void visitCompare(Compare *curr) override { curr->doPrint(o, indent); } - void visitConvert(Convert *curr) override { curr->doPrint(o, indent); } void visitSelect(Select *curr) override { curr->doPrint(o, indent); } void visitHost(Host *curr) override { curr->doPrint(o, indent); } void visitNop(Nop *curr) override { curr->doPrint(o, indent); } @@ -1176,7 +1185,7 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { // // Simple WebAssembly children-first walking (i.e., post-order, if you look -// at the children as subtrees of the current node), with the ability to +// at the children as subtrees of the current node), with the ability to replace // the current expression node. Useful for writing optimization passes. // @@ -1194,7 +1203,6 @@ struct WasmWalker : public WasmVisitor<void> { void visitBlock(Block *curr) override {} void visitIf(If *curr) override {} void visitLoop(Loop *curr) override {} - void visitLabel(Label *curr) override {} void visitBreak(Break *curr) override {} void visitSwitch(Switch *curr) override {} void visitCall(Call *curr) override {} @@ -1207,8 +1215,6 @@ struct WasmWalker : public WasmVisitor<void> { void visitConst(Const *curr) override {} void visitUnary(Unary *curr) override {} void visitBinary(Binary *curr) override {} - void visitCompare(Compare *curr) override {} - void visitConvert(Convert *curr) override {} void visitSelect(Select *curr) override {} void visitHost(Host *curr) override {} void visitNop(Nop *curr) override {} @@ -1245,8 +1251,8 @@ struct WasmWalker : public WasmVisitor<void> { void visitLoop(Loop *curr) override { parent.walk(curr->body); } - void visitLabel(Label *curr) override {} void visitBreak(Break *curr) override { + parent.walk(curr->condition); parent.walk(curr->value); } void visitSwitch(Switch *curr) override { @@ -1293,13 +1299,6 @@ struct WasmWalker : public WasmVisitor<void> { parent.walk(curr->left); parent.walk(curr->right); } - void visitCompare(Compare *curr) override { - parent.walk(curr->left); - parent.walk(curr->right); - } - void visitConvert(Convert *curr) override { - parent.walk(curr->value); - } void visitSelect(Select *curr) override { parent.walk(curr->condition); parent.walk(curr->ifTrue); |