diff options
author | Alon Zakai <alonzakai@gmail.com> | 2015-11-27 16:39:05 -0800 |
---|---|---|
committer | Alon Zakai <alonzakai@gmail.com> | 2015-11-27 18:45:24 -0800 |
commit | e7461ed17e5f0cc9e49ada34b0fb340dce8e9b49 (patch) | |
tree | 49a5810350437931c48612107f9f0694e4b2bc51 | |
parent | 26842a8da165276fd9f38dc4bab2267269c237a0 (diff) | |
download | binaryen-e7461ed17e5f0cc9e49ada34b0fb340dce8e9b49.tar.gz binaryen-e7461ed17e5f0cc9e49ada34b0fb340dce8e9b49.tar.bz2 binaryen-e7461ed17e5f0cc9e49ada34b0fb340dce8e9b49.zip |
unify convert/compare into unary/binary, and do a pre-pass in s-expression parser for function types, to fix new assertions that notice some missing types
-rw-r--r-- | src/asm2wasm.h | 138 | ||||
-rw-r--r-- | src/binaryen-shell.cpp | 5 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 197 | ||||
-rw-r--r-- | src/wasm-s-parser.h | 115 | ||||
-rw-r--r-- | src/wasm.h | 202 |
5 files changed, 293 insertions, 364 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h index 20ea6f056..7f7be358b 100644 --- a/src/asm2wasm.h +++ b/src/asm2wasm.h @@ -327,19 +327,18 @@ private: return detectSign(ast, Math_fround) == ASM_UNSIGNED; } - // an asm.js binary op can either be a binary or a relational in wasm - bool parseAsmBinaryOp(IString op, Ref left, Ref right, BinaryOp &binary, RelationalOp &relational, AsmData *asmData) { - if (op == PLUS) { binary = BinaryOp::Add; return true; } - if (op == MINUS) { binary = BinaryOp::Sub; return true; } - if (op == MUL) { binary = BinaryOp::Mul; return true; } - if (op == AND) { binary = BinaryOp::And; return true; } - if (op == OR) { binary = BinaryOp::Or; return true; } - if (op == XOR) { binary = BinaryOp::Xor; return true; } - if (op == LSHIFT) { binary = BinaryOp::Shl; return true; } - if (op == RSHIFT) { binary = BinaryOp::ShrS; return true; } - if (op == TRSHIFT) { binary = BinaryOp::ShrU; return true; } - if (op == EQ) { relational = RelationalOp::Eq; return false; } - if (op == NE) { relational = RelationalOp::Ne; return false; } + BinaryOp parseAsmBinaryOp(IString op, Ref left, Ref right, AsmData *asmData) { + if (op == PLUS) return BinaryOp::Add; + if (op == MINUS) return BinaryOp::Sub; + if (op == MUL) return BinaryOp::Mul; + if (op == AND) return BinaryOp::And; + if (op == OR) return BinaryOp::Or; + if (op == XOR) return BinaryOp::Xor; + if (op == LSHIFT) return BinaryOp::Shl; + if (op == RSHIFT) return BinaryOp::ShrS; + if (op == TRSHIFT) return BinaryOp::ShrU; + if (op == EQ) return BinaryOp::Eq; + if (op == NE) return BinaryOp::Ne; WasmType leftType = detectWasmType(left, asmData); #if 0 std::cout << "CHECK\n"; @@ -353,42 +352,42 @@ private: bool isUnsigned = isUnsignedCoercion(left) || isUnsignedCoercion(right); if (op == DIV) { if (isInteger) { - { binary = isUnsigned ? BinaryOp::DivU : BinaryOp::DivS; return true; } + return isUnsigned ? BinaryOp::DivU : BinaryOp::DivS; } - { binary = BinaryOp::Div; return true; } + return BinaryOp::Div; } if (op == MOD) { if (isInteger) { - { binary = isUnsigned ? BinaryOp::RemU : BinaryOp::RemS; return true; } + return isUnsigned ? BinaryOp::RemU : BinaryOp::RemS; } - { binary = BinaryOp::RemS; return true; } // XXX no floating-point remainder op, this must be handled by the caller + return BinaryOp::RemS; // XXX no floating-point remainder op, this must be handled by the caller } if (op == GE) { if (isInteger) { - { relational = isUnsigned ? RelationalOp::GeU : RelationalOp::GeS; return false; } + return isUnsigned ? BinaryOp::GeU : BinaryOp::GeS; } - { relational = RelationalOp::Ge; return false; } + return BinaryOp::Ge; } if (op == GT) { if (isInteger) { - { relational = isUnsigned ? RelationalOp::GtU : RelationalOp::GtS; return false; } + return isUnsigned ? BinaryOp::GtU : BinaryOp::GtS; } - { relational = RelationalOp::Gt; return false; } + return BinaryOp::Gt; } if (op == LE) { if (isInteger) { - { relational = isUnsigned ? RelationalOp::LeU : RelationalOp::LeS; return false; } + return isUnsigned ? BinaryOp::LeU : BinaryOp::LeS; } - { relational = RelationalOp::Le; return false; } + return BinaryOp::Le; } if (op == LT) { if (isInteger) { - { relational = isUnsigned ? RelationalOp::LtU : RelationalOp::LtS; return false; } + return isUnsigned ? BinaryOp::LtU : BinaryOp::LtS; } - { relational = RelationalOp::Lt; return false; } + return BinaryOp::Lt; } abort_on("bad wasm binary op", op); - return false; // avoid warning + abort(); // avoid warning } unsigned bytesToShift(unsigned bytes) { @@ -891,7 +890,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { if (ret->type != ret->value->type) { // in asm.js we have some implicit coercions that we must do explicitly here if (ret->type == f32 && ret->value->type == f64) { - auto conv = allocator.alloc<Convert>(); + auto conv = allocator.alloc<Unary>(); conv->op = DemoteFloat64; conv->value = ret->value; conv->type = WasmType::f32; @@ -909,47 +908,35 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { fixCallType(ret, i32); return ret; } - BinaryOp binary; - RelationalOp relational; - bool isBinary = parseAsmBinaryOp(ast[1]->getIString(), ast[2], ast[3], binary, relational, &asmData); - if (isBinary) { - auto ret = allocator.alloc<Binary>(); - ret->op = binary; - ret->left = process(ast[2]); - ret->right = process(ast[3]); - ret->type = ret->left->type; - if (binary == BinaryOp::RemS && isWasmTypeFloat(ret->type)) { - // WebAssembly does not have floating-point remainder, we have to emit a call to a special import of ours - CallImport *call = allocator.alloc<CallImport>(); - call->target = F64_REM; - call->operands.push_back(ret->left); - call->operands.push_back(ret->right); - call->type = f64; - static bool addedImport = false; - if (!addedImport) { - addedImport = true; - auto import = allocator.alloc<Import>(); // f64-rem = asm2wasm.f64-rem; - import->name = F64_REM; - import->module = ASM2WASM; - import->base = F64_REM; - import->type.name = F64_REM; - import->type.result = f64; - import->type.params.push_back(f64); - import->type.params.push_back(f64); - wasm.addImport(import); - } - return call; + BinaryOp binary = parseAsmBinaryOp(ast[1]->getIString(), ast[2], ast[3], &asmData); + auto ret = allocator.alloc<Binary>(); + ret->op = binary; + ret->left = process(ast[2]); + ret->right = process(ast[3]); + ret->finalize(); + if (binary == BinaryOp::RemS && isWasmTypeFloat(ret->type)) { + // WebAssembly does not have floating-point remainder, we have to emit a call to a special import of ours + CallImport *call = allocator.alloc<CallImport>(); + call->target = F64_REM; + call->operands.push_back(ret->left); + call->operands.push_back(ret->right); + call->type = f64; + static bool addedImport = false; + if (!addedImport) { + addedImport = true; + auto import = allocator.alloc<Import>(); // f64-rem = asm2wasm.f64-rem; + import->name = F64_REM; + import->module = ASM2WASM; + import->base = F64_REM; + import->type.name = F64_REM; + import->type.result = f64; + import->type.params.push_back(f64); + import->type.params.push_back(f64); + wasm.addImport(import); } - return ret; - } else { - auto ret = allocator.alloc<Compare>(); - ret->op = relational; - ret->left = process(ast[2]); - ret->right = process(ast[3]); - assert(ret->left->type == ret->right->type); - ret->inputType = ret->left->type; - return ret; + return call; } + return ret; } else if (what == NUM) { auto ret = allocator.alloc<Const>(); double num = ast[1]->getNumber(); @@ -1024,14 +1011,14 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { } auto ret = process(ast[2]); // we are a +() coercion if (ret->type == i32) { - auto conv = allocator.alloc<Convert>(); + auto conv = allocator.alloc<Unary>(); conv->op = isUnsignedCoercion(ast[2]) ? ConvertUInt32 : ConvertSInt32; conv->value = ret; conv->type = WasmType::f64; return conv; } if (ret->type == f32) { - auto conv = allocator.alloc<Convert>(); + auto conv = allocator.alloc<Unary>(); conv->op = PromoteFloat32; conv->value = ret; conv->type = WasmType::f64; @@ -1071,7 +1058,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { // ~, might be ~~ as a coercion or just a not if (ast[2][0] == UNARY_PREFIX && ast[2][1] == B_NOT) { #if 0 - auto ret = allocator.alloc<Convert>(); + auto ret = allocator.alloc<Unary>(); ret->op = TruncSFloat64; // equivalent to U, except for error handling, which asm.js doesn't have anyhow ret->value = process(ast[2][2]); ret->type = WasmType::i32; @@ -1105,12 +1092,12 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { return ret; } else if (ast[1] == L_NOT) { // no logical unary not, so do == 0 - auto ret = allocator.alloc<Compare>(); + auto ret = allocator.alloc<Binary>(); ret->op = Eq; ret->left = process(ast[2]); ret->right = allocator.alloc<Const>()->set(Literal(0)); assert(ret->left->type == ret->right->type); - ret->inputType = ret->left->type; + ret->finalize(); return ret; } abort_on("bad unary", ast); @@ -1148,7 +1135,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { } else if (lit.type == f64) { return allocator.alloc<Const>()->set(Literal((float)lit.getf64())); } - auto ret = allocator.alloc<Convert>(); + auto ret = allocator.alloc<Unary>(); ret->value = process(ast[2][0]); if (ret->value->type == f64) { ret->op = DemoteFloat64; @@ -1181,11 +1168,11 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->type = i32; return ret; }; - auto isNegative = allocator.alloc<Compare>(); + auto isNegative = allocator.alloc<Binary>(); isNegative->op = LtS; - isNegative->inputType = i32; isNegative->left = get(); isNegative->right = allocator.alloc<Const>()->set(0); + isNegative->finalize(); auto block = allocator.alloc<Block>(); block->list.push_back(set); auto flip = allocator.alloc<Binary>(); @@ -1249,7 +1236,8 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { for (unsigned i = 0; i < args->size(); i++) { ret->operands.push_back(process(args[i])); } - ret->type = getFunctionType(astStackHelper.getParent(), ret->operands); + ret->fullType = getFunctionType(astStackHelper.getParent(), ret->operands); + ret->type = ret->fullType->result; callIndirects[ret] = target[1][1]->getIString(); // we need to fix this up later, when we know how asm function tables are layed out inside the wasm table. return ret; } else if (what == RETURN) { diff --git a/src/binaryen-shell.cpp b/src/binaryen-shell.cpp index 9c1a05447..b2ba51544 100644 --- a/src/binaryen-shell.cpp +++ b/src/binaryen-shell.cpp @@ -105,8 +105,9 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { } break; } - case f32: *((float*)(memory+addr)) = value.getf32(); break; - case f64: *((double*)(memory+addr)) = value.getf64(); break; + // write floats carefully, ensuring all bits reach memory + case f32: *((int32_t*)(memory+addr)) = value.reinterpreti32(); break; + case f64: *((int64_t*)(memory+addr)) = value.reinterpreti64(); break; default: abort(); } } diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index cd7708657..83d974d47 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -303,7 +303,7 @@ private: if (index >= instance.wasm.table.names.size()) trap("callIndirect: overflow"); Name name = instance.wasm.table.names[index]; Function *func = instance.wasm.functionsMap[name]; - if (func->type.is() && func->type != curr->type->name) trap("callIndirect: bad type"); + if (func->type.is() && func->type != curr->fullType->name) trap("callIndirect: bad type"); LiteralList arguments; Flow flow = generateArguments(curr->operands, arguments); if (flow.breaking()) return flow; @@ -363,6 +363,17 @@ private: return Literal((int32_t)safe_ctz(v)); } case Popcnt: return Literal((int32_t)__builtin_popcount(v)); + case ReinterpretInt: { + float v = value.reinterpretf32(); + if (isnan(v)) { + return Literal(Literal(value.geti32() | 0x7f800000).reinterpretf32()); + } + return Literal(value.reinterpretf32()); + } + case ExtendSInt32: return Literal(int64_t(value.geti32())); + case ExtendUInt32: return Literal(uint64_t((uint32_t)value.geti32())); + case ConvertUInt32: return curr->type == f32 ? Literal(float(uint32_t(value.geti32()))) : Literal(double(uint32_t(value.geti32()))); + case ConvertSInt32: return curr->type == f32 ? Literal(float(int32_t(value.geti32()))) : Literal(double(int32_t(value.geti32()))); default: abort(); } } @@ -381,6 +392,12 @@ private: return Literal((int64_t)safe_ctz(low)); } case Popcnt: return Literal(int64_t(__builtin_popcount(low) + __builtin_popcount(high))); + case WrapInt64: return Literal(int32_t(value.geti64())); + case ReinterpretInt: { + return Literal(value.reinterpretf64()); + } + case ConvertUInt64: return curr->type == f32 ? Literal(float((uint64_t)value.geti64())) : Literal(double((uint64_t)value.geti64())); + case ConvertSInt64: return curr->type == f32 ? Literal(float(value.geti64())) : Literal(double(value.geti64())); default: abort(); } } @@ -395,6 +412,10 @@ private: case Trunc: ret = std::trunc(v); break; case Nearest: ret = std::nearbyint(v); break; case Sqrt: ret = std::sqrt(v); break; + case TruncSFloat32: return truncSFloat(curr, value); + case TruncUFloat32: return truncUFloat(curr, value); + case ReinterpretFloat: return Literal(value.reinterpreti32()); + case PromoteFloat32: return Literal(double(value.getf32())); default: abort(); } return Literal(fixNaN(v, ret)); @@ -410,6 +431,10 @@ private: case Trunc: ret = std::trunc(v); break; case Nearest: ret = std::nearbyint(v); break; case Sqrt: ret = std::sqrt(v); break; + case TruncSFloat64: return truncSFloat(curr, value); + case TruncUFloat64: return truncUFloat(curr, value); + case ReinterpretFloat: return Literal(value.reinterpreti64()); + case DemoteFloat64: return Literal(float(value.getf64())); default: abort(); } return Literal(fixNaN(v, ret)); @@ -466,6 +491,16 @@ private: r = r & 31; return Literal(l >> r); } + case Eq: return Literal(l == r); + case Ne: return Literal(l != r); + case LtS: return Literal(l < r); + case LtU: return Literal(uint32_t(l) < uint32_t(r)); + case LeS: return Literal(l <= r); + case LeU: return Literal(uint32_t(l) <= uint32_t(r)); + case GtS: return Literal(l > r); + case GtU: return Literal(uint32_t(l) > uint32_t(r)); + case GeS: return Literal(l >= r); + case GeU: return Literal(uint32_t(l) >= uint32_t(r)); default: abort(); } } else if (left.type == i64) { @@ -507,6 +542,16 @@ private: r = r & 63; return Literal(l >> r); } + case Eq: return Literal(l == r); + case Ne: return Literal(l != r); + case LtS: return Literal(l < r); + case LtU: return Literal(uint64_t(l) < uint64_t(r)); + case LeS: return Literal(l <= r); + case LeU: return Literal(uint64_t(l) <= uint64_t(r)); + case GtS: return Literal(l > r); + case GtU: return Literal(uint64_t(l) > uint64_t(r)); + case GeS: return Literal(l >= r); + case GeU: return Literal(uint64_t(l) >= uint64_t(r)); default: abort(); } } else if (left.type == f32) { @@ -531,6 +576,12 @@ private: else ret = std::max(l, r); break; } + case Eq: return Literal(l == r); + case Ne: return Literal(l != r); + case Lt: return Literal(l < r); + case Le: return Literal(l <= r); + case Gt: return Literal(l > r); + case Ge: return Literal(l >= r); default: abort(); } return Literal(fixNaN(l, r, ret)); @@ -556,65 +607,6 @@ private: else ret = std::max(l, r); break; } - default: abort(); - } - return Literal(fixNaN(l, r, ret)); - } - abort(); - } - Flow visitCompare(Compare *curr) override { - NOTE_ENTER("Compare"); - Flow flow = visit(curr->left); - if (flow.breaking()) return flow; - Literal left = flow.value; - flow = visit(curr->right); - if (flow.breaking()) return flow; - Literal right = flow.value; - NOTE_EVAL2(left, right); - if (left.type == i32) { - int32_t l = left.geti32(), r = right.geti32(); - switch (curr->op) { - case Eq: return Literal(l == r); - case Ne: return Literal(l != r); - case LtS: return Literal(l < r); - case LtU: return Literal(uint32_t(l) < uint32_t(r)); - case LeS: return Literal(l <= r); - case LeU: return Literal(uint32_t(l) <= uint32_t(r)); - case GtS: return Literal(l > r); - case GtU: return Literal(uint32_t(l) > uint32_t(r)); - case GeS: return Literal(l >= r); - case GeU: return Literal(uint32_t(l) >= uint32_t(r)); - default: abort(); - } - } else if (left.type == i64) { - int64_t l = left.geti64(), r = right.geti64(); - switch (curr->op) { - case Eq: return Literal(l == r); - case Ne: return Literal(l != r); - case LtS: return Literal(l < r); - case LtU: return Literal(uint64_t(l) < uint64_t(r)); - case LeS: return Literal(l <= r); - case LeU: return Literal(uint64_t(l) <= uint64_t(r)); - case GtS: return Literal(l > r); - case GtU: return Literal(uint64_t(l) > uint64_t(r)); - case GeS: return Literal(l >= r); - case GeU: return Literal(uint64_t(l) >= uint64_t(r)); - default: abort(); - } - } else if (left.type == f32) { - float l = left.getf32(), r = right.getf32(); - switch (curr->op) { - case Eq: return Literal(l == r); - case Ne: return Literal(l != r); - case Lt: return Literal(l < r); - case Le: return Literal(l <= r); - case Gt: return Literal(l > r); - case Ge: return Literal(l >= r); - default: abort(); - } - } else if (left.type == f64) { - double l = left.getf64(), r = right.getf64(); - switch (curr->op) { case Eq: return Literal(l == r); case Ne: return Literal(l != r); case Lt: return Literal(l < r); @@ -623,67 +615,10 @@ private: case Ge: return Literal(l >= r); default: abort(); } + return Literal(fixNaN(l, r, ret)); } abort(); } - Flow visitConvert(Convert *curr) override { - NOTE_ENTER("Convert"); - Flow flow = visit(curr->value); - if (flow.breaking()) return flow; - Literal value = flow.value; - switch (curr->op) { // :-) - case ExtendSInt32: return Literal(int64_t(value.geti32())); - case ExtendUInt32: return Literal(uint64_t((uint32_t)value.geti32())); - case WrapInt64: return Literal(int32_t(value.geti64())); - case TruncSFloat32: - case TruncSFloat64: { - double val = curr->op == TruncSFloat32 ? value.getf32() : value.getf64(); - if (isnan(val)) trap("truncSFloat of nan"); - if (curr->type == i32) { - if (val > (double)INT_MAX || val < (double)INT_MIN) trap("i32.truncSFloat overflow"); - return Literal(int32_t(val)); - } else { - int64_t converted = val; - if ((val >= 1 && converted <= 0) || val < (double)LLONG_MIN) trap("i32.truncSFloat overflow"); - return Literal(converted); - } - } - case TruncUFloat32: - case TruncUFloat64: { - double val = curr->op == TruncUFloat32 ? value.getf32() : value.getf64(); - if (isnan(val)) trap("truncUFloat of nan"); - if (curr->type == i32) { - if (val > (double)UINT_MAX || val <= (double)-1) trap("i64.truncUFloat overflow"); - return Literal(uint32_t(val)); - } else { - uint64_t converted = val; - if (converted < val - 1 || val <= (double)-1) trap("i64.truncUFloat overflow"); - return Literal(converted); - } - } - case ReinterpretFloat: { - return curr->type == i32 ? Literal(value.reinterpreti32()) : Literal(value.reinterpreti64()); - } - case ConvertUInt32: return curr->type == f32 ? Literal(float(uint32_t(value.geti32()))) : Literal(double(uint32_t(value.geti32()))); - case ConvertSInt32: return curr->type == f32 ? Literal(float(int32_t(value.geti32()))) : Literal(double(int32_t(value.geti32()))); - case ConvertUInt64: return curr->type == f32 ? Literal(float((uint64_t)value.geti64())) : Literal(double((uint64_t)value.geti64())); - case ConvertSInt64: return curr->type == f32 ? Literal(float(value.geti64())) : Literal(double(value.geti64())); - case PromoteFloat32: return Literal(double(value.getf32())); - case DemoteFloat64: return Literal(float(value.getf64())); - case ReinterpretInt: { - if (curr->type == f32) { - float v = value.reinterpretf32(); - if (isnan(v)) { - return Literal(Literal(value.geti32() | 0x7f800000).reinterpretf32()); - } - return Literal(value.reinterpretf32()); - } else { - return Literal(value.reinterpretf64()); - } - } - default: abort(); - } - } Flow visitSelect(Select *curr) override { NOTE_ENTER("Select"); Flow condition = visit(curr->condition); @@ -767,6 +702,32 @@ private: return Literal(int64_t(Literal(lnan ? l : r).reinterpreti64() | 0x8000000000000LL)).reinterpretf64(); } + Literal truncSFloat(Unary* curr, Literal value) { + double val = curr->op == TruncSFloat32 ? value.getf32() : value.getf64(); + if (isnan(val)) trap("truncSFloat of nan"); + if (curr->type == i32) { + if (val > (double)INT_MAX || val < (double)INT_MIN) trap("i32.truncSFloat overflow"); + return Literal(int32_t(val)); + } else { + int64_t converted = val; + if ((val >= 1 && converted <= 0) || val < (double)LLONG_MIN) trap("i32.truncSFloat overflow"); + return Literal(converted); + } + } + + Literal truncUFloat(Unary* curr, Literal value) { + double val = curr->op == TruncUFloat32 ? value.getf32() : value.getf64(); + if (isnan(val)) trap("truncUFloat of nan"); + if (curr->type == i32) { + if (val > (double)UINT_MAX || val <= (double)-1) trap("i64.truncUFloat overflow"); + return Literal(uint32_t(val)); + } else { + uint64_t converted = val; + if (converted < val - 1 || val <= (double)-1) trap("i64.truncUFloat overflow"); + return Literal(converted); + } + } + void trap(const char* why) { instance.externalInterface->trap(why); } diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index 718c459ca..3d55a1bc1 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -237,24 +237,56 @@ class SExpressionWasmBuilder { MixedArena& allocator; std::function<void ()> onError; int functionCounter; - std::vector<Call*> calls; // we only know call types afterwards, so we set their type in a post-pass + std::map<Name, WasmType> functionTypes; // we need to know function return types before we parse their contents public: // Assumes control of and modifies the input. - SExpressionWasmBuilder(AllocatingModule& wasm, Element& module, std::function<void ()> onError) : wasm(wasm), allocator(wasm.allocator), onError(onError), functionCounter(0) { + SExpressionWasmBuilder(AllocatingModule& wasm, Element& module, std::function<void ()> onError) : wasm(wasm), allocator(wasm.allocator), onError(onError) { assert(module[0]->str() == MODULE); + functionCounter = 0; for (unsigned i = 1; i < module.size(); i++) { - parseModuleElement(*module[i]); + preParseFunctionType(*module[i]); } - // post-pass, fix up call types - for (auto call : calls) { - call->type = wasm.functionsMap[call->target]->result; + functionCounter = 0; + for (unsigned i = 1; i < module.size(); i++) { + parseModuleElement(*module[i]); } - calls.clear(); } private: + // pre-parse types and function definitions, so we know function return types before parsing their contents + void preParseFunctionType(Element& s) { + IString id = s[0]->str(); + if (id == TYPE) return parseType(s); + if (id != FUNC) return; + size_t i = 1; + Name name; + if (s[i]->isStr()) { + name = s[i]->str(); + i++; + } else { + // unnamed, use an index + name = Name::fromInt(functionCounter); + } + functionCounter++; + for (;i < s.size(); i++) { + Element& curr = *s[i]; + IString id = curr[0]->str(); + if (id == RESULT) { + functionTypes[name] = stringToWasmType(curr[1]->str()); + return; + } else if (id == TYPE) { + Name name = curr[1]->str(); + if (wasm.functionTypesMap.find(name) == wasm.functionTypesMap.end()) onError(); + FunctionType* type = wasm.functionTypesMap[name]; + functionTypes[name] = type->result; + return; + } + } + functionTypes[name] = none; + } + void parseModuleElement(Element& curr) { IString id = curr[0]->str(); if (id == FUNC) return parseFunction(curr); @@ -262,7 +294,7 @@ private: if (id == EXPORT) return parseExport(curr); if (id == IMPORT) return parseImport(curr); if (id == TABLE) return parseTable(curr); - if (id == TYPE) return parseType(curr); + if (id == TYPE) return; // already done std::cerr << "bad module element " << id.str << '\n'; onError(); } @@ -409,8 +441,8 @@ public: if (op[2] == 'p') return makeBinary(s, BinaryOp::CopySign, type); if (op[2] == 'n') { if (op[3] == 'v') { - if (op[8] == 's') return makeConvert(s, op[11] == '3' ? ConvertOp::ConvertSInt32 : ConvertOp::ConvertSInt64, type); - if (op[8] == 'u') return makeConvert(s, op[11] == '3' ? ConvertOp::ConvertUInt32 : ConvertOp::ConvertUInt64, type); + if (op[8] == 's') return makeUnary(s, op[11] == '3' ? UnaryOp::ConvertSInt32 : UnaryOp::ConvertSInt64, type); + if (op[8] == 'u') return makeUnary(s, op[11] == '3' ? UnaryOp::ConvertUInt32 : UnaryOp::ConvertUInt64, type); } if (op[3] == 's') return makeConst(s, type); } @@ -423,12 +455,12 @@ public: if (op[3] == '_') return makeBinary(s, op[4] == 'u' ? BinaryOp::DivU : BinaryOp::DivS, type); if (op[3] == 0) return makeBinary(s, BinaryOp::Div, type); } - if (op[1] == 'e') return makeConvert(s, ConvertOp::DemoteFloat64, type); + if (op[1] == 'e') return makeUnary(s, UnaryOp::DemoteFloat64, type); abort_on(op); } case 'e': { - if (op[1] == 'q') return makeCompare(s, RelationalOp::Eq, type); - if (op[1] == 'x') return makeConvert(s, op[7] == 'u' ? ConvertOp::ExtendUInt32 : ConvertOp::ExtendSInt32, type); + if (op[1] == 'q') return makeBinary(s, BinaryOp::Eq, type); + if (op[1] == 'x') return makeUnary(s, op[7] == 'u' ? UnaryOp::ExtendUInt32 : UnaryOp::ExtendSInt32, type); abort_on(op); } case 'f': { @@ -437,23 +469,23 @@ public: } case 'g': { if (op[1] == 't') { - if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::GtU : RelationalOp::GtS, type); - if (op[2] == 0) return makeCompare(s, RelationalOp::Gt, type); + if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::GtU : BinaryOp::GtS, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Gt, type); } if (op[1] == 'e') { - if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::GeU : RelationalOp::GeS, type); - if (op[2] == 0) return makeCompare(s, RelationalOp::Ge, type); + if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::GeU : BinaryOp::GeS, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Ge, type); } abort_on(op); } case 'l': { if (op[1] == 't') { - if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::LtU : RelationalOp::LtS, type); - if (op[2] == 0) return makeCompare(s, RelationalOp::Lt, type); + if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::LtU : BinaryOp::LtS, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Lt, type); } if (op[1] == 'e') { - if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::LeU : RelationalOp::LeS, type); - if (op[2] == 0) return makeCompare(s, RelationalOp::Le, type); + if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::LeU : BinaryOp::LeS, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Le, type); } if (op[1] == 'o') return makeLoad(s, type); abort_on(op); @@ -466,7 +498,7 @@ public: } case 'n': { if (op[1] == 'e') { - if (op[2] == 0) return makeCompare(s, RelationalOp::Ne, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Ne, type); if (op[2] == 'a') return makeUnary(s, UnaryOp::Nearest, type); if (op[2] == 'g') return makeUnary(s, UnaryOp::Neg, type); } @@ -477,14 +509,14 @@ public: abort_on(op); } case 'p': { - if (op[1] == 'r') return makeConvert(s, ConvertOp::PromoteFloat32, type); + if (op[1] == 'r') return makeUnary(s, UnaryOp::PromoteFloat32, type); if (op[1] == 'o') return makeUnary(s, UnaryOp::Popcnt, type); abort_on(op); } case 'r': { if (op[1] == 'e') { if (op[2] == 'm') return makeBinary(s, op[4] == 'u' ? BinaryOp::RemU : BinaryOp::RemS, type); - if (op[2] == 'i') return makeConvert(s, isWasmTypeFloat(type) ? ConvertOp::ReinterpretInt : ConvertOp::ReinterpretFloat, type); + if (op[2] == 'i') return makeUnary(s, isWasmTypeFloat(type) ? UnaryOp::ReinterpretInt : UnaryOp::ReinterpretFloat, type); } abort_on(op); } @@ -501,14 +533,14 @@ public: } case 't': { if (op[1] == 'r') { - if (op[6] == 's') return makeConvert(s, op[9] == '3' ? ConvertOp::TruncSFloat32 : ConvertOp::TruncSFloat64, type); - if (op[6] == 'u') return makeConvert(s, op[9] == '3' ? ConvertOp::TruncUFloat32 : ConvertOp::TruncUFloat64, type); + if (op[6] == 's') return makeUnary(s, op[9] == '3' ? UnaryOp::TruncSFloat32 : UnaryOp::TruncSFloat64, type); + if (op[6] == 'u') return makeUnary(s, op[9] == '3' ? UnaryOp::TruncUFloat32 : UnaryOp::TruncUFloat64, type); if (op[2] == 'u') return makeUnary(s, UnaryOp::Trunc, type); } abort_on(op); } case 'w': { - if (op[1] == 'r') return makeConvert(s, ConvertOp::WrapInt64, type); + if (op[1] == 'r') return makeUnary(s, UnaryOp::WrapInt64, type); abort_on(op); } case 'x': { @@ -591,7 +623,7 @@ private: ret->op = op; ret->left = parseExpression(s[1]); ret->right = parseExpression(s[2]); - ret->type = type; + ret->finalize(); return ret; } @@ -603,23 +635,6 @@ private: return ret; } - Expression* makeCompare(Element& s, RelationalOp op, WasmType type) { - auto ret = allocator.alloc<Compare>(); - ret->op = op; - ret->left = parseExpression(s[1]); - ret->right = parseExpression(s[2]); - ret->inputType = type; - return ret; - } - - Expression* makeConvert(Element& s, ConvertOp op, WasmType type) { - auto ret = allocator.alloc<Convert>(); - ret->op = op; - ret->value = parseExpression(s[1]); - ret->type = type; - return ret; - } - Expression* makeSelect(Element& s, WasmType type) { auto ret = allocator.alloc<Select>(); ret->condition = parseExpression(s[1]); @@ -678,6 +693,7 @@ private: for (; i < s.size(); i++) { ret->list.push_back(parseExpression(s[i])); } + ret->type = ret->list.back()->type; return ret; } @@ -904,6 +920,7 @@ private: ret->ifTrue = parseExpression(s[2]); if (s.size() == 4) { ret->ifFalse = parseExpression(s[3]); + ret->type = ret->ifTrue->type == ret->ifFalse->type ? ret->ifTrue->type : none; // if not the same type, this does not return a value } return ret; } @@ -929,6 +946,7 @@ private: for (; i < s.size() && i < stopAt; i++) { ret->list.push_back(parseExpression(s[i])); } + ret->type = ret->list.back()->type; return ret; } @@ -957,8 +975,8 @@ private: Expression* makeCall(Element& s) { auto ret = allocator.alloc<Call>(); - calls.push_back(ret); ret->target = s[1]->str(); + ret->type = functionTypes[ret->target]; parseCallOperands(s, 2, ret); return ret; } @@ -966,6 +984,8 @@ private: Expression* makeCallImport(Element& s) { auto ret = allocator.alloc<CallImport>(); ret->target = s[1]->str(); + Import* import = wasm.importsMap[ret->target]; + ret->type = import->type.result; parseCallOperands(s, 2, ret); return ret; } @@ -974,7 +994,8 @@ private: auto ret = allocator.alloc<CallIndirect>(); IString type = s[1]->str(); assert(wasm.functionTypesMap.find(type) != wasm.functionTypesMap.end()); - ret->type = wasm.functionTypesMap[type]; + ret->fullType = wasm.functionTypesMap[type]; + ret->type = ret->fullType->result; ret->target = parseExpression(s[2]); parseCallOperands(s, 3, ret); return ret; diff --git a/src/wasm.h b/src/wasm.h index 04b524260..d272f2037 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -15,6 +15,16 @@ // * Validation: See wasm-validator.h. // +// +// wasm.js internal WebAssembly representation design: +// +// * Optimize for size. This justifies separating if and if_else +// (so that if doesn't have an always-empty else) +// * Unify where possible. Where size isn't a concern, combine +// classes, so binary ops and relational ops are joined. This +// simplifies that AST and makes traversals easier. +// + #ifndef __wasm_h__ #define __wasm_h__ @@ -153,7 +163,7 @@ struct Literal { } } - void printFloat(std::ostream &o, float f) { + static void printFloat(std::ostream &o, float f) { if (isnan(f)) { union { float ff; @@ -166,7 +176,7 @@ struct Literal { printDouble(o, f); } - void printDouble(std::ostream &o, double d) { + static void printDouble(std::ostream &o, double d) { if (d == 0 && 1/d < 0) { o << "-0"; return; @@ -210,26 +220,22 @@ struct Literal { enum UnaryOp { Clz, Ctz, Popcnt, // int - Neg, Abs, Ceil, Floor, Trunc, Nearest, Sqrt // float + Neg, Abs, Ceil, Floor, Trunc, Nearest, Sqrt, // float + // conversions + ExtendSInt32, ExtendUInt32, WrapInt64, TruncSFloat32, TruncUFloat32, TruncSFloat64, TruncUFloat64, ReinterpretFloat, // int + ConvertSInt32, ConvertUInt32, ConvertSInt64, ConvertUInt64, PromoteFloat32, DemoteFloat64, ReinterpretInt // float }; enum BinaryOp { Add, Sub, Mul, // int or float DivS, DivU, RemS, RemU, And, Or, Xor, Shl, ShrU, ShrS, // int - Div, CopySign, Min, Max // float -}; - -enum RelationalOp { + Div, CopySign, Min, Max, // float + // relational ops Eq, Ne, // int or float LtS, LtU, LeS, LeU, GtS, GtU, GeS, GeU, // int Lt, Le, Gt, Ge // float }; -enum ConvertOp { - ExtendSInt32, ExtendUInt32, WrapInt64, TruncSFloat32, TruncUFloat32, TruncSFloat64, TruncUFloat64, ReinterpretFloat, // int - ConvertSInt32, ConvertUInt32, ConvertSInt64, ConvertUInt64, PromoteFloat32, DemoteFloat64, ReinterpretInt // float -}; - enum HostOp { PageSize, MemorySize, GrowMemory, HasFeature }; @@ -269,12 +275,10 @@ public: ConstId = 14, UnaryId = 15, BinaryId = 16, - CompareId = 17, - ConvertId = 18, - SelectId = 19, - HostId = 20, - NopId = 21, - UnreachableId = 22 + SelectId = 17, + HostId = 18, + NopId = 19, + UnreachableId = 20 }; Id _id; @@ -319,7 +323,9 @@ public: class Block : public Expression { public: - Block() : Expression(BlockId) {} + Block() : Expression(BlockId) { + type = none; // blocks by default do not return, but if their last statement does, they might + } Name name; ExpressionList list; @@ -339,7 +345,9 @@ public: class If : public Expression { public: - If() : Expression(IfId), ifFalse(nullptr) {} + If() : Expression(IfId), ifFalse(nullptr) { + type = none; // by default none; if-else can have one, though + } Expression *condition, *ifTrue, *ifFalse; @@ -543,12 +551,12 @@ class CallIndirect : public Expression { public: CallIndirect() : Expression(CallIndirectId) {} - FunctionType *type; + FunctionType *fullType; Expression *target; ExpressionList operands; std::ostream& doPrint(std::ostream &o, unsigned indent) { - printOpening(o, "call_indirect ") << type->name; + printOpening(o, "call_indirect ") << fullType->name; incIndent(o, indent); printFullLine(o, indent, target); for (auto operand : operands) { @@ -683,16 +691,31 @@ public: o << '('; prepareColor(o) << printWasmType(type) << '.'; switch (op) { - case Clz: o << "clz"; break; - case Ctz: o << "ctz"; break; - case Popcnt: o << "popcnt"; break; - case Neg: o << "neg"; break; - case Abs: o << "abs"; break; - case Ceil: o << "ceil"; break; - case Floor: o << "floor"; break; - case Trunc: o << "trunc"; break; - case Nearest: o << "nearest"; break; - case Sqrt: o << "sqrt"; break; + case Clz: o << "clz"; break; + case Ctz: o << "ctz"; break; + case Popcnt: o << "popcnt"; break; + case Neg: o << "neg"; break; + case Abs: o << "abs"; break; + case Ceil: o << "ceil"; break; + case Floor: o << "floor"; break; + case Trunc: o << "trunc"; break; + case Nearest: o << "nearest"; break; + case Sqrt: o << "sqrt"; break; + case ExtendSInt32: o << "extend_s/i32"; break; + case ExtendUInt32: o << "extend_u/i32"; break; + case WrapInt64: o << "wrap/i64"; break; + case TruncSFloat32: o << "trunc_s/f32"; break; + case TruncUFloat32: o << "trunc_u/f32"; break; + case TruncSFloat64: o << "trunc_s/f64"; break; + case TruncUFloat64: o << "trunc_u/f64"; break; + case ReinterpretFloat: o << "reinterpret/" << (type == i64 ? "f64" : "f32"); break; + case ConvertUInt32: o << "convert_u/i32"; break; + case ConvertSInt32: o << "convert_s/i32"; break; + case ConvertUInt64: o << "convert_u/i64"; break; + case ConvertSInt64: o << "convert_s/i64"; break; + case PromoteFloat32: o << "promote/f32"; break; + case DemoteFloat64: o << "demote/f64"; break; + case ReinterpretInt: o << "reinterpret" << (type == f64 ? "i64" : "i32"); break; default: abort(); } incIndent(o, indent); @@ -710,7 +733,7 @@ public: std::ostream& doPrint(std::ostream &o, unsigned indent) { o << '('; - prepareColor(o) << printWasmType(type) << '.'; + prepareColor(o) << printWasmType(isRelational() ? left->type : type) << '.'; switch (op) { case Add: o << "add"; break; case Sub: o << "sub"; break; @@ -729,50 +752,21 @@ public: case CopySign: o << "copysign"; break; case Min: o << "min"; break; case Max: o << "max"; break; - default: abort(); - } - restoreNormalColor(o); - incIndent(o, indent); - printFullLine(o, indent, left); - printFullLine(o, indent, right); - return decIndent(o, indent); - } - - // the type is always the type of the operands - void finalize() { - type = left->type; - } -}; - -class Compare : public Expression { -public: - Compare() : Expression(CompareId) { - type = WasmType::i32; // output is always i32 - } - - RelationalOp op; - WasmType inputType; - Expression *left, *right; - - std::ostream& doPrint(std::ostream &o, unsigned indent) { - o << '('; - prepareColor(o) << printWasmType(inputType) << '.'; - switch (op) { - case Eq: o << "eq"; break; - case Ne: o << "ne"; break; - case LtS: o << "lt_s"; break; - case LtU: o << "lt_u"; break; - case LeS: o << "le_s"; break; - case LeU: o << "le_u"; break; - case GtS: o << "gt_s"; break; - case GtU: o << "gt_u"; break; - case GeS: o << "ge_s"; break; - case GeU: o << "ge_u"; break; - case Lt: o << "lt"; break; - case Le: o << "le"; break; - case Gt: o << "gt"; break; - case Ge: o << "ge"; break; - default: abort(); + case Eq: o << "eq"; break; + case Ne: o << "ne"; break; + case LtS: o << "lt_s"; break; + case LtU: o << "lt_u"; break; + case LeS: o << "le_s"; break; + case LeU: o << "le_u"; break; + case GtS: o << "gt_s"; break; + case GtU: o << "gt_u"; break; + case GeS: o << "ge_s"; break; + case GeU: o << "ge_u"; break; + case Lt: o << "lt"; break; + case Le: o << "le"; break; + case Gt: o << "gt"; break; + case Ge: o << "ge"; break; + default: abort(); } restoreNormalColor(o); incIndent(o, indent); @@ -780,40 +774,19 @@ public: printFullLine(o, indent, right); return decIndent(o, indent); } -}; -class Convert : public Expression { -public: - Convert() : Expression(ConvertId) {} + // the type is always the type of the operands, + // except for relationals - ConvertOp op; - Expression *value; + bool isRelational() { return op >= Eq; } - std::ostream& doPrint(std::ostream &o, unsigned indent) { - o << '('; - prepareColor(o) << printWasmType(type) << '.'; - switch (op) { - case ExtendSInt32: o << "extend_s/i32"; break; - case ExtendUInt32: o << "extend_u/i32"; break; - case WrapInt64: o << "wrap/i64"; break; - case TruncSFloat32: o << "trunc_s/f32"; break; - case TruncUFloat32: o << "trunc_u/f32"; break; - case TruncSFloat64: o << "trunc_s/f64"; break; - case TruncUFloat64: o << "trunc_u/f64"; break; - case ReinterpretFloat: o << "reinterpret/" << (type == i64 ? "f64" : "f32"); break; - case ConvertUInt32: o << "convert_u/i32"; break; - case ConvertSInt32: o << "convert_s/i32"; break; - case ConvertUInt64: o << "convert_u/i64"; break; - case ConvertSInt64: o << "convert_s/i64"; break; - case PromoteFloat32: o << "promote/f32"; break; - case DemoteFloat64: o << "demote/f64"; break; - case ReinterpretInt: o << "reinterpret" << (type == f64 ? "i64" : "i32"); break; - default: abort(); + void finalize() { + if (isRelational()) { + type = i32; + } else { + assert(left->type == right->type); + type = left->type; } - restoreNormalColor(o); - incIndent(o, indent); - printFullLine(o, indent, value); - return decIndent(o, indent); } }; @@ -1117,8 +1090,6 @@ struct WasmVisitor { virtual ReturnType visitConst(Const *curr) { abort(); } virtual ReturnType visitUnary(Unary *curr) { abort(); } virtual ReturnType visitBinary(Binary *curr) { abort(); } - virtual ReturnType visitCompare(Compare *curr) { abort(); } - virtual ReturnType visitConvert(Convert *curr) { abort(); } virtual ReturnType visitSelect(Select *curr) { abort(); } virtual ReturnType visitHost(Host *curr) { abort(); } virtual ReturnType visitNop(Nop *curr) { abort(); } @@ -1151,8 +1122,6 @@ struct WasmVisitor { case Expression::Id::ConstId: return visitConst((Const*)curr); case Expression::Id::UnaryId: return visitUnary((Unary*)curr); case Expression::Id::BinaryId: return visitBinary((Binary*)curr); - case Expression::Id::CompareId: return visitCompare((Compare*)curr); - case Expression::Id::ConvertId: return visitConvert((Convert*)curr); case Expression::Id::SelectId: return visitSelect((Select*)curr); case Expression::Id::HostId: return visitHost((Host*)curr); case Expression::Id::NopId: return visitNop((Nop*)curr); @@ -1188,8 +1157,6 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { void visitConst(Const *curr) override { curr->doPrint(o, indent); } void visitUnary(Unary *curr) override { curr->doPrint(o, indent); } void visitBinary(Binary *curr) override { curr->doPrint(o, indent); } - void visitCompare(Compare *curr) override { curr->doPrint(o, indent); } - void visitConvert(Convert *curr) override { curr->doPrint(o, indent); } void visitSelect(Select *curr) override { curr->doPrint(o, indent); } void visitHost(Host *curr) override { curr->doPrint(o, indent); } void visitNop(Nop *curr) override { curr->doPrint(o, indent); } @@ -1234,8 +1201,6 @@ struct WasmWalker : public WasmVisitor<void> { void visitConst(Const *curr) override {} void visitUnary(Unary *curr) override {} void visitBinary(Binary *curr) override {} - void visitCompare(Compare *curr) override {} - void visitConvert(Convert *curr) override {} void visitSelect(Select *curr) override {} void visitHost(Host *curr) override {} void visitNop(Nop *curr) override {} @@ -1321,13 +1286,6 @@ struct WasmWalker : public WasmVisitor<void> { parent.walk(curr->left); parent.walk(curr->right); } - void visitCompare(Compare *curr) override { - parent.walk(curr->left); - parent.walk(curr->right); - } - void visitConvert(Convert *curr) override { - parent.walk(curr->value); - } void visitSelect(Select *curr) override { parent.walk(curr->condition); parent.walk(curr->ifTrue); |