summary | refs | log | tree | commit | diff
path: root/src/wasm.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm.h')
-rw-r--r-- src/wasm.h 818
1 files changed, 818 insertions, 0 deletions
diff --git a/src/wasm.h b/src/wasm.h
new file mode 100644
index 000000000..d313c6e4b
--- /dev/null
+++ b/src/wasm.h
@@ -0,0 +1,818 @@
+//
+// WebAssembly representation and processing library
+//
+
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <map>
+#include <new>
+#include <ostream>
+#include <vector>
+
+#include "colors.h"
+
+namespace wasm {
+
+// Utilities
+
+// Arena allocation for mixed-type data.
+//
+// Objects are carved out of fixed-size chunks, value-initialized in place and
+// handed back as raw pointers. The arena owns only the raw memory: clear() and
+// the destructor release the chunks but never run the destructors of the
+// objects placed in them.
+struct Arena {
+  std::vector<char*> chunks;
+  int index = 0; // byte offset of the next free slot in the last chunk
+
+  Arena() = default;
+  // The arena owns its chunks; a member-wise copy would delete[] each one twice.
+  Arena(const Arena&) = delete;
+  Arena& operator=(const Arena&) = delete;
+
+  // Returns a pointer to a new value-initialized T placed in the arena. The
+  // pointer stays valid until clear() is called or the arena is destroyed.
+  template<class T>
+  T* alloc() {
+    const size_t CHUNK = 10000;
+    const size_t ALIGN = 8; // slot granularity, same alignment as malloc TODO optimize?
+    const size_t currSize = (sizeof(T) + ALIGN - 1) & ~(ALIGN - 1);
+    static_assert(alignof(T) <= ALIGN, "Arena slots are only 8-byte aligned");
+    static_assert(((sizeof(T) + ALIGN - 1) & ~(ALIGN - 1)) < CHUNK, "type does not fit in an Arena chunk");
+    // Start a new chunk when there is none yet or the current one cannot hold T.
+    if (chunks.empty() || size_t(index) + currSize >= CHUNK) {
+      chunks.push_back(new char[CHUNK]);
+      index = 0;
+    }
+    void* slot = chunks.back() + index;
+    index += int(currSize);
+    return new (slot) T();
+  }
+
+  // Frees every chunk; all pointers previously returned by alloc() dangle afterwards.
+  void clear() {
+    for (char* chunk : chunks) {
+      delete[] chunk;
+    }
+    chunks.clear();
+    index = 0;
+  }
+
+  ~Arena() {
+    clear();
+  }
+};
+
+// Writes one indentation unit per nesting level to `o` and returns `o`.
+// Declared inline because this header defines it; otherwise every translation
+// unit including the header would emit its own external definition.
+inline std::ostream &doIndent(std::ostream &o, unsigned indent) {
+  for (unsigned level = 0; level < indent; level++) {
+    o << " ";
+  }
+  return o;
+}
+// Ends the current output line and increases the nesting level by one.
+// Declared inline so the header can be included from several translation units.
+inline void incIndent(std::ostream &o, unsigned& indent) {
+  o << '\n';
+  ++indent;
+}
+// Decreases the nesting level by one, indents to the new level and writes the
+// closing ')'. Declared inline so the header can be included from several
+// translation units.
+inline void decIndent(std::ostream &o, unsigned& indent) {
+  // `indent` is unsigned: decrementing 0 would wrap around and make doIndent()
+  // print billions of indentation units.
+  assert(indent > 0);
+  --indent;
+  doIndent(o, indent);
+  o << ')';
+}
+
+// Basics
+
+// An identifier. Thin wrapper over cashew::IString that adds printing.
+struct Name : public cashew::IString {
+  Name() : cashew::IString() {}
+  Name(const char *str) : cashew::IString(str) {}
+  Name(cashew::IString str) : cashew::IString(str) {}
+
+  // Writes the name prefixed with '$' and returns `o`. The name must be set
+  // (non-null `str`). const-qualified so that const Names can be printed.
+  std::ostream& print(std::ostream &o) const {
+    assert(str);
+    return o << '$' << str; // reference interpreter requires we prefix all names
+  }
+};
+
+// Types
+
+// The value types an expression or literal can have; `none` means no value.
+enum BasicType {
+  none, i32, i64, f32, f64
+};
+
+// Writes the textual name of `type` to `o` and returns `o`. A value outside
+// the enumeration prints nothing. Declared inline so the header can be
+// included from several translation units.
+inline std::ostream& printBasicType(std::ostream &o, BasicType type) {
+  switch (type) {
+    case BasicType::none: return o << "none";
+    case BasicType::i32: return o << "i32";
+    case BasicType::i64: return o << "i64";
+    case BasicType::f32: return o << "f32";
+    case BasicType::f64: return o << "f64";
+  }
+  return o;
+}
+
+// Returns the size in bytes of a value of `type`. Aborts for `none` and for
+// any value outside the enumeration. Declared inline so the header can be
+// included from several translation units.
+inline unsigned getBasicTypeSize(BasicType type) {
+  switch (type) {
+    case BasicType::none: abort();
+    case BasicType::i32: return 4;
+    case BasicType::i64: return 8;
+    case BasicType::f32: return 4;
+    case BasicType::f64: return 8;
+  }
+  // Reached only for an out-of-range enum value; without this the function
+  // would flow off its end, which is undefined behaviour.
+  abort();
+}
+
+// Returns true for the floating-point types (f32, f64), false otherwise.
+// Written as a comparison rather than a partial switch so that it does not
+// trigger unhandled-enumerator warnings; inline because it lives in a header.
+inline bool isFloat(BasicType type) {
+  return type == BasicType::f32 || type == BasicType::f64;
+}
+
+// Maps a byte size and a float flag to a BasicType:
+//  - sizes below 4 always give i32, whatever `float_` says
+//  - size 4 gives f32 or i32, size 8 gives f64 or i64
+//  - any other size aborts
+// Declared inline so the header can be included from several translation units.
+inline BasicType getBasicType(unsigned size, bool float_) {
+  if (size < 4) return BasicType::i32;
+  switch (size) {
+    case 4: return float_ ? BasicType::f32 : BasicType::i32;
+    case 8: return float_ ? BasicType::f64 : BasicType::i64;
+    default: abort();
+  }
+}
+
+// Switches `o` to bold red. Declared inline so the header can be included
+// from several translation units.
+inline void prepareMajorColor(std::ostream &o) {
+  Colors::red(o);
+  Colors::bold(o);
+}
+
+// Switches `o` to bold magenta. Declared inline so the header can be included
+// from several translation units.
+inline void prepareColor(std::ostream &o) {
+  Colors::magenta(o);
+  Colors::bold(o);
+}
+
+// Switches `o` to orange. Declared inline so the header can be included from
+// several translation units.
+inline void prepareMinorColor(std::ostream &o) {
+  Colors::orange(o);
+}
+
+// Resets `o` to the normal colour. Declared inline so the header can be
+// included from several translation units.
+inline void restoreNormalColor(std::ostream &o) {
+  Colors::normal(o);
+}
+
+// Writes `str` in green between double quotes and returns `o`. No escaping is
+// applied to the text. Declared inline so the header can be included from
+// several translation units.
+inline std::ostream& printText(std::ostream &o, const char *str) {
+  // Streaming a null char pointer is not a valid way to print "nothing".
+  assert(str);
+  o << '"';
+  Colors::green(o);
+  o << str;
+  Colors::normal(o);
+  return o << '"';
+}
+
+// A constant value tagged with its BasicType. Only the union member that
+// matches `type` is meaningful.
+struct Literal {
+  BasicType type;
+  union {
+    int32_t i32;
+    int64_t i64;
+    float f32;
+    double f64;
+  };
+
+  // The default literal has type `none`; the payload is zeroed so that copying
+  // a default-constructed Literal never reads indeterminate bytes.
+  Literal() : type(BasicType::none), i64(0) {}
+  Literal(int32_t init) : type(BasicType::i32), i32(init) {}
+  Literal(int64_t init) : type(BasicType::i64), i64(init) {}
+  Literal(float init) : type(BasicType::f32), f32(init) {}
+  Literal(double init) : type(BasicType::f64), f64(init) {}
+
+  // Writes "(<type>.const <value>)" and returns `o`. Aborts for a literal of
+  // type `none`. const-qualified so that const Literals can be printed.
+  std::ostream& print(std::ostream &o) const {
+    o << '(';
+    prepareMinorColor(o);
+    printBasicType(o, type) << ".const ";
+    switch (type) {
+      case BasicType::none: abort();
+      case BasicType::i32: o << i32; break;
+      case BasicType::i64: o << i64; break;
+      case BasicType::f32: o << f32; break;
+      case BasicType::f64: o << f64; break;
+    }
+    restoreNormalColor(o);
+    o << ')';
+    return o;
+  }
+};
+
+// Operators
+
+// Operators taking a single operand.
+enum UnaryOp {
+  // int
+  Clz,
+  Ctz,
+  Popcnt,
+  // float
+  Neg,
+  Abs,
+  Ceil,
+  Floor,
+  Trunc,
+  Nearest,
+  Sqrt
+};
+
+// Operators taking two operands.
+enum BinaryOp {
+  // int or float
+  Add,
+  Sub,
+  Mul,
+  // int
+  DivS,
+  DivU,
+  RemS,
+  RemU,
+  And,
+  Or,
+  Xor,
+  Shl,
+  ShrU,
+  ShrS,
+  // float
+  Div,
+  CopySign,
+  Min,
+  Max
+};
+
+// Comparison operators.
+enum RelationalOp {
+  // int or float
+  Eq,
+  Ne,
+  // int
+  LtS,
+  LtU,
+  LeS,
+  LeU,
+  GtS,
+  GtU,
+  GeS,
+  GeU,
+  // float
+  Lt,
+  Le,
+  Gt,
+  Ge
+};
+
+// Conversion operators.
+enum ConvertOp {
+  // int
+  ExtendSInt32,
+  ExtendUInt32,
+  WrapInt64,
+  TruncSFloat32,
+  TruncUFloat32,
+  TruncSFloat64,
+  TruncUFloat64,
+  ReinterpretFloat,
+  // float
+  ConvertSInt32,
+  ConvertUInt32,
+  ConvertSInt64,
+  ConvertUInt64,
+  PromoteFloat32,
+  DemoteFloat64,
+  ReinterpretInt
+};
+
+// Operations used by Host expressions.
+enum HostOp {
+  PageSize,
+  MemorySize,
+  GrowMemory,
+  HasFeature
+};
+
+// Expressions
+
+// Abstract base of every expression node.
+class Expression {
+public:
+  BasicType type; // the type of the value this expression has
+
+  // Starts out as `none`. (The member used to be initialized from itself,
+  // which read an indeterminate value.)
+  Expression() : type(BasicType::none) {}
+
+  // Polymorphic base: make destruction through an Expression* well-defined.
+  virtual ~Expression() {}
+
+  // Writes this expression to `o`; `indent` is the current nesting level.
+  virtual std::ostream& print(std::ostream &o, unsigned indent) = 0;
+
+  // Returns true when the dynamic type of this node is T or derives from T.
+  template<class T>
+  bool is() {
+    return dynamic_cast<T*>(this) != nullptr;
+  }
+};
+
+// Prints `expression` on a line of its own: indentation, the expression, then
+// a newline. Returns `o`. Declared inline so the header can be included from
+// several translation units.
+inline std::ostream& printFullLine(std::ostream &o, unsigned indent, Expression *expression) {
+  assert(expression);
+  doIndent(o, indent);
+  expression->print(o, indent);
+  // The function previously ended without a return statement, which is
+  // undefined behaviour for a function returning a reference.
+  return o << '\n';
+}
+
+// Writes '(' followed by `str` in the major colour (when `major` is true) or
+// the regular colour, restores the normal colour and returns `o`. Declared
+// inline so the header can be included from several translation units.
+inline std::ostream& printOpening(std::ostream &o, const char *str, bool major=false) {
+  o << '(';
+  if (major) {
+    prepareMajorColor(o);
+  } else {
+    prepareColor(o);
+  }
+  o << str;
+  restoreNormalColor(o);
+  return o;
+}
+
+// Writes '(' followed by `str` in the minor colour, restores the normal colour
+// and returns `o`. Declared inline so the header can be included from several
+// translation units.
+inline std::ostream& printMinorOpening(std::ostream &o, const char *str) {
+  o << '(';
+  prepareMinorColor(o);
+  o << str;
+  restoreNormalColor(o);
+  return o;
+}
+
+typedef std::vector<Expression*> ExpressionList; // TODO: optimize
+
+// An expression that does nothing.
+class Nop : public Expression {
+public: // print() is public in every other node; keep Nop consistent
+  // Writes "(nop)" and returns `o`. Parenthesized like every other node so
+  // that the output stays a well-formed s-expression.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    return printMinorOpening(o, "nop") << ')';
+  }
+};
+
+// A sequence of expressions with an optional name.
+class Block : public Expression {
+public:
+  Name var;
+  ExpressionList list;
+
+  // Writes "(block", the name when one is set, then each child on its own
+  // line one level deeper, then the closing ')'.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    printOpening(o, "block");
+    if (var.is()) {
+      var.print(o << " ");
+    }
+    incIndent(o, indent);
+    for (size_t i = 0; i < list.size(); i++) {
+      printFullLine(o, indent, list[i]);
+    }
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+// A conditional with an optional else arm.
+class If : public Expression {
+public:
+  // print() tests ifFalse for null, so the pointers need a null default that
+  // does not depend on how the node happens to be constructed.
+  Expression *condition = nullptr, *ifTrue = nullptr, *ifFalse = nullptr;
+
+  // Writes "(if" followed by the condition, the true arm and, when present,
+  // the false arm, each on its own line; then the closing ')'.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    printOpening(o, "if");
+    incIndent(o, indent);
+    printFullLine(o, indent, condition);
+    printFullLine(o, indent, ifTrue);
+    if (ifFalse) printFullLine(o, indent, ifFalse);
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+// A loop with two optional names and a body.
+class Loop : public Expression {
+public:
+  Name out, in;
+  Expression *body;
+
+  // Writes "(loop", then `out` when set (and `in` only if `out` is set too),
+  // then the body one level deeper and the closing ')'.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    printOpening(o, "loop");
+    if (out.is()) {
+      out.print(o << " ");
+      if (in.is()) {
+        in.print(o << " ");
+      }
+    }
+    incIndent(o, indent);
+    printFullLine(o, indent, body);
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+class Label : public Expression { // Expression node that carries only a name.
+public:
+ Name var; // the name held by this node
+}; // NOTE(review): the pure virtual print() is not overridden here, so Label is abstract as written.
+
+// A break to the target named `var`, with optional condition and value.
+class Break : public Expression {
+public:
+  Name var;
+  // Both children are optional and print() tests them for null, so give them
+  // a null default that does not depend on how the node is constructed.
+  Expression *condition = nullptr, *value = nullptr;
+
+  // Writes "(break $var", then the condition and the value (each only when
+  // present) on their own lines, then the closing ')'.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    printOpening(o, "break ");
+    var.print(o);
+    incIndent(o, indent);
+    if (condition) printFullLine(o, indent, condition);
+    if (value) printFullLine(o, indent, value);
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+// A switch over `value` with a list of cases and a default.
+class Switch : public Expression {
+public:
+  // One arm of the switch.
+  struct Case {
+    Literal value;
+    Expression *body;
+    bool fallthru;
+  };
+
+  Name var;
+  Expression *value;
+  std::vector<Case> cases;
+  Expression *default_;
+
+  // Writes "(switch $var" and the switched-on value. The cases and the
+  // default are not printed yet; a TODO marker line is emitted instead.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    var.print(printOpening(o, "switch "));
+    incIndent(o, indent);
+    printFullLine(o, indent, value);
+    o << "TODO: cases/default\n";
+    decIndent(o, indent);
+    return o;
+  }
+
+};
+
+// A call to the function named `target`.
+class Call : public Expression {
+public:
+  Name target;
+  ExpressionList operands;
+
+  // Writes "(call $target". Without operands the form is closed on the same
+  // line; otherwise each operand goes on its own line one level deeper.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    printOpening(o, "call ");
+    target.print(o);
+    if (operands.empty()) {
+      o << ')';
+      return o;
+    }
+    incIndent(o, indent);
+    for (size_t i = 0; i < operands.size(); i++) {
+      printFullLine(o, indent, operands[i]);
+    }
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+class CallImport : public Call { // A Call subtype with no members of its own; target, operands and print() are inherited from Call.
+};
+
+// A function signature: a name, a result type and the parameter types.
+class FunctionType {
+public:
+  Name name;
+  BasicType result;
+  std::vector<BasicType> params;
+
+  // Writes the signature and returns `o`. With `full` set it is wrapped as
+  // "(type $name ...)"; otherwise only the "(param ...)" and "(result ...)"
+  // parts are written, each preceded by a space. The param part is omitted
+  // when there are no parameters and the result part when the result is
+  // `none`. `indent` is not used.
+  std::ostream& print(std::ostream &o, unsigned indent, bool full=false) {
+    if (full) {
+      printOpening(o, "type") << ' ';
+      name.print(o);
+    }
+    if (params.size() > 0) {
+      o << ' ';
+      printMinorOpening(o, "param");
+      for (auto& param : params) {
+        o << ' ';
+        printBasicType(o, param);
+      }
+      o << ')';
+    }
+    if (result != none) {
+      o << ' ';
+      printMinorOpening(o, "result ");
+      printBasicType(o, result) << ')';
+    }
+    if (full) {
+      o << ')';
+    }
+    return o;
+  }
+
+  // Two types are equal when name, result and parameter list all match.
+  // Both operators are const-correct so that const objects and temporaries
+  // can be compared.
+  bool operator==(const FunctionType& b) const {
+    if (name != b.name) return false; // XXX
+    return result == b.result && params == b.params;
+  }
+  bool operator!=(const FunctionType& b) const {
+    return !(*this == b);
+  }
+};
+
+// A call through a computed target, using the signature `type`.
+class CallIndirect : public Expression {
+public:
+  FunctionType *type; // note: hides the BasicType member `type` inherited from Expression
+  Expression *target;
+  ExpressionList operands;
+
+  // Writes "(call_indirect $<signature name>", then the target followed by
+  // the operands, each on its own line, then the closing ')'.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    type->name.print(printOpening(o, "call_indirect "));
+    incIndent(o, indent);
+    printFullLine(o, indent, target);
+    for (size_t i = 0; i < operands.size(); i++) {
+      printFullLine(o, indent, operands[i]);
+    }
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+// Reads the local named `id`.
+class GetLocal : public Expression {
+public:
+  Name id;
+
+  // Writes "(get_local $id)" on a single line.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    printOpening(o, "get_local ");
+    id.print(o);
+    o << ')';
+    return o;
+  }
+};
+
+// Assigns `value` to the local named `id`.
+class SetLocal : public Expression {
+public:
+  Name id;
+  Expression *value;
+
+  // Writes "(set_local $id", the value on its own line, then ')'.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    id.print(printOpening(o, "set_local "));
+    incIndent(o, indent);
+    printFullLine(o, indent, value);
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+// A memory load of `bytes` bytes from the address given by `ptr`.
+class Load : public Expression {
+public:
+  unsigned bytes;
+  bool signed_;
+  bool float_;
+  int offset;
+  unsigned align;
+  Expression *ptr;
+
+  // Writes "(<type>.load", a width suffix for loads narrower than 4 bytes
+  // (plus "_u" when the load is not signed), the alignment, and the pointer
+  // on its own line. Widths below 4 other than 1 and 2 abort; a non-zero
+  // offset trips an assert.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    o << '(';
+    prepareColor(o);
+    printBasicType(o, getBasicType(bytes, float_)) << ".load";
+    if (bytes < 4) {
+      switch (bytes) {
+        case 1: o << '8'; break;
+        case 2: o << "16"; break;
+        default: abort();
+      }
+      if (!signed_) {
+        o << "_u";
+      }
+    }
+    restoreNormalColor(o);
+    o << " align=" << align;
+    assert(!offset);
+    incIndent(o, indent);
+    printFullLine(o, indent, ptr);
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+// A memory store of `value` (`bytes` bytes wide) to the address given by `ptr`.
+class Store : public Expression {
+public:
+  unsigned bytes;
+  bool float_;
+  int offset;
+  unsigned align;
+  Expression *ptr, *value;
+
+  // Writes "(<type>.store", a width suffix for stores narrower than 4 bytes,
+  // the alignment, then the pointer and the value on their own lines. Widths
+  // below 4 other than 1 and 2 abort; a non-zero offset trips an assert.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    o << '(';
+    prepareColor(o);
+    printBasicType(o, getBasicType(bytes, float_)) << ".store";
+    if (bytes < 4) {
+      switch (bytes) {
+        case 1: o << '8'; break;
+        case 2: o << "16"; break;
+        default: abort();
+      }
+    }
+    restoreNormalColor(o);
+    o << " align=" << align;
+    assert(!offset);
+    incIndent(o, indent);
+    printFullLine(o, indent, ptr);
+    printFullLine(o, indent, value);
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+// A constant expression holding a Literal.
+class Const : public Expression {
+public:
+  Literal value;
+
+  // Stores `value_` and returns this node so that calls can be chained. The
+  // expression type is updated as well, so that it always agrees with the
+  // literal it holds.
+  Const* set(Literal value_) {
+    value = value_;
+    type = value.type;
+    return this;
+  }
+
+  // Prints the literal; `indent` is not used.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    value.print(o);
+    return o;
+  }
+};
+
+// Applies the unary operator `op` to `value`.
+class Unary : public Expression {
+public:
+  UnaryOp op;
+  Expression *value;
+
+  // Writes "(unary <op>", the operand on its own line, then ')'. Every
+  // UnaryOp enumerator has a name (previously only Neg could be printed and
+  // all others aborted); a value outside the enumeration still aborts.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    printOpening(o, "unary ");
+    switch (op) {
+      case Clz: o << "clz"; break;
+      case Ctz: o << "ctz"; break;
+      case Popcnt: o << "popcnt"; break;
+      case Neg: o << "neg"; break;
+      case Abs: o << "abs"; break;
+      case Ceil: o << "ceil"; break;
+      case Floor: o << "floor"; break;
+      case Trunc: o << "trunc"; break;
+      case Nearest: o << "nearest"; break;
+      case Sqrt: o << "sqrt"; break;
+      default: abort();
+    }
+    incIndent(o, indent);
+    printFullLine(o, indent, value);
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+// Applies the binary operator `op` to `left` and `right`.
+class Binary : public Expression {
+public:
+  BinaryOp op;
+  Expression *left, *right;
+
+  // Writes "(<type>.<op>", both operands on their own lines, then ')'.
+  // Aborts for a value outside the BinaryOp enumeration.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    o << '(';
+    prepareColor(o);
+    printBasicType(o, type) << '.';
+    // Pick the mnemonic after the prefix has been written, so an invalid op
+    // aborts at the same point in the output as before.
+    const char *mnemonic = nullptr;
+    switch (op) {
+      case Add: mnemonic = "add"; break;
+      case Sub: mnemonic = "sub"; break;
+      case Mul: mnemonic = "mul"; break;
+      case DivS: mnemonic = "div_s"; break;
+      case DivU: mnemonic = "div_u"; break;
+      case RemS: mnemonic = "rem_s"; break;
+      case RemU: mnemonic = "rem_u"; break;
+      case And: mnemonic = "and"; break;
+      case Or: mnemonic = "or"; break;
+      case Xor: mnemonic = "xor"; break;
+      case Shl: mnemonic = "shl"; break;
+      case ShrU: mnemonic = "shr_u"; break;
+      case ShrS: mnemonic = "shr_s"; break;
+      case Div: mnemonic = "div"; break;
+      case CopySign: mnemonic = "copysign"; break;
+      case Min: mnemonic = "min"; break;
+      case Max: mnemonic = "max"; break;
+      default: abort();
+    }
+    o << mnemonic;
+    restoreNormalColor(o);
+    incIndent(o, indent);
+    printFullLine(o, indent, left);
+    printFullLine(o, indent, right);
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+// Compares `left` and `right` with the relational operator `op`.
+class Compare : public Expression {
+public:
+  RelationalOp op;
+  Expression *left, *right;
+
+  // The type of a comparison node is set to i32 on construction.
+  Compare() {
+    type = BasicType::i32;
+  }
+
+  // Writes "(<type>.<op>", both operands on their own lines, then ')'.
+  // Aborts for a value outside the RelationalOp enumeration.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    o << '(';
+    prepareColor(o);
+    printBasicType(o, type) << '.';
+    // Pick the mnemonic after the prefix has been written, so an invalid op
+    // aborts at the same point in the output as before.
+    const char *mnemonic = nullptr;
+    switch (op) {
+      case Eq: mnemonic = "eq"; break;
+      case Ne: mnemonic = "ne"; break;
+      case LtS: mnemonic = "lt_s"; break;
+      case LtU: mnemonic = "lt_u"; break;
+      case LeS: mnemonic = "le_s"; break;
+      case LeU: mnemonic = "le_u"; break;
+      case GtS: mnemonic = "gt_s"; break;
+      case GtU: mnemonic = "gt_u"; break;
+      case GeS: mnemonic = "ge_s"; break;
+      case GeU: mnemonic = "ge_u"; break;
+      case Lt: mnemonic = "lt"; break;
+      case Le: mnemonic = "le"; break;
+      case Gt: mnemonic = "gt"; break;
+      case Ge: mnemonic = "ge"; break;
+      default: abort();
+    }
+    o << mnemonic;
+    restoreNormalColor(o);
+    incIndent(o, indent);
+    printFullLine(o, indent, left);
+    printFullLine(o, indent, right);
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+// Applies the conversion `op` to `value`.
+class Convert : public Expression {
+public:
+  ConvertOp op;
+  Expression *value;
+
+  // Writes "(convert <name>", the operand on its own line, then ')'. Only
+  // ConvertUInt32, ConvertSInt32 and TruncSFloat64 have a name; any other
+  // operator aborts.
+  std::ostream& print(std::ostream &o, unsigned indent) override {
+    printOpening(o, "convert ");
+    const char *label = nullptr;
+    switch (op) {
+      case ConvertUInt32: label = "uint32toDouble"; break;
+      case ConvertSInt32: label = "sint32toDouble"; break;
+      case TruncSFloat64: label = "float64tosint32"; break;
+      default: abort();
+    }
+    o << label;
+    incIndent(o, indent);
+    printFullLine(o, indent, value);
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+class Host : public Expression { // Expression node holding a HostOp and a list of operand expressions.
+public:
+ HostOp op; // which host operation this node stands for
+ ExpressionList operands; // operand expressions, held as raw pointers
+}; // NOTE(review): the pure virtual print() is not overridden here, so Host is abstract as written.
+
+// Globals
+
+// Pairs a name with a type; Function uses it for its params and locals.
+struct NameType {
+  Name name;
+  BasicType type;
+
+  // Default: a Name constructed from nullptr and the type `none`.
+  NameType()
+    : name(nullptr), type(none) {}
+
+  NameType(Name name, BasicType type)
+    : name(name), type(type) {}
+};
+
+// A function definition: name, signature, locals and a body expression.
+class Function {
+public:
+  Name name;
+  BasicType result;
+  std::vector<NameType> params;
+  std::vector<NameType> locals;
+  Expression *body;
+
+  // Writes "(func $name" followed by one "(param $n <type>)" per parameter
+  // and, unless the result is `none`, "(result <type>)" on the same line.
+  // The locals and the body follow one level deeper, then the closing ')'.
+  std::ostream& print(std::ostream &o, unsigned indent) {
+    printOpening(o, "func ", true);
+    name.print(o);
+    for (size_t i = 0; i < params.size(); i++) {
+      NameType& param = params[i];
+      o << ' ';
+      printMinorOpening(o, "param ");
+      param.name.print(o) << " ";
+      printBasicType(o, param.type) << ")";
+    }
+    if (result != none) {
+      o << ' ';
+      printMinorOpening(o, "result ");
+      printBasicType(o, result) << ")";
+    }
+    incIndent(o, indent);
+    for (size_t i = 0; i < locals.size(); i++) {
+      NameType& local = locals[i];
+      doIndent(o, indent);
+      printMinorOpening(o, "local ");
+      local.name.print(o) << " ";
+      printBasicType(o, local.type) << ")\n";
+    }
+    printFullLine(o, indent, body);
+    decIndent(o, indent);
+    return o;
+  }
+};
+
+// An imported function.
+class Import {
+public:
+  Name name, module, base; // name = module.base
+  FunctionType type;
+
+  // Writes "(import $name "module" "base" <signature>)" and returns `o`.
+  std::ostream& print(std::ostream &o, unsigned indent) {
+    printOpening(o, "import ");
+    name.print(o);
+    o << ' ';
+    printText(o, module.str);
+    o << ' ';
+    printText(o, base.str);
+    o << ' ';
+    return type.print(o, indent) << ')';
+  }
+};
+
+// An export: makes the item named `value` visible under the name `name`.
+class Export {
+public:
+  Name name;
+  Name value;
+
+  // Writes (export "name" $value) and returns `o`; `indent` is not used.
+  // The export name is written once, as quoted text. It used to be emitted a
+  // second time in front of that, as a '$'-prefixed identifier.
+  std::ostream& print(std::ostream &o, unsigned indent) {
+    printOpening(o, "export") << ' ';
+    printText(o, name.str) << ' ';
+    value.print(o);
+    o << ')';
+    return o;
+  }
+};
+
+// A table: an ordered list of names.
+class Table {
+public:
+  std::vector<Name> vars;
+
+  // Writes "(table" followed by each name preceded by a space, then ')'.
+  std::ostream& print(std::ostream &o, unsigned indent) {
+    printOpening(o, "table");
+    for (size_t i = 0; i < vars.size(); i++) {
+      vars[i].print(o << ' ');
+    }
+    return o << ')';
+  }
+};
+
+// A whole module: function types, imports, exports, the table and functions.
+class Module {
+protected:
+  // wasm contents
+  std::map<Name, FunctionType*> functionTypes;
+  std::map<Name, Import> imports;
+  std::vector<Export> exports;
+  Table table;
+  std::vector<Function*> functions;
+
+  // internals
+  std::map<Name, void*> map; // maps var ids/names to things
+  unsigned nextVar;
+
+public:
+  Module() : nextVar(1) {}
+
+  // Writes the module as "(module ...)" with each function type, import and
+  // export, then the table, then each function, one per line and one level
+  // deep; ends with ')' and a newline. Returns `o`.
+  std::ostream& print(std::ostream &o) {
+    unsigned indent = 0;
+    printOpening(o, "module", true);
+    incIndent(o, indent);
+    for (auto& curr : functionTypes) {
+      doIndent(o, indent);
+      curr.second->print(o, indent, true);
+      o << '\n';
+    }
+    for (auto& curr : imports) {
+      doIndent(o, indent);
+      curr.second.print(o, indent);
+      o << '\n';
+    }
+    for (auto& curr : exports) {
+      doIndent(o, indent);
+      curr.print(o, indent);
+      o << '\n';
+    }
+    doIndent(o, indent);
+    table.print(o, indent);
+    o << '\n';
+    for (auto& curr : functions) {
+      doIndent(o, indent);
+      curr->print(o, indent);
+      o << '\n';
+    }
+    decIndent(o, indent);
+    o << '\n';
+    // The function previously ended without a return statement, which is
+    // undefined behaviour for a function returning a reference.
+    return o;
+  }
+};
+
+} // namespace wasm
+