diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/asm2wasm-main.cpp | 34 | ||||
-rw-r--r-- | src/asm2wasm.h | 797 | ||||
-rw-r--r-- | src/asm_v_wasm.h | 70 | ||||
-rw-r--r-- | src/binaryen-shell.cpp | 41 | ||||
-rw-r--r-- | src/emscripten-optimizer/optimizer-shared.cpp | 122 | ||||
-rw-r--r-- | src/emscripten-optimizer/optimizer.h | 23 | ||||
-rw-r--r-- | src/emscripten-optimizer/parser.h | 10 | ||||
-rw-r--r-- | src/emscripten-optimizer/simple_ast.cpp | 2 | ||||
-rw-r--r-- | src/emscripten-optimizer/simple_ast.h | 139 | ||||
-rw-r--r-- | src/js/post.js | 177 | ||||
-rw-r--r-- | src/mixed_arena.h | 11 | ||||
-rw-r--r-- | src/parsing.h | 153 | ||||
-rw-r--r-- | src/pass.h | 1 | ||||
-rw-r--r-- | src/passes/NameManager.cpp | 3 | ||||
-rw-r--r-- | src/passes/RemoveImports.cpp | 43 | ||||
-rw-r--r-- | src/s2wasm-main.cpp | 48 | ||||
-rw-r--r-- | src/s2wasm.h | 1033 | ||||
-rw-r--r-- | src/shared-constants.h | 75 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 352 | ||||
-rw-r--r-- | src/wasm-js.cpp | 213 | ||||
-rw-r--r-- | src/wasm-s-parser.h | 355 | ||||
-rw-r--r-- | src/wasm-validator.h | 27 | ||||
-rw-r--r-- | src/wasm.h | 347 | ||||
-rw-r--r-- | src/wasm2asm-main.cpp | 60 | ||||
-rw-r--r-- | src/wasm2asm.h | 993 |
25 files changed, 4129 insertions, 1000 deletions
diff --git a/src/asm2wasm-main.cpp b/src/asm2wasm-main.cpp index 7c4d93052..b2b490965 100644 --- a/src/asm2wasm-main.cpp +++ b/src/asm2wasm-main.cpp @@ -7,10 +7,15 @@ using namespace cashew; using namespace wasm; +namespace wasm { +int debug = 0; +} + int main(int argc, char **argv) { debug = getenv("ASM2WASM_DEBUG") ? getenv("ASM2WASM_DEBUG")[0] - '0' : 0; char *infile = argv[1]; + char *mappedGlobals = argc < 3 ? nullptr : argv[2]; if (debug) std::cerr << "loading '" << infile << "'...\n"; FILE *f = fopen(argv[1], "r"); @@ -27,33 +32,17 @@ int main(int argc, char **argv) { fclose(f); input[num] = 0; - // emcc --separate-asm modules look like - // - // Module["asm"] = (function(global, env, buffer) { - // .. - // }); - // - // we need to clean that up. - if (*input == 'M') { - while (*input != 'f') { - input++; - num--; - } - char *end = input + num - 1; - while (*end != '}') { - *end = 0; - end--; - } - } + Asm2WasmPreProcessor pre; + input = pre.process(input); if (debug) std::cerr << "parsing...\n"; cashew::Parser<Ref, DotZeroValueBuilder> builder; Ref asmjs = builder.parseToplevel(input); if (debug) std::cerr << "wasming...\n"; - Module wasm; + AllocatingModule wasm; wasm.memory.initial = wasm.memory.max = 16*1024*1024; // we would normally receive this from the compiler - Asm2WasmBuilder asm2wasm(wasm); + Asm2WasmBuilder asm2wasm(wasm, pre.memoryGrowth); asm2wasm.processAsm(asmjs); if (debug) std::cerr << "optimizing...\n"; @@ -62,6 +51,11 @@ int main(int argc, char **argv) { if (debug) std::cerr << "printing...\n"; std::cout << wasm; + if (mappedGlobals) { + if (debug) std::cerr << "serializing mapped globals...\n"; + asm2wasm.serializeMappedGlobals(mappedGlobals); + } + if (debug) std::cerr << "done.\n"; } diff --git a/src/asm2wasm.h b/src/asm2wasm.h index 9fa8d11dd..ca88fa883 100644 --- a/src/asm2wasm.h +++ b/src/asm2wasm.h @@ -7,34 +7,17 @@ #include "wasm.h" #include "emscripten-optimizer/optimizer.h" #include "mixed_arena.h" +#include 
"shared-constants.h" +#include "asm_v_wasm.h" namespace wasm { using namespace cashew; -int debug = 0; // wasm::debug is set in main(), typically from an env var +extern int debug; // wasm::debug is set in main(), typically from an env var // Utilities -IString GLOBAL("global"), NAN_("NaN"), INFINITY_("Infinity"), - TOPMOST("topmost"), - INT8ARRAY("Int8Array"), - INT16ARRAY("Int16Array"), - INT32ARRAY("Int32Array"), - UINT8ARRAY("Uint8Array"), - UINT16ARRAY("Uint16Array"), - UINT32ARRAY("Uint32Array"), - FLOAT32ARRAY("Float32Array"), - FLOAT64ARRAY("Float64Array"), - IMPOSSIBLE_CONTINUE("impossible-continue"), - MATH("Math"), - IMUL("imul"), - CLZ32("clz32"), - FROUND("fround"), - ASM2WASM("asm2wasm"), - F64_REM("f64-rem"); - - static void abort_on(std::string why) { std::cerr << why << '\n'; abort(); @@ -75,13 +58,69 @@ struct AstStackHelper { std::vector<Ref> AstStackHelper::astStack; // +// Asm2WasmPreProcessor - does some initial parsing/processing +// of asm.js code. +// + +struct Asm2WasmPreProcessor { + bool memoryGrowth = false; + + char* process(char* input) { + // emcc --separate-asm modules can look like + // + // Module["asm"] = (function(global, env, buffer) { + // .. + // }); + // + // we need to clean that up. + if (*input == 'M') { + size_t num = strlen(input); + while (*input != 'f') { + input++; + num--; + } + char *end = input + num - 1; + while (*end != '}') { + *end = 0; + end--; + } + } + + // asm.js memory growth uses a quite elaborate pattern. 
Instead of parsing and + // matching it, we do a simpler detection on emscripten's asm.js output format + const char* START_FUNCS = "// EMSCRIPTEN_START_FUNCS"; + char *marker = strstr(input, START_FUNCS); + if (marker) { + *marker = 0; // look for memory growth code just up to here + char *growthSign = strstr(input, "return true;"); // this can only show up in growth code, as normal asm.js lacks "true" + if (growthSign) { + memoryGrowth = true; + // clean out this function, we don't need it + char *growthFuncStart = strstr(input, "function "); + assert(strstr(growthFuncStart + 1, "function ") == 0); // should be only this one function in this area, so no confusion for us + char *growthFuncEnd = strchr(growthSign, '}'); + assert(growthFuncEnd > growthFuncStart + 5); + growthFuncStart[0] = '/'; + growthFuncStart[1] = '*'; + growthFuncEnd--; + growthFuncEnd[0] = '*'; + growthFuncEnd[1] = '/'; + } + *marker = START_FUNCS[0]; + } + + return input; + } +}; + +// // Asm2WasmBuilder - converts an asm.js module into WebAssembly // class Asm2WasmBuilder { - Module& wasm; + AllocatingModule& wasm; - MixedArena allocator; + MixedArena &allocator; // globals @@ -96,9 +135,34 @@ class Asm2WasmBuilder { MappedGlobal(unsigned address, WasmType type, bool import, IString module, IString base) : address(address), type(type), import(import), module(module), base(base) {} }; + // function table + std::map<IString, int> functionTableStarts; // each asm function table gets a range in the one wasm table, starting at a location + std::map<CallIndirect*, IString> callIndirects; // track these, as we need to fix them after we know the functionTableStarts. this maps call => its function table + + bool memoryGrowth; + public: std::map<IString, MappedGlobal> mappedGlobals; + // the global mapping info is not present in the output wasm. We need to save it on the side + // if we intend to load and run this module's wasm. 
+ void serializeMappedGlobals(const char *filename) { + FILE *f = fopen(filename, "w"); + assert(f); + fprintf(f, "{\n"); + bool first = true; + for (auto& pair : mappedGlobals) { + auto name = pair.first; + auto& global = pair.second; + if (first) first = false; + else fprintf(f, ","); + fprintf(f, "\"%s\": { \"address\": %d, \"type\": %d, \"import\": %d, \"module\": \"%s\", \"base\": \"%s\" }\n", + name.str, global.address, global.type, global.import, global.module.str, global.base.str); + } + fprintf(f, "}"); + fclose(f); + } + private: void allocateGlobal(IString name, WasmType type, bool import, IString module = IString(), IString base = IString()) { assert(mappedGlobals.find(name) == mappedGlobals.end()); @@ -116,9 +180,14 @@ private: }; std::map<IString, View> views; // name (e.g. HEAP8) => view info - IString Math_imul; // imported name of Math.imul - IString Math_clz32; // imported name of Math.imul - IString Math_fround; // imported name of Math.fround + + // Imported names of Math.* + IString Math_imul; + IString Math_clz32; + IString Math_fround; + IString Math_abs; + IString Math_floor; + IString Math_sqrt; // function types. 
we fill in this information as we see // uses, in the first pass @@ -150,9 +219,9 @@ private: if (previous.params.size() > i) { if (previous.params[i] == none) { previous.params[i] = type.params[i]; // use a more concrete type - } else { - previous.params.push_back(type.params[i]); // add a new param } + } else { + previous.params.push_back(type.params[i]); // add a new param } } if (previous.result == none) { @@ -164,22 +233,13 @@ private: } } - char getSigFromType(WasmType type) { - switch (type) { - case i32: return 'i'; - case f64: return 'd'; - case none: return 'v'; - default: abort(); - } - } - FunctionType *getFunctionType(Ref parent, ExpressionList& operands) { // generate signature WasmType result = detectWasmType(parent, nullptr); std::string str = "FUNCSIG$"; - str += getSigFromType(result); + str += getSig(result); for (auto operand : operands) { - str += getSigFromType(operand->type); + str += getSig(operand->type); } IString sig(str.c_str(), false); if (wasm.functionTypesMap.find(sig) == wasm.functionTypesMap.end()) { @@ -196,35 +256,12 @@ private: } public: - Asm2WasmBuilder(Module& wasm) : wasm(wasm), nextGlobal(8), maxGlobal(1000) {} // XXX sync with emcc + Asm2WasmBuilder(AllocatingModule& wasm, bool memoryGrowth) : wasm(wasm), allocator(wasm.allocator), nextGlobal(8), maxGlobal(1000), memoryGrowth(memoryGrowth) {} void processAsm(Ref ast); void optimize(); private: - WasmType asmToWasmType(AsmType asmType) { - switch (asmType) { - case ASM_INT: return WasmType::i32; - case ASM_DOUBLE: return WasmType::f64; - case ASM_FLOAT: return WasmType::f32; - case ASM_NONE: return WasmType::none; - default: {} - } - abort_on("confused asmType", asmType); - return (WasmType)-1; // avoid warning - } - AsmType wasmToAsmType(WasmType type) { - switch (type) { - case WasmType::i32: return ASM_INT; - case WasmType::f32: return ASM_FLOAT; - case WasmType::f64: return ASM_DOUBLE; - case WasmType::none: return ASM_NONE; - default: {} - } - abort_on("confused 
wasmType", type); - return (AsmType)-1; // avoid warning - } - AsmType detectAsmType(Ref ast, AsmData *data) { if (ast[0] == NAME) { IString name = ast[1]->getIString(); @@ -240,31 +277,29 @@ private: return view->second.type; } } - return detectType(ast, data); + return detectType(ast, data, false, Math_fround); } WasmType detectWasmType(Ref ast, AsmData *data) { return asmToWasmType(detectAsmType(ast, data)); } - bool isUnsignedCoercion(Ref ast) { // TODO: use detectSign? - if (ast[0] == BINARY && ast[1] == TRSHIFT) return true; - return false; + bool isUnsignedCoercion(Ref ast) { + return detectSign(ast, Math_fround) == ASM_UNSIGNED; } - // an asm.js binary op can either be a binary or a relational in wasm - bool parseAsmBinaryOp(IString op, Ref left, Ref right, BinaryOp &binary, RelationalOp &relational, AsmData *asmData) { - if (op == PLUS) { binary = BinaryOp::Add; return true; } - if (op == MINUS) { binary = BinaryOp::Sub; return true; } - if (op == MUL) { binary = BinaryOp::Mul; return true; } - if (op == AND) { binary = BinaryOp::And; return true; } - if (op == OR) { binary = BinaryOp::Or; return true; } - if (op == XOR) { binary = BinaryOp::Xor; return true; } - if (op == LSHIFT) { binary = BinaryOp::Shl; return true; } - if (op == RSHIFT) { binary = BinaryOp::ShrS; return true; } - if (op == TRSHIFT) { binary = BinaryOp::ShrU; return true; } - if (op == EQ) { relational = RelationalOp::Eq; return false; } - if (op == NE) { relational = RelationalOp::Ne; return false; } + BinaryOp parseAsmBinaryOp(IString op, Ref left, Ref right, AsmData *asmData) { + if (op == PLUS) return BinaryOp::Add; + if (op == MINUS) return BinaryOp::Sub; + if (op == MUL) return BinaryOp::Mul; + if (op == AND) return BinaryOp::And; + if (op == OR) return BinaryOp::Or; + if (op == XOR) return BinaryOp::Xor; + if (op == LSHIFT) return BinaryOp::Shl; + if (op == RSHIFT) return BinaryOp::ShrS; + if (op == TRSHIFT) return BinaryOp::ShrU; + if (op == EQ) return BinaryOp::Eq; + if (op == 
NE) return BinaryOp::Ne; WasmType leftType = detectWasmType(left, asmData); #if 0 std::cout << "CHECK\n"; @@ -278,42 +313,42 @@ private: bool isUnsigned = isUnsignedCoercion(left) || isUnsignedCoercion(right); if (op == DIV) { if (isInteger) { - { binary = isUnsigned ? BinaryOp::DivU : BinaryOp::DivS; return true; } + return isUnsigned ? BinaryOp::DivU : BinaryOp::DivS; } - { binary = BinaryOp::Div; return true; } + return BinaryOp::Div; } if (op == MOD) { if (isInteger) { - { binary = isUnsigned ? BinaryOp::RemU : BinaryOp::RemS; return true; } + return isUnsigned ? BinaryOp::RemU : BinaryOp::RemS; } - { binary = BinaryOp::RemS; return true; } // XXX no floating-point remainder op, this must be handled by the caller + return BinaryOp::RemS; // XXX no floating-point remainder op, this must be handled by the caller } if (op == GE) { if (isInteger) { - { relational = isUnsigned ? RelationalOp::GeU : RelationalOp::GeS; return false; } + return isUnsigned ? BinaryOp::GeU : BinaryOp::GeS; } - { relational = RelationalOp::Ge; return false; } + return BinaryOp::Ge; } if (op == GT) { if (isInteger) { - { relational = isUnsigned ? RelationalOp::GtU : RelationalOp::GtS; return false; } + return isUnsigned ? BinaryOp::GtU : BinaryOp::GtS; } - { relational = RelationalOp::Gt; return false; } + return BinaryOp::Gt; } if (op == LE) { if (isInteger) { - { relational = isUnsigned ? RelationalOp::LeU : RelationalOp::LeS; return false; } + return isUnsigned ? BinaryOp::LeU : BinaryOp::LeS; } - { relational = RelationalOp::Le; return false; } + return BinaryOp::Le; } if (op == LT) { if (isInteger) { - { relational = isUnsigned ? RelationalOp::LtU : RelationalOp::LtS; return false; } + return isUnsigned ? 
BinaryOp::LtU : BinaryOp::LtS; } - { relational = RelationalOp::Lt; return false; } + return BinaryOp::Lt; } abort_on("bad wasm binary op", op); - return false; // avoid warning + abort(); // avoid warning } unsigned bytesToShift(unsigned bytes) { @@ -330,20 +365,83 @@ private: std::map<unsigned, Ref> tempNums; - Literal getLiteral(Ref ast) { + Literal checkLiteral(Ref ast) { if (ast[0] == NUM) { return Literal((int32_t)ast[1]->getInteger()); } else if (ast[0] == UNARY_PREFIX) { + if (ast[1] == PLUS && ast[2][0] == NUM) { + return Literal((double)ast[2][1]->getNumber()); + } if (ast[1] == MINUS && ast[2][0] == NUM) { double num = -ast[2][1]->getNumber(); assert(isInteger32(num)); return Literal((int32_t)num); } + if (ast[1] == PLUS && ast[2][0] == UNARY_PREFIX && ast[2][1] == MINUS && ast[2][2][0] == NUM) { + return Literal((double)-ast[2][2][1]->getNumber()); + } if (ast[1] == MINUS && ast[2][0] == UNARY_PREFIX && ast[2][1] == PLUS && ast[2][2][0] == NUM) { return Literal((double)-ast[2][2][1]->getNumber()); } } - abort(); + return Literal(); + } + + Literal getLiteral(Ref ast) { + Literal ret = checkLiteral(ast); + if (ret.type == none) abort(); + return ret; + } + + void fixCallType(Expression* call, WasmType type) { + if (call->is<Call>()) call->type = type; + if (call->is<CallImport>()) call->type = type; + else if (call->is<CallIndirect>()) call->type = type; + } + + FunctionType* getBuiltinFunctionType(Name module, Name base, ExpressionList* operands = nullptr) { + if (module == GLOBAL_MATH) { + if (base == ABS) { + assert(operands && operands->size() == 1); + WasmType type = (*operands)[0]->type; + if (type == i32) { + static FunctionType* builtin = nullptr; + if (!builtin) { + builtin = new FunctionType(); + builtin->params.push_back(i32); + builtin->result = i32; + } + return builtin; + } + if (type == f32) { + static FunctionType* builtin = nullptr; + if (!builtin) { + builtin = new FunctionType(); + builtin->params.push_back(f32); + builtin->result = 
f32; + } + return builtin; + } + if (type == f64) { + static FunctionType* builtin = nullptr; + if (!builtin) { + builtin = new FunctionType(); + builtin->params.push_back(f64); + builtin->result = f64; + } + return builtin; + } + + } + } + return nullptr; + } + + Block* blockify(Expression* expression) { + if (expression->is<Block>()) return expression->dyn_cast<Block>(); + auto ret = allocator.alloc<Block>(); + ret->list.push_back(expression); + return ret; } Function* processFunction(Ref ast); @@ -376,6 +474,18 @@ void Asm2WasmBuilder::processAsm(Ref ast) { assert(Math_fround.isNull()); Math_fround = name; return; + } else if (imported[2] == ABS) { + assert(Math_abs.isNull()); + Math_abs = name; + return; + } else if (imported[2] == FLOOR) { + assert(Math_floor.isNull()); + Math_floor = name; + return; + } else if (imported[2] == SQRT) { + assert(Math_sqrt.isNull()); + Math_sqrt = name; + return; } } std::string fullName = module[1][1]->getCString(); @@ -402,7 +512,9 @@ void Asm2WasmBuilder::processAsm(Ref ast) { } }; - // first pass - do almost everything, but function imports + IString Int8Array, Int16Array, Int32Array, UInt8Array, UInt16Array, UInt32Array, Float32Array, Float64Array; + + // first pass - do almost everything, but function imports and indirect calls for (unsigned i = 1; i < body->size(); i++) { Ref curr = body[i]; @@ -437,56 +549,96 @@ void Asm2WasmBuilder::processAsm(Ref ast) { assert(value[1][0] == NAME && value[1][1] == Math_fround && value[2][0][0] == NUM && value[2][0][1]->getNumber() == 0); allocateGlobal(name, WasmType::f32, false); } else if (value[0] == DOT) { + // simple module.base import. can be a view, or a function. 
+ if (value[1][0] == NAME) { + IString module = value[1][1]->getIString(); + IString base = value[2]->getIString(); + if (module == GLOBAL) { + if (base == INT8ARRAY) { + Int8Array = name; + } else if (base == INT16ARRAY) { + Int16Array = name; + } else if (base == INT32ARRAY) { + Int32Array = name; + } else if (base == UINT8ARRAY) { + UInt8Array = name; + } else if (base == UINT16ARRAY) { + UInt16Array = name; + } else if (base == UINT32ARRAY) { + UInt32Array = name; + } else if (base == FLOAT32ARRAY) { + Float32Array = name; + } else if (base == FLOAT64ARRAY) { + Float64Array = name; + } + } + } // function import addImport(name, value, WasmType::none); } else if (value[0] == NEW) { // ignore imports of typed arrays, but note the names of the arrays value = value[1]; assert(value[0] == CALL); - Ref constructor = value[1]; - assert(constructor[0] == DOT); // global.*Array - IString heap = constructor[2]->getIString(); unsigned bytes; bool integer, signed_; AsmType asmType; - if (heap == INT8ARRAY) { - bytes = 1; integer = true; signed_ = true; asmType = ASM_INT; - } else if (heap == INT16ARRAY) { - bytes = 2; integer = true; signed_ = true; asmType = ASM_INT; - } else if (heap == INT32ARRAY) { - bytes = 4; integer = true; signed_ = true; asmType = ASM_INT; - } else if (heap == UINT8ARRAY) { - bytes = 1; integer = true; signed_ = false; asmType = ASM_INT; - } else if (heap == UINT16ARRAY) { - bytes = 2; integer = true; signed_ = false; asmType = ASM_INT; - } else if (heap == UINT32ARRAY) { - bytes = 4; integer = true; signed_ = false; asmType = ASM_INT; - } else if (heap == FLOAT32ARRAY) { - bytes = 4; integer = false; signed_ = true; asmType = ASM_DOUBLE; - } else if (heap == FLOAT64ARRAY) { - bytes = 8; integer = false; signed_ = true; asmType = ASM_DOUBLE; + Ref constructor = value[1]; + if (constructor[0] == DOT) { // global.*Array + IString heap = constructor[2]->getIString(); + if (heap == INT8ARRAY) { + bytes = 1; integer = true; signed_ = true; asmType = 
ASM_INT; + } else if (heap == INT16ARRAY) { + bytes = 2; integer = true; signed_ = true; asmType = ASM_INT; + } else if (heap == INT32ARRAY) { + bytes = 4; integer = true; signed_ = true; asmType = ASM_INT; + } else if (heap == UINT8ARRAY) { + bytes = 1; integer = true; signed_ = false; asmType = ASM_INT; + } else if (heap == UINT16ARRAY) { + bytes = 2; integer = true; signed_ = false; asmType = ASM_INT; + } else if (heap == UINT32ARRAY) { + bytes = 4; integer = true; signed_ = false; asmType = ASM_INT; + } else if (heap == FLOAT32ARRAY) { + bytes = 4; integer = false; signed_ = true; asmType = ASM_FLOAT; + } else if (heap == FLOAT64ARRAY) { + bytes = 8; integer = false; signed_ = true; asmType = ASM_DOUBLE; + } else { + abort_on("invalid view import", heap); + } + } else { // *ArrayView that was previously imported + assert(constructor[0] == NAME); + IString viewName = constructor[1]->getIString(); + if (viewName == Int8Array) { + bytes = 1; integer = true; signed_ = true; asmType = ASM_INT; + } else if (viewName == Int16Array) { + bytes = 2; integer = true; signed_ = true; asmType = ASM_INT; + } else if (viewName == Int32Array) { + bytes = 4; integer = true; signed_ = true; asmType = ASM_INT; + } else if (viewName == UInt8Array) { + bytes = 1; integer = true; signed_ = false; asmType = ASM_INT; + } else if (viewName == UInt16Array) { + bytes = 2; integer = true; signed_ = false; asmType = ASM_INT; + } else if (viewName == UInt32Array) { + bytes = 4; integer = true; signed_ = false; asmType = ASM_INT; + } else if (viewName == Float32Array) { + bytes = 4; integer = false; signed_ = true; asmType = ASM_FLOAT; + } else if (viewName == Float64Array) { + bytes = 8; integer = false; signed_ = true; asmType = ASM_DOUBLE; + } else { + abort_on("invalid short view import", viewName); + } } assert(views.find(name) == views.end()); views.emplace(name, View(bytes, integer, signed_, asmType)); } else if (value[0] == ARRAY) { - // function table. we "merge" them, so e.g. 
[foo, b1] , [b2, bar] => [foo, bar] , assuming b* are the aborting thunks - // when minified, we can't tell from the name b\d+, but null thunks appear multiple times in a table; others never do - // TODO: we can drop some b*s at the end of the table + // function table. we merge them into one big table, so e.g. [foo, b1] , [b2, bar] => [foo, b1, b2, bar] + // TODO: when not using aliasing function pointers, we could merge them by noticing that + // index 0 in each table is the null func, and each other index should only have one + // non-null func. However, that breaks down when function pointer casts are emulated. + functionTableStarts[name] = wasm.table.names.size(); // this table starts here Ref contents = value[1]; - std::map<IString, unsigned> counts; // name -> how many times seen for (unsigned k = 0; k < contents->size(); k++) { IString curr = contents[k][1]->getIString(); - counts[curr]++; - } - for (unsigned k = 0; k < contents->size(); k++) { - IString curr = contents[k][1]->getIString(); - if (wasm.table.names.size() <= k) { - wasm.table.names.push_back(curr); - } else { - if (counts[curr] == 1) { // if just one appearance, not a null thunk - wasm.table.names[k] = curr; - } - } + wasm.table.names.push_back(curr); } } else { abort_on("invalid var element", pair); @@ -512,7 +664,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) { } } - // second pass - function imports + // second pass. 
first, function imports std::vector<IString> toErase; @@ -520,6 +672,12 @@ void Asm2WasmBuilder::processAsm(Ref ast) { IString name = pair.first; Import& import = *pair.second; if (importedFunctionTypes.find(name) != importedFunctionTypes.end()) { + // special math builtins + FunctionType* builtin = getBuiltinFunctionType(import.module, import.base); + if (builtin) { + import.type = *builtin; + continue; + } import.type = importedFunctionTypes[name]; } else if (import.module != ASM2WASM) { // special-case the special module // never actually used @@ -530,6 +688,40 @@ void Asm2WasmBuilder::processAsm(Ref ast) { for (auto curr : toErase) { wasm.removeImport(curr); } + + // finalize indirect calls + + for (auto& pair : callIndirects) { + CallIndirect* call = pair.first; + IString tableName = pair.second; + assert(functionTableStarts.find(tableName) != functionTableStarts.end()); + auto sub = allocator.alloc<Binary>(); + // note that the target is already masked, so we just offset it, we don't need to guard against overflow (which would be an error anyhow) + sub->op = Add; + sub->left = call->target; + sub->right = allocator.alloc<Const>()->set(Literal((int32_t)functionTableStarts[tableName])); + sub->type = WasmType::i32; + call->target = sub; + } + + // apply memory growth, if relevant + if (memoryGrowth) { + // create and export a function that just calls memory growth + auto growWasmMemory = allocator.alloc<Function>(); + growWasmMemory->name = GROW_WASM_MEMORY; + growWasmMemory->params.emplace_back(NEW_SIZE, i32); // the new size + auto get = allocator.alloc<GetLocal>(); + get->name = NEW_SIZE; + auto grow = allocator.alloc<Host>(); + grow->op = GrowMemory; + grow->operands.push_back(get); + growWasmMemory->body = grow; + wasm.addFunction(growWasmMemory); + auto export_ = allocator.alloc<Export>(); + export_->name = export_->value = GROW_WASM_MEMORY; + wasm.addExport(export_); + } + } Function* Asm2WasmBuilder::processFunction(Ref ast) { @@ -563,7 +755,7 @@ 
Function* Asm2WasmBuilder::processFunction(Ref ast) { IStringSet functionVariables; // params or locals - IString parentLabel; // set in LABEL, then read in WHILE/DO + IString parentLabel; // set in LABEL, then read in WHILE/DO/SWITCH std::vector<IString> breakStack; // where a break will go std::vector<IString> continueStack; // where a continue will go @@ -575,7 +767,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { curr = curr[1]; assert(curr[0] == ASSIGN && curr[2][0] == NAME); IString name = curr[2][1]->getIString(); - AsmType asmType = detectType(curr[3]); + AsmType asmType = detectType(curr[3], nullptr, false, Math_fround); function->params.emplace_back(name, asmToWasmType(asmType)); functionVariables.insert(name); asmData.addParam(name, asmType); @@ -586,7 +778,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { for (unsigned j = 0; j < curr[1]->size(); j++) { Ref pair = curr[1][j]; IString name = pair[0]->getIString(); - AsmType asmType = detectType(pair[1], nullptr, true); + AsmType asmType = detectType(pair[1], nullptr, true, Math_fround); function->locals.emplace_back(name, asmToWasmType(asmType)); functionVariables.insert(name); asmData.addVar(name, asmType); @@ -594,6 +786,15 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { start++; } + bool addedI32Temp = false; + auto ensureI32Temp = [&]() { + if (addedI32Temp) return; + addedI32Temp = true; + function->locals.emplace_back(I32_TEMP, i32); + functionVariables.insert(I32_TEMP); + asmData.addVar(I32_TEMP, ASM_INT); + }; + bool seenReturn = false; // function->result is updated if we see a return bool needTopmost = false; // we label the topmost b lock if we need one for a return // processors @@ -646,63 +847,63 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->align = view.bytes; ret->ptr = processUnshifted(target[2], view.bytes); ret->value = process(ast[3]); - ret->type = ret->value->type; + ret->type = asmToWasmType(view.type); + if (ret->type != ret->value->type) 
{ + // in asm.js we have some implicit coercions that we must do explicitly here + if (ret->type == f32 && ret->value->type == f64) { + auto conv = allocator.alloc<Unary>(); + conv->op = DemoteFloat64; + conv->value = ret->value; + conv->type = WasmType::f32; + ret->value = conv; + } else { + abort(); + } + } return ret; } abort_on("confusing assign", ast); } else if (what == BINARY) { - if (ast[1] == OR && ast[3][0] == NUM && ast[3][1]->getNumber() == 0) { - auto ret = process(ast[2]); // just look through the ()|0 coercion - ret->type = WasmType::i32; // we add it here for e.g. call coercions + if ((ast[1] == OR || ast[1] == TRSHIFT) && ast[3][0] == NUM && ast[3][1]->getNumber() == 0) { + auto ret = process(ast[2]); // just look through the ()|0 or ()>>>0 coercion + fixCallType(ret, i32); return ret; } - BinaryOp binary; - RelationalOp relational; - bool isBinary = parseAsmBinaryOp(ast[1]->getIString(), ast[2], ast[3], binary, relational, &asmData); - if (isBinary) { - auto ret = allocator.alloc<Binary>(); - ret->op = binary; - ret->left = process(ast[2]); - ret->right = process(ast[3]); - ret->type = ret->left->type; - if (binary == BinaryOp::RemS && isWasmTypeFloat(ret->type)) { - // WebAssembly does not have floating-point remainder, we have to emit a call to a special import of ours - CallImport *call = allocator.alloc<CallImport>(); - call->target = F64_REM; - call->operands.push_back(ret->left); - call->operands.push_back(ret->right); - call->type = f64; - static bool addedImport = false; - if (!addedImport) { - addedImport = true; - auto import = allocator.alloc<Import>(); // f64-rem = asm2wasm.f64-rem; - import->name = F64_REM; - import->module = ASM2WASM; - import->base = F64_REM; - import->type.name = F64_REM; - import->type.result = f64; - import->type.params.push_back(f64); - import->type.params.push_back(f64); - wasm.addImport(import); - } - return call; + BinaryOp binary = parseAsmBinaryOp(ast[1]->getIString(), ast[2], ast[3], &asmData); + auto ret 
= allocator.alloc<Binary>(); + ret->op = binary; + ret->left = process(ast[2]); + ret->right = process(ast[3]); + ret->finalize(); + if (binary == BinaryOp::RemS && isWasmTypeFloat(ret->type)) { + // WebAssembly does not have floating-point remainder, we have to emit a call to a special import of ours + CallImport *call = allocator.alloc<CallImport>(); + call->target = F64_REM; + call->operands.push_back(ret->left); + call->operands.push_back(ret->right); + call->type = f64; + static bool addedImport = false; + if (!addedImport) { + addedImport = true; + auto import = allocator.alloc<Import>(); // f64-rem = asm2wasm.f64-rem; + import->name = F64_REM; + import->module = ASM2WASM; + import->base = F64_REM; + import->type.name = F64_REM; + import->type.result = f64; + import->type.params.push_back(f64); + import->type.params.push_back(f64); + wasm.addImport(import); } - return ret; - } else { - auto ret = allocator.alloc<Compare>(); - ret->op = relational; - ret->left = process(ast[2]); - ret->right = process(ast[3]); - assert(ret->left->type == ret->right->type); - ret->inputType = ret->left->type; - return ret; + return call; } + return ret; } else if (what == NUM) { auto ret = allocator.alloc<Const>(); double num = ast[1]->getNumber(); if (isInteger32(num)) { ret->value.type = WasmType::i32; - ret->value.i32 = num; + ret->value.i32 = toInteger32(num); } else { ret->value.type = WasmType::f64; ret->value.f64 = num; @@ -718,6 +919,23 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->type = asmToWasmType(asmData.getType(name)); return ret; } + if (name == DEBUGGER) { + CallImport *call = allocator.alloc<CallImport>(); + call->target = DEBUGGER; + call->type = none; + static bool addedImport = false; + if (!addedImport) { + addedImport = true; + auto import = allocator.alloc<Import>(); // debugger = asm2wasm.debugger; + import->name = DEBUGGER; + import->module = ASM2WASM; + import->base = DEBUGGER; + import->type.name = DEBUGGER; + import->type.result = 
none; + wasm.addImport(import); + } + return call; + } // global var, do a load from memory assert(mappedGlobals.find(name) != mappedGlobals.end()); MappedGlobal global = mappedGlobals[name]; @@ -748,24 +966,26 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { return ret; } else if (what == UNARY_PREFIX) { if (ast[1] == PLUS) { - if (ast[2][0] == NUM) { - auto ret = allocator.alloc<Const>(); - ret->value.type = WasmType::f64; - ret->value.f64 = ast[2][1]->getNumber(); - ret->type = ret->value.type; - return ret; + Literal literal = checkLiteral(ast); + if (literal.type != none) { + return allocator.alloc<Const>()->set(literal); } - AsmType childType = detectAsmType(ast[2], &asmData); - if (childType == ASM_INT) { - auto ret = allocator.alloc<Convert>(); - ret->op = isUnsignedCoercion(ast[2]) ? ConvertUInt32 : ConvertSInt32; - ret->value = process(ast[2]); - ret->type = WasmType::f64; - return ret; + auto ret = process(ast[2]); // we are a +() coercion + if (ret->type == i32) { + auto conv = allocator.alloc<Unary>(); + conv->op = isUnsignedCoercion(ast[2]) ? ConvertUInt32 : ConvertSInt32; + conv->value = ret; + conv->type = WasmType::f64; + return conv; + } + if (ret->type == f32) { + auto conv = allocator.alloc<Unary>(); + conv->op = PromoteFloat32; + conv->value = ret; + conv->type = WasmType::f64; + return conv; } - assert(childType == ASM_NONE || childType == ASM_DOUBLE); // e.g. a coercion on a call or for a return - auto ret = process(ast[2]); // just look through the +() coercion - ret->type = WasmType::f64; // we add it here for e.g. 
call coercions + fixCallType(ret, f64); return ret; } else if (ast[1] == MINUS) { if (ast[2][0] == NUM || (ast[2][0] == UNARY_PREFIX && ast[2][1] == PLUS && ast[2][2][0] == NUM)) { @@ -784,20 +1004,45 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->type = WasmType::i32; return ret; } - assert(asmType == ASM_DOUBLE); auto ret = allocator.alloc<Unary>(); ret->op = Neg; ret->value = process(ast[2]); - ret->type = WasmType::f64; + if (asmType == ASM_DOUBLE) { + ret->type = WasmType::f64; + } else if (asmType == ASM_FLOAT) { + ret->type = WasmType::f32; + } else { + abort(); + } return ret; } else if (ast[1] == B_NOT) { // ~, might be ~~ as a coercion or just a not if (ast[2][0] == UNARY_PREFIX && ast[2][1] == B_NOT) { - auto ret = allocator.alloc<Convert>(); +#if 0 + auto ret = allocator.alloc<Unary>(); ret->op = TruncSFloat64; // equivalent to U, except for error handling, which asm.js doesn't have anyhow ret->value = process(ast[2][2]); ret->type = WasmType::i32; return ret; +#endif + // WebAssembly traps on float-to-int overflows, but asm.js wouldn't, so we must emulate that + CallImport *ret = allocator.alloc<CallImport>(); + ret->target = F64_TO_INT; + ret->operands.push_back(process(ast[2][2])); + ret->type = i32; + static bool addedImport = false; + if (!addedImport) { + addedImport = true; + auto import = allocator.alloc<Import>(); // f64-to-int = asm2wasm.f64-to-int; + import->name = F64_TO_INT; + import->module = ASM2WASM; + import->base = F64_TO_INT; + import->type.name = F64_TO_INT; + import->type.result = i32; + import->type.params.push_back(f64); + wasm.addImport(import); + } + return ret; } // no bitwise unary not, so do xor with -1 auto ret = allocator.alloc<Binary>(); @@ -808,12 +1053,12 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { return ret; } else if (ast[1] == L_NOT) { // no logical unary not, so do == 0 - auto ret = allocator.alloc<Compare>(); + auto ret = allocator.alloc<Binary>(); ret->op = Eq; ret->left = 
process(ast[2]); ret->right = allocator.alloc<Const>()->set(Literal(0)); assert(ret->left->type == ret->right->type); - ret->inputType = ret->left->type; + ret->finalize(); return ret; } abort_on("bad unary", ast); @@ -843,20 +1088,96 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->type = WasmType::i32; return ret; } + if (name == Math_fround) { + assert(ast[2]->size() == 1); + Literal lit = checkLiteral(ast[2][0]); + if (lit.type == i32) { + return allocator.alloc<Const>()->set(Literal((float)lit.geti32())); + } else if (lit.type == f64) { + return allocator.alloc<Const>()->set(Literal((float)lit.getf64())); + } + auto ret = allocator.alloc<Unary>(); + ret->value = process(ast[2][0]); + if (ret->value->type == f64) { + ret->op = DemoteFloat64; + } else if (ret->value->type == i32) { + ret->op = ConvertSInt32; + } else if (ret->value->type == f32) { + return ret->value; + } else if (ret->value->type == none) { // call, etc. + ret->value->type = f32; + return ret->value; + } else { + abort_on("confusing fround target", ast[2][0]); + } + ret->type = f32; + return ret; + } + if (name == Math_abs) { + // overloaded on type: i32, f32 or f64 + Expression* value = process(ast[2][0]); + if (value->type == i32) { + // No wasm support, so use a temp local + ensureI32Temp(); + auto set = allocator.alloc<SetLocal>(); + set->name = I32_TEMP; + set->value = value; + set->type = i32; + auto get = [&]() { + auto ret = allocator.alloc<GetLocal>(); + ret->name = I32_TEMP; + ret->type = i32; + return ret; + }; + auto isNegative = allocator.alloc<Binary>(); + isNegative->op = LtS; + isNegative->left = get(); + isNegative->right = allocator.alloc<Const>()->set(0); + isNegative->finalize(); + auto block = allocator.alloc<Block>(); + block->list.push_back(set); + auto flip = allocator.alloc<Binary>(); + flip->op = Sub; + flip->left = allocator.alloc<Const>()->set(0); + flip->right = get(); + flip->type = i32; + auto select = allocator.alloc<Select>(); + select->condition = 
isNegative; + select->ifTrue = flip; + select->ifFalse = get(); + select->type = i32; + block->list.push_back(select); + block->type = i32; + return block; + } else if (value->type == f32 || value->type == f64) { + auto ret = allocator.alloc<Unary>(); + ret->op = Abs; + ret->value = value; + ret->type = value->type; + return ret; + } else { + abort(); + } + } + if (name == Math_floor || name == Math_sqrt) { + // overloaded on type: f32 or f64 + Expression* value = process(ast[2][0]); + if (value->type == f32 || value->type == f64) { + auto ret = allocator.alloc<Unary>(); + ret->op = name == Math_floor ? Floor : Sqrt; + ret->value = value; + ret->type = value->type; + return ret; + } else { + abort(); + } + } Call* ret; if (wasm.importsMap.find(name) != wasm.importsMap.end()) { Ref parent = astStackHelper.getParent(); WasmType type = !!parent ? detectWasmType(parent, &asmData) : none; -#ifndef __EMSCRIPTEN__ - // no imports yet in reference interpreter, fake it - if (type == none) return allocator.alloc<Nop>(); - if (type == i32) return allocator.alloc<Const>()->set(Literal((int32_t)0)); - if (type == f64) return allocator.alloc<Const>()->set(Literal((double)0.0)); - abort(); -#else ret = allocator.alloc<CallImport>(); noteImportedFunctionCall(ast, type, &asmData); -#endif } else { ret = allocator.alloc<Call>(); } @@ -871,12 +1192,14 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { auto ret = allocator.alloc<CallIndirect>(); Ref target = ast[1]; assert(target[0] == SUB && target[1][0] == NAME && target[2][0] == BINARY && target[2][1] == AND && target[2][3][0] == NUM); // FUNCTION_TABLE[(expr) & mask] - ret->target = process(target[2][2]); + ret->target = process(target[2]); // TODO: as an optimization, we could look through the mask Ref args = ast[2]; for (unsigned i = 0; i < args->size(); i++) { ret->operands.push_back(process(args[i])); } - ret->type = getFunctionType(astStackHelper.getParent(), ret->operands); + ret->fullType = 
getFunctionType(astStackHelper.getParent(), ret->operands); + ret->type = ret->fullType->result; + callIndirects[ret] = target[1][1]->getIString(); // we need to fix this up later, when we know how asm function tables are layed out inside the wasm table. return ret; } else if (what == RETURN) { WasmType type = !!ast[1] ? detectWasmType(ast[1], &asmData) : none; @@ -892,7 +1215,26 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->value = !!ast[1] ? process(ast[1]) : nullptr; return ret; } else if (what == BLOCK) { - return processStatements(ast[1], 0); + Name name; + if (parentLabel.is()) { + name = getBreakLabelName(parentLabel); + parentLabel = IString(); + breakStack.push_back(name); + } + auto ret = processStatements(ast[1], 0); + if (name.is()) { + breakStack.pop_back(); + Block* block = ret->dyn_cast<Block>(); + if (block && block->name.isNull()) { + block->name = name; + } else { + block = allocator.alloc<Block>(); + block->name = name; + block->list.push_back(ret); + ret = block; + } + } + return ret; } else if (what == BREAK) { auto ret = allocator.alloc<Break>(); assert(breakStack.size() > 0); @@ -933,6 +1275,12 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { body->list.push_back(process(ast[2])); ret->body = body; } + // loops do not automatically loop, add a branch back + Block* block = blockify(ret->body); + auto continuer = allocator.alloc<Break>(); + continuer->name = ret->in; + block->list.push_back(continuer); + ret->body = block; continueStack.pop_back(); breakStack.pop_back(); return ret; @@ -973,22 +1321,15 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->body = process(ast[2]); continueStack.pop_back(); breakStack.pop_back(); - Break *breakOut = allocator.alloc<Break>(); - breakOut->name = out; - If *condition = allocator.alloc<If>(); - condition->condition = process(ast[1]); - condition->ifTrue = allocator.alloc<Nop>(); - condition->ifFalse = breakOut; - if (Block *block = ret->body->dyn_cast<Block>()) { - 
block->list.push_back(condition); - } else { - auto newBody = allocator.alloc<Block>(); - newBody->list.push_back(ret->body); - newBody->list.push_back(condition); - ret->body = newBody; - } + Break *continuer = allocator.alloc<Break>(); + continuer->name = in; + continuer->condition = process(ast[1]); + Block *block = blockify(ret->body); + block->list.push_back(continuer); + ret->body = block; return ret; } else if (what == LABEL) { + assert(parentLabel.isNull()); parentLabel = ast[1]->getIString(); return process(ast[2]); } else if (what == CONDITIONAL) { @@ -1005,8 +1346,13 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->type = ret->list[1]->type; return ret; } else if (what == SWITCH) { - // XXX switch is still in flux in the spec repo, just emit a placeholder - IString name = getNextId("switch"); + IString name; + if (!parentLabel.isNull()) { + name = getBreakLabelName(parentLabel); + parentLabel = IString(); + } else { + name = getNextId("switch"); + } breakStack.push_back(name); auto ret = allocator.alloc<Switch>(); ret->name = name; @@ -1118,6 +1464,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { } // cleanups/checks assert(breakStack.size() == 0 && continueStack.size() == 0); + assert(parentLabel.isNull()); return function; } diff --git a/src/asm_v_wasm.h b/src/asm_v_wasm.h new file mode 100644 index 000000000..1bf689802 --- /dev/null +++ b/src/asm_v_wasm.h @@ -0,0 +1,70 @@ +#ifndef _asm_v_wasm_h_ +#define _asm_v_wasm_h_ + +#include "emscripten-optimizer/optimizer.h" + +namespace wasm { + +WasmType asmToWasmType(AsmType asmType) { + switch (asmType) { + case ASM_INT: return WasmType::i32; + case ASM_DOUBLE: return WasmType::f64; + case ASM_FLOAT: return WasmType::f32; + case ASM_NONE: return WasmType::none; + default: {} + } + abort(); +} + +AsmType wasmToAsmType(WasmType type) { + switch (type) { + case WasmType::i32: return ASM_INT; + case WasmType::f32: return ASM_FLOAT; + case WasmType::f64: return ASM_DOUBLE; + case 
WasmType::none: return ASM_NONE; + default: {} + } + abort(); +} + +char getSig(WasmType type) { + switch (type) { + case i32: return 'i'; + case f32: return 'f'; + case f64: return 'd'; + case none: return 'v'; + default: abort(); + } +} + +std::string getSig(FunctionType *type) { + std::string ret; + ret += getSig(type->result); + for (auto param : type->params) { + ret += getSig(param); + } + return ret; +} + +std::string getSig(Function *func) { + std::string ret; + ret += getSig(func->result); + for (auto param : func->params) { + ret += getSig(param.type); + } + return ret; +} + +std::string getSig(CallBase *call) { + std::string ret; + ret += getSig(call->type); + for (auto operand : call->operands) { + ret += getSig(operand->type); + } + return ret; +} + +} // namespace wasm + +#endif // _asm_v_wasm_h_ + diff --git a/src/binaryen-shell.cpp b/src/binaryen-shell.cpp index 3ec9e684b..1dd2c1c7c 100644 --- a/src/binaryen-shell.cpp +++ b/src/binaryen-shell.cpp @@ -6,6 +6,7 @@ // #include <setjmp.h> +#include <memory> #include "wasm-s-parser.h" #include "wasm-interpreter.h" @@ -15,6 +16,10 @@ using namespace cashew; using namespace wasm; +namespace wasm { +int debug = 0; +} + // Globals MixedArena globalAllocator; @@ -22,7 +27,7 @@ MixedArena globalAllocator; IString ASSERT_RETURN("assert_return"), ASSERT_TRAP("assert_trap"), ASSERT_INVALID("assert_invalid"), - STDIO("stdio"), + SPECTEST("spectest"), PRINT("print"), INVOKE("invoke"); @@ -36,7 +41,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { ShellExternalInterface() : memory(nullptr) {} void init(Module& wasm) override { - memory = (char*)malloc(wasm.memory.initial); + memory = (char*)calloc(wasm.memory.initial, 1); // apply memory segments for (auto segment : wasm.memory.segments) { memcpy(memory + segment.offset, segment.data, segment.size); @@ -44,7 +49,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { } Literal callImport(Import *import, 
ModuleInstance::LiteralList& arguments) override { - if (import->module == STDIO && import->base == PRINT) { + if (import->module == SPECTEST && import->base == PRINT) { for (auto argument : arguments) { std::cout << argument << '\n'; } @@ -104,19 +109,24 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { } break; } - case f32: *((float*)(memory+addr)) = value.getf32(); break; - case f64: *((double*)(memory+addr)) = value.getf64(); break; + // write floats carefully, ensuring all bits reach memory + case f32: *((int32_t*)(memory+addr)) = value.reinterpreti32(); break; + case f64: *((int64_t*)(memory+addr)) = value.reinterpreti64(); break; default: abort(); } } void growMemory(size_t oldSize, size_t newSize) override { memory = (char*)realloc(memory, newSize); + if (newSize > oldSize) { + memset(memory + oldSize, 0, newSize - oldSize); + } } jmp_buf trapState; - void trap() override { + void trap(const char* why) override { + std::cerr << "[trap " << why << "]\n"; longjmp(trapState, 1); } }; @@ -232,7 +242,7 @@ int main(int argc, char **argv) { size_t i = 0; while (i < root.size()) { if (debug) std::cerr << "parsing s-expressions to wasm...\n"; - Module wasm; + AllocatingModule wasm; SExpressionWasmBuilder builder(wasm, *root[i], [&]() { abort(); }); i++; @@ -241,7 +251,7 @@ int main(int argc, char **argv) { if (print_before) { Colors::bold(std::cout); - std::cout << "printing before:\n"; + std::cerr << "printing before:\n"; Colors::normal(std::cout); std::cout << wasm; } @@ -259,7 +269,7 @@ int main(int argc, char **argv) { if (print_after) { Colors::bold(std::cout); - std::cout << "printing after:\n"; + std::cerr << "printing after:\n"; Colors::normal(std::cout); std::cout << wasm; } @@ -278,14 +288,21 @@ int main(int argc, char **argv) { std::cerr << curr << '\n'; if (id == ASSERT_INVALID) { // a module invalidity test - Module wasm; + AllocatingModule wasm; bool invalid = false; jmp_buf trapState; + std::unique_ptr<SExpressionWasmBuilder> 
builder; if (setjmp(trapState) == 0) { - SExpressionWasmBuilder builder(wasm, *curr[1], [&]() { + builder = std::unique_ptr<SExpressionWasmBuilder>(new SExpressionWasmBuilder(wasm, *curr[1], [&]() { invalid = true; longjmp(trapState, 1); - }); + })); + } + if (print_before || print_after) { + Colors::bold(std::cout); + std::cerr << "printing in module invalidity test:\n"; + Colors::normal(std::cout); + std::cout << wasm; } if (!invalid) { // maybe parsed ok, but otherwise incorrect diff --git a/src/emscripten-optimizer/optimizer-shared.cpp b/src/emscripten-optimizer/optimizer-shared.cpp index 1c36efc62..7fe91eeab 100644 --- a/src/emscripten-optimizer/optimizer-shared.cpp +++ b/src/emscripten-optimizer/optimizer-shared.cpp @@ -19,6 +19,12 @@ bool isInteger32(double x) { return isInteger(x) && (x == (int32_t)x || x == (uint32_t)x); } +int32_t toInteger32(double x) { + if (x == (int32_t)x) return (int32_t)x; + assert(x == (uint32_t)x); + return (uint32_t)x; +} + int parseInt(const char *str) { int ret = *str - '0'; while (*(++str)) { @@ -42,7 +48,7 @@ HeapInfo parseHeap(const char *name) { return ret; } -AsmType detectType(Ref node, AsmData *asmData, bool inVarDef) { +AsmType detectType(Ref node, AsmData *asmData, bool inVarDef, IString minifiedFround) { switch (node[0]->getCString()[0]) { case 'n': { if (node[0] == NUM) { @@ -69,7 +75,7 @@ AsmType detectType(Ref node, AsmData *asmData, bool inVarDef) { if (node[0] == UNARY_PREFIX) { switch (node[1]->getCString()[0]) { case '+': return ASM_DOUBLE; - case '-': return detectType(node[2], asmData, inVarDef); + case '-': return detectType(node[2], asmData, inVarDef, minifiedFround); case '!': case '~': return ASM_INT; } break; @@ -80,7 +86,7 @@ AsmType detectType(Ref node, AsmData *asmData, bool inVarDef) { if (node[0] == CALL) { if (node[1][0] == NAME) { IString name = node[1][1]->getIString(); - if (name == MATH_FROUND) return ASM_FLOAT; + if (name == MATH_FROUND || name == minifiedFround) return ASM_FLOAT; else if 
(name == SIMD_FLOAT32X4 || name == SIMD_FLOAT32X4_CHECK) return ASM_FLOAT32X4; else if (name == SIMD_FLOAT64X2 || name == SIMD_FLOAT64X2_CHECK) return ASM_FLOAT64X2; else if (name == SIMD_INT8X16 || name == SIMD_INT8X16_CHECK) return ASM_INT8X16; @@ -89,7 +95,7 @@ AsmType detectType(Ref node, AsmData *asmData, bool inVarDef) { } return ASM_NONE; } else if (node[0] == CONDITIONAL) { - return detectType(node[2], asmData, inVarDef); + return detectType(node[2], asmData, inVarDef, minifiedFround); } break; } @@ -97,7 +103,7 @@ AsmType detectType(Ref node, AsmData *asmData, bool inVarDef) { if (node[0] == BINARY) { switch (node[1]->getCString()[0]) { case '+': case '-': - case '*': case '/': case '%': return detectType(node[2], asmData, inVarDef); + case '*': case '/': case '%': return detectType(node[2], asmData, inVarDef, minifiedFround); case '|': case '&': case '^': case '<': case '>': // handles <<, >>, >>=, <=, >= case '=': case '!': { // handles ==, != return ASM_INT; @@ -108,7 +114,7 @@ AsmType detectType(Ref node, AsmData *asmData, bool inVarDef) { } case 's': { if (node[0] == SEQ) { - return detectType(node[2], asmData, inVarDef); + return detectType(node[2], asmData, inVarDef, minifiedFround); } else if (node[0] == SUB) { assert(node[1][0] == NAME); HeapInfo info = parseHeap(node[1][1]->getCString()); @@ -123,3 +129,107 @@ AsmType detectType(Ref node, AsmData *asmData, bool inVarDef) { return ASM_NONE; } +static void abort_on(Ref node) { + node->stringify(std::cerr); + std::cerr << '\n'; + abort(); +} + +AsmSign detectSign(Ref node, IString minifiedFround) { + IString type = node[0]->getIString(); + if (type == BINARY) { + IString op = node[1]->getIString(); + switch (op.str[0]) { + case '>': { + if (op == TRSHIFT) return ASM_UNSIGNED; + // fallthrough + } + case '|': case '&': case '^': case '<': case '=': case '!': return ASM_SIGNED; + case '+': case '-': return ASM_FLEXIBLE; + case '*': case '/': return ASM_NONSIGNED; // without a coercion, these are 
double + default: abort_on(node); + } + } else if (type == UNARY_PREFIX) { + IString op = node[1]->getIString(); + switch (op.str[0]) { + case '-': return ASM_FLEXIBLE; + case '+': return ASM_NONSIGNED; // XXX double + case '~': return ASM_SIGNED; + default: abort_on(node); + } + } else if (type == NUM) { + double value = node[1]->getNumber(); + if (value < 0) return ASM_SIGNED; + if (value > uint32_t(-1) || fmod(value, 1) != 0) return ASM_NONSIGNED; + if (value == int32_t(value)) return ASM_FLEXIBLE; + return ASM_UNSIGNED; + } else if (type == NAME) { + return ASM_FLEXIBLE; + } else if (type == CONDITIONAL) { + return detectSign(node[2], minifiedFround); + } else if (type == CALL) { + if (node[1][0] == NAME && (node[1][1] == MATH_FROUND || node[1][1] == minifiedFround)) return ASM_NONSIGNED; + } else if (type == SEQ) { + return detectSign(node[2], minifiedFround); + } + abort_on(node); + abort(); // avoid warning +} + +Ref makeAsmCoercedZero(AsmType type) { + switch (type) { + case ASM_INT: return ValueBuilder::makeNum(0); break; + case ASM_DOUBLE: return ValueBuilder::makeUnary(PLUS, ValueBuilder::makeNum(0)); break; + case ASM_FLOAT: { + if (!ASM_FLOAT_ZERO.isNull()) { + return ValueBuilder::makeName(ASM_FLOAT_ZERO); + } else { + return ValueBuilder::makeCall(MATH_FROUND, ValueBuilder::makeNum(0)); + } + break; + } + case ASM_FLOAT32X4: { + return ValueBuilder::makeCall(SIMD_FLOAT32X4, ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0)); + break; + } + case ASM_FLOAT64X2: { + return ValueBuilder::makeCall(SIMD_FLOAT64X2, ValueBuilder::makeNum(0), ValueBuilder::makeNum(0)); + break; + } + case ASM_INT8X16: { + return ValueBuilder::makeCall(SIMD_INT8X16, ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), 
ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0)); + break; + } + case ASM_INT16X8: { + return ValueBuilder::makeCall(SIMD_INT16X8, ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0)); + break; + } + case ASM_INT32X4: { + return ValueBuilder::makeCall(SIMD_INT32X4, ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0), ValueBuilder::makeNum(0)); + break; + } + default: assert(0); + } + abort(); +} + +Ref makeAsmCoercion(Ref node, AsmType type) { + switch (type) { + case ASM_INT: return ValueBuilder::makeBinary(node, OR, ValueBuilder::makeNum(0)); + case ASM_DOUBLE: return ValueBuilder::makeUnary(PLUS, node); + case ASM_FLOAT: return ValueBuilder::makeCall(MATH_FROUND, node); + case ASM_FLOAT32X4: return ValueBuilder::makeCall(SIMD_FLOAT32X4_CHECK, node); + case ASM_FLOAT64X2: return ValueBuilder::makeCall(SIMD_FLOAT64X2_CHECK, node); + case ASM_INT8X16: return ValueBuilder::makeCall(SIMD_INT8X16_CHECK, node); + case ASM_INT16X8: return ValueBuilder::makeCall(SIMD_INT16X8_CHECK, node); + case ASM_INT32X4: return ValueBuilder::makeCall(SIMD_INT32X4_CHECK, node); + case ASM_NONE: + default: return node; // non-validating code, emit nothing XXX this is dangerous, we should only allow this when we know we are not validating + } +} + +Ref makeSigning(Ref node, AsmSign sign) { + assert(sign == ASM_SIGNED || sign == ASM_UNSIGNED); + return ValueBuilder::makeBinary(node, sign == ASM_SIGNED ? 
OR : TRSHIFT, ValueBuilder::makeNum(0)); +} + diff --git a/src/emscripten-optimizer/optimizer.h b/src/emscripten-optimizer/optimizer.h index 5edcaad87..6850214c2 100644 --- a/src/emscripten-optimizer/optimizer.h +++ b/src/emscripten-optimizer/optimizer.h @@ -39,7 +39,7 @@ enum AsmType { struct AsmData; -AsmType detectType(cashew::Ref node, AsmData *asmData=nullptr, bool inVarDef=false); +AsmType detectType(cashew::Ref node, AsmData *asmData=nullptr, bool inVarDef=false, cashew::IString minifiedFround=cashew::IString()); struct AsmData { struct Local { @@ -102,8 +102,8 @@ struct AsmData { }; bool isInteger(double x); - bool isInteger32(double x); +int32_t toInteger32(double x); extern cashew::IString ASM_FLOAT_ZERO; @@ -123,5 +123,24 @@ struct HeapInfo { HeapInfo parseHeap(const char *name); +enum AsmSign { + ASM_FLEXIBLE = 0, // small constants can be signed or unsigned, variables are also flexible + ASM_SIGNED = 1, + ASM_UNSIGNED = 2, + ASM_NONSIGNED = 3, +}; + +extern AsmSign detectSign(cashew::Ref node, cashew::IString minifiedFround); + +inline cashew::Ref deStat(cashew::Ref node) { + if (node[0] == cashew::STAT) return node[1]; + return node; +} + +cashew::Ref makeAsmCoercedZero(AsmType type); +cashew::Ref makeAsmCoercion(cashew::Ref node, AsmType type); + +cashew::Ref makeSigning(cashew::Ref node, AsmSign sign); + #endif // __optimizer_h__ diff --git a/src/emscripten-optimizer/parser.h b/src/emscripten-optimizer/parser.h index 8ca07c6ed..386195d77 100644 --- a/src/emscripten-optimizer/parser.h +++ b/src/emscripten-optimizer/parser.h @@ -830,6 +830,16 @@ class Parser { src++; return Builder::makeBlock(); // we don't need the brackets here, but oh well } + if (*src == '{') { // detect a trivial {} in a statement context + char *before = src; + src++; + skipSpace(src); + if (*src == '}') { + src++; + return Builder::makeBlock(); // we don't need the brackets here, but oh well + } + src = before; + } NodeRef ret = parseElement(src, seps); skipSpace(src); if (*src 
== ';') { diff --git a/src/emscripten-optimizer/simple_ast.cpp b/src/emscripten-optimizer/simple_ast.cpp index 89b8ce8ed..dcecab33d 100644 --- a/src/emscripten-optimizer/simple_ast.cpp +++ b/src/emscripten-optimizer/simple_ast.cpp @@ -253,7 +253,7 @@ void traverseFunctions(Ref ast, std::function<void (Ref)> visit) { // ValueBuilder -IStringSet ValueBuilder::statable("assign call binary unary-prefix if name num conditional dot new sub seq string object array"); +IStringSet ValueBuilder::statable("assign call binary unary-prefix name num conditional dot new sub seq string object array"); } // namespace cashew diff --git a/src/emscripten-optimizer/simple_ast.h b/src/emscripten-optimizer/simple_ast.h index 4012ff8f6..5952eee72 100644 --- a/src/emscripten-optimizer/simple_ast.h +++ b/src/emscripten-optimizer/simple_ast.h @@ -1268,9 +1268,20 @@ struct JSPrinter { pretty ? emit(", ") : emit(','); newline(); } - emit('"'); - emit(args[i][0]->getCString()); - emit("\":"); + const char *str = args[i][0]->getCString(); + const char *check = str; + bool needQuote = false; + while (*check) { + if (!isalnum(*check) && *check != '_' && *check != '$') { + needQuote = true; + break; + } + check++; + } + if (needQuote) emit('"'); + emit(str); + if (needQuote) emit('"'); + emit(":"); space(); print(args[i][1]); } @@ -1289,15 +1300,15 @@ class ValueBuilder { return &arena.alloc()->setString(s); } - static Ref makeRawArray(int size_hint=0) { - return &arena.alloc()->setArray(size_hint); - } - static Ref makeNull() { return &arena.alloc()->setNull(); } public: + static Ref makeRawArray(int size_hint=0) { + return &arena.alloc()->setArray(size_hint); + } + static Ref makeToplevel() { return &makeRawArray(2)->push_back(makeRawString(TOPLEVEL)) .push_back(makeRawArray()); @@ -1336,6 +1347,80 @@ public: .push_back(target) .push_back(makeRawArray()); } + static Ref makeCall(Ref target, Ref arg) { + Ref ret = &makeRawArray(3)->push_back(makeRawString(CALL)) + .push_back(target) + 
.push_back(makeRawArray()); + ret[2]->push_back(arg); + return ret; + } + static Ref makeCall(IString target) { + Ref ret = &makeRawArray(3)->push_back(makeRawString(CALL)) + .push_back(makeName(target)) + .push_back(makeRawArray()); + return ret; + } + static Ref makeCall(IString target, Ref arg) { + Ref ret = &makeRawArray(3)->push_back(makeRawString(CALL)) + .push_back(makeName(target)) + .push_back(makeRawArray(1)); + ret[2]->push_back(arg); + return ret; + } + static Ref makeCall(IString target, Ref arg1, Ref arg2) { + Ref ret = &makeRawArray(3)->push_back(makeRawString(CALL)) + .push_back(makeName(target)) + .push_back(makeRawArray(2)); + ret[2]->push_back(arg1); + ret[2]->push_back(arg2); + return ret; + } + static Ref makeCall(IString target, Ref arg1, Ref arg2, Ref arg3, Ref arg4) { + Ref ret = &makeRawArray(3)->push_back(makeRawString(CALL)) + .push_back(makeName(target)) + .push_back(makeRawArray(4)); + ret[2]->push_back(arg1); + ret[2]->push_back(arg2); + ret[2]->push_back(arg3); + ret[2]->push_back(arg4); + return ret; + } + static Ref makeCall(IString target, Ref arg1, Ref arg2, Ref arg3, Ref arg4, Ref arg5, Ref arg6, Ref arg7, Ref arg8) { + Ref ret = &makeRawArray(3)->push_back(makeRawString(CALL)) + .push_back(makeName(target)) + .push_back(makeRawArray(8)); + ret[2]->push_back(arg1); + ret[2]->push_back(arg2); + ret[2]->push_back(arg3); + ret[2]->push_back(arg4); + ret[2]->push_back(arg5); + ret[2]->push_back(arg6); + ret[2]->push_back(arg7); + ret[2]->push_back(arg8); + return ret; + } + static Ref makeCall(IString target, Ref arg1, Ref arg2, Ref arg3, Ref arg4, Ref arg5, Ref arg6, Ref arg7, Ref arg8, Ref arg9, Ref arg10, Ref arg11, Ref arg12, Ref arg13, Ref arg14, Ref arg15, Ref arg16) { + Ref ret = &makeRawArray(3)->push_back(makeRawString(CALL)) + .push_back(makeName(target)) + .push_back(makeRawArray(16)); + ret[2]->push_back(arg1); + ret[2]->push_back(arg2); + ret[2]->push_back(arg3); + ret[2]->push_back(arg4); + ret[2]->push_back(arg5); + 
ret[2]->push_back(arg6); + ret[2]->push_back(arg7); + ret[2]->push_back(arg8); + ret[2]->push_back(arg9); + ret[2]->push_back(arg10); + ret[2]->push_back(arg11); + ret[2]->push_back(arg12); + ret[2]->push_back(arg13); + ret[2]->push_back(arg14); + ret[2]->push_back(arg15); + ret[2]->push_back(arg16); + return ret; + } static void appendToCall(Ref call, Ref element) { assert(call[0] == CALL); @@ -1358,6 +1443,15 @@ public: static Ref makeInt(uint32_t num) { return makeDouble(double(num)); } + static Ref makeNum(double num) { + return makeDouble(num); + } + + static Ref makeUnary(IString op, Ref value) { + return &makeRawArray(3)->push_back(makeRawString(UNARY_PREFIX)) + .push_back(makeRawString(op)) + .push_back(value); + } static Ref makeBinary(Ref left, IString op, Ref right) { if (op == SET) { @@ -1395,7 +1489,7 @@ public: func[2]->push_back(makeRawString(arg)); } - static Ref makeVar(bool is_const) { + static Ref makeVar(bool is_const=false) { return &makeRawArray(2)->push_back(makeRawString(VAR)) .push_back(makeRawArray()); } @@ -1432,6 +1526,12 @@ public: .push_back(ifFalse); } + static Ref makeSeq(Ref left, Ref right) { + return &makeRawArray(3)->push_back(makeRawString(SEQ)) + .push_back(left) + .push_back(right); + } + static Ref makeDo(Ref body, Ref condition) { return &makeRawArray(3)->push_back(makeRawString(DO)) .push_back(condition) @@ -1526,6 +1626,29 @@ public: array[1]->push_back(&makeRawArray(2)->push_back(makeRawString(key)) .push_back(value)); } + + static Ref makeAssign(Ref target, Ref value) { + return &makeRawArray(3)->push_back(makeRawString(ASSIGN)) + .push_back(&arena.alloc()->setBool(true)) + .push_back(target) + .push_back(value); + } + static Ref makeAssign(IString target, Ref value) { + return &makeRawArray(3)->push_back(makeRawString(ASSIGN)) + .push_back(&arena.alloc()->setBool(true)) + .push_back(makeName(target)) + .push_back(value); + } + + static Ref makeSub(Ref obj, Ref index) { + return 
&makeRawArray(2)->push_back(makeRawString(SUB)) + .push_back(obj) + .push_back(index); + } + + static Ref makePtrShift(Ref ptr, int shifts) { + return makeBinary(ptr, RSHIFT, makeInt(shifts)); + } }; // Tolerates 0.0 in the input; does not trust a +() to be there. diff --git a/src/js/post.js b/src/js/post.js index 6406674ba..7f7ca9c1e 100644 --- a/src/js/post.js +++ b/src/js/post.js @@ -1,38 +1,121 @@ -(function() { - var wasmJS = WasmJS({}); // do not use the normal Module in the current scope +function integrateWasmJS(Module) { + // wasm.js has several methods for creating the compiled code module here: + // * 'wasm-s-parser': load s-expression code from a .wast and create wasm + // * 'asm2wasm': load asm.js code and translate to wasm + // * 'just-asm': no wasm, just load the asm.js code and use that (good for testing) + // The method can be set at compile time (BINARYEN_METHOD), or runtime by setting Module['wasmJSMethod']. + var method = Module['wasmJSMethod'] || 'wasm-s-parser'; + assert(method == 'asm2wasm' || method == 'wasm-s-parser' || method == 'just-asm'); - // XXX don't be confused. Module here is in the outside program. wasmJS is the inner wasm-js.cpp. + if (method == 'just-asm') { + eval(Module['read'](Module['asmjsCodeFile'])); + return; + } + + var asm2wasmImports = { // special asm2wasm imports + "f64-rem": function(x, y) { + return x % y; + }, + "f64-to-int": function(x) { + return x | 0; + }, + "debugger": function() { + debugger; + }, + }; + + function flatten(obj) { + var ret = {}; + for (var x in obj) { + for (var y in obj[x]) { + if (ret[y]) Module['printErr']('warning: flatten dupe: ' + y); + ret[y] = obj[x][y]; + } + } + return ret; + } + + // wasm lacks globals, so asm2wasm maps them into locations in memory. that information cannot + // be present in the wasm output of asm2wasm, so we store it in a side file. 
If we load asm2wasm + // output, either generated ahead of time or on the client, we need to apply those mapped + // globals after loading the module. + function applyMappedGlobals() { + var mappedGlobals = JSON.parse(Module['read'](Module['wasmCodeFile'] + '.mappedGlobals')); + for (var name in mappedGlobals) { + var global = mappedGlobals[name]; + if (!global.import) continue; // non-imports are initialized to zero in the typed array anyhow, so nothing to do here + var value = wasmJS['lookupImport'](global.module, global.base); + var address = global.address; + switch (global.type) { + case WasmTypes.i32: Module['HEAP32'][address >> 2] = value; break; + case WasmTypes.f32: Module['HEAPF32'][address >> 2] = value; break; + case WasmTypes.f64: Module['HEAPF64'][address >> 3] = value; break; + default: abort(); + } + } + } + + if (typeof WASM === 'object') { + // Provide an "asm.js function" for the application, called to "link" the asm.js module. We instantiate + // the wasm module at that time, and it receives imports and provides exports and so forth, the app + // doesn't need to care that it is wasm and not asm. + Module['asm'] = function(global, env, providedBuffer) { + // Load the wasm module + var binary = Module['readBinary'](Module['wasmCodeFile']); + + // Create an instance of the module using native support in the JS engine. + var instance = WASM.instantiateModule(binary, flatten({ // XXX for now, flatten the imports + "global.Math": global.Math, + "env": env, + "asm2wasm": asm2wasmImports + })); - // Generate a module instance of the asm.js converted into wasm. - var code; - if (typeof read === 'function') { - // spidermonkey or v8 shells - code = read(Module['asmjsCodeFile']); - } else if (typeof process === 'object' && typeof require === 'function') { - // node.js - code = require('fs')['readFileSync'](Module['asmjsCodeFile']).toString(); - } else { - throw 'TODO: loading in other platforms'; + // The wasm instance creates its memory. 
But static init code might have written to + // buffer already, and we must copy it over. + // TODO: avoid this copy, by avoiding such static init writes + // TODO: in shorter term, just copy up to the last static init write + var oldBuffer = Module['buffer']; + var newBuffer = instance.memory; + assert(newBuffer.byteLength >= oldBuffer.byteLength, 'we might fail if we allocated more than TOTAL_MEMORY'); + // the wasm module does write out the memory initialization, in range STATIC_BASE..STATIC_BUMP, so avoid that + (new Int8Array(newBuffer).subarray(0, STATIC_BASE)).set(new Int8Array(oldBuffer).subarray(0, STATIC_BASE)); + (new Int8Array(newBuffer).subarray(STATIC_BASE + STATIC_BUMP)).set(new Int8Array(oldBuffer).subarray(STATIC_BASE + STATIC_BUMP)); + updateGlobalBuffer(newBuffer); + updateGlobalBufferViews(); + Module['reallocBuffer'] = function(size) { + var old = Module['buffer']; + wasmJS['asmExports']['__growWasmMemory'](size); // tiny wasm method that just does grow_memory + return Module['buffer'] !== old ? Module['buffer'] : null; // if it was reallocated, it changed + }; + + applyMappedGlobals(); + + return instance; + }; + + return; } - var temp = wasmJS._malloc(code.length + 1); - wasmJS.writeAsciiToMemory(code, temp); - wasmJS._load_asm(temp); - wasmJS._free(temp); + var WasmTypes = { + none: 0, + i32: 1, + i64: 2, + f32: 3, + f64: 4 + }; + + // Use wasm.js to polyfill and execute code in a wasm interpreter. + var wasmJS = WasmJS({}); - // Generate memory XXX TODO get the right size - var theBuffer = Module['buffer'] = new ArrayBuffer(Module['providedTotalMemory'] || 64*1024*1024); + // XXX don't be confused. Module here is in the outside program. wasmJS is the inner wasm-js.cpp. + wasmJS['outside'] = Module; // Inside wasm-js.cpp, Module['outside'] reaches the outside module. // Information for the instance of the module. 
var info = wasmJS['info'] = { global: null, env: null, - asm2wasm: { // special asm2wasm imports - "f64-rem": function(x, y) { - return x % y; - }, - }, + asm2wasm: asm2wasmImports, parent: Module // Module inside wasm-js.cpp refers to wasm-js.cpp; this allows access to the outside program. }; @@ -52,14 +135,48 @@ return lookup; } - // The asm.js function, called to "link" the asm.js module. - Module['asm'] = function(global, env, buffer) { - assert(buffer === theBuffer); // we should not even need to pass it as a 3rd arg for wasm, but that's the asm.js way. - // write the provided data to a location the wasm instance can get at it. + // The asm.js function, called to "link" the asm.js module. At that time, we are provided imports + // and respond with exports, and so forth. + Module['asm'] = function(global, env, providedBuffer) { + assert(providedBuffer === Module['buffer']); // we should not even need to pass it as a 3rd arg for wasm, but that's the asm.js way. + info.global = global; info.env = env; - wasmJS['_load_mapped_globals'](); // now that we have global and env, we can ready the provided imported globals, copying them to their mapped locations. + + // wasm code would create its own buffer, at this time. But static init code might have + // written to the buffer already, and we must copy it over. We could just avoid + // this copy in wasm.js polyfilling, but to be as close as possible to real wasm, + // we do what wasm would do. 
+ // TODO: avoid this copy, by avoiding such static init writes + // TODO: in shorter term, just copy up to the last static init write + var oldBuffer = Module['buffer']; + var newBuffer = new ArrayBuffer(oldBuffer.byteLength); + (new Int8Array(newBuffer)).set(new Int8Array(oldBuffer)); + updateGlobalBuffer(newBuffer); + updateGlobalBufferViews(); + wasmJS['providedTotalMemory'] = Module['buffer'].byteLength; + + Module['reallocBuffer'] = function(size) { + var old = Module['buffer']; + wasmJS['asmExports']['__growWasmMemory'](size); // tiny wasm method that just does grow_memory + return Module['buffer'] !== old ? Module['buffer'] : null; // if it was reallocated, it changed + }; + + // Prepare to generate wasm, using either asm2wasm or wasm-s-parser + var code = Module['read'](method == 'asm2wasm' ? Module['asmjsCodeFile'] : Module['wasmCodeFile']); + var temp = wasmJS['_malloc'](code.length + 1); + wasmJS['writeAsciiToMemory'](code, temp); + if (method == 'asm2wasm') { + wasmJS['_load_asm2wasm'](temp); + } else { + wasmJS['_load_s_expr2wasm'](temp); + applyMappedGlobals(); + } + wasmJS['_free'](temp); + + wasmJS['_instantiate'](temp); + return wasmJS['asmExports']; }; -})(); +} diff --git a/src/mixed_arena.h b/src/mixed_arena.h index 4d0d85c67..b2fbe06a6 100644 --- a/src/mixed_arena.h +++ b/src/mixed_arena.h @@ -41,5 +41,16 @@ struct MixedArena { extern MixedArena globalAllocator; +#ifdef __wasm_h__ +namespace wasm { + +class AllocatingModule : public Module { +public: + MixedArena allocator; +}; + +} +#endif + #endif // mixed_arena_h diff --git a/src/parsing.h b/src/parsing.h new file mode 100644 index 000000000..59e8fc5ae --- /dev/null +++ b/src/parsing.h @@ -0,0 +1,153 @@ + +#include <sstream> + +#include "wasm.h" +#include "shared-constants.h" +#include "mixed_arena.h" + +namespace wasm { + +Expression* parseConst(cashew::IString s, WasmType type, MixedArena& allocator) { + const char *str = s.str; + auto ret = allocator.alloc<Const>(); + ret->type = 
ret->value.type = type; + if (isWasmTypeFloat(type)) { + if (s == INFINITY_) { + switch (type) { + case f32: ret->value.f32 = std::numeric_limits<float>::infinity(); break; + case f64: ret->value.f64 = std::numeric_limits<double>::infinity(); break; + default: return nullptr; + } + //std::cerr << "make constant " << str << " ==> " << ret->value << '\n'; + return ret; + } + if (s == NEG_INFINITY) { + switch (type) { + case f32: ret->value.f32 = -std::numeric_limits<float>::infinity(); break; + case f64: ret->value.f64 = -std::numeric_limits<double>::infinity(); break; + default: return nullptr; + } + //std::cerr << "make constant " << str << " ==> " << ret->value << '\n'; + return ret; + } + if (s == NAN_) { + switch (type) { + case f32: ret->value.f32 = std::nan(""); break; + case f64: ret->value.f64 = std::nan(""); break; + default: return nullptr; + } + //std::cerr << "make constant " << str << " ==> " << ret->value << '\n'; + return ret; + } + bool negative = str[0] == '-'; + const char *positive = negative ? str + 1 : str; + if (positive[0] == '+') positive++; + if (positive[0] == 'n' && positive[1] == 'a' && positive[2] == 'n') { + const char * modifier = positive[3] == ':' ? positive + 4 : nullptr; + assert(modifier ? 
positive[4] == '0' && positive[5] == 'x' : 1); + switch (type) { + case f32: { + union { + uint32_t pattern; + float f; + } u; + if (modifier) { + std::istringstream istr(modifier); + istr >> std::hex >> u.pattern; + u.pattern |= 0x7f800000; + } else { + u.pattern = 0x7fc00000; + } + if (negative) u.pattern |= 0x80000000; + if (!isnan(u.f)) u.pattern |= 1; + assert(isnan(u.f)); + ret->value.f32 = u.f; + break; + } + case f64: { + union { + uint64_t pattern; + double d; + } u; + if (modifier) { + std::istringstream istr(modifier); + istr >> std::hex >> u.pattern; + u.pattern |= 0x7ff0000000000000LL; + } else { + u.pattern = 0x7ff8000000000000L; + } + if (negative) u.pattern |= 0x8000000000000000LL; + if (!isnan(u.d)) u.pattern |= 1; + assert(isnan(u.d)); + ret->value.f64 = u.d; + break; + } + default: return nullptr; + } + //std::cerr << "make constant " << str << " ==> " << ret->value << '\n'; + return ret; + } + if (s == NEG_NAN) { + switch (type) { + case f32: ret->value.f32 = -std::nan(""); break; + case f64: ret->value.f64 = -std::nan(""); break; + default: return nullptr; + } + //std::cerr << "make constant " << str << " ==> " << ret->value << '\n'; + return ret; + } + } + switch (type) { + case i32: { + if ((str[0] == '0' && str[1] == 'x') || (str[0] == '-' && str[1] == '0' && str[2] == 'x')) { + bool negative = str[0] == '-'; + if (negative) str++; + std::istringstream istr(str); + uint32_t temp; + istr >> std::hex >> temp; + ret->value.i32 = negative ? -temp : temp; + } else { + std::istringstream istr(str); + int32_t temp; + istr >> temp; + ret->value.i32 = temp; + } + break; + } + case i64: { + if ((str[0] == '0' && str[1] == 'x') || (str[0] == '-' && str[1] == '0' && str[2] == 'x')) { + bool negative = str[0] == '-'; + if (negative) str++; + std::istringstream istr(str); + uint64_t temp; + istr >> std::hex >> temp; + ret->value.i64 = negative ? 
-temp : temp; + } else { + std::istringstream istr(str); + int64_t temp; + istr >> temp; + ret->value.i64 = temp; + } + break; + } + case f32: { + char *end; + ret->value.f32 = strtof(str, &end); + assert(!isnan(ret->value.f32)); + break; + } + case f64: { + char *end; + ret->value.f64 = strtod(str, &end); + assert(!isnan(ret->value.f64)); + break; + } + default: return nullptr; + } + //std::cerr << "make constant " << str << " ==> " << ret->value << '\n'; + return ret; +} + + +} // namespace wasm + diff --git a/src/pass.h b/src/pass.h index de22e7daa..6af9e83e1 100644 --- a/src/pass.h +++ b/src/pass.h @@ -92,7 +92,6 @@ struct NameManager : public Pass { // visitors void visitBlock(Block* curr) override; void visitLoop(Loop* curr) override; - void visitLabel(Label* curr) override; void visitBreak(Break* curr) override; void visitSwitch(Switch* curr) override; void visitCall(Call* curr) override; diff --git a/src/passes/NameManager.cpp b/src/passes/NameManager.cpp index 73a09262b..34ea6adac 100644 --- a/src/passes/NameManager.cpp +++ b/src/passes/NameManager.cpp @@ -24,9 +24,6 @@ void NameManager::visitLoop(Loop* curr) { names.insert(curr->out); names.insert(curr->in); } -void NameManager::visitLabel(Label* curr) { - names.insert(curr->name); -} void NameManager::visitBreak(Break* curr) { names.insert(curr->name); } diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp new file mode 100644 index 000000000..561002f57 --- /dev/null +++ b/src/passes/RemoveImports.cpp @@ -0,0 +1,43 @@ +// +// Removeds imports, and replaces them with nops. This is useful +// for running a module through the reference interpreter, which +// does not validate imports for a JS environment (by removing +// imports, we can at least get the reference interpreter to +// look at all the rest of the code). 
+// + +#include <wasm.h> +#include <pass.h> + +namespace wasm { + +struct RemoveImports : public Pass { + MixedArena* allocator; + std::map<Name, Import*> importsMap; + + void prepare(PassRunner* runner, Module *module) override { + allocator = runner->allocator; + importsMap = module->importsMap; + } + + void visitCallImport(CallImport *curr) override { + WasmType type = importsMap[curr->target]->type.result; + if (type == none) { + replaceCurrent(allocator->alloc<Nop>()); + } else { + Literal nopLiteral; + nopLiteral.type = type; + replaceCurrent(allocator->alloc<Const>()->set(nopLiteral)); + } + } + + void visitModule(Module *curr) override { + curr->importsMap.clear(); + curr->imports.clear(); + } +}; + +static RegisterPass<RemoveImports> registerPass("remove-imports", "removes imports and replaces them with nops"); + +} // namespace wasm + diff --git a/src/s2wasm-main.cpp b/src/s2wasm-main.cpp new file mode 100644 index 000000000..0df9b204a --- /dev/null +++ b/src/s2wasm-main.cpp @@ -0,0 +1,48 @@ +// +// wasm2asm console tool +// + +#include "s2wasm.h" + +using namespace cashew; +using namespace wasm; + +namespace wasm { +int debug = 0; +} + +int main(int argc, char **argv) { + debug = getenv("S2WASM_DEBUG") ? getenv("S2WASM_DEBUG")[0] - '0' : 0; + + char *infile = argv[1]; + + if (debug) std::cerr << "loading '" << infile << "'...\n"; + FILE *f = fopen(argv[1], "r"); + assert(f); + fseek(f, 0, SEEK_END); + int size = ftell(f); + char *input = new char[size+1]; + rewind(f); + int num = fread(input, 1, size, f); + // On Windows, ftell() gives the byte position (\r\n counts as two bytes), but when + // reading, fread() returns the number of characters read (\r\n is read as one char \n, and counted as one), + // so return value of fread can be less than size reported by ftell, and that is normal. 
+ assert((num > 0 || size == 0) && num <= size); + fclose(f); + input[num] = 0; + + if (debug) std::cerr << "parsing and wasming...\n"; + AllocatingModule wasm; + S2WasmBuilder s2wasm(wasm, input); + + if (debug) std::cerr << "emscripten gluing...\n"; + std::stringstream meta; + s2wasm.emscriptenGlue(meta); + + if (debug) std::cerr << "printing...\n"; + std::cout << wasm; + std::cout << meta.str(); + + if (debug) std::cerr << "done.\n"; +} + diff --git a/src/s2wasm.h b/src/s2wasm.h new file mode 100644 index 000000000..71b17fc5b --- /dev/null +++ b/src/s2wasm.h @@ -0,0 +1,1033 @@ + +// +// .s to WebAssembly translator. +// + +#include "wasm.h" +#include "parsing.h" +#include "asm_v_wasm.h" + +namespace wasm { + +extern int debug; // wasm::debug is set in main(), typically from an env var + +cashew::IString EMSCRIPTEN_ASM_CONST("emscripten_asm_const"); + +// +// S2WasmBuilder - parses a .s file into WebAssembly +// + +class S2WasmBuilder { + AllocatingModule& wasm; + MixedArena& allocator; + char *s; + +public: + S2WasmBuilder(AllocatingModule& wasm, char *input) : wasm(wasm), allocator(wasm.allocator) { + s = input; + scan(); + s = input; + process(); + fix(); + } + +private: + // state + + size_t nextStatic = 1; // location of next static allocation, i.e., the data segment + std::map<Name, int32_t> staticAddresses; // name => address + + typedef std::pair<Const*, Name> Addressing; + std::vector<Addressing> addressings; // we fix these up + + struct Relocation { + uint32_t* data; + Name value; + int offset; + Relocation(uint32_t* data, Name value, int offset) : data(data), value(value), offset(offset) {} + }; + std::vector<Relocation> relocations; + + std::set<Name> implementedFunctions; + + std::map<size_t, size_t> addressSegments; // address => segment index + + // utilities + + void skipWhitespace() { + while (1) { + while (*s && isspace(*s)) s++; + if (*s != '#') break; + while (*s != '\n') s++; + } + } + + bool skipComma() { + skipWhitespace(); + if (*s != 
',') return false; + s++; + skipWhitespace(); + return true; + } + + #define abort_on(why) { \ + dump(why ":"); \ + abort(); \ + } + + // match and skip the pattern, if matched + bool match(const char *pattern) { + size_t size = strlen(pattern); + if (strncmp(s, pattern, size) == 0) { + s += size; + skipWhitespace(); + return true; + } + return false; + } + + void mustMatch(const char *pattern) { + bool matched = match(pattern); + if (!matched) { + std::cerr << "<< " << pattern << " >>\n"; + abort_on("bad mustMatch"); + } + } + + void dump(const char *text) { + std::cerr << "[[" << text << "]]:\n==========\n"; + for (size_t i = 0; i < 60; i++) { + if (!s[i]) break; + std::cerr << s[i]; + } + std::cerr << "\n==========\n"; + } + + void unget(Name str) { + s -= strlen(str.str); + } + + Name getStr() { + std::string str; // TODO: optimize this and the other get* methods + while (*s && !isspace(*s)) { + str += *s; + s++; + } + return cashew::IString(str.c_str(), false); + } + + void skipToSep() { + while (*s && !isspace(*s) && *s != ',' && *s != '(' && *s != ')' && *s != ':' && *s != '+') { + s++; + } + } + + Name getStrToSep() { + std::string str; + while (*s && !isspace(*s) && *s != ',' && *s != '(' && *s != ')' && *s != ':' && *s != '+') { + str += *s; + s++; + } + return cashew::IString(str.c_str(), false); + } + + Name getStrToColon() { + std::string str; + while (*s && !isspace(*s) && *s != ':') { + str += *s; + s++; + } + return cashew::IString(str.c_str(), false); + } + + int32_t getInt() { + int32_t ret = 0; + bool neg = false; + if (*s == '-') { + neg = true; + s++; + } + while (isdigit(*s)) { + ret *= 10; + ret += (*s - '0'); + s++; + } + if (neg) ret = -ret; + return ret; + } + + void getConst(uint32_t* target) { + if (isdigit(*s)) { + *target = getInt(); + } else { + // a global constant, we need to fix it up later + Name name = getStrToSep(); + int offset = 0; + if (*s == '+') { + s++; + offset = getInt(); + } + relocations.emplace_back(target, name, 
offset); + } + } + + int64_t getInt64() { + int64_t ret = 0; + bool neg = false; + if (*s == '-') { + neg = true; + s++; + } + while (isdigit(*s)) { + ret *= 10; + ret += (*s - '0'); + s++; + } + if (neg) ret = -ret; + return ret; + } + + Name getCommaSeparated() { + skipWhitespace(); + std::string str; + while (*s && *s != ',' && *s != '\n') { + str += *s; + s++; + } + skipWhitespace(); + return cashew::IString(str.c_str(), false); + } + + Name getAssign() { + skipWhitespace(); + if (*s != '$') return Name(); + std::string str; + char *before = s; + while (*s && *s != '=' && *s != '\n' && *s != ',') { + str += *s; + s++; + } + if (*s != '=') { // not an assign + s = before; + return Name(); + } + s++; + skipComma(); + return cashew::IString(str.c_str(), false); + } + + std::vector<char> getQuoted() { // TODO: support 0 in the middle, etc., use a raw buffer, etc. + assert(*s == '"'); + s++; + std::vector<char> str; + while (*s && *s != '\"') { + if (s[0] == '\\') { + switch (s[1]) { + case 'n': str.push_back('\n'); s += 2; continue; + case 'r': str.push_back('\r'); s += 2; continue; + case 't': str.push_back('\t'); s += 2; continue; + case 'f': str.push_back('\f'); s += 2; continue; + case 'b': str.push_back('\b'); s += 2; continue; + case '\\': str.push_back('\\'); s += 2; continue; + case '"': str.push_back('"'); s += 2; continue; + default: { + if (isdigit(s[1])) { + int code = (s[1] - '0')*8*8 + (s[2] - '0')*8 + (s[3] - '0'); + str.push_back(char(code)); + s += 4; + continue; + } else abort_on("getQuoted-escape"); + } + } + } + str.push_back(*s); + s++; + } + s++; + skipWhitespace(); + return str; + } + + WasmType getType() { + if (match("i32")) return i32; + if (match("i64")) return i64; + if (match("f32")) return f32; + if (match("f64")) return f64; + abort_on("getType"); + } + + // processors + + void scan() { + while (*s) { + s = strstr(s, "\n .type "); + if (!s) break; + mustMatch("\n .type "); + Name name = getCommaSeparated(); + skipComma(); + if 
(!match("@function")) continue; + mustMatch(name.str); + mustMatch(":"); + implementedFunctions.insert(name); + } + } + + void process() { + while (*s) { + skipWhitespace(); + if (!*s) break; + if (*s != '.') break; + s++; + if (match("text")) parseText(); + else if (match("type")) parseType(); + else if (match("weak")) getStr(); // contents are in the type that follows + else if (match("imports")) skipImports(); + else if (match("data")) {} + else if (match("ident")) {} + else if (match("section")) s = strchr(s, '\n'); + else abort_on("process"); + } + } + + void parseText() { + while (*s) { + skipWhitespace(); + if (!*s) break; + if (*s != '.') break; + s++; + if (match("file")) parseFile(); + else if (match("globl")) parseGlobl(); + else if (match("type")) parseType(); + else { + s--; + break; + } + } + } + + void parseFile() { + assert(*s == '"'); + s++; + std::string filename; + while (*s != '"') { + filename += *s; + s++; + } + s++; + // TODO: use the filename? + } + + void parseGlobl() { + Name name = getStr(); + skipWhitespace(); + } + + void parseFunction() { + if (debug) dump("func"); + Name name = getStrToSep(); + mustMatch(":"); + + unsigned nextId = 0; + auto getNextId = [&nextId]() { + return cashew::IString(('$' + std::to_string(nextId++)).c_str(), false); + }; + + auto func = allocator.alloc<Function>(); + func->name = name; + std::map<Name, WasmType> localTypes; + // params and result + while (1) { + if (match(".param")) { + while (1) { + Name name = getNextId(); + WasmType type = getType(); + func->params.emplace_back(name, type); + localTypes[name] = type; + skipWhitespace(); + if (!match(",")) break; + } + } else if (match(".result")) { + func->result = getType(); + } else if (match(".local")) { + while (1) { + Name name = getNextId(); + WasmType type = getType(); + func->locals.emplace_back(name, type); + localTypes[name] = type; + skipWhitespace(); + if (!match(",")) break; + } + } else break; + } + // parse body + func->body = 
allocator.alloc<Block>(); + std::vector<Block*> bstack; + bstack.push_back(func->body->dyn_cast<Block>()); + std::vector<Expression*> estack; + auto push = [&](Expression* curr) { + //std::cerr << "push " << curr << '\n'; + estack.push_back(curr); + }; + auto pop = [&]() { + assert(!estack.empty()); + Expression* ret = estack.back(); + assert(ret); + estack.pop_back(); + //std::cerr << "pop " << ret << '\n'; + return ret; + }; + auto getNumInputs = [&]() { + int ret = 1; + char *t = s; + while (*t != '\n') { + if (*t == ',') ret++; + t++; + } + return ret; + }; + auto getInputs = [&](int num) { + // we may have $pop, $0, $pop, $1 etc., which are getlocals + // interleaved with stack pops, and the stack pops must be done in + // *reverse* order, i.e., that input should turn into + // lastpop, getlocal(0), firstpop, getlocal(1) + std::vector<Expression*> inputs; // TODO: optimize (if .s format doesn't change) + inputs.resize(num); + for (int i = 0; i < num; i++) { + if (match("$pop")) { + skipToSep(); + inputs[i] = nullptr; + } else { + auto curr = allocator.alloc<GetLocal>(); + curr->name = getStrToSep(); + curr->type = localTypes[curr->name]; + inputs[i] = curr; + } + if (*s == ')') s++; // tolerate 0(argument) syntax, where we started at the 'a' + if (i < num - 1) skipComma(); + } + for (int i = num-1; i >= 0; i--) { + if (inputs[i] == nullptr) inputs[i] = pop(); + } + return inputs; + }; + auto getInput = [&]() { + return getInputs(1)[0]; + }; + auto setOutput = [&](Expression* curr, Name assign) { + if (assign.isNull() || assign.str[1] == 'd') { // discard + bstack.back()->list.push_back(curr); + } else if (assign.str[1] == 'p') { // push + estack.push_back(curr); + } else { // set to a local + auto set = allocator.alloc<SetLocal>(); + set->name = assign; + set->value = curr; + set->type = curr->type; + bstack.back()->list.push_back(set); + } + }; + auto makeBinary = [&](BinaryOp op, WasmType type) { + Name assign = getAssign(); + skipComma(); + auto curr = 
allocator.alloc<Binary>(); + curr->op = op; + auto inputs = getInputs(2); + curr->left = inputs[0]; + curr->right = inputs[1]; + curr->finalize(); + assert(curr->type == type); + setOutput(curr, assign); + }; + auto makeUnary = [&](UnaryOp op, WasmType type) { + Name assign = getAssign(); + skipComma(); + auto curr = allocator.alloc<Unary>(); + curr->op = op; + curr->value = getInput(); + curr->type = type; + setOutput(curr, assign); + }; + auto makeHost = [&](HostOp op) { + Name assign = getAssign(); + auto curr = allocator.alloc<Host>(); + curr->op = MemorySize; + setOutput(curr, assign); + }; + auto makeHost1 = [&](HostOp op) { + Name assign = getAssign(); + auto curr = allocator.alloc<Host>(); + curr->op = MemorySize; + curr->operands.push_back(getInput()); + setOutput(curr, assign); + }; + auto makeLoad = [&](WasmType type) { + skipComma(); + auto curr = allocator.alloc<Load>(); + curr->type = type; + int32_t bytes = getInt(); + curr->bytes = bytes > 0 ? bytes : getWasmTypeSize(type); + curr->signed_ = match("_s"); + match("_u"); + Name assign = getAssign(); + getConst(&curr->offset); + curr->align = curr->bytes; // XXX + mustMatch("("); + curr->ptr = getInput(); + setOutput(curr, assign); + }; + auto makeStore = [&](WasmType type) { + skipComma(); + auto curr = allocator.alloc<Store>(); + curr->type = type; + int32_t bytes = getInt(); + curr->bytes = bytes > 0 ? 
bytes : getWasmTypeSize(type); + curr->align = curr->bytes; // XXX + Name assign = getAssign(); + getConst(&curr->offset); + mustMatch("("); + auto inputs = getInputs(2); + curr->ptr = inputs[0]; + curr->value = inputs[1]; + setOutput(curr, assign); + }; + auto makeSelect = [&](WasmType type) { + Name assign = getAssign(); + skipComma(); + auto curr = allocator.alloc<Select>(); + auto inputs = getInputs(3); + curr->condition = inputs[0]; + curr->ifTrue = inputs[1]; + curr->ifFalse = inputs[2]; + curr->type = type; + setOutput(curr, assign); + }; + auto makeCall = [&](WasmType type) { + CallBase* curr; + Name assign; + if (match("_indirect")) { + auto specific = allocator.alloc<CallIndirect>(); + assign = getAssign(); + specific->target = getInput(); + curr = specific; + } else { + assign = getAssign(); + Name target = getCommaSeparated(); + if (implementedFunctions.count(target) > 0) { + auto specific = allocator.alloc<Call>(); + specific->target = target; + curr = specific; + } else { + auto specific = allocator.alloc<CallImport>(); + specific->target = target; + curr = specific; + if (wasm.importsMap.count(target) == 0) { + auto import = allocator.alloc<Import>(); + import->name = import->base = target; + import->module = ENV; + wasm.addImport(import); + } + } + } + curr->type = type; + skipWhitespace(); + if (*s == ',') { + skipComma(); + int num = getNumInputs(); + auto inputs = getInputs(num); + for (int i = 0; i < num; i++) { + curr->operands.push_back(inputs[i]); + } + } + std::reverse(curr->operands.begin(), curr->operands.end()); + setOutput(curr, assign); + if (curr->is<CallIndirect>()) { + auto call = curr->dyn_cast<CallIndirect>(); + auto typeName = cashew::IString((std::string("FUNCSIG_") + getSig(call)).c_str(), false); + if (wasm.functionTypesMap.count(typeName) == 0) { + auto type = allocator.alloc<FunctionType>(); + type->name = typeName; + // TODO type->result + for (auto operand : call->operands) { + type->params.push_back(operand->type); + } + 
wasm.addFunctionType(type); + call->fullType = type; + } else { + call->fullType = wasm.functionTypesMap[typeName]; + } + } + }; + auto handleTyped = [&](WasmType type) { + switch (*s) { + case 'a': { + if (match("add")) makeBinary(BinaryOp::Add, type); + else if (match("and")) makeBinary(BinaryOp::And, type); + else if (match("abs")) makeUnary(UnaryOp::Abs, type); + else abort_on("type.a"); + break; + } + case 'c': { + if (match("const")) { + Name assign = getAssign(); + char start = *s; + cashew::IString str = getStr(); + if (start == '.' || (isalpha(*s) && str != NAN_ && str != INFINITY_)) { + // global address + auto curr = allocator.alloc<Const>(); + curr->type = i32; + addressings.emplace_back(curr, str); + setOutput(curr, assign); + } else { + // constant + setOutput(parseConst(str, type, allocator), assign); + } + } + else if (match("call")) makeCall(type); + else if (match("convert_s/i32")) makeUnary(UnaryOp::ConvertSInt32, type); + else if (match("convert_u/i32")) makeUnary(UnaryOp::ConvertUInt32, type); + else if (match("convert_s/i64")) makeUnary(UnaryOp::ConvertSInt64, type); + else if (match("convert_u/i64")) makeUnary(UnaryOp::ConvertUInt64, type); + else if (match("clz")) makeUnary(UnaryOp::Clz, type); + else if (match("ctz")) makeUnary(UnaryOp::Ctz, type); + else if (match("copysign")) makeBinary(BinaryOp::CopySign, type); + else if (match("ceil")) makeUnary(UnaryOp::Ceil, type); + else abort_on("type.c"); + break; + } + case 'd': { + if (match("demote/f64")) makeUnary(UnaryOp::DemoteFloat64, type); + else if (match("div_s")) makeBinary(BinaryOp::DivS, type); + else if (match("div_u")) makeBinary(BinaryOp::DivU, type); + else if (match("div")) makeBinary(BinaryOp::Div, type); + else abort_on("type.g"); + break; + } + case 'e': { + if (match("eq")) makeBinary(BinaryOp::Eq, i32); + else if (match("extend_s/i32")) makeUnary(UnaryOp::ExtendSInt32, type); + else if (match("extend_u/i32")) makeUnary(UnaryOp::ExtendUInt32, type); + else 
abort_on("type.e"); + break; + } + case 'f': { + if (match("floor")) makeUnary(UnaryOp::Floor, type); + else abort_on("type.e"); + break; + } + case 'g': { + if (match("gt_s")) makeBinary(BinaryOp::GtS, i32); + else if (match("gt_u")) makeBinary(BinaryOp::GtU, i32); + else if (match("ge_s")) makeBinary(BinaryOp::GeS, i32); + else if (match("ge_u")) makeBinary(BinaryOp::GeU, i32); + else if (match("gt")) makeBinary(BinaryOp::Gt, i32); + else if (match("ge")) makeBinary(BinaryOp::Ge, i32); + else abort_on("type.g"); + break; + } + case 'l': { + if (match("lt_s")) makeBinary(BinaryOp::LtS, i32); + else if (match("lt_u")) makeBinary(BinaryOp::LtU, i32); + else if (match("le_s")) makeBinary(BinaryOp::LeS, i32); + else if (match("le_u")) makeBinary(BinaryOp::LeU, i32); + else if (match("load")) makeLoad(type); + else if (match("lt")) makeBinary(BinaryOp::Lt, i32); + else if (match("le")) makeBinary(BinaryOp::Le, i32); + else abort_on("type.g"); + break; + } + case 'm': { + if (match("mul")) makeBinary(BinaryOp::Mul, type); + else if (match("min")) makeBinary(BinaryOp::Min, type); + else if (match("max")) makeBinary(BinaryOp::Max, type); + else abort_on("type.m"); + break; + } + case 'n': { + if (match("neg")) makeUnary(UnaryOp::Neg, i32); + else if (match("nearest")) makeUnary(UnaryOp::Nearest, i32); + else if (match("ne")) makeBinary(BinaryOp::Ne, i32); + else abort_on("type.n"); + break; + } + case 'o': { + if (match("or")) makeBinary(BinaryOp::Or, type); + else abort_on("type.o"); + break; + } + case 'p': { + if (match("promote/f32")) makeUnary(UnaryOp::PromoteFloat32, type); + else if (match("popcnt")) makeUnary(UnaryOp::Popcnt, type); + else abort_on("type.p"); + break; + } + case 'r': { + if (match("rem_s")) makeBinary(BinaryOp::RemS, type); + else if (match("rem_u")) makeBinary(BinaryOp::RemU, type); + else if (match("reinterpret/i32") || match("reinterpret/i64")) makeUnary(UnaryOp::ReinterpretInt, type); + else if (match("reinterpret/f32") || 
match("reinterpret/f64")) makeUnary(UnaryOp::ReinterpretFloat, type); + else abort_on("type.r"); + break; + } + case 's': { + if (match("shr_s")) makeBinary(BinaryOp::ShrS, type); + else if (match("shr_u")) makeBinary(BinaryOp::ShrU, type); + else if (match("shl")) makeBinary(BinaryOp::Shl, type); + else if (match("sub")) makeBinary(BinaryOp::Sub, type); + else if (match("store")) makeStore(type); + else if (match("select")) makeSelect(type); + else if (match("sqrt")) makeUnary(UnaryOp::Sqrt, type); + else abort_on("type.s"); + break; + } + case 't': { + if (match("trunc_s/f32")) makeUnary(UnaryOp::TruncSFloat32, type); + else if (match("trunc_u/f32")) makeUnary(UnaryOp::TruncUFloat32, type); + else if (match("trunc_s/f64")) makeUnary(UnaryOp::TruncSFloat64, type); + else if (match("trunc_u/f64")) makeUnary(UnaryOp::TruncUFloat64, type); + else if (match("trunc")) makeUnary(UnaryOp::Trunc, type); + else abort_on("type.t"); + break; + } + case 'w': { + if (match("wrap/i64")) makeUnary(UnaryOp::WrapInt64, type); + else abort_on("type.w"); + break; + } + case 'x': { + if (match("xor")) makeBinary(BinaryOp::Xor, type); + else abort_on("type.x"); + break; + } + default: abort_on("type.?"); + } + }; + // fixups + std::vector<Block*> loopBlocks; // we need to clear their names + std::set<Name> seenLabels; // if we already used a label, we don't need it in a loop (there is a block above it, with that label) + Name lastLabel; // A loop has an 'in' label which appears before it. 
There might also be a block in between it and the loop, so we have to remember the last label + // main loop + while (1) { + skipWhitespace(); + if (debug) dump("main function loop"); + if (match("i32.")) { + handleTyped(i32); + } else if (match("i64.")) { + handleTyped(i64); + } else if (match("f32.")) { + handleTyped(f32); + } else if (match("f64.")) { + handleTyped(f64); + } else if (match("block")) { + auto curr = allocator.alloc<Block>(); + curr->name = getStr(); + bstack.back()->list.push_back(curr); + bstack.push_back(curr); + seenLabels.insert(curr->name); + } else if (match("BB")) { + s -= 2; + lastLabel = getStrToColon(); + s++; + skipWhitespace(); + // pop all blocks/loops that reach this target + // pop all targets with this label + while (!bstack.empty()) { + auto curr = bstack.back(); + if (curr->name == lastLabel) { + bstack.pop_back(); + continue; + } + break; + } + } else if (match("loop")) { + auto curr = allocator.alloc<Loop>(); + bstack.back()->list.push_back(curr); + curr->in = lastLabel; + Name out = getStr(); + if (seenLabels.count(out) == 0) { + curr->out = out; + } + auto block = allocator.alloc<Block>(); + block->name = out; // temporary, fake + curr->body = block; + loopBlocks.push_back(block); + bstack.push_back(block); + } else if (match("br")) { + auto curr = allocator.alloc<Break>(); + if (*s == '_') { + mustMatch("_if"); + curr->condition = getInput(); + skipComma(); + } + curr->name = getStr(); + bstack.back()->list.push_back(curr); + } else if (match("call")) { + makeCall(none); + } else if (match("copy_local")) { + Name assign = getAssign(); + skipComma(); + setOutput(getInput(), assign); + } else if (match("return")) { + Block *temp; + if (!(func->body && (temp = func->body->dyn_cast<Block>()) && temp->name == FAKE_RETURN)) { + Expression* old = func->body; + temp = allocator.alloc<Block>(); + temp->name = FAKE_RETURN; + if (old) temp->list.push_back(old); + func->body = temp; + } + auto curr = allocator.alloc<Break>(); + 
curr->name = FAKE_RETURN; + if (*s == '$') { + curr->value = getInput(); + } + bstack.back()->list.push_back(curr); + } else if (match("tableswitch")) { + auto curr = allocator.alloc<Switch>(); + curr->value = getInput(); + skipComma(); + curr->default_ = getCommaSeparated(); + while (skipComma()) { + curr->targets.push_back(getCommaSeparated()); + } + bstack.back()->list.push_back(curr); + } else if (match("unreachable")) { + bstack.back()->list.push_back(allocator.alloc<Unreachable>()); + } else if (match("memory_size")) { + makeHost(MemorySize); + } else if (match("grow_memory")) { + makeHost1(GrowMemory); + } else if (match("func_end")) { + s = strchr(s, '\n'); + s++; + s = strchr(s, '\n'); + break; // the function is done + } else { + abort_on("function element"); + } + } + // finishing touches + bstack.pop_back(); // remove the base block for the function body + assert(bstack.empty()); + assert(estack.empty()); + for (auto block : loopBlocks) { + block->name = Name(); + } + wasm.addFunction(func); + // XXX for now, export all functions + auto exp = allocator.alloc<Export>(); + exp->name = exp->value = func->name; + wasm.addExport(exp); + } + + void parseType() { + if (debug) dump("type"); + Name name = getStrToSep(); + skipComma(); + if (match("@function")) return parseFunction(); + else if (match("@object")) return parseObject(name); + abort_on("parseType"); + } + + void parseObject(Name name) { + if (match(".data") || match(".bss")) { + } else if (match(".section")) { + s = strchr(s, '\n'); + } else if (match(".lcomm")) { + mustMatch(name.str); + skipComma(); + getInt(); + return; + } + skipWhitespace(); + size_t align = 16; // XXX default? 
+ if (match(".globl")) { + mustMatch(name.str); + skipWhitespace(); + } + if (match(".align")) { + align = getInt(); + skipWhitespace(); + } + if (match(".lcomm")) { + mustMatch(name.str); + skipComma(); + getInt(); + skipComma(); + getInt(); + return; // XXX wtf is this thing and what do we do with it + } + mustMatch(name.str); + mustMatch(":"); + auto raw = new std::vector<char>(); // leaked intentionally, no new allocation in Memory + bool zero = true; + while (1) { + skipWhitespace(); + if (match(".asciz")) { + *raw = getQuoted(); + raw->push_back(0); + zero = false; + } else if (match(".ascii")) { + *raw = getQuoted(); + zero = false; + } else if (match(".zero")) { + int32_t size = getInt(); + for (size_t i = 0; i < size; i++) { + raw->push_back(0); + } + } else if (match(".int32")) { + size_t size = raw->size(); + raw->resize(size + 4); + getConst((uint32_t*)&(*raw)[size]); + zero = false; + } else if (match(".int64")) { + size_t size = raw->size(); + raw->resize(size + 8); + (*(int64_t*)(&(*raw)[size])) = getInt(); + zero = false; + } else { + break; + } + } + skipWhitespace(); + size_t size = raw->size(); + if (match(".size")) { + mustMatch(name.str); + mustMatch(","); + size_t seenSize = atoi(getStr().str); // TODO: optimize + assert(seenSize == size); + } + while (nextStatic % align) nextStatic++; + // assign the address, add to memory + staticAddresses[name] = nextStatic; + if (!zero) { + addressSegments[nextStatic] = wasm.memory.segments.size(); + wasm.memory.segments.emplace_back(nextStatic, (const char*)&(*raw)[0], size); + } + nextStatic += size; + } + + void skipImports() { + while (1) { + if (match(".import")) { + s = strchr(s, '\n'); + skipWhitespace(); + continue; + } + break; + } + } + + void fix() { + for (auto& pair : addressings) { + Const* curr = pair.first; + Name name = pair.second; + curr->value = Literal(staticAddresses[name]); + assert(curr->value.i32 > 0); + curr->type = i32; + } + for (auto& relocation : relocations) { + 
*(relocation.data) = staticAddresses[relocation.value] + relocation.offset; + } + } + + template<class C> + void printSet(std::ostream& o, C& c) { + o << "["; + bool first = true; + for (auto& item : c) { + if (first) first = false; + else o << ","; + o << '"' << item << '"'; + } + o << "]"; + } + +public: + + // extra emscripten processing + void emscriptenGlue(std::ostream& o) { + wasm.removeImport(EMSCRIPTEN_ASM_CONST); // we create _sig versions + + o << ";; METADATA: { "; + // find asmConst calls, and emit their metadata + struct AsmConstWalker : public WasmWalker { + S2WasmBuilder* parent; + + std::map<std::string, std::set<std::string>> sigsForCode; + std::map<std::string, size_t> ids; + std::set<std::string> allSigs; + + AsmConstWalker(S2WasmBuilder* parent) : parent(parent) {} + + void visitCallImport(CallImport* curr) override { + if (curr->target == EMSCRIPTEN_ASM_CONST) { + auto arg = curr->operands[0]->cast<Const>(); + size_t segmentIndex = parent->addressSegments[arg->value.geti32()]; + std::string code = escape(parent->wasm.memory.segments[segmentIndex].data); + int32_t id; + if (ids.count(code) == 0) { + id = ids.size(); + ids[code] = id; + } else { + id = ids[code]; + } + std::string sig = getSig(curr); + sigsForCode[code].insert(sig); + std::string fixedTarget = std::string("_") + EMSCRIPTEN_ASM_CONST.str + '_' + sig; + curr->target = cashew::IString(fixedTarget.c_str(), false); + arg->value = Literal(id); + // add import, if necessary + if (allSigs.count(sig) == 0) { + allSigs.insert(sig); + auto import = parent->allocator.alloc<Import>(); + import->name = import->base = curr->target; + import->module = ENV; + parent->wasm.addImport(import); + } + } + } + + std::string escape(const char *input) { + std::string code = input; + // replace newlines quotes with escaped newlines + size_t curr = 0; + while ((curr = code.find("\\n", curr)) != std::string::npos) { + code = code.replace(curr, 2, "\\\\n"); + curr += 3; // skip this one + } + // replace 
double quotes with escaped single quotes + curr = 0; + while ((curr = code.find('"', curr)) != std::string::npos) { + if (curr == 0 || code[curr-1] != '\\') { + code = code.replace(curr, 1, "\\" "\""); + curr += 2; // skip this one + } else { // already escaped, escape the slash as well + code = code.replace(curr, 1, "\\" "\\" "\""); + curr += 3; // skip this one + } + } + return code; + } + }; + AsmConstWalker walker(this); + walker.startWalk(&wasm); + // print + o << "\"asmConsts\": {"; + bool first = true; + for (auto& pair : walker.sigsForCode) { + auto& code = pair.first; + auto& sigs = pair.second; + if (first) first = false; + else o << ","; + o << '"' << walker.ids[code] << "\": [\"" << code << "\", "; + printSet(o, sigs); + o << "]"; + } + o << "}"; + + o << " }"; + } +}; + +} // namespace wasm + diff --git a/src/shared-constants.h b/src/shared-constants.h new file mode 100644 index 000000000..a67475303 --- /dev/null +++ b/src/shared-constants.h @@ -0,0 +1,75 @@ +#ifndef _shared_constants_h_ +#define _shared_constants_h_ + +#include "emscripten-optimizer/optimizer.h" + +namespace wasm { + +cashew::IString GLOBAL("global"), + NAN_("NaN"), + INFINITY_("Infinity"), + TOPMOST("topmost"), + INT8ARRAY("Int8Array"), + INT16ARRAY("Int16Array"), + INT32ARRAY("Int32Array"), + UINT8ARRAY("Uint8Array"), + UINT16ARRAY("Uint16Array"), + UINT32ARRAY("Uint32Array"), + FLOAT32ARRAY("Float32Array"), + FLOAT64ARRAY("Float64Array"), + IMPOSSIBLE_CONTINUE("impossible-continue"), + MATH("Math"), + IMUL("imul"), + CLZ32("clz32"), + FROUND("fround"), + ASM2WASM("asm2wasm"), + F64_REM("f64-rem"), + F64_TO_INT("f64-to-int"), + GLOBAL_MATH("global.Math"), + ABS("abs"), + FLOOR("floor"), + SQRT("sqrt"), + I32_TEMP("asm2wasm_i32_temp"), + DEBUGGER("debugger"), + GROW_WASM_MEMORY("__growWasmMemory"), + NEW_SIZE("newSize"), + MODULE("module"), + FUNC("func"), + PARAM("param"), + RESULT("result"), + MEMORY("memory"), + SEGMENT("segment"), + EXPORT("export"), + IMPORT("import"), + 
TABLE("table"), + LOCAL("local"), + TYPE("type"), + CALL("call"), + CALL_IMPORT("call_import"), + CALL_INDIRECT("call_indirect"), + BR_IF("br_if"), + NEG_INFINITY("-infinity"), + NEG_NAN("-nan"), + CASE("case"), + BR("br"), + USE_ASM("use asm"), + BUFFER("buffer"), + ENV("env"), + FAKE_RETURN("fake_return_waka123"), + MATH_IMUL("Math_imul"), + MATH_CLZ32("Math_clz32"), + MATH_CTZ32("Math_ctz32"), + MATH_POPCNT32("Math_popcnt32"), + MATH_ABS("Math_abs"), + MATH_CEIL("Math_ceil"), + MATH_FLOOR("Math_floor"), + MATH_TRUNC("Math_trunc"), + MATH_NEAREST("Math_NEAREST"), + MATH_SQRT("Math_sqrt"), + MATH_MIN("Math_max"), + MATH_MAX("Math_min"); + +} + +#endif // _shared_constants_h_ + diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 1cdd86387..4a04c6912 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -32,6 +32,30 @@ enum { maxCallDepth = 250 }; +// Stuff that flows around during executing expressions: a literal, or a change in control flow +class Flow { +public: + Flow() {} + Flow(Literal value) : value(value) {} + Flow(IString breakTo) : breakTo(breakTo) {} + + Literal value; + IString breakTo; // if non-null, a break is going on + + bool breaking() { return breakTo.is(); } + + void clearIf(IString target) { + if (breakTo == target) { + breakTo.clear(); + } + } + + friend std::ostream& operator<<(std::ostream& o, Flow& flow) { + o << "(flow " << (flow.breakTo.is() ? flow.breakTo.str : "-") << " : " << flow.value << ')'; + return o; + } +}; + // // An instance of a WebAssembly module, which can execute it via AST interpretation. 
// @@ -57,7 +81,7 @@ public: virtual Literal load(Load* load, size_t addr) = 0; virtual void store(Store* store, size_t addr, Literal value) = 0; virtual void growMemory(size_t oldSize, size_t newSize) = 0; - virtual void trap() = 0; + virtual void trap(const char* why) = 0; }; Module& wasm; @@ -69,7 +93,7 @@ public: Literal callExport(IString name, LiteralList& arguments) { Export *export_ = wasm.exportsMap[name]; - if (!export_) externalInterface->trap(); + if (!export_) externalInterface->trap("callExport not found"); return callFunction(export_->value, arguments); } @@ -105,30 +129,6 @@ private: } }; - // Stuff that flows around during executing expressions: a literal, or a change in control flow - class Flow { - public: - Flow() {} - Flow(Literal value) : value(value) {} - Flow(IString breakTo) : breakTo(breakTo) {} - - Literal value; - IString breakTo; // if non-null, a break is going on - - bool breaking() { return breakTo.is(); } - - void clearIf(IString target) { - if (breakTo == target) { - breakTo.clear(); - } - } - - std::ostream& print(std::ostream& o) { - o << "(flow " << (breakTo.is() ? 
breakTo.str : "-") << " : " << value << ')'; - return o; - } - }; - #ifdef WASM_INTERPRETER_DEBUG struct IndentHandler { int& indent; @@ -137,22 +137,28 @@ private: doIndent(std::cout, indent); std::cout << "visit " << name << " :\n"; indent++; - //doIndent(std::cout, indent); - //expression->print(std::cout, indent) << '\n'; - //indent++; +#if WASM_INTERPRETER_DEBUG == 2 + doIndent(std::cout, indent); + expression->print(std::cout, indent) << '\n'; + indent++; +#endif } ~IndentHandler() { - //indent--; +#if WASM_INTERPRETER_DEBUG == 2 + indent--; +#endif indent--; doIndent(std::cout, indent); std::cout << "exit " << name << '\n'; } }; #define NOTE_ENTER(x) IndentHandler indentHandler(instance.indent, x, curr); + #define NOTE_NAME(p0) { doIndent(std::cout, instance.indent); std::cout << "name in " << indentHandler.name << '(' << Name(p0) << ")\n"; } #define NOTE_EVAL1(p0) { doIndent(std::cout, instance.indent); std::cout << "eval in " << indentHandler.name << '(' << p0 << ")\n"; } #define NOTE_EVAL2(p0, p1) { doIndent(std::cout, instance.indent); std::cout << "eval in " << indentHandler.name << '(' << p0 << ", " << p1 << ")\n"; } #else #define NOTE_ENTER(x) + #define NOTE_NAME(p0) #define NOTE_EVAL1(p0) #define NOTE_EVAL2(p0, p1) #endif @@ -182,7 +188,11 @@ private: Flow flow = visit(curr->condition); if (flow.breaking()) return flow; NOTE_EVAL1(flow.value); - if (flow.value.geti32()) return visit(curr->ifTrue); + if (flow.value.geti32()) { + Flow flow = visit(curr->ifTrue); + if (!flow.breaking() && !curr->ifFalse) flow.value = Literal(); // if_else returns a value, but if does not + return flow; + } if (curr->ifFalse) return visit(curr->ifFalse); return Flow(); } @@ -193,26 +203,25 @@ private: if (flow.breaking()) { if (flow.breakTo == curr->in) continue; // lol flow.clearIf(curr->out); - return flow; } + return flow; // loop does not loop automatically, only continue achieves that } } - Flow visitLabel(Label *curr) override { - NOTE_ENTER("Label"); - Flow flow = 
visit(curr->body); - flow.clearIf(curr->name); - return flow; - } Flow visitBreak(Break *curr) override { NOTE_ENTER("Break"); + bool condition = true; + if (curr->condition) { + Flow flow = visit(curr->condition); + if (flow.breaking()) return flow; + condition = flow.value.getInteger(); + } + Flow flow(curr->name); if (curr->value) { - Flow flow = visit(curr->value); - if (!flow.breaking()) { - flow.breakTo = curr->name; - } - return flow; + flow = visit(curr->value); + if (flow.breaking()) return flow; + flow.breakTo = curr->name; } - return Flow(curr->name); + return condition ? flow : Flow(); } Flow visitSwitch(Switch *curr) override { NOTE_ENTER("Switch"); @@ -263,6 +272,7 @@ private: Flow visitCall(Call *curr) override { NOTE_ENTER("Call"); + NOTE_NAME(curr->target); LiteralList arguments; Flow flow = generateArguments(curr->operands, arguments); if (flow.breaking()) return flow; @@ -284,10 +294,10 @@ private: Flow target = visit(curr->target); if (target.breaking()) return target; size_t index = target.value.geti32(); - if (index >= instance.wasm.table.names.size()) trap(); + if (index >= instance.wasm.table.names.size()) trap("callIndirect: overflow"); Name name = instance.wasm.table.names[index]; Function *func = instance.wasm.functionsMap[name]; - if (func->type.is() && func->type != curr->type->name) trap(); + if (func->type.is() && func->type != curr->fullType->name) trap("callIndirect: bad type"); LiteralList arguments; Flow flow = generateArguments(curr->operands, arguments); if (flow.breaking()) return flow; @@ -297,6 +307,7 @@ private: Flow visitGetLocal(GetLocal *curr) override { NOTE_ENTER("GetLocal"); IString name = curr->name; + NOTE_NAME(name); NOTE_EVAL1(scope.locals[name]); return scope.locals[name]; } @@ -305,7 +316,9 @@ private: IString name = curr->name; Flow flow = visit(curr->value); if (flow.breaking()) return flow; + NOTE_NAME(name); NOTE_EVAL1(flow.value); + assert(flow.value.type == curr->type); scope.locals[name] = flow.value; 
return flow; } @@ -344,6 +357,17 @@ private: return Literal((int32_t)safe_ctz(v)); } case Popcnt: return Literal((int32_t)__builtin_popcount(v)); + case ReinterpretInt: { + float v = value.reinterpretf32(); + if (isnan(v)) { + return Literal(Literal(value.geti32() | 0x7f800000).reinterpretf32()); + } + return Literal(value.reinterpretf32()); + } + case ExtendSInt32: return Literal(int64_t(value.geti32())); + case ExtendUInt32: return Literal(uint64_t((uint32_t)value.geti32())); + case ConvertUInt32: return curr->type == f32 ? Literal(float(uint32_t(value.geti32()))) : Literal(double(uint32_t(value.geti32()))); + case ConvertSInt32: return curr->type == f32 ? Literal(float(int32_t(value.geti32()))) : Literal(double(int32_t(value.geti32()))); default: abort(); } } @@ -362,6 +386,12 @@ private: return Literal((int64_t)safe_ctz(low)); } case Popcnt: return Literal(int64_t(__builtin_popcount(low) + __builtin_popcount(high))); + case WrapInt64: return Literal(int32_t(value.geti64())); + case ReinterpretInt: { + return Literal(value.reinterpretf64()); + } + case ConvertUInt64: return curr->type == f32 ? Literal(float((uint64_t)value.geti64())) : Literal(double((uint64_t)value.geti64())); + case ConvertSInt64: return curr->type == f32 ? 
Literal(float(value.geti64())) : Literal(double(value.geti64())); default: abort(); } } @@ -376,6 +406,10 @@ private: case Trunc: ret = std::trunc(v); break; case Nearest: ret = std::nearbyint(v); break; case Sqrt: ret = std::sqrt(v); break; + case TruncSFloat32: return truncSFloat(curr, value); + case TruncUFloat32: return truncUFloat(curr, value); + case ReinterpretFloat: return Literal(value.reinterpreti32()); + case PromoteFloat32: return Literal(double(value.getf32())); default: abort(); } return Literal(fixNaN(v, ret)); @@ -391,6 +425,10 @@ private: case Trunc: ret = std::trunc(v); break; case Nearest: ret = std::nearbyint(v); break; case Sqrt: ret = std::sqrt(v); break; + case TruncSFloat64: return truncSFloat(curr, value); + case TruncUFloat64: return truncUFloat(curr, value); + case ReinterpretFloat: return Literal(value.reinterpreti64()); + case DemoteFloat64: return Literal(float(value.getf64())); default: abort(); } return Literal(fixNaN(v, ret)); @@ -406,6 +444,8 @@ private: if (flow.breaking()) return flow; Literal right = flow.value; NOTE_EVAL2(left, right); + assert(left.type == curr->left->type); + assert(right.type == curr->right->type); if (left.type == i32) { int32_t l = left.geti32(), r = right.geti32(); switch (curr->op) { @@ -413,21 +453,21 @@ private: case Sub: return Literal(l - r); case Mul: return Literal(l * r); case DivS: { - if (r == 0) trap(); - if (l == INT32_MIN && r == -1) trap(); // signed division overflow + if (r == 0) trap("i32.div_s by 0"); + if (l == INT32_MIN && r == -1) trap("i32.div_s overflow"); // signed division overflow return Literal(l / r); } case DivU: { - if (r == 0) trap(); + if (r == 0) trap("i32.div_u by 0"); return Literal(int32_t(uint32_t(l) / uint32_t(r))); } case RemS: { - if (r == 0) trap(); + if (r == 0) trap("i32.rem_s by 0"); if (l == INT32_MIN && r == -1) return Literal(int32_t(0)); return Literal(l % r); } case RemU: { - if (r == 0) trap(); + if (r == 0) trap("i32.rem_u by 0"); return 
Literal(int32_t(uint32_t(l) % uint32_t(r))); } case And: return Literal(l & r); @@ -445,6 +485,16 @@ private: r = r & 31; return Literal(l >> r); } + case Eq: return Literal(l == r); + case Ne: return Literal(l != r); + case LtS: return Literal(l < r); + case LtU: return Literal(uint32_t(l) < uint32_t(r)); + case LeS: return Literal(l <= r); + case LeU: return Literal(uint32_t(l) <= uint32_t(r)); + case GtS: return Literal(l > r); + case GtU: return Literal(uint32_t(l) > uint32_t(r)); + case GeS: return Literal(l >= r); + case GeU: return Literal(uint32_t(l) >= uint32_t(r)); default: abort(); } } else if (left.type == i64) { @@ -454,21 +504,21 @@ private: case Sub: return Literal(l - r); case Mul: return Literal(l * r); case DivS: { - if (r == 0) trap(); - if (l == LLONG_MIN && r == -1) trap(); // signed division overflow + if (r == 0) trap("i64.div_s by 0"); + if (l == LLONG_MIN && r == -1) trap("i64.div_s overflow"); // signed division overflow return Literal(l / r); } case DivU: { - if (r == 0) trap(); + if (r == 0) trap("i64.div_u by 0"); return Literal(int64_t(uint64_t(l) / uint64_t(r))); } case RemS: { - if (r == 0) trap(); + if (r == 0) trap("i64.rem_s by 0"); if (l == LLONG_MIN && r == -1) return Literal(int64_t(0)); return Literal(l % r); } case RemU: { - if (r == 0) trap(); + if (r == 0) trap("i64.rem_u by 0"); return Literal(int64_t(uint64_t(l) % uint64_t(r))); } case And: return Literal(l & r); @@ -486,6 +536,16 @@ private: r = r & 63; return Literal(l >> r); } + case Eq: return Literal(l == r); + case Ne: return Literal(l != r); + case LtS: return Literal(l < r); + case LtU: return Literal(uint64_t(l) < uint64_t(r)); + case LeS: return Literal(l <= r); + case LeU: return Literal(uint64_t(l) <= uint64_t(r)); + case GtS: return Literal(l > r); + case GtU: return Literal(uint64_t(l) > uint64_t(r)); + case GeS: return Literal(l >= r); + case GeU: return Literal(uint64_t(l) >= uint64_t(r)); default: abort(); } } else if (left.type == f32) { @@ -510,6 
+570,12 @@ private: else ret = std::max(l, r); break; } + case Eq: return Literal(l == r); + case Ne: return Literal(l != r); + case Lt: return Literal(l < r); + case Le: return Literal(l <= r); + case Gt: return Literal(l > r); + case Ge: return Literal(l >= r); default: abort(); } return Literal(fixNaN(l, r, ret)); @@ -535,65 +601,6 @@ private: else ret = std::max(l, r); break; } - default: abort(); - } - return Literal(fixNaN(l, r, ret)); - } - abort(); - } - Flow visitCompare(Compare *curr) override { - NOTE_ENTER("Compare"); - Flow flow = visit(curr->left); - if (flow.breaking()) return flow; - Literal left = flow.value; - flow = visit(curr->right); - if (flow.breaking()) return flow; - Literal right = flow.value; - NOTE_EVAL2(left, right); - if (left.type == i32) { - int32_t l = left.geti32(), r = right.geti32(); - switch (curr->op) { - case Eq: return Literal(l == r); - case Ne: return Literal(l != r); - case LtS: return Literal(l < r); - case LtU: return Literal(uint32_t(l) < uint32_t(r)); - case LeS: return Literal(l <= r); - case LeU: return Literal(uint32_t(l) <= uint32_t(r)); - case GtS: return Literal(l > r); - case GtU: return Literal(uint32_t(l) > uint32_t(r)); - case GeS: return Literal(l >= r); - case GeU: return Literal(uint32_t(l) >= uint32_t(r)); - default: abort(); - } - } else if (left.type == i64) { - int64_t l = left.geti64(), r = right.geti64(); - switch (curr->op) { - case Eq: return Literal(l == r); - case Ne: return Literal(l != r); - case LtS: return Literal(l < r); - case LtU: return Literal(uint64_t(l) < uint64_t(r)); - case LeS: return Literal(l <= r); - case LeU: return Literal(uint64_t(l) <= uint64_t(r)); - case GtS: return Literal(l > r); - case GtU: return Literal(uint64_t(l) > uint64_t(r)); - case GeS: return Literal(l >= r); - case GeU: return Literal(uint64_t(l) >= uint64_t(r)); - default: abort(); - } - } else if (left.type == f32) { - float l = left.getf32(), r = right.getf32(); - switch (curr->op) { - case Eq: return 
Literal(l == r); - case Ne: return Literal(l != r); - case Lt: return Literal(l < r); - case Le: return Literal(l <= r); - case Gt: return Literal(l > r); - case Ge: return Literal(l >= r); - default: abort(); - } - } else if (left.type == f64) { - double l = left.getf64(), r = right.getf64(); - switch (curr->op) { case Eq: return Literal(l == r); case Ne: return Literal(l != r); case Lt: return Literal(l < r); @@ -602,67 +609,10 @@ private: case Ge: return Literal(l >= r); default: abort(); } + return Literal(fixNaN(l, r, ret)); } abort(); } - Flow visitConvert(Convert *curr) override { - NOTE_ENTER("Convert"); - Flow flow = visit(curr->value); - if (flow.breaking()) return flow; - Literal value = flow.value; - switch (curr->op) { // :-) - case ExtendSInt32: return Literal(int64_t(value.geti32())); - case ExtendUInt32: return Literal(uint64_t((uint32_t)value.geti32())); - case WrapInt64: return Literal(int32_t(value.geti64())); - case TruncSFloat32: - case TruncSFloat64: { - double val = curr->op == TruncSFloat32 ? value.getf32() : value.getf64(); - if (isnan(val)) trap(); - if (curr->type == i32) { - if (val > (double)INT_MAX || val < (double)INT_MIN) trap(); - return Literal(int32_t(val)); - } else { - int64_t converted = val; - if ((val >= 1 && converted <= 0) || val < (double)LLONG_MIN) trap(); - return Literal(converted); - } - } - case TruncUFloat32: - case TruncUFloat64: { - double val = curr->op == TruncUFloat32 ? value.getf32() : value.getf64(); - if (isnan(val)) trap(); - if (curr->type == i32) { - if (val > (double)UINT_MAX || val <= (double)-1) trap(); - return Literal(uint32_t(val)); - } else { - uint64_t converted = val; - if (converted < val - 1 || val <= (double)-1) trap(); - return Literal(converted); - } - } - case ReinterpretFloat: { - return curr->type == i32 ? Literal(value.reinterpreti32()) : Literal(value.reinterpreti64()); - } - case ConvertUInt32: return curr->type == f32 ? 
Literal(float(uint32_t(value.geti32()))) : Literal(double(uint32_t(value.geti32()))); - case ConvertSInt32: return curr->type == f32 ? Literal(float(int32_t(value.geti32()))) : Literal(double(int32_t(value.geti32()))); - case ConvertUInt64: return curr->type == f32 ? Literal(float((uint64_t)value.geti64())) : Literal(double((uint64_t)value.geti64())); - case ConvertSInt64: return curr->type == f32 ? Literal(float(value.geti64())) : Literal(double(value.geti64())); - case PromoteFloat32: return Literal(double(value.getf32())); - case DemoteFloat64: return Literal(float(value.getf64())); - case ReinterpretInt: { - if (curr->type == f32) { - float v = value.reinterpretf32(); - if (isnan(v)) { - return Literal(Literal(value.geti32() | 0x7f800000).reinterpretf32()); - } - return Literal(value.reinterpretf32()); - } else { - return Literal(value.reinterpretf64()); - } - } - default: abort(); - } - } Flow visitSelect(Select *curr) override { NOTE_ENTER("Select"); Flow condition = visit(curr->condition); @@ -683,11 +633,11 @@ private: Flow flow = visit(curr->operands[0]); if (flow.breaking()) return flow; uint32_t delta = flow.value.geti32(); - if (delta % pageSize != 0) trap(); - if (delta > uint32_t(-1) - pageSize) trap(); - if (instance.memorySize >= uint32_t(-1) - delta) trap(); + if (delta % pageSize != 0) trap("growMemory: delta not multiple"); + if (delta > uint32_t(-1) - pageSize) trap("growMemory: delta relatively too big"); + if (instance.memorySize >= uint32_t(-1) - delta) trap("growMemory: delta objectively too big"); uint32_t newSize = instance.memorySize + delta; - if (newSize > instance.wasm.memory.max) trap(); + if (newSize > instance.wasm.memory.max) trap("growMemory: exceeds max"); instance.externalInterface->growMemory(instance.memorySize, newSize); instance.memorySize = newSize; return Literal(); @@ -706,7 +656,7 @@ private: } Flow visitUnreachable(Unreachable *curr) override { NOTE_ENTER("Unreachable"); - trap(); + trap("unreachable"); return Flow(); } 
@@ -746,12 +696,38 @@ private: return Literal(int64_t(Literal(lnan ? l : r).reinterpreti64() | 0x8000000000000LL)).reinterpretf64(); } - void trap() { - instance.externalInterface->trap(); + Literal truncSFloat(Unary* curr, Literal value) { + double val = curr->op == TruncSFloat32 ? value.getf32() : value.getf64(); + if (isnan(val)) trap("truncSFloat of nan"); + if (curr->type == i32) { + if (val > (double)INT_MAX || val < (double)INT_MIN) trap("i32.truncSFloat overflow"); + return Literal(int32_t(val)); + } else { + int64_t converted = val; + if ((val >= 1 && converted <= 0) || val < (double)LLONG_MIN) trap("i32.truncSFloat overflow"); + return Literal(converted); + } + } + + Literal truncUFloat(Unary* curr, Literal value) { + double val = curr->op == TruncUFloat32 ? value.getf32() : value.getf64(); + if (isnan(val)) trap("truncUFloat of nan"); + if (curr->type == i32) { + if (val > (double)UINT_MAX || val <= (double)-1) trap("i64.truncUFloat overflow"); + return Literal(uint32_t(val)); + } else { + uint64_t converted = val; + if (converted < val - 1 || val <= (double)-1) trap("i64.truncUFloat overflow"); + return Literal(converted); + } + } + + void trap(const char* why) { + instance.externalInterface->trap(why); } }; - if (callDepth > maxCallDepth) externalInterface->trap(); + if (callDepth > maxCallDepth) externalInterface->trap("stack limit"); callDepth++; Function *function = wasm.functionsMap[name]; @@ -762,12 +738,14 @@ private: std::cout << "entering " << function->name << '\n'; #endif - Literal ret = ExpressionRunner(*this, scope).visit(function->body).value; + Flow flow = ExpressionRunner(*this, scope).visit(function->body); + assert(!flow.breaking()); // cannot still be breaking, it means we missed our stop + Literal ret = flow.value; if (function->result == none) ret = Literal(); assert(function->result == ret.type); callDepth--; #ifdef WASM_INTERPRETER_DEBUG - std::cout << "exiting " << function->name << '\n'; + std::cout << "exiting " << 
function->name << " with " << ret << '\n'; #endif return ret; } @@ -777,11 +755,11 @@ private: template<class LS> size_t getFinalAddress(LS *curr, Literal ptr) { uint64_t addr = ptr.type == i32 ? ptr.geti32() : ptr.geti64(); - if (memorySize < curr->offset) externalInterface->trap(); - if (addr > memorySize - curr->offset) externalInterface->trap(); + if (memorySize < curr->offset) externalInterface->trap("offset > memory"); + if (addr > memorySize - curr->offset) externalInterface->trap("final > memory"); addr += curr->offset; - if (curr->bytes > memorySize) externalInterface->trap(); - if (addr > memorySize - curr->bytes) externalInterface->trap(); + if (curr->bytes > memorySize) externalInterface->trap("bytes > memory"); + if (addr > memorySize - curr->bytes) externalInterface->trap("highest > memory"); return addr; } diff --git a/src/wasm-js.cpp b/src/wasm-js.cpp index 868c8e6ac..6858659f6 100644 --- a/src/wasm-js.cpp +++ b/src/wasm-js.cpp @@ -10,68 +10,109 @@ #include "asm2wasm.h" #include "wasm-interpreter.h" +#include "wasm-s-parser.h" using namespace cashew; using namespace wasm; +namespace wasm { +int debug = 0; +} + // global singletons Asm2WasmBuilder* asm2wasm = nullptr; +SExpressionParser* sExpressionParser = nullptr; +SExpressionWasmBuilder* sExpressionWasmBuilder = nullptr; ModuleInstance* instance = nullptr; -Module* module = nullptr; - -// receives asm.js code, parses into wasm and returns an instance handle. -// this creates a module, an external interface, a builder, and a module instance, -// all of which are then the responsibility of the caller to free. -// note: this modifies the input. -extern "C" void EMSCRIPTEN_KEEPALIVE load_asm(char *input) { - assert(instance == nullptr); // singleton - - // emcc --separate-asm modules look like - // - // Module["asm"] = (function(global, env, buffer) { - // .. - // }); - // - // we need to clean that up. 
- size_t num = strlen(input); - assert(*input == 'M'); - while (*input != 'f') { - input++; - num--; - } - char *end = input + num - 1; - while (*end != '}') { - *end = 0; - end--; - } +AllocatingModule* module = nullptr; +bool wasmJSDebug = false; +static void prepare2wasm() { + assert(asm2wasm == nullptr && sExpressionParser == nullptr && sExpressionWasmBuilder == nullptr && instance == nullptr); // singletons #if WASM_JS_DEBUG - std::cerr << "parsing...\n"; + wasmJSDebug = 1; +#else + wasmJSDebug = EM_ASM_INT_V({ return !!Module['outside']['WASM_JS_DEBUG'] }); // Set WASM_JS_DEBUG on the outside Module to get debugging #endif +} + +// receives asm.js code, parses into wasm. +// note: this modifies the input. +extern "C" void EMSCRIPTEN_KEEPALIVE load_asm2wasm(char *input) { + prepare2wasm(); + + Asm2WasmPreProcessor pre; + input = pre.process(input); + + // proceed to parse and wasmify + if (wasmJSDebug) std::cerr << "asm parsing...\n"; + cashew::Parser<Ref, DotZeroValueBuilder> builder; Ref asmjs = builder.parseToplevel(input); - module = new Module(); - module->memory.initial = module->memory.max = 16*1024*1024; // TODO: receive this from emscripten + module = new AllocatingModule(); + module->memory.initial = EM_ASM_INT_V({ + return Module['providedTotalMemory']; // we receive the size of memory from emscripten + }); + module->memory.max = pre.memoryGrowth ? 
-1 : module->memory.initial; -#if WASM_JS_DEBUG - std::cerr << "wasming...\n"; -#endif - asm2wasm = new Asm2WasmBuilder(*module); + if (wasmJSDebug) std::cerr << "wasming...\n"; + asm2wasm = new Asm2WasmBuilder(*module, pre.memoryGrowth); asm2wasm->processAsm(asmjs); -#if WASM_JS_DEBUG - std::cerr << "optimizing...\n"; -#endif + if (wasmJSDebug) std::cerr << "optimizing...\n"; asm2wasm->optimize(); -#if WASM_JS_DEBUG - std::cerr << *module << '\n'; -#endif + if (wasmJSDebug) std::cerr << "mapping globals...\n"; + for (auto& pair : asm2wasm->mappedGlobals) { + auto name = pair.first; + auto& global = pair.second; + if (!global.import) continue; // non-imports are initialized to zero in the typed array anyhow, so nothing to do here + double value = EM_ASM_DOUBLE({ return Module['lookupImport'](Pointer_stringify($0), Pointer_stringify($1)) }, global.module.str, global.base.str); + unsigned address = global.address; + switch (global.type) { + case i32: EM_ASM_({ Module['info'].parent['HEAP32'][$0 >> 2] = $1 }, address, value); break; + case f32: EM_ASM_({ Module['info'].parent['HEAPF32'][$0 >> 2] = $1 }, address, value); break; + case f64: EM_ASM_({ Module['info'].parent['HEAPF64'][$0 >> 3] = $1 }, address, value); break; + default: abort(); + } + } +} + +// loads wasm code in s-expression format +extern "C" void EMSCRIPTEN_KEEPALIVE load_s_expr2wasm(char *input, char *mappedGlobals) { + prepare2wasm(); + + if (wasmJSDebug) std::cerr << "wasm-s-expression parsing...\n"; + + sExpressionParser = new SExpressionParser(input); + Element& root = *sExpressionParser->root; + if (wasmJSDebug) std::cout << root << '\n'; + + if (wasmJSDebug) std::cerr << "wasming...\n"; + + module = new AllocatingModule(); + // A .wast may have multiple modules, with some asserts after them, but we just read the first here. 
+ sExpressionWasmBuilder = new SExpressionWasmBuilder(*module, *root[0], [&]() { + std::cerr << "error in parsing s-expressions to wasm\n"; + abort(); + }); + + module->memory.initial = EM_ASM_INT_V({ + return Module['providedTotalMemory']; // we receive the size of memory from emscripten + }); + module->memory.max = (module->exportsMap.find(GROW_WASM_MEMORY) != module->exportsMap.end()) ? -1 : module->memory.initial; + + // global mapping is done in js in post.js +} + +// instantiates the loaded wasm (which might be from asm2wasm, or +// s-expressions, or something else) with a JS external interface. +extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { + if (wasmJSDebug) std::cerr << "instantiating module: \n" << *module << '\n'; + + if (wasmJSDebug) std::cerr << "generating exports...\n"; -#if WASM_JS_DEBUG - std::cerr << "generating exports...\n"; -#endif EM_ASM({ Module['asmExports'] = {}; }); @@ -81,26 +122,25 @@ extern "C" void EMSCRIPTEN_KEEPALIVE load_asm(char *input) { var name = Pointer_stringify($0); Module['asmExports'][name] = function() { Module['tempArguments'] = Array.prototype.slice.call(arguments); - return Module['_call_from_js']($0); + Module['_call_from_js']($0); + return Module['tempReturn']; }; }, curr->name.str); } -#if WASM_JS_DEBUG - std::cerr << "creating instance...\n"; -#endif + if (wasmJSDebug) std::cerr << "creating instance...\n"; struct JSExternalInterface : ModuleInstance::ExternalInterface { Literal callImport(Import *import, ModuleInstance::LiteralList& arguments) override { -#ifdef WASM_JS_DEBUG - std::cout << "calling import " << import->name.str << '\n'; -#endif + if (wasmJSDebug) std::cout << "calling import " << import->name.str << '\n'; EM_ASM({ Module['tempArguments'] = []; }); for (auto& argument : arguments) { if (argument.type == i32) { EM_ASM_({ Module['tempArguments'].push($0) }, argument.geti32()); + } else if (argument.type == f32) { + EM_ASM_({ Module['tempArguments'].push($0) }, argument.getf32()); } else if 
(argument.type == f64) { EM_ASM_({ Module['tempArguments'].push($0) }, argument.getf64()); } else { @@ -115,9 +155,9 @@ extern "C" void EMSCRIPTEN_KEEPALIVE load_asm(char *input) { var lookup = Module['lookupImport'](mod, base); return lookup.apply(null, tempArguments); }, import->module.str, import->base.str); -#ifdef WASM_JS_DEBUG - std::cout << "calling import returning " << ret << '\n'; -#endif + + if (wasmJSDebug) std::cout << "calling import returning " << ret << '\n'; + switch (import->type.result) { case none: return Literal(0); case i32: return Literal((int32_t)ret); @@ -152,7 +192,7 @@ extern "C" void EMSCRIPTEN_KEEPALIVE load_asm(char *input) { abort(); } else { if (load->bytes == 4) { - return Literal(EM_ASM_DOUBLE({ return Module['info'].parent['HEAPF32'][$0 >> 2] }, addr)); // XXX expands into double + return Literal((float)EM_ASM_DOUBLE({ return Module['info'].parent['HEAPF32'][$0 >> 2] }, addr)); } else if (load->bytes == 8) { return Literal(EM_ASM_DOUBLE({ return Module['info'].parent['HEAPF64'][$0 >> 3] }, addr)); } @@ -174,7 +214,7 @@ extern "C" void EMSCRIPTEN_KEEPALIVE load_asm(char *input) { } } else { if (store->bytes == 4) { - EM_ASM_DOUBLE({ Module['info'].parent['HEAPF32'][$0 >> 2] = $1 }, addr, value.getf64()); + EM_ASM_DOUBLE({ Module['info'].parent['HEAPF32'][$0 >> 2] = $1 }, addr, value.getf32()); } else if (store->bytes == 8) { EM_ASM_DOUBLE({ Module['info'].parent['HEAPF64'][$0 >> 3] = $1 }, addr, value.getf64()); } else { @@ -184,41 +224,36 @@ extern "C" void EMSCRIPTEN_KEEPALIVE load_asm(char *input) { } void growMemory(size_t oldSize, size_t newSize) override { - abort(); + EM_ASM_({ + var size = $0; + var buffer; + try { + buffer = new ArrayBuffer(size); + } catch(e) { + // fail to grow memory. 
post.js notices this since the buffer is unchanged + return; + } + var oldHEAP8 = Module['outside']['HEAP8']; + var temp = new Int8Array(buffer); + temp.set(oldHEAP8); + Module['outside']['buffer'] = buffer; + }, newSize); } - void trap() override { - EM_ASM({ - abort("wasm trap!"); - }); + void trap(const char* why) override { + EM_ASM_({ + abort("wasm trap: " + Pointer_stringify($0)); + }, why); } }; instance = new ModuleInstance(*module, new JSExternalInterface()); } -// Ready the provided imported globals, copying them to their mapped locations. -extern "C" void EMSCRIPTEN_KEEPALIVE load_mapped_globals() { - for (auto& pair : asm2wasm->mappedGlobals) { - auto name = pair.first; - auto& global = pair.second; - if (!global.import) continue; // non-imports are initialized to zero in the typed array anyhow, so nothing to do here - double value = EM_ASM_DOUBLE({ return Module['lookupImport'](Pointer_stringify($0), Pointer_stringify($1)) }, global.module.str, global.base.str); - unsigned address = global.address; - switch (global.type) { - case i32: EM_ASM_({ Module['info'].parent['HEAP32'][$0 >> 2] = $1 }, address, value); break; - case f32: EM_ASM_({ Module['info'].parent['HEAPF32'][$0 >> 2] = $1 }, address, value); break; - case f64: EM_ASM_({ Module['info'].parent['HEAPF64'][$0 >> 3] = $1 }, address, value); break; - default: abort(); - } - } -} - // Does a call from js into an export of the module. 
-extern "C" double EMSCRIPTEN_KEEPALIVE call_from_js(const char *target) { -#ifdef WASM_JS_DEBUG - std::cout << "call_from_js " << target << '\n'; -#endif +extern "C" void EMSCRIPTEN_KEEPALIVE call_from_js(const char *target) { + if (wasmJSDebug) std::cout << "call_from_js " << target << '\n'; + IString exportName(target); IString functionName = instance->wasm.exportsMap[exportName]->value; Function *function = instance->wasm.functionsMap[functionName]; @@ -240,13 +275,13 @@ extern "C" double EMSCRIPTEN_KEEPALIVE call_from_js(const char *target) { } } Literal ret = instance->callExport(exportName, arguments); -#ifdef WASM_JS_DEBUG - std::cout << "call_from_js returning " << ret << '\n'; -#endif - if (ret.type == none) return 0; - if (ret.type == i32) return ret.i32; - if (ret.type == f32) return ret.f32; - if (ret.type == f64) return ret.f64; - abort(); + + if (wasmJSDebug) std::cout << "call_from_js returning " << ret << '\n'; + + if (ret.type == none) EM_ASM({ Module['tempReturn'] = undefined }); + else if (ret.type == i32) EM_ASM_({ Module['tempReturn'] = $0 }, ret.i32); + else if (ret.type == f32) EM_ASM_({ Module['tempReturn'] = $0 }, ret.f32); + else if (ret.type == f64) EM_ASM_({ Module['tempReturn'] = $0 }, ret.f64); + else abort(); } diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index 2e01bfc5f..6061dacc5 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -5,41 +5,18 @@ // #include <cmath> -#include <sstream> #include "wasm.h" #include "mixed_arena.h" +#include "shared-constants.h" +#include "parsing.h" namespace wasm { -int debug = 0; // wasm::debug is set in main(), typically from an env var - using namespace cashew; // Globals -IString MODULE("module"), - FUNC("func"), - PARAM("param"), - RESULT("result"), - MEMORY("memory"), - SEGMENT("segment"), - EXPORT("export"), - IMPORT("import"), - TABLE("table"), - LOCAL("local"), - TYPE("type"), - CALL("call"), - CALL_IMPORT("call_import"), - CALL_INDIRECT("call_indirect"), - 
INFINITY_("infinity"), - NEG_INFINITY("-infinity"), - NAN_("nan"), - NEG_NAN("-nan"), - CASE("case"), - BR("br"), - FAKE_RETURN("fake_return_waka123"); - int unhex(char c) { if (c >= '0' && c <= '9') return c - '0'; if (c >= 'a' && c <= 'f') return c - 'a' + 10; @@ -176,7 +153,7 @@ private: while (1) { while (isspace(input[0])) input++; if (input[0] == ';' && input[1] == ';') { - while (input[0] != '\n') input++; + while (input[0] && input[0] != '\n') input++; } else if (input[0] == '(' && input[1] == ';') { input = strstr(input, ";)") + 2; } else { @@ -232,15 +209,21 @@ private: // class SExpressionWasmBuilder { - Module& wasm; - MixedArena allocator; + AllocatingModule& wasm; + MixedArena& allocator; std::function<void ()> onError; int functionCounter; + std::map<Name, WasmType> functionTypes; // we need to know function return types before we parse their contents public: // Assumes control of and modifies the input. - SExpressionWasmBuilder(Module& wasm, Element& module, std::function<void ()> onError) : wasm(wasm), onError(onError), functionCounter(0) { + SExpressionWasmBuilder(AllocatingModule& wasm, Element& module, std::function<void ()> onError) : wasm(wasm), allocator(wasm.allocator), onError(onError) { assert(module[0]->str() == MODULE); + functionCounter = 0; + for (unsigned i = 1; i < module.size(); i++) { + preParseFunctionType(*module[i]); + } + functionCounter = 0; for (unsigned i = 1; i < module.size(); i++) { parseModuleElement(*module[i]); } @@ -248,6 +231,38 @@ public: private: + // pre-parse types and function definitions, so we know function return types before parsing their contents + void preParseFunctionType(Element& s) { + IString id = s[0]->str(); + if (id == TYPE) return parseType(s); + if (id != FUNC) return; + size_t i = 1; + Name name; + if (s[i]->isStr()) { + name = s[i]->str(); + i++; + } else { + // unnamed, use an index + name = Name::fromInt(functionCounter); + } + functionCounter++; + for (;i < s.size(); i++) { + Element& curr = 
*s[i]; + IString id = curr[0]->str(); + if (id == RESULT) { + functionTypes[name] = stringToWasmType(curr[1]->str()); + return; + } else if (id == TYPE) { + Name name = curr[1]->str(); + if (wasm.functionTypesMap.find(name) == wasm.functionTypesMap.end()) onError(); + FunctionType* type = wasm.functionTypesMap[name]; + functionTypes[name] = type->result; + return; + } + } + functionTypes[name] = none; + } + void parseModuleElement(Element& curr) { IString id = curr[0]->str(); if (id == FUNC) return parseFunction(curr); @@ -255,7 +270,7 @@ private: if (id == EXPORT) return parseExport(curr); if (id == IMPORT) return parseImport(curr); if (id == TABLE) return parseTable(curr); - if (id == TYPE) return parseType(curr); + if (id == TYPE) return; // already done std::cerr << "bad module element " << id.str << '\n'; onError(); } @@ -402,8 +417,8 @@ public: if (op[2] == 'p') return makeBinary(s, BinaryOp::CopySign, type); if (op[2] == 'n') { if (op[3] == 'v') { - if (op[8] == 's') return makeConvert(s, op[11] == '3' ? ConvertOp::ConvertSInt32 : ConvertOp::ConvertSInt64, type); - if (op[8] == 'u') return makeConvert(s, op[11] == '3' ? ConvertOp::ConvertUInt32 : ConvertOp::ConvertUInt64, type); + if (op[8] == 's') return makeUnary(s, op[11] == '3' ? UnaryOp::ConvertSInt32 : UnaryOp::ConvertSInt64, type); + if (op[8] == 'u') return makeUnary(s, op[11] == '3' ? UnaryOp::ConvertUInt32 : UnaryOp::ConvertUInt64, type); } if (op[3] == 's') return makeConst(s, type); } @@ -416,12 +431,12 @@ public: if (op[3] == '_') return makeBinary(s, op[4] == 'u' ? BinaryOp::DivU : BinaryOp::DivS, type); if (op[3] == 0) return makeBinary(s, BinaryOp::Div, type); } - if (op[1] == 'e') return makeConvert(s, ConvertOp::DemoteFloat64, type); + if (op[1] == 'e') return makeUnary(s, UnaryOp::DemoteFloat64, type); abort_on(op); } case 'e': { - if (op[1] == 'q') return makeCompare(s, RelationalOp::Eq, type); - if (op[1] == 'x') return makeConvert(s, op[7] == 'u' ? 
ConvertOp::ExtendUInt32 : ConvertOp::ExtendSInt32, type); + if (op[1] == 'q') return makeBinary(s, BinaryOp::Eq, type); + if (op[1] == 'x') return makeUnary(s, op[7] == 'u' ? UnaryOp::ExtendUInt32 : UnaryOp::ExtendSInt32, type); abort_on(op); } case 'f': { @@ -430,23 +445,23 @@ public: } case 'g': { if (op[1] == 't') { - if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::GtU : RelationalOp::GtS, type); - if (op[2] == 0) return makeCompare(s, RelationalOp::Gt, type); + if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::GtU : BinaryOp::GtS, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Gt, type); } if (op[1] == 'e') { - if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::GeU : RelationalOp::GeS, type); - if (op[2] == 0) return makeCompare(s, RelationalOp::Ge, type); + if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::GeU : BinaryOp::GeS, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Ge, type); } abort_on(op); } case 'l': { if (op[1] == 't') { - if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::LtU : RelationalOp::LtS, type); - if (op[2] == 0) return makeCompare(s, RelationalOp::Lt, type); + if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::LtU : BinaryOp::LtS, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Lt, type); } if (op[1] == 'e') { - if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::LeU : RelationalOp::LeS, type); - if (op[2] == 0) return makeCompare(s, RelationalOp::Le, type); + if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? 
BinaryOp::LeU : BinaryOp::LeS, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Le, type); } if (op[1] == 'o') return makeLoad(s, type); abort_on(op); @@ -459,7 +474,7 @@ public: } case 'n': { if (op[1] == 'e') { - if (op[2] == 0) return makeCompare(s, RelationalOp::Ne, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Ne, type); if (op[2] == 'a') return makeUnary(s, UnaryOp::Nearest, type); if (op[2] == 'g') return makeUnary(s, UnaryOp::Neg, type); } @@ -470,14 +485,14 @@ public: abort_on(op); } case 'p': { - if (op[1] == 'r') return makeConvert(s, ConvertOp::PromoteFloat32, type); + if (op[1] == 'r') return makeUnary(s, UnaryOp::PromoteFloat32, type); if (op[1] == 'o') return makeUnary(s, UnaryOp::Popcnt, type); abort_on(op); } case 'r': { if (op[1] == 'e') { if (op[2] == 'm') return makeBinary(s, op[4] == 'u' ? BinaryOp::RemU : BinaryOp::RemS, type); - if (op[2] == 'i') return makeConvert(s, isWasmTypeFloat(type) ? ConvertOp::ReinterpretInt : ConvertOp::ReinterpretFloat, type); + if (op[2] == 'i') return makeUnary(s, isWasmTypeFloat(type) ? UnaryOp::ReinterpretInt : UnaryOp::ReinterpretFloat, type); } abort_on(op); } @@ -494,14 +509,14 @@ public: } case 't': { if (op[1] == 'r') { - if (op[6] == 's') return makeConvert(s, op[9] == '3' ? ConvertOp::TruncSFloat32 : ConvertOp::TruncSFloat64, type); - if (op[6] == 'u') return makeConvert(s, op[9] == '3' ? ConvertOp::TruncUFloat32 : ConvertOp::TruncUFloat64, type); + if (op[6] == 's') return makeUnary(s, op[9] == '3' ? UnaryOp::TruncSFloat32 : UnaryOp::TruncSFloat64, type); + if (op[6] == 'u') return makeUnary(s, op[9] == '3' ? 
UnaryOp::TruncUFloat32 : UnaryOp::TruncUFloat64, type); if (op[2] == 'u') return makeUnary(s, UnaryOp::Trunc, type); } abort_on(op); } case 'w': { - if (op[1] == 'r') return makeConvert(s, ConvertOp::WrapInt64, type); + if (op[1] == 'r') return makeUnary(s, UnaryOp::WrapInt64, type); abort_on(op); } case 'x': { @@ -540,7 +555,6 @@ public: abort_on(str); } case 'l': { - if (str[1] == 'a') return makeLabel(s); if (str[1] == 'o') return makeLoop(s); abort_on(str); } @@ -584,7 +598,7 @@ private: ret->op = op; ret->left = parseExpression(s[1]); ret->right = parseExpression(s[2]); - ret->type = type; + ret->finalize(); return ret; } @@ -596,23 +610,6 @@ private: return ret; } - Expression* makeCompare(Element& s, RelationalOp op, WasmType type) { - auto ret = allocator.alloc<Compare>(); - ret->op = op; - ret->left = parseExpression(s[1]); - ret->right = parseExpression(s[2]); - ret->inputType = type; - return ret; - } - - Expression* makeConvert(Element& s, ConvertOp op, WasmType type) { - auto ret = allocator.alloc<Convert>(); - ret->op = op; - ret->value = parseExpression(s[1]); - ret->type = type; - return ret; - } - Expression* makeSelect(Element& s, WasmType type) { auto ret = allocator.alloc<Select>(); ret->condition = parseExpression(s[1]); @@ -630,6 +627,7 @@ private: } else { parseCallOperands(s, 1, ret); } + ret->finalize(); return ret; } @@ -666,151 +664,21 @@ private: if (s[1]->isStr()) { ret->name = s[1]->str(); i++; + } else { + ret->name = getPrefixedName("block"); } + labelStack.push_back(ret->name); for (; i < s.size(); i++) { ret->list.push_back(parseExpression(s[i])); } + labelStack.pop_back(); + ret->type = ret->list.back()->type; return ret; } Expression* makeConst(Element& s, WasmType type) { - const char *str = s[1]->c_str(); - auto ret = allocator.alloc<Const>(); - ret->type = ret->value.type = type; - if (isWasmTypeFloat(type)) { - if (s[1]->str() == INFINITY_) { - switch (type) { - case f32: ret->value.f32 = 
std::numeric_limits<float>::infinity(); break; - case f64: ret->value.f64 = std::numeric_limits<double>::infinity(); break; - default: onError(); - } - //std::cerr << "make constant " << str << " ==> " << ret->value << '\n'; - return ret; - } - if (s[1]->str() == NEG_INFINITY) { - switch (type) { - case f32: ret->value.f32 = -std::numeric_limits<float>::infinity(); break; - case f64: ret->value.f64 = -std::numeric_limits<double>::infinity(); break; - default: onError(); - } - //std::cerr << "make constant " << str << " ==> " << ret->value << '\n'; - return ret; - } - if (s[1]->str() == NAN_) { - switch (type) { - case f32: ret->value.f32 = std::nan(""); break; - case f64: ret->value.f64 = std::nan(""); break; - default: onError(); - } - //std::cerr << "make constant " << str << " ==> " << ret->value << '\n'; - return ret; - } - bool negative = str[0] == '-'; - const char *positive = negative ? str + 1 : str; - if (positive[0] == '+') positive++; - if (positive[0] == 'n' && positive[1] == 'a' && positive[2] == 'n') { - const char * modifier = positive[3] == ':' ? positive + 4 : nullptr; - assert(modifier ? 
positive[4] == '0' && positive[5] == 'x' : 1); - switch (type) { - case f32: { - union { - uint32_t pattern; - float f; - } u; - if (modifier) { - std::istringstream istr(modifier); - istr >> std::hex >> u.pattern; - u.pattern |= 0x7f800000; - } else { - u.pattern = 0x7fc00000; - } - if (negative) u.pattern |= 0x80000000; - if (!isnan(u.f)) u.pattern |= 1; - assert(isnan(u.f)); - ret->value.f32 = u.f; - break; - } - case f64: { - union { - uint64_t pattern; - double d; - } u; - if (modifier) { - std::istringstream istr(modifier); - istr >> std::hex >> u.pattern; - u.pattern |= 0x7ff0000000000000LL; - } else { - u.pattern = 0x7ff8000000000000L; - } - if (negative) u.pattern |= 0x8000000000000000LL; - if (!isnan(u.d)) u.pattern |= 1; - assert(isnan(u.d)); - ret->value.f64 = u.d; - break; - } - default: onError(); - } - //std::cerr << "make constant " << str << " ==> " << ret->value << '\n'; - return ret; - } - if (s[1]->str() == NEG_NAN) { - switch (type) { - case f32: ret->value.f32 = -std::nan(""); break; - case f64: ret->value.f64 = -std::nan(""); break; - default: onError(); - } - //std::cerr << "make constant " << str << " ==> " << ret->value << '\n'; - return ret; - } - } - switch (type) { - case i32: { - if ((str[0] == '0' && str[1] == 'x') || (str[0] == '-' && str[1] == '0' && str[2] == 'x')) { - bool negative = str[0] == '-'; - if (negative) str++; - std::istringstream istr(str); - uint32_t temp; - istr >> std::hex >> temp; - ret->value.i32 = negative ? -temp : temp; - } else { - std::istringstream istr(str); - int32_t temp; - istr >> temp; - ret->value.i32 = temp; - } - break; - } - case i64: { - if ((str[0] == '0' && str[1] == 'x') || (str[0] == '-' && str[1] == '0' && str[2] == 'x')) { - bool negative = str[0] == '-'; - if (negative) str++; - std::istringstream istr(str); - uint64_t temp; - istr >> std::hex >> temp; - ret->value.i64 = negative ? 
-temp : temp; - } else { - std::istringstream istr(str); - int64_t temp; - istr >> temp; - ret->value.i64 = temp; - } - break; - } - case f32: { - char *end; - ret->value.f32 = strtof(str, &end); - assert(!isnan(ret->value.f32)); - break; - } - case f64: { - char *end; - ret->value.f64 = strtod(str, &end); - assert(!isnan(ret->value.f64)); - break; - } - default: onError(); - } - //std::cerr << "make constant " << str << " ==> " << ret->value << '\n'; + auto ret = parseConst(s[1]->str(), type, allocator); + if (!ret) onError(); return ret; } @@ -896,31 +764,20 @@ private: ret->ifTrue = parseExpression(s[2]); if (s.size() == 4) { ret->ifFalse = parseExpression(s[3]); + ret->type = ret->ifTrue->type == ret->ifFalse->type ? ret->ifTrue->type : none; // if not the same type, this does not return a value } return ret; } - Expression* makeLabel(Element& s) { - auto ret = allocator.alloc<Label>(); - size_t i = 1; - if (s[i]->isStr()) { - ret->name = s[i]->str(); - i++; - } else { - ret->name = getPrefixedName("label"); - } - labelStack.push_back(ret->name); - ret->body = parseExpression(s[i]); - labelStack.pop_back(); - return ret; - } - Expression* makeMaybeBlock(Element& s, size_t i, size_t stopAt=-1) { if (s.size() == i+1) return parseExpression(s[i]); auto ret = allocator.alloc<Block>(); for (; i < s.size() && i < stopAt; i++) { ret->list.push_back(parseExpression(s[i])); } + if (ret->list.size() > 0) { + ret->type = ret->list.back()->type; + } return ret; } @@ -950,6 +807,7 @@ private: Expression* makeCall(Element& s) { auto ret = allocator.alloc<Call>(); ret->target = s[1]->str(); + ret->type = functionTypes[ret->target]; parseCallOperands(s, 2, ret); return ret; } @@ -957,6 +815,8 @@ private: Expression* makeCallImport(Element& s) { auto ret = allocator.alloc<CallImport>(); ret->target = s[1]->str(); + Import* import = wasm.importsMap[ret->target]; + ret->type = import->type.result; parseCallOperands(s, 2, ret); return ret; } @@ -965,7 +825,8 @@ private: auto ret = 
allocator.alloc<CallIndirect>(); IString type = s[1]->str(); assert(wasm.functionTypesMap.find(type) != wasm.functionTypesMap.end()); - ret->type = wasm.functionTypesMap[type]; + ret->fullType = wasm.functionTypesMap[type]; + ret->type = ret->fullType->result; ret->target = parseExpression(s[2]); parseCallOperands(s, 3, ret); return ret; @@ -981,16 +842,22 @@ private: Expression* makeBreak(Element& s) { auto ret = allocator.alloc<Break>(); - if (s[1]->dollared()) { - ret->name = s[1]->str(); + size_t i = 1; + if (s[0]->str() == BR_IF) { + ret->condition = parseExpression(s[i]); + i++; + } + if (s[i]->dollared()) { + ret->name = s[i]->str(); } else { // offset, break to nth outside label - size_t offset = atol(s[1]->c_str()); + size_t offset = atol(s[i]->c_str()); assert(offset < labelStack.size()); ret->name = labelStack[labelStack.size() - 1 - offset]; } - if (s.size() == 3) { - ret->value = parseExpression(s[2]); + i++; + if (i < s.size()) { + ret->value = parseExpression(s[i]); } return ret; } @@ -1082,23 +949,27 @@ private: im->module = s[2]->str(); if (!s[3]->isStr()) onError(); im->base = s[3]->str(); - Element& params = *s[4]; - IString id = params[0]->str(); - if (id == PARAM) { - for (size_t i = 1; i < params.size(); i++) { - im->type.params.push_back(stringToWasmType(params[i]->str())); + if (s.size() > 4) { + Element& params = *s[4]; + IString id = params[0]->str(); + if (id == PARAM) { + for (size_t i = 1; i < params.size(); i++) { + im->type.params.push_back(stringToWasmType(params[i]->str())); + } + } else if (id == RESULT) { + im->type.result = stringToWasmType(params[1]->str()); + } else if (id == TYPE) { + IString name = params[1]->str(); + assert(wasm.functionTypesMap.find(name) != wasm.functionTypesMap.end()); + im->type = *wasm.functionTypesMap[name]; + } else { + onError(); + } + if (s.size() > 5) { + Element& result = *s[5]; + assert(result[0]->str() == RESULT); + im->type.result = stringToWasmType(result[1]->str()); } - } else if (id == TYPE) 
{ - IString name = params[1]->str(); - assert(wasm.functionTypesMap.find(name) != wasm.functionTypesMap.end()); - im->type = *wasm.functionTypesMap[name]; - } else { - onError(); - } - if (s.size() > 5) { - Element& result = *s[5]; - assert(result[0]->str() == RESULT); - im->type.result = stringToWasmType(result[1]->str()); } wasm.addImport(im); } diff --git a/src/wasm-validator.h b/src/wasm-validator.h index 49e7a6a33..4daf9a4a7 100644 --- a/src/wasm-validator.h +++ b/src/wasm-validator.h @@ -19,6 +19,26 @@ public: // visitors + void visitLoop(Loop *curr) override { + if (curr->in.is()) { + // the "in" label has a none type, since no one can receive its value. make sure no one breaks to it with a value. + struct ChildChecker : public WasmWalker { + Name in; + bool valid = true; + + ChildChecker(Name in) : in(in) {} + + void visitBreak(Break *curr) override { + if (curr->name == in && curr->value) { + valid = false; + } + } + }; + ChildChecker childChecker(curr->in); + childChecker.walk(curr->body); + shouldBeTrue(childChecker.valid); + } + } void visitSetLocal(SetLocal *curr) override { shouldBeTrue(curr->type == curr->value->type); } @@ -40,6 +60,13 @@ public: } shouldBeFalse(curr->default_.is() && inTable.find(curr->default_) == inTable.end()); } + void visitUnary(Unary *curr) override { + shouldBeTrue(curr->value->type == curr->type); + } + + void visitFunction(Function *curr) override { + shouldBeTrue(curr->result == curr->body->type); + } void visitMemory(Memory *curr) override { shouldBeFalse(curr->initial > curr->max); size_t top = 0; diff --git a/src/wasm.h b/src/wasm.h index d9e25783a..af3917744 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -15,6 +15,17 @@ // * Validation: See wasm-validator.h. // +// +// wasm.js internal WebAssembly representation design: +// +// * Unify where possible. Where size isn't a concern, combine +// classes, so binary ops and relational ops are joined. This +// simplifies that AST and makes traversals easier. 
+// * Optimize for size? This might justify separating if and if_else +// (so that if doesn't have an always-empty else; also it avoids +// a branch). +// + #ifndef __wasm_h__ #define __wasm_h__ @@ -153,7 +164,7 @@ struct Literal { } } - void printFloat(std::ostream &o, float f) { + static void printFloat(std::ostream &o, float f) { if (isnan(f)) { union { float ff; @@ -166,7 +177,7 @@ struct Literal { printDouble(o, f); } - void printDouble(std::ostream &o, double d) { + static void printDouble(std::ostream &o, double d) { if (d == 0 && 1/d < 0) { o << "-0"; return; @@ -210,26 +221,22 @@ struct Literal { enum UnaryOp { Clz, Ctz, Popcnt, // int - Neg, Abs, Ceil, Floor, Trunc, Nearest, Sqrt // float + Neg, Abs, Ceil, Floor, Trunc, Nearest, Sqrt, // float + // conversions + ExtendSInt32, ExtendUInt32, WrapInt64, TruncSFloat32, TruncUFloat32, TruncSFloat64, TruncUFloat64, ReinterpretFloat, // int + ConvertSInt32, ConvertUInt32, ConvertSInt64, ConvertUInt64, PromoteFloat32, DemoteFloat64, ReinterpretInt // float }; enum BinaryOp { Add, Sub, Mul, // int or float DivS, DivU, RemS, RemU, And, Or, Xor, Shl, ShrU, ShrS, // int - Div, CopySign, Min, Max // float -}; - -enum RelationalOp { + Div, CopySign, Min, Max, // float + // relational ops Eq, Ne, // int or float LtS, LtU, LeS, LeU, GtS, GtU, GeS, GeU, // int Lt, Le, Gt, Ge // float }; -enum ConvertOp { - ExtendSInt32, ExtendUInt32, WrapInt64, TruncSFloat32, TruncUFloat32, TruncSFloat64, TruncUFloat64, ReinterpretFloat, // int - ConvertSInt32, ConvertUInt32, ConvertSInt64, ConvertUInt64, PromoteFloat32, DemoteFloat64, ReinterpretInt // float -}; - enum HostOp { PageSize, MemorySize, GrowMemory, HasFeature }; @@ -253,28 +260,25 @@ class Expression { public: enum Id { InvalidId = 0, - BlockId = 1, - IfId = 2, - LoopId = 3, - LabelId = 4, - BreakId = 5, - SwitchId =6 , - CallId = 7, - CallImportId = 8, - CallIndirectId = 9, - GetLocalId = 10, - SetLocalId = 11, - LoadId = 12, - StoreId = 13, - ConstId = 14, - UnaryId = 15, 
- BinaryId = 16, - CompareId = 17, - ConvertId = 18, - SelectId = 19, - HostId = 20, - NopId = 21, - UnreachableId = 22 + BlockId, + IfId, + LoopId, + BreakId, + SwitchId, + CallId, + CallImportId, + CallIndirectId, + GetLocalId, + SetLocalId, + LoadId, + StoreId, + ConstId, + UnaryId, + BinaryId, + SelectId, + HostId, + NopId, + UnreachableId }; Id _id; @@ -293,6 +297,12 @@ public: return _id == T()._id ? (T*)this : nullptr; } + template<class T> + T* cast() { + assert(_id == T()._id); + return (T*)this; + } + inline std::ostream& print(std::ostream &o, unsigned indent); // avoid virtual here, for performance friend std::ostream& operator<<(std::ostream &o, Expression* expression) { @@ -319,7 +329,9 @@ public: class Block : public Expression { public: - Block() : Expression(BlockId) {} + Block() : Expression(BlockId) { + type = none; // blocks by default do not return, but if their last statement does, they might + } Name name; ExpressionList list; @@ -339,7 +351,9 @@ public: class If : public Expression { public: - If() : Expression(IfId), ifFalse(nullptr) {} + If() : Expression(IfId), ifFalse(nullptr) { + type = none; // by default none; if-else can have one, though + } Expression *condition, *ifTrue, *ifFalse; @@ -375,31 +389,29 @@ public: } }; -class Label : public Expression { -public: - Label() : Expression(LabelId) {} - - Name name; - Expression* body; - - std::ostream& doPrint(std::ostream &o, unsigned indent) { - printOpening(o, "label ") << name; - incIndent(o, indent); - printFullLine(o, indent, body); - return decIndent(o, indent); - } -}; - class Break : public Expression { public: - Break() : Expression(BreakId), value(nullptr) {} + Break() : Expression(BreakId), condition(nullptr), value(nullptr) {} + Expression *condition; Name name; Expression *value; std::ostream& doPrint(std::ostream &o, unsigned indent) { - printOpening(o, "br ") << name; - incIndent(o, indent); + if (condition) { + printOpening(o, "br_if"); + incIndent(o, indent); + 
printFullLine(o, indent, condition); + doIndent(o, indent) << name << '\n'; + } else { + printOpening(o, "br ") << name; + if (!value) { + // avoid a new line just for the parens + o << ")"; + return o; + } + incIndent(o, indent); + } if (value) printFullLine(o, indent, value); return decIndent(o, indent); } @@ -428,14 +440,15 @@ public: incIndent(o, indent); printFullLine(o, indent, value); doIndent(o, indent) << "(table"; - assert(default_.is()); for (auto& t : targets) { o << " (case " << (t.is() ? t : default_) << ")"; } - o << ") (case " << default_ << ")\n"; + o << ")"; + if (default_.is()) o << " (case " << default_ << ")"; + o << "\n"; for (auto& c : cases) { doIndent(o, indent); - printMinorOpening(o, "case ") << c.name.str; + printMinorOpening(o, "case ") << c.name; incIndent(o, indent); printFullLine(o, indent, c.body); decIndent(o, indent) << '\n'; @@ -445,12 +458,18 @@ public: }; -class Call : public Expression { +class CallBase : public Expression { public: - Call() : Expression(CallId) {} + CallBase(Id which) : Expression(which) {} - Name target; ExpressionList operands; +}; + +class Call : public CallBase { +public: + Call() : CallBase(CallId) {} + + Name target; std::ostream& printBody(std::ostream &o, unsigned indent) { o << target; @@ -526,16 +545,15 @@ public: } }; -class CallIndirect : public Expression { +class CallIndirect : public CallBase { public: - CallIndirect() : Expression(CallIndirectId) {} + CallIndirect() : CallBase(CallIndirectId) {} - FunctionType *type; + FunctionType *fullType; Expression *target; - ExpressionList operands; std::ostream& doPrint(std::ostream &o, unsigned indent) { - printOpening(o, "call_indirect ") << type->name; + printOpening(o, "call_indirect ") << fullType->name; incIndent(o, indent); printFullLine(o, indent, target); for (auto operand : operands) { @@ -670,16 +688,31 @@ public: o << '('; prepareColor(o) << printWasmType(type) << '.'; switch (op) { - case Clz: o << "clz"; break; - case Ctz: o << "ctz"; 
break; - case Popcnt: o << "popcnt"; break; - case Neg: o << "neg"; break; - case Abs: o << "abs"; break; - case Ceil: o << "ceil"; break; - case Floor: o << "floor"; break; - case Trunc: o << "trunc"; break; - case Nearest: o << "nearest"; break; - case Sqrt: o << "sqrt"; break; + case Clz: o << "clz"; break; + case Ctz: o << "ctz"; break; + case Popcnt: o << "popcnt"; break; + case Neg: o << "neg"; break; + case Abs: o << "abs"; break; + case Ceil: o << "ceil"; break; + case Floor: o << "floor"; break; + case Trunc: o << "trunc"; break; + case Nearest: o << "nearest"; break; + case Sqrt: o << "sqrt"; break; + case ExtendSInt32: o << "extend_s/i32"; break; + case ExtendUInt32: o << "extend_u/i32"; break; + case WrapInt64: o << "wrap/i64"; break; + case TruncSFloat32: o << "trunc_s/f32"; break; + case TruncUFloat32: o << "trunc_u/f32"; break; + case TruncSFloat64: o << "trunc_s/f64"; break; + case TruncUFloat64: o << "trunc_u/f64"; break; + case ReinterpretFloat: o << "reinterpret/" << (type == i64 ? "f64" : "f32"); break; + case ConvertUInt32: o << "convert_u/i32"; break; + case ConvertSInt32: o << "convert_s/i32"; break; + case ConvertUInt64: o << "convert_u/i64"; break; + case ConvertSInt64: o << "convert_s/i64"; break; + case PromoteFloat32: o << "promote/f32"; break; + case DemoteFloat64: o << "demote/f64"; break; + case ReinterpretInt: o << "reinterpret" << (type == f64 ? "i64" : "i32"); break; default: abort(); } incIndent(o, indent); @@ -697,7 +730,7 @@ public: std::ostream& doPrint(std::ostream &o, unsigned indent) { o << '('; - prepareColor(o) << printWasmType(type) << '.'; + prepareColor(o) << printWasmType(isRelational() ? 
left->type : type) << '.'; switch (op) { case Add: o << "add"; break; case Sub: o << "sub"; break; @@ -716,7 +749,21 @@ public: case CopySign: o << "copysign"; break; case Min: o << "min"; break; case Max: o << "max"; break; - default: abort(); + case Eq: o << "eq"; break; + case Ne: o << "ne"; break; + case LtS: o << "lt_s"; break; + case LtU: o << "lt_u"; break; + case LeS: o << "le_s"; break; + case LeU: o << "le_u"; break; + case GtS: o << "gt_s"; break; + case GtU: o << "gt_u"; break; + case GeS: o << "ge_s"; break; + case GeU: o << "ge_u"; break; + case Lt: o << "lt"; break; + case Le: o << "le"; break; + case Gt: o << "gt"; break; + case Ge: o << "ge"; break; + default: abort(); } restoreNormalColor(o); incIndent(o, indent); @@ -725,82 +772,18 @@ public: return decIndent(o, indent); } - // the type is always the type of the operands - void finalize() { - type = left->type; - } -}; + // the type is always the type of the operands, + // except for relationals -class Compare : public Expression { -public: - Compare() : Expression(CompareId) { - type = WasmType::i32; // output is always i32 - } + bool isRelational() { return op >= Eq; } - RelationalOp op; - WasmType inputType; - Expression *left, *right; - - std::ostream& doPrint(std::ostream &o, unsigned indent) { - o << '('; - prepareColor(o) << printWasmType(inputType) << '.'; - switch (op) { - case Eq: o << "eq"; break; - case Ne: o << "ne"; break; - case LtS: o << "lt_s"; break; - case LtU: o << "lt_u"; break; - case LeS: o << "le_s"; break; - case LeU: o << "le_u"; break; - case GtS: o << "gt_s"; break; - case GtU: o << "gt_u"; break; - case GeS: o << "ge_s"; break; - case GeU: o << "ge_u"; break; - case Lt: o << "lt"; break; - case Le: o << "le"; break; - case Gt: o << "gt"; break; - case Ge: o << "ge"; break; - default: abort(); - } - restoreNormalColor(o); - incIndent(o, indent); - printFullLine(o, indent, left); - printFullLine(o, indent, right); - return decIndent(o, indent); - } -}; - -class Convert 
: public Expression { -public: - Convert() : Expression(ConvertId) {} - - ConvertOp op; - Expression *value; - - std::ostream& doPrint(std::ostream &o, unsigned indent) { - o << '('; - prepareColor(o) << printWasmType(type) << '.'; - switch (op) { - case ExtendSInt32: o << "extend_s/i32"; break; - case ExtendUInt32: o << "extend_u/i32"; break; - case WrapInt64: o << "wrap/i64"; break; - case TruncSFloat32: o << "trunc_s/f32"; break; - case TruncUFloat32: o << "trunc_u/f32"; break; - case TruncSFloat64: o << "trunc_s/f64"; break; - case TruncUFloat64: o << "trunc_u/f64"; break; - case ReinterpretFloat: o << "reinterpret/" << (type == i64 ? "f64" : "f32"); break; - case ConvertUInt32: o << "convert_u/i32"; break; - case ConvertSInt32: o << "convert_s/i32"; break; - case ConvertUInt64: o << "convert_u/i64"; break; - case ConvertSInt64: o << "convert_s/i64"; break; - case PromoteFloat32: o << "promote/f32"; break; - case DemoteFloat64: o << "demote/f64"; break; - case ReinterpretInt: o << "reinterpret" << (type == f64 ? 
"i64" : "i32"); break; - default: abort(); + void finalize() { + if (isRelational()) { + type = i32; + } else { + assert(left->type == right->type); + type = left->type; } - restoreNormalColor(o); - incIndent(o, indent); - printFullLine(o, indent, value); - return decIndent(o, indent); } }; @@ -832,7 +815,7 @@ public: std::ostream& doPrint(std::ostream &o, unsigned indent) { switch (op) { case PageSize: printOpening(o, "pagesize") << ')'; break; - case MemorySize: printOpening(o, "memorysize") << ')'; break; + case MemorySize: printOpening(o, "memory_size") << ')'; break; case GrowMemory: { printOpening(o, "grow_memory"); incIndent(o, indent); @@ -845,6 +828,20 @@ public: } return o; } + + void finalize() { + switch (op) { + case PageSize: case MemorySize: case HasFeature: { + type = i32; + break; + } + case GrowMemory: { + type = none; + break; + } + default: abort(); + } + } }; class Unreachable : public Expression { @@ -906,7 +903,7 @@ public: std::ostream& print(std::ostream &o, unsigned indent) { printOpening(o, "import ") << name << ' '; printText(o, module.str) << ' '; - printText(o, base.str) << ' '; + printText(o, base.str); type.print(o, indent); return o << ')'; } @@ -1028,7 +1025,28 @@ public: printOpening(o, "memory") << " " << module.memory.initial; if (module.memory.max) o << " " << module.memory.max; for (auto segment : module.memory.segments) { - o << " (segment " << segment.offset << " \"" << segment.data << "\")"; + o << " (segment " << segment.offset << " \""; + for (size_t i = 0; i < segment.size; i++) { + unsigned char c = segment.data[i]; + switch (c) { + case '\n': o << "\\n"; break; + case '\r': o << "\\0d"; break; + case '\t': o << "\\t"; break; + case '\f': o << "\\0c"; break; + case '\b': o << "\\08"; break; + case '\\': o << "\\\\"; break; + case '"' : o << "\\\""; break; + case '\'' : o << "\\'"; break; + default: { + if (c >= 32 && c < 127) { + o << c; + } else { + o << std::hex << '\\' << (c/16) << (c%16) << std::dec; + } + } + } + } 
+ o << "\")"; } o << ")\n"; for (auto& curr : module.functionTypes) { @@ -1077,7 +1095,6 @@ struct WasmVisitor { virtual ReturnType visitBlock(Block *curr) { abort(); } virtual ReturnType visitIf(If *curr) { abort(); } virtual ReturnType visitLoop(Loop *curr) { abort(); } - virtual ReturnType visitLabel(Label *curr) { abort(); } virtual ReturnType visitBreak(Break *curr) { abort(); } virtual ReturnType visitSwitch(Switch *curr) { abort(); } virtual ReturnType visitCall(Call *curr) { abort(); } @@ -1090,8 +1107,6 @@ struct WasmVisitor { virtual ReturnType visitConst(Const *curr) { abort(); } virtual ReturnType visitUnary(Unary *curr) { abort(); } virtual ReturnType visitBinary(Binary *curr) { abort(); } - virtual ReturnType visitCompare(Compare *curr) { abort(); } - virtual ReturnType visitConvert(Convert *curr) { abort(); } virtual ReturnType visitSelect(Select *curr) { abort(); } virtual ReturnType visitHost(Host *curr) { abort(); } virtual ReturnType visitNop(Nop *curr) { abort(); } @@ -1111,7 +1126,6 @@ struct WasmVisitor { case Expression::Id::BlockId: return visitBlock((Block*)curr); case Expression::Id::IfId: return visitIf((If*)curr); case Expression::Id::LoopId: return visitLoop((Loop*)curr); - case Expression::Id::LabelId: return visitLabel((Label*)curr); case Expression::Id::BreakId: return visitBreak((Break*)curr); case Expression::Id::SwitchId: return visitSwitch((Switch*)curr); case Expression::Id::CallId: return visitCall((Call*)curr); @@ -1124,8 +1138,6 @@ struct WasmVisitor { case Expression::Id::ConstId: return visitConst((Const*)curr); case Expression::Id::UnaryId: return visitUnary((Unary*)curr); case Expression::Id::BinaryId: return visitBinary((Binary*)curr); - case Expression::Id::CompareId: return visitCompare((Compare*)curr); - case Expression::Id::ConvertId: return visitConvert((Convert*)curr); case Expression::Id::SelectId: return visitSelect((Select*)curr); case Expression::Id::HostId: return visitHost((Host*)curr); case 
Expression::Id::NopId: return visitNop((Nop*)curr); @@ -1148,7 +1160,6 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { void visitBlock(Block *curr) override { curr->doPrint(o, indent); } void visitIf(If *curr) override { curr->doPrint(o, indent); } void visitLoop(Loop *curr) override { curr->doPrint(o, indent); } - void visitLabel(Label *curr) override { curr->doPrint(o, indent); } void visitBreak(Break *curr) override { curr->doPrint(o, indent); } void visitSwitch(Switch *curr) override { curr->doPrint(o, indent); } void visitCall(Call *curr) override { curr->doPrint(o, indent); } @@ -1161,8 +1172,6 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { void visitConst(Const *curr) override { curr->doPrint(o, indent); } void visitUnary(Unary *curr) override { curr->doPrint(o, indent); } void visitBinary(Binary *curr) override { curr->doPrint(o, indent); } - void visitCompare(Compare *curr) override { curr->doPrint(o, indent); } - void visitConvert(Convert *curr) override { curr->doPrint(o, indent); } void visitSelect(Select *curr) override { curr->doPrint(o, indent); } void visitHost(Host *curr) override { curr->doPrint(o, indent); } void visitNop(Nop *curr) override { curr->doPrint(o, indent); } @@ -1176,7 +1185,7 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { // // Simple WebAssembly children-first walking (i.e., post-order, if you look -// at the children as subtrees of the current node), with the ability to +// at the children as subtrees of the current node), with the ability to replace // the current expression node. Useful for writing optimization passes. 
// @@ -1194,7 +1203,6 @@ struct WasmWalker : public WasmVisitor<void> { void visitBlock(Block *curr) override {} void visitIf(If *curr) override {} void visitLoop(Loop *curr) override {} - void visitLabel(Label *curr) override {} void visitBreak(Break *curr) override {} void visitSwitch(Switch *curr) override {} void visitCall(Call *curr) override {} @@ -1207,8 +1215,6 @@ struct WasmWalker : public WasmVisitor<void> { void visitConst(Const *curr) override {} void visitUnary(Unary *curr) override {} void visitBinary(Binary *curr) override {} - void visitCompare(Compare *curr) override {} - void visitConvert(Convert *curr) override {} void visitSelect(Select *curr) override {} void visitHost(Host *curr) override {} void visitNop(Nop *curr) override {} @@ -1245,8 +1251,8 @@ struct WasmWalker : public WasmVisitor<void> { void visitLoop(Loop *curr) override { parent.walk(curr->body); } - void visitLabel(Label *curr) override {} void visitBreak(Break *curr) override { + parent.walk(curr->condition); parent.walk(curr->value); } void visitSwitch(Switch *curr) override { @@ -1293,13 +1299,6 @@ struct WasmWalker : public WasmVisitor<void> { parent.walk(curr->left); parent.walk(curr->right); } - void visitCompare(Compare *curr) override { - parent.walk(curr->left); - parent.walk(curr->right); - } - void visitConvert(Convert *curr) override { - parent.walk(curr->value); - } void visitSelect(Select *curr) override { parent.walk(curr->condition); parent.walk(curr->ifTrue); diff --git a/src/wasm2asm-main.cpp b/src/wasm2asm-main.cpp new file mode 100644 index 000000000..19aa83e2d --- /dev/null +++ b/src/wasm2asm-main.cpp @@ -0,0 +1,60 @@ +// +// wasm2asm console tool +// + +#include "wasm2asm.h" +#include "wasm-s-parser.h" + +using namespace cashew; +using namespace wasm; + +namespace wasm { +int debug = 0; +} + +int main(int argc, char **argv) { + debug = getenv("WASM2ASM_DEBUG") ? 
getenv("WASM2ASM_DEBUG")[0] - '0' : 0; + + char *infile = argv[1]; + + if (debug) std::cerr << "loading '" << infile << "'...\n"; + FILE *f = fopen(argv[1], "r"); + assert(f); + fseek(f, 0, SEEK_END); + int size = ftell(f); + char *input = new char[size+1]; + rewind(f); + int num = fread(input, 1, size, f); + // On Windows, ftell() gives the byte position (\r\n counts as two bytes), but when + // reading, fread() returns the number of characters read (\r\n is read as one char \n, and counted as one), + // so return value of fread can be less than size reported by ftell, and that is normal. + assert((num > 0 || size == 0) && num <= size); + fclose(f); + input[num] = 0; + + if (debug) std::cerr << "s-parsing...\n"; + SExpressionParser parser(input); + Element& root = *parser.root; + + if (debug) std::cerr << "w-parsing...\n"; + AllocatingModule wasm; + SExpressionWasmBuilder builder(wasm, *root[0], [&]() { abort(); }); + + if (debug) std::cerr << "asming...\n"; + Wasm2AsmBuilder wasm2asm; + Ref asmjs = wasm2asm.processWasm(&wasm); + + if (debug) { + std::cerr << "a-printing...\n"; + asmjs->stringify(std::cout, true); + std::cout << '\n'; + } + + if (debug) std::cerr << "j-printing...\n"; + JSPrinter jser(true, true, asmjs); + jser.printAst(); + std::cout << jser.buffer << "\n"; + + if (debug) std::cerr << "done.\n"; +} + diff --git a/src/wasm2asm.h b/src/wasm2asm.h new file mode 100644 index 000000000..b256a63c5 --- /dev/null +++ b/src/wasm2asm.h @@ -0,0 +1,993 @@ + +// +// WebAssembly-to-asm.js translator. Uses the Emscripten optimizer +// infrastructure. 
+// + +#include "wasm.h" +#include "emscripten-optimizer/optimizer.h" +#include "mixed_arena.h" +#include "asm_v_wasm.h" +#include "shared-constants.h" + +namespace wasm { + +extern int debug; + +using namespace cashew; + +IString ASM_FUNC("asmFunc"), + ABORT_FUNC("abort"), + FUNCTION_TABLE("FUNCTION_TABLE"), + NO_RESULT("wasm2asm$noresult"), // no result at all + EXPRESSION_RESULT("wasm2asm$expresult"); // result in an expression, no temp var + +// Appends extra to block, flattening out if extra is a block as well +void flattenAppend(Ref ast, Ref extra) { + int index; + if (ast[0] == BLOCK) index = 1; + else if (ast[0] == DEFUN) index = 3; + else abort(); + if (extra[0] == BLOCK) { + for (int i = 0; i < extra[1]->size(); i++) { + ast[index]->push_back(extra[1][i]); + } + } else { + ast[index]->push_back(extra); + } +} + +// +// Wasm2AsmBuilder - converts a WebAssembly module into asm.js +// +// In general, asm.js => wasm is very straightforward, as can +// be seen in asm2wasm.h. Just a single pass, plus a little +// state bookkeeping (breakStack, etc.), and a few after-the +// fact corrections for imports, etc. However, wasm => asm.js +// is tricky because wasm has statements == expressions, or in +// other words, things like `break` and `if` can show up +// in places where asm.js can't handle them, like inside +// a loop's condition check. +// +// We therefore need the ability to lower an expression into +// a block of statements, and we keep statementizing until we +// reach a context in which we can emit those statements. This +// requires that we create temp variables to store values +// that would otherwise flow directly into their targets if +// we were an expression (e.g. 
if a loop's condition check +// is a bunch of statements, we execute those statements, +// then use the computed value in the loop's condition; +// we might also be able to avoid an assign to a temp var +// at the end of those statements, and put just that +// value in the loop's condition). +// +// It is possible to do this in a single pass, if we just +// allocate temp vars freely. However, pathological cases +// can easily show bad behavior here, with many unnecessary +// temp vars. We could rely on optimization passes like +// Emscripten's eliminate/registerize pair, but we want +// wasm2asm to be fairly fast to run, as it might run on +// the client. +// +// The approach taken here therefore performs 2 passes on +// each function. First, it finds which expressions will need to +// be statementized. It also sees which labels can receive a break +// with a value. Given that information, in the second pass we can +// allocate temp vars in an efficient manner, as we know when we +// need them and when their use is finished. They are allocated +// using an RAII class, so that they are automatically freed +// when the scope ends. This means that a node cannot allocate +// its own temp var; instead, the parent - which knows the +// child will return a value in a temp var - allocates it, +// and tells the child what temp var to emit to. The child +// can then pass forward that temp var to its children, +// optimizing away unnecessary forwarding. + + +class Wasm2AsmBuilder { +public: + Ref processWasm(Module* wasm); + Ref processFunction(Function* func); + + // The first pass on an expression: scan it to see whether it will + // need to be statementized, and note spooky returns of values at + // a distance (aka break with a value). 
+ void scanFunctionBody(Expression* curr); + + // The second pass on an expression: process it fully, generating + // asm.js + // @param result Whether the context we are in receives a value, + // and its type, or if not, then we can drop our return, + // if we have one. + Ref processFunctionBody(Expression* curr, IString result); + + // Get a temp var. + IString getTemp(WasmType type) { + IString ret; + if (frees[type].size() > 0) { + ret = frees[type].back(); + frees[type].pop_back(); + } else { + size_t index = temps[type]++; + ret = IString((std::string("wasm2asm_") + printWasmType(type) + "$" + std::to_string(index)).c_str(), false); + } + return ret; + } + // Free a temp var. + void freeTemp(WasmType type, IString temp) { + frees[type].push_back(temp); + } + + static IString fromName(Name name) { + // TODO: more clever name fixing, including checking we do not collide + const char *str = name.str; + // check the various issues, and recurse so we check the others + if (strchr(str, '-')) { + char *mod = strdup(str); // XXX leak + str = mod; + while (*mod) { + if (*mod == '-') *mod = '_'; + mod++; + } + return fromName(IString(str, false)); + } + if (isdigit(str[0])) { + std::string prefixed = "$$"; + prefixed += name.str; + return fromName(IString(prefixed.c_str(), false)); + } + return name; + } + + void setStatement(Expression* curr) { + willBeStatement.insert(curr); + } + bool isStatement(Expression* curr) { + return curr && willBeStatement.find(curr) != willBeStatement.end(); + } + + size_t getTableSize() { + return tableSize; + } + +private: + // How many temp vars we need + std::vector<int> temps; // type => num temps + // Which are currently free to use + std::vector<std::vector<IString>> frees; // type => list of free names + + // Expressions that will be a statement. + std::set<Expression*> willBeStatement; + + // All our function tables have the same size TODO: optimize? 
+ size_t tableSize; + + void addBasics(Ref ast); + void addImport(Ref ast, Import *import); + void addTables(Ref ast, Module *wasm); + void addExports(Ref ast, Module *wasm); +}; + +Ref Wasm2AsmBuilder::processWasm(Module* wasm) { + Ref ret = ValueBuilder::makeToplevel(); + Ref asmFunc = ValueBuilder::makeFunction(ASM_FUNC); + ret[1]->push_back(asmFunc); + ValueBuilder::appendArgumentToFunction(asmFunc, GLOBAL); + ValueBuilder::appendArgumentToFunction(asmFunc, ENV); + ValueBuilder::appendArgumentToFunction(asmFunc, BUFFER); + asmFunc[3]->push_back(ValueBuilder::makeStatement(ValueBuilder::makeString(USE_ASM))); + // create heaps, etc + addBasics(asmFunc[3]); + for (auto import : wasm->imports) { + addImport(asmFunc[3], import); + } + // figure out the table size + tableSize = wasm->table.names.size(); + size_t pow2ed = 1; + while (pow2ed < tableSize) { + pow2ed <<= 1; + } + tableSize = pow2ed; + // functions + for (auto func : wasm->functions) { + asmFunc[3]->push_back(processFunction(func)); + } + addTables(asmFunc[3], wasm); + // memory XXX + addExports(asmFunc[3], wasm); + return ret; +} + +void Wasm2AsmBuilder::addBasics(Ref ast) { + // heaps, var HEAP8 = new global.Int8Array(buffer); etc + auto addHeap = [&](IString name, IString view) { + Ref theVar = ValueBuilder::makeVar(); + ast->push_back(theVar); + ValueBuilder::appendToVar(theVar, + name, + ValueBuilder::makeNew( + ValueBuilder::makeCall( + ValueBuilder::makeDot( + ValueBuilder::makeName(GLOBAL), + view + ), + ValueBuilder::makeName(BUFFER) + ) + ) + ); + }; + addHeap(HEAP8, INT8ARRAY); + addHeap(HEAP16, INT16ARRAY); + addHeap(HEAP32, INT32ARRAY); + addHeap(HEAPU8, UINT8ARRAY); + addHeap(HEAPU16, UINT16ARRAY); + addHeap(HEAPU32, UINT32ARRAY); + addHeap(HEAPF32, FLOAT32ARRAY); + addHeap(HEAPF64, FLOAT64ARRAY); + // core asm.js imports + auto addMath = [&](IString name, IString base) { + Ref theVar = ValueBuilder::makeVar(); + ast->push_back(theVar); + ValueBuilder::appendToVar(theVar, + name, + 
ValueBuilder::makeDot( + ValueBuilder::makeDot( + ValueBuilder::makeName(GLOBAL), + MATH + ), + base + ) + ); + }; + addMath(MATH_IMUL, IMUL); + addMath(MATH_FROUND, FROUND); + addMath(MATH_ABS, ABS); + addMath(MATH_CLZ32, CLZ32); +} + +void Wasm2AsmBuilder::addImport(Ref ast, Import *import) { + Ref theVar = ValueBuilder::makeVar(); + ast->push_back(theVar); + Ref module = ValueBuilder::makeName(ENV); // TODO: handle nested module imports + ValueBuilder::appendToVar(theVar, + fromName(import->name), + ValueBuilder::makeDot( + module, + fromName(import->base) + ) + ); +} + +void Wasm2AsmBuilder::addTables(Ref ast, Module *wasm) { + std::map<std::string, std::vector<IString>> tables; // asm.js tables, sig => contents of table + for (size_t i = 0; i < wasm->table.names.size(); i++) { + Name name = wasm->table.names[i]; + auto func = wasm->functionsMap[name]; + std::string sig = getSig(func); + auto& table = tables[sig]; + if (table.size() == 0) { + // fill it with the first of its type seen. 
we have to fill with something; and for asm2wasm output, the first is the null anyhow + table.resize(tableSize); + for (int j = 0; j < tableSize; j++) { + table[j] = fromName(name); + } + } else { + table[i] = fromName(name); + } + } + for (auto& pair : tables) { + auto& sig = pair.first; + auto& table = pair.second; + std::string stable = std::string("FUNCTION_TABLE_") + sig; + IString asmName = IString(stable.c_str(), false); + // add to asm module + Ref theVar = ValueBuilder::makeVar(); + ast->push_back(theVar); + Ref theArray = ValueBuilder::makeArray(); + ValueBuilder::appendToVar(theVar, asmName, theArray); + for (auto& name : table) { + ValueBuilder::appendToArray(theArray, ValueBuilder::makeName(name)); + } + } +} + +void Wasm2AsmBuilder::addExports(Ref ast, Module *wasm) { + Ref exports = ValueBuilder::makeObject(); + for (auto export_ : wasm->exports) { + ValueBuilder::appendToObject(exports, fromName(export_->name), ValueBuilder::makeName(fromName(export_->value))); + } + ast->push_back(ValueBuilder::makeStatement(ValueBuilder::makeReturn(exports))); +} + +Ref Wasm2AsmBuilder::processFunction(Function* func) { + if (debug) std::cerr << " processFunction " << func->name << '\n'; + Ref ret = ValueBuilder::makeFunction(fromName(func->name)); + frees.clear(); + frees.resize(std::max(i32, std::max(f32, f64)) + 1); + temps.clear(); + temps.resize(std::max(i32, std::max(f32, f64)) + 1); + temps[i32] = temps[f32] = temps[f64] = 0; + // arguments + for (auto& param : func->params) { + IString name = fromName(param.name); + ValueBuilder::appendArgumentToFunction(ret, name); + ret[3]->push_back( + ValueBuilder::makeStatement( + ValueBuilder::makeAssign( + ValueBuilder::makeName(name), + makeAsmCoercion(ValueBuilder::makeName(name), wasmToAsmType(param.type)) + ) + ) + ); + } + Ref theVar = ValueBuilder::makeVar(); + size_t theVarIndex = ret[3]->size(); + ret[3]->push_back(theVar); + // body + scanFunctionBody(func->body); + if (isStatement(func->body)) { + IString 
result = func->result != none ? getTemp(func->result) : NO_RESULT; + flattenAppend(ret, ValueBuilder::makeStatement(processFunctionBody(func->body, result))); + if (func->result != none) { + // do the actual return + ret[3]->push_back(ValueBuilder::makeStatement(ValueBuilder::makeReturn(makeAsmCoercion(ValueBuilder::makeName(result), wasmToAsmType(func->result))))); + freeTemp(func->result, result); + } + } else { + // whole thing is an expression, just do a return + if (func->result != none) { + ret[3]->push_back(ValueBuilder::makeStatement(ValueBuilder::makeReturn(makeAsmCoercion(processFunctionBody(func->body, EXPRESSION_RESULT), wasmToAsmType(func->result))))); + } else { + flattenAppend(ret, processFunctionBody(func->body, NO_RESULT)); + } + } + // locals, including new temp locals + for (auto& local : func->locals) { + ValueBuilder::appendToVar(theVar, fromName(local.name), makeAsmCoercedZero(wasmToAsmType(local.type))); + } + for (auto f : frees[i32]) { + ValueBuilder::appendToVar(theVar, f, makeAsmCoercedZero(ASM_INT)); + } + for (auto f : frees[f32]) { + ValueBuilder::appendToVar(theVar, f, makeAsmCoercedZero(ASM_FLOAT)); + } + for (auto f : frees[f64]) { + ValueBuilder::appendToVar(theVar, f, makeAsmCoercedZero(ASM_DOUBLE)); + } + if (theVar[1]->size() == 0) { + ret[3]->splice(theVarIndex, 1); + } + // checks + assert(frees[i32].size() == temps[i32]); // all temp vars should be free at the end + assert(frees[f32].size() == temps[f32]); // all temp vars should be free at the end + assert(frees[f64].size() == temps[f64]); // all temp vars should be free at the end + // cleanups + willBeStatement.clear(); + return ret; +} + +void Wasm2AsmBuilder::scanFunctionBody(Expression* curr) { + struct ExpressionScanner : public WasmWalker { + Wasm2AsmBuilder* parent; + + ExpressionScanner(Wasm2AsmBuilder* parent) : parent(parent) {} + + // Visitors + + void visitBlock(Block *curr) override { + parent->setStatement(curr); + } + void visitIf(If *curr) override { + 
parent->setStatement(curr); + } + void visitLoop(Loop *curr) override { + parent->setStatement(curr); + } + void visitBreak(Break *curr) override { + parent->setStatement(curr); + } + void visitSwitch(Switch *curr) override { + parent->setStatement(curr); + } + void visitCall(Call *curr) override { + for (auto item : curr->operands) { + if (parent->isStatement(item)) { + parent->setStatement(curr); + break; + } + } + } + void visitCallImport(CallImport *curr) override { + visitCall(curr); + } + void visitCallIndirect(CallIndirect *curr) override { + if (parent->isStatement(curr->target)) { + parent->setStatement(curr); + return; + } + for (auto item : curr->operands) { + if (parent->isStatement(item)) { + parent->setStatement(curr); + break; + } + } + } + void visitSetLocal(SetLocal *curr) override { + if (parent->isStatement(curr->value)) { + parent->setStatement(curr); + } + } + void visitLoad(Load *curr) override { + if (parent->isStatement(curr->ptr)) { + parent->setStatement(curr); + } + } + void visitStore(Store *curr) override { + if (parent->isStatement(curr->ptr) || parent->isStatement(curr->value)) { + parent->setStatement(curr); + } + } + void visitUnary(Unary *curr) override { + if (parent->isStatement(curr->value)) { + parent->setStatement(curr); + } + } + void visitBinary(Binary *curr) override { + if (parent->isStatement(curr->left) || parent->isStatement(curr->right)) { + parent->setStatement(curr); + } + } + void visitSelect(Select *curr) override { + if (parent->isStatement(curr->condition) || parent->isStatement(curr->ifTrue) || parent->isStatement(curr->ifFalse)) { + parent->setStatement(curr); + } + } + void visitHost(Host *curr) override { + for (auto item : curr->operands) { + if (parent->isStatement(item)) { + parent->setStatement(curr); + break; + } + } + } + }; + ExpressionScanner(this).walk(curr); +} + +Ref Wasm2AsmBuilder::processFunctionBody(Expression* curr, IString result) { + struct ExpressionProcessor : public WasmVisitor<Ref> { + 
Wasm2AsmBuilder* parent; + IString result; + ExpressionProcessor(Wasm2AsmBuilder* parent) : parent(parent) {} + + // A scoped temporary variable. + struct ScopedTemp { + Wasm2AsmBuilder* parent; + WasmType type; + IString temp; + bool needFree; + // @param possible if provided, this is a variable we can use as our temp. it has already been + // allocated in a higher scope, and we can just assign to it as our result is + // going there anyhow. + ScopedTemp(WasmType type, Wasm2AsmBuilder* parent, IString possible = NO_RESULT) : parent(parent), type(type) { + assert(possible != EXPRESSION_RESULT); + if (possible == NO_RESULT) { + temp = parent->getTemp(type); + needFree = true; + } else { + temp = possible; + needFree = false; + } + } + ~ScopedTemp() { + if (needFree) { + parent->freeTemp(type, temp); + } + } + + IString getName() { + return temp; + } + Ref getAstName() { + return ValueBuilder::makeName(temp); + } + }; + + Ref visit(Expression* curr, IString nextResult) { + IString old = result; + result = nextResult; + Ref ret = WasmVisitor::visit(curr); + result = old; // keep it consistent for the rest of this frame, which may call visit on multiple children + return ret; + } + + Ref visit(Expression* curr, ScopedTemp& temp) { + return visit(curr, temp.temp); + } + + Ref visitForExpression(Expression* curr, WasmType type, IString& tempName) { // this result is for an asm expression slot, but it might be a statement + if (isStatement(curr)) { + ScopedTemp temp(type, parent); + tempName = temp.temp; + return visit(curr, temp); + } else { + return visit(curr, EXPRESSION_RESULT); + } + } + + Ref visitAndAssign(Expression* curr, IString result) { + Ref ret = visit(curr, result); + // if it's not already a statement, then it's an expression, and we need to assign it + // (if it is a statement, it already assigns to the result var) + if (!isStatement(curr) && result != NO_RESULT) { + ret = 
ValueBuilder::makeStatement(ValueBuilder::makeAssign(ValueBuilder::makeName(result), ret)); + } + return ret; + } + + Ref visitAndAssign(Expression* curr, ScopedTemp& temp) { + return visitAndAssign(curr, temp.getName()); + } + + bool isStatement(Expression* curr) { + return parent->isStatement(curr); + } + + // Expressions with control flow turn into a block, which we must + // then handle, even if we are an expression. + bool isBlock(Ref ast) { + return !!ast && ast[0] == BLOCK; + } + + Ref blockify(Ref ast) { + if (isBlock(ast)) return ast; + Ref ret = ValueBuilder::makeBlock(); + ret[1]->push_back(ValueBuilder::makeStatement(ast)); + return ret; + } + + // For spooky return-at-a-distance/break-with-result, this tells us + // what the result var is for a specific label. + std::map<Name, IString> breakResults; + + // Breaks to the top of a loop should be emitted as continues, to that loop's main label + std::map<Name, Name> continueLabels; + + IString fromName(Name name) { + return parent->fromName(name); + } + + // Visitors + + Ref visitBlock(Block *curr) override { + breakResults[curr->name] = result; + Ref ret = ValueBuilder::makeBlock(); + size_t size = curr->list.size(); + int noResults = result == NO_RESULT ? 
size : size-1; + for (size_t i = 0; i < noResults; i++) { + flattenAppend(ret, ValueBuilder::makeStatement(visit(curr->list[i], NO_RESULT))); + } + if (result != NO_RESULT) { + flattenAppend(ret, visitAndAssign(curr->list[size-1], result)); + } + if (curr->name.is()) { + ret = ValueBuilder::makeLabel(fromName(curr->name), ret); + } + return ret; + } + Ref visitIf(If *curr) override { + IString temp; + Ref condition = visitForExpression(curr->condition, i32, temp); + Ref ifTrue = ValueBuilder::makeStatement(visitAndAssign(curr->ifTrue, result)); + Ref ifFalse; + if (curr->ifFalse) { + ifFalse = ValueBuilder::makeStatement(visitAndAssign(curr->ifFalse, result)); + } + if (temp.isNull()) { + return ValueBuilder::makeIf(condition, ifTrue, ifFalse); // simple if + } + condition = blockify(condition); + // just add an if to the block + condition[1]->push_back(ValueBuilder::makeIf(ValueBuilder::makeName(temp), ifTrue, ifFalse)); + return condition; + } + Ref visitLoop(Loop *curr) override { + Name asmLabel = curr->out.is() ? curr->out : curr->in; // label using the outside, normal for breaks. 
if no outside, then inside + if (curr->in.is()) continueLabels[curr->in] = asmLabel; + Ref body = visit(curr->body, result); + Ref ret = ValueBuilder::makeDo(body, ValueBuilder::makeInt(0)); + if (asmLabel.is()) { + ret = ValueBuilder::makeLabel(fromName(asmLabel), ret); + } + return ret; + } + Ref visitBreak(Break *curr) override { + if (curr->condition) { + // we need an equivalent to an if here, so use that code + Break fakeBreak = *curr; + fakeBreak.condition = nullptr; + If fakeIf; + fakeIf.condition = curr->condition; + fakeIf.ifTrue = &fakeBreak; + return visit(&fakeIf, result); + } + Ref theBreak; + auto iter = continueLabels.find(curr->name); + if (iter == continueLabels.end()) { + theBreak = ValueBuilder::makeBreak(fromName(curr->name)); + } else { + theBreak = ValueBuilder::makeContinue(fromName(iter->second)); + } + if (!curr->value) return theBreak; + // generate the value, including assigning to the result, and then do the break + Ref ret = visitAndAssign(curr->value, breakResults[curr->name]); + ret = blockify(ret); + ret[1]->push_back(theBreak); + return ret; + } + Ref visitSwitch(Switch *curr) override { + Ref ret = ValueBuilder::makeLabel(fromName(curr->name), ValueBuilder::makeBlock()); + Ref value; + if (isStatement(curr->value)) { + ScopedTemp temp(i32, parent); + flattenAppend(ret[2], visit(curr->value, temp)); + value = temp.getAstName(); + } else { + value = visit(curr->value, EXPRESSION_RESULT); + } + Ref theSwitch = ValueBuilder::makeSwitch(value); + ret[2][1]->push_back(theSwitch); + for (auto& c : curr->cases) { + bool added = false; + for (size_t i = 0; i < curr->targets.size(); i++) { + if (curr->targets[i] == c.name) { + ValueBuilder::appendCaseToSwitch(theSwitch, ValueBuilder::makeNum(i)); + added = true; + } + } + if (c.name == curr->default_) { + ValueBuilder::appendDefaultToSwitch(theSwitch); + added = true; + } + assert(added); + ValueBuilder::appendCodeToSwitch(theSwitch, blockify(visit(c.body, NO_RESULT)), false); + } + return 
ret; + } + + Ref makeStatementizedCall(ExpressionList& operands, Ref ret, Ref theCall, IString result, WasmType type) { + std::vector<ScopedTemp*> temps; // TODO: utility class, with destructor? + for (auto& operand : operands) { + temps.push_back(new ScopedTemp(operand->type, parent)); + IString temp = temps.back()->temp; + flattenAppend(ret, visitAndAssign(operand, temp)); + theCall[2]->push_back(makeAsmCoercion(ValueBuilder::makeName(temp), wasmToAsmType(operand->type))); + } + theCall = makeAsmCoercion(theCall, wasmToAsmType(type)); + if (result != NO_RESULT) { + theCall = ValueBuilder::makeStatement(ValueBuilder::makeAssign(ValueBuilder::makeName(result), theCall)); + } + flattenAppend(ret, theCall); + for (auto temp : temps) { + delete temp; + } + return ret; + } + + Ref visitCall(Call *curr) override { + Ref theCall = ValueBuilder::makeCall(fromName(curr->target)); + if (!isStatement(curr)) { + // none of our operands is a statement; go right ahead and create a simple expression + for (auto operand : curr->operands) { + theCall[2]->push_back(makeAsmCoercion(visit(operand, EXPRESSION_RESULT), wasmToAsmType(operand->type))); + } + return makeAsmCoercion(theCall, wasmToAsmType(curr->type)); + } + // we must statementize them all + return makeStatementizedCall(curr->operands, ValueBuilder::makeBlock(), theCall, result, curr->type); + } + Ref visitCallImport(CallImport *curr) override { + return visitCall(curr); + } + Ref visitCallIndirect(CallIndirect *curr) override { + std::string stable = std::string("FUNCTION_TABLE_") + getSig(curr->fullType); + IString table = IString(stable.c_str(), false); + auto makeTableCall = [&](Ref target) { + return ValueBuilder::makeCall(ValueBuilder::makeSub( + ValueBuilder::makeName(table), + ValueBuilder::makeBinary(target, AND, ValueBuilder::makeInt(parent->getTableSize()-1)) + )); + }; + if (!isStatement(curr)) { + // none of our operands is a statement; go right ahead and create a simple expression + Ref theCall = 
makeTableCall(visit(curr->target, EXPRESSION_RESULT)); + for (auto operand : curr->operands) { + theCall[2]->push_back(makeAsmCoercion(visit(operand, EXPRESSION_RESULT), wasmToAsmType(operand->type))); + } + return makeAsmCoercion(theCall, wasmToAsmType(curr->type)); + } + // we must statementize them all + Ref ret = ValueBuilder::makeBlock(); + ScopedTemp temp(i32, parent); + flattenAppend(ret, visit(curr->target, temp)); + Ref theCall = makeTableCall(temp.getAstName()); + return makeStatementizedCall(curr->operands, ret, theCall, result, curr->type); + } + Ref visitGetLocal(GetLocal *curr) override { + return ValueBuilder::makeName(fromName(curr->name)); + } + Ref visitSetLocal(SetLocal *curr) override { + if (!isStatement(curr)) { + return ValueBuilder::makeAssign(ValueBuilder::makeName(fromName(curr->name)), visit(curr->value, EXPRESSION_RESULT)); + } + ScopedTemp temp(curr->type, parent, result); // if result was provided, our child can just assign there. otherwise, allocate a temp for it to assign to. + Ref ret = blockify(visit(curr->value, temp)); + // the output was assigned to result, so we can just assign it to our target + ret[1]->push_back(ValueBuilder::makeStatement(ValueBuilder::makeAssign(ValueBuilder::makeName(fromName(curr->name)), temp.getAstName()))); + return ret; + } + Ref visitLoad(Load *curr) override { + if (isStatement(curr)) { + ScopedTemp temp(i32, parent); + GetLocal fakeLocal; + fakeLocal.name = temp.getName(); + Load fakeLoad = *curr; + fakeLoad.ptr = &fakeLocal; + Ref ret = blockify(visitAndAssign(curr->ptr, temp)); + flattenAppend(ret, visitAndAssign(&fakeLoad, result)); + return ret; + } + // normal load + assert(curr->bytes == curr->align); // TODO: unaligned + Ref ptr = visit(curr->ptr, EXPRESSION_RESULT); + Ref ret; + switch (curr->type) { + case i32: { + switch (curr->bytes) { + case 1: ret = ValueBuilder::makeSub(ValueBuilder::makeName(curr->signed_ ? 
HEAP8 : HEAPU8 ), ValueBuilder::makePtrShift(ptr, 0)); break; + case 2: ret = ValueBuilder::makeSub(ValueBuilder::makeName(curr->signed_ ? HEAP16 : HEAPU16), ValueBuilder::makePtrShift(ptr, 1)); break; + case 4: ret = ValueBuilder::makeSub(ValueBuilder::makeName(curr->signed_ ? HEAP32 : HEAPU32), ValueBuilder::makePtrShift(ptr, 2)); break; + default: abort(); + } + break; + } + case f32: ret = ValueBuilder::makeSub(ValueBuilder::makeName(HEAPF32), ValueBuilder::makePtrShift(ptr, 2)); break; + case f64: ret = ValueBuilder::makeSub(ValueBuilder::makeName(HEAPF64), ValueBuilder::makePtrShift(ptr, 3)); break; + default: abort(); + } + return makeAsmCoercion(ret, wasmToAsmType(curr->type)); + } + Ref visitStore(Store *curr) override { + if (isStatement(curr)) { + ScopedTemp tempPtr(i32, parent); + ScopedTemp tempValue(curr->type, parent); + GetLocal fakeLocalPtr; + fakeLocalPtr.name = tempPtr.getName(); + GetLocal fakeLocalValue; + fakeLocalValue.name = tempValue.getName(); + Store fakeStore = *curr; + fakeStore.ptr = &fakeLocalPtr; + fakeStore.value = &fakeLocalValue; + Ref ret = blockify(visitAndAssign(curr->ptr, tempPtr)); + flattenAppend(ret, visitAndAssign(curr->value, tempValue)); + flattenAppend(ret, visitAndAssign(&fakeStore, result)); + return ret; + } + // normal store + assert(curr->bytes == curr->align); // TODO: unaligned + Ref ptr = visit(curr->ptr, EXPRESSION_RESULT); + Ref value = visit(curr->value, EXPRESSION_RESULT); + Ref ret; + switch (curr->type) { + case i32: { + switch (curr->bytes) { + case 1: ret = ValueBuilder::makeSub(ValueBuilder::makeName(HEAP8), ValueBuilder::makePtrShift(ptr, 0)); break; + case 2: ret = ValueBuilder::makeSub(ValueBuilder::makeName(HEAP16), ValueBuilder::makePtrShift(ptr, 1)); break; + case 4: ret = ValueBuilder::makeSub(ValueBuilder::makeName(HEAP32), ValueBuilder::makePtrShift(ptr, 2)); break; + default: abort(); + } + break; + } + case f32: ret = ValueBuilder::makeSub(ValueBuilder::makeName(HEAPF32), 
ValueBuilder::makePtrShift(ptr, 2)); break; + case f64: ret = ValueBuilder::makeSub(ValueBuilder::makeName(HEAPF64), ValueBuilder::makePtrShift(ptr, 3)); break; + default: abort(); + } + return ValueBuilder::makeAssign(ret, value); + } + Ref visitConst(Const *curr) override { + switch (curr->type) { + case i32: return ValueBuilder::makeInt(curr->value.i32); + case f32: { + Ref ret = ValueBuilder::makeCall(MATH_FROUND); + Const fake; + fake.value = double(curr->value.f32); + fake.type = f64; + ret[2]->push_back(visitConst(&fake)); + return ret; + } + case f64: { + double d = curr->value.f64; + if (d == 0 && 1/d < 0) { // negative zero + return ValueBuilder::makeUnary(PLUS, ValueBuilder::makeUnary(MINUS, ValueBuilder::makeDouble(0))); + } + return ValueBuilder::makeUnary(PLUS, ValueBuilder::makeDouble(curr->value.f64)); + } + default: abort(); + } + } + Ref visitUnary(Unary *curr) override { + if (isStatement(curr)) { + ScopedTemp temp(curr->value->type, parent); + GetLocal fakeLocal; + fakeLocal.name = temp.getName(); + Unary fakeUnary = *curr; + fakeUnary.value = &fakeLocal; + Ref ret = blockify(visitAndAssign(curr->value, temp)); + flattenAppend(ret, visitAndAssign(&fakeUnary, result)); + return ret; + } + // normal unary + Ref value = visit(curr->value, EXPRESSION_RESULT); + switch (curr->type) { + case i32: { + switch (curr->op) { + case Clz: return ValueBuilder::makeCall(MATH_CLZ32, value); + case Ctz: return ValueBuilder::makeCall(MATH_CTZ32, value); + case Popcnt: return ValueBuilder::makeCall(MATH_POPCNT32, value); + default: abort(); + } + } + case f32: + case f64: { + Ref ret; + switch (curr->op) { + case Neg: ret = ValueBuilder::makeUnary(MINUS, value); break; + case Abs: ret = ValueBuilder::makeCall(MATH_ABS, value); break; + case Ceil: ret = ValueBuilder::makeCall(MATH_CEIL, value); break; + case Floor: ret = ValueBuilder::makeCall(MATH_FLOOR, value); break; + case Trunc: ret = ValueBuilder::makeCall(MATH_TRUNC, value); break; + case Nearest: ret = 
ValueBuilder::makeCall(MATH_NEAREST, value); break; + case Sqrt: ret = ValueBuilder::makeCall(MATH_SQRT, value); break; + case TruncSFloat32: ret = ValueBuilder::makePrefix(B_NOT, ValueBuilder::makePrefix(B_NOT, value)); break; + case PromoteFloat32: + case ConvertSInt32: ret = ValueBuilder::makePrefix(PLUS, ValueBuilder::makeBinary(value, OR, ValueBuilder::makeNum(0))); break; + case ConvertUInt32: ret = ValueBuilder::makePrefix(PLUS, ValueBuilder::makeBinary(value, TRSHIFT, ValueBuilder::makeNum(0))); break; + case DemoteFloat64: ret = value; break; + default: std::cerr << curr << '\n'; abort(); + } + if (curr->type == f32) { // doubles need much less coercing + return makeAsmCoercion(ret, ASM_FLOAT); + } + return ret; + } + default: abort(); + } + } + Ref visitBinary(Binary *curr) override { + if (isStatement(curr)) { + ScopedTemp tempLeft(curr->left->type, parent); + GetLocal fakeLocalLeft; + fakeLocalLeft.name = tempLeft.getName(); + ScopedTemp tempRight(curr->right->type, parent); + GetLocal fakeLocalRight; + fakeLocalRight.name = tempRight.getName(); + Binary fakeBinary = *curr; + fakeBinary.left = &fakeLocalLeft; + fakeBinary.right = &fakeLocalRight; + Ref ret = blockify(visitAndAssign(curr->left, tempLeft)); + flattenAppend(ret, visitAndAssign(curr->right, tempRight)); + flattenAppend(ret, visitAndAssign(&fakeBinary, result)); + return ret; + } + // normal binary + Ref left = visit(curr->left, EXPRESSION_RESULT); + Ref right = visit(curr->right, EXPRESSION_RESULT); + Ref ret; + switch (curr->op) { + case Add: ret = ValueBuilder::makeBinary(left, PLUS, right); break; + case Sub: ret = ValueBuilder::makeBinary(left, MINUS, right); break; + case Mul: { + if (curr->type == i32) { + return ValueBuilder::makeCall(MATH_IMUL, left, right); // TODO: when one operand is a small int, emit a multiply + } else { + return ValueBuilder::makeBinary(left, MINUS, right); break; + } + } + case DivS: ret = ValueBuilder::makeBinary(makeSigning(left, ASM_SIGNED), DIV, 
makeSigning(right, ASM_SIGNED)); break; + case DivU: ret = ValueBuilder::makeBinary(makeSigning(left, ASM_UNSIGNED), DIV, makeSigning(right, ASM_UNSIGNED)); break; + case RemS: ret = ValueBuilder::makeBinary(makeSigning(left, ASM_SIGNED), MOD, makeSigning(right, ASM_SIGNED)); break; + case RemU: ret = ValueBuilder::makeBinary(makeSigning(left, ASM_UNSIGNED), MOD, makeSigning(right, ASM_UNSIGNED)); break; + case And: ret = ValueBuilder::makeBinary(left, AND, right); break; + case Or: ret = ValueBuilder::makeBinary(left, OR, right); break; + case Xor: ret = ValueBuilder::makeBinary(left, XOR, right); break; + case Shl: ret = ValueBuilder::makeBinary(left, LSHIFT, right); break; + case ShrU: ret = ValueBuilder::makeBinary(left, TRSHIFT, right); break; + case ShrS: ret = ValueBuilder::makeBinary(left, RSHIFT, right); break; + case Div: ret = ValueBuilder::makeBinary(left, DIV, right); break; + case Min: ret = ValueBuilder::makeCall(MATH_MIN, left, right); break; + case Max: ret = ValueBuilder::makeCall(MATH_MAX, left, right); break; + case Eq: { + if (curr->left->type == i32) { + return ValueBuilder::makeBinary(makeSigning(left, ASM_SIGNED), EQ, makeSigning(right, ASM_SIGNED)); + } else { + return ValueBuilder::makeBinary(left, EQ, right); + } + } + case Ne: { + if (curr->left->type == i32) { + return ValueBuilder::makeBinary(makeSigning(left, ASM_SIGNED), NE, makeSigning(right, ASM_SIGNED)); + } else { + return ValueBuilder::makeBinary(left, NE, right); + } + } + case LtS: return ValueBuilder::makeBinary(makeSigning(left, ASM_SIGNED), LT, makeSigning(right, ASM_SIGNED)); + case LtU: return ValueBuilder::makeBinary(makeSigning(left, ASM_UNSIGNED), LT, makeSigning(right, ASM_UNSIGNED)); + case LeS: return ValueBuilder::makeBinary(makeSigning(left, ASM_SIGNED), LE, makeSigning(right, ASM_SIGNED)); + case LeU: return ValueBuilder::makeBinary(makeSigning(left, ASM_UNSIGNED), LE, makeSigning(right, ASM_UNSIGNED)); + case GtS: return 
ValueBuilder::makeBinary(makeSigning(left, ASM_SIGNED), GT, makeSigning(right, ASM_SIGNED)); + case GtU: return ValueBuilder::makeBinary(makeSigning(left, ASM_UNSIGNED), GT, makeSigning(right, ASM_UNSIGNED)); + case GeS: return ValueBuilder::makeBinary(makeSigning(left, ASM_SIGNED), GE, makeSigning(right, ASM_SIGNED)); + case GeU: return ValueBuilder::makeBinary(makeSigning(left, ASM_UNSIGNED), GE, makeSigning(right, ASM_UNSIGNED)); + case Lt: return ValueBuilder::makeBinary(left, LT, right); + case Le: return ValueBuilder::makeBinary(left, LE, right); + case Gt: return ValueBuilder::makeBinary(left, GT, right); + case Ge: return ValueBuilder::makeBinary(left, GE, right); + default: abort(); + } + return makeAsmCoercion(ret, wasmToAsmType(curr->type)); + } + Ref visitSelect(Select *curr) override { + if (isStatement(curr)) { + ScopedTemp tempCondition(i32, parent); + GetLocal fakeCondition; + fakeCondition.name = tempCondition.getName(); + ScopedTemp tempIfTrue(curr->ifTrue->type, parent); + GetLocal fakeLocalIfTrue; + fakeLocalIfTrue.name = tempIfTrue.getName(); + ScopedTemp tempIfFalse(curr->ifFalse->type, parent); + GetLocal fakeLocalIfFalse; + fakeLocalIfFalse.name = tempIfFalse.getName(); + Select fakeSelect = *curr; + fakeSelect.condition = &fakeCondition; + fakeSelect.ifTrue = &fakeLocalIfTrue; + fakeSelect.ifFalse = &fakeLocalIfFalse; + Ref ret = blockify(visitAndAssign(curr->condition, tempCondition)); + flattenAppend(ret, visitAndAssign(curr->ifTrue, tempIfTrue)); + flattenAppend(ret, visitAndAssign(curr->ifFalse, tempIfFalse)); + flattenAppend(ret, visitAndAssign(&fakeSelect, result)); + return ret; + } + // normal select + Ref condition = visit(curr->condition, EXPRESSION_RESULT); + Ref ifTrue = visit(curr->ifTrue, EXPRESSION_RESULT); + Ref ifFalse = visit(curr->ifFalse, EXPRESSION_RESULT); + ScopedTemp tempCondition(i32, parent), + tempIfTrue(curr->type, parent), + tempIfFalse(curr->type, parent); + return + ValueBuilder::makeSeq( + 
ValueBuilder::makeAssign(tempCondition.getAstName(), condition), + ValueBuilder::makeSeq( + ValueBuilder::makeAssign(tempIfTrue.getAstName(), ifTrue), + ValueBuilder::makeSeq( + ValueBuilder::makeAssign(tempIfFalse.getAstName(), ifFalse), + ValueBuilder::makeConditional(tempCondition.getAstName(), tempIfTrue.getAstName(), tempIfFalse.getAstName()) + ) + ) + ); + } + Ref visitHost(Host *curr) override { + abort(); + } + Ref visitNop(Nop *curr) override { + return ValueBuilder::makeToplevel(); + } + Ref visitUnreachable(Unreachable *curr) override { + return ValueBuilder::makeCall(ABORT_FUNC); + } + }; + return ExpressionProcessor(this).visit(curr, result); +} + +} // namespace wasm + |