diff options
Diffstat (limited to 'src/wasm')
-rw-r--r-- | src/wasm/literal.cpp | 45 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 83 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 43 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 76 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 53 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 306 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 90 |
7 files changed, 473 insertions, 223 deletions
diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index 82a150257..4f66b36e3 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -137,8 +137,11 @@ void Literal::getBits(uint8_t (&buf)[16]) const { case Type::v128: memcpy(buf, &v128, sizeof(v128)); break; - case Type::anyref: // anyref type is opaque - case Type::exnref: // exnref type is opaque + case Type::funcref: + case Type::nullref: + break; + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("invalid type"); @@ -146,10 +149,20 @@ void Literal::getBits(uint8_t (&buf)[16]) const { } bool Literal::operator==(const Literal& other) const { + if (type.isRef() && other.type.isRef()) { + if (type == Type::nullref && other.type == Type::nullref) { + return true; + } + if (type == Type::funcref && other.type == Type::funcref && + func == other.func) { + return true; + } + return false; + } if (type != other.type) { return false; } - if (type == none) { + if (type == Type::none) { return true; } uint8_t bits[16], other_bits[16]; @@ -273,8 +286,14 @@ std::ostream& operator<<(std::ostream& o, Literal literal) { o << "i32x4 "; literal.printVec128(o, literal.getv128()); break; - case Type::anyref: // anyref type is opaque - case Type::exnref: // exnref type is opaque + case Type::funcref: + o << "funcref(" << literal.getFunc() << ")"; + break; + case Type::nullref: + o << "nullref"; + break; + case Type::anyref: + case Type::exnref: case Type::unreachable: WASM_UNREACHABLE("invalid type"); } @@ -477,7 +496,9 @@ Literal Literal::eqz() const { case Type::f64: return eq(Literal(double(0))); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -497,7 +518,9 @@ Literal Literal::neg() const { case Type::f64: return Literal(int64_t(i64 ^ 0x8000000000000000ULL)).castToF64(); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -517,7 +540,9 @@ Literal Literal::abs() const { case Type::f64: return Literal(int64_t(i64 & 0x7fffffffffffffffULL)).castToF64(); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -620,7 +645,9 @@ Literal Literal::add(const Literal& other) const { case Type::f64: return Literal(getf64() + other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -640,7 +667,9 @@ Literal Literal::sub(const Literal& other) const { case Type::f64: return Literal(getf64() - other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -731,7 +760,9 @@ Literal Literal::mul(const Literal& other) const { case Type::f64: return Literal(getf64() * other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -967,7 +998,9 @@ Literal Literal::eq(const Literal& other) const { case Type::f64: return Literal(getf64() == other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -987,7 +1020,9 @@ Literal Literal::ne(const Literal& other) const { case Type::f64: return Literal(getf64() != other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 82eb51d7e..ba5a8d3dd 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -262,7 +262,7 @@ void WasmBinaryWriter::writeImports() { BYN_TRACE("write one table\n"); writeImportHeader(&wasm->table); o << U32LEB(int32_t(ExternalKind::Table)); - o << S32LEB(BinaryConsts::EncodedType::AnyFunc); + o << S32LEB(BinaryConsts::EncodedType::funcref); writeResizableLimits(wasm->table.initial, wasm->table.max, wasm->table.hasMax(), @@ -463,7 +463,7 @@ void WasmBinaryWriter::writeFunctionTableDeclaration() { BYN_TRACE("== writeFunctionTableDeclaration\n"); auto start = startSection(BinaryConsts::Section::Table); o << U32LEB(1); // Declare 1 table. - o << S32LEB(BinaryConsts::EncodedType::AnyFunc); + o << S32LEB(BinaryConsts::EncodedType::funcref); writeResizableLimits(wasm->table.initial, wasm->table.max, wasm->table.hasMax(), @@ -1059,8 +1059,12 @@ Type WasmBinaryBuilder::getType() { return f64; case BinaryConsts::EncodedType::v128: return v128; + case BinaryConsts::EncodedType::funcref: + return funcref; case BinaryConsts::EncodedType::anyref: return anyref; + case BinaryConsts::EncodedType::nullref: + return nullref; case BinaryConsts::EncodedType::exnref: return exnref; default: @@ -1258,8 +1262,8 @@ void WasmBinaryBuilder::readImports() { wasm.table.name = Name(std::string("timport$") + std::to_string(i)); auto elementType = getS32LEB(); WASM_UNUSED(elementType); - if (elementType != BinaryConsts::EncodedType::AnyFunc) { - throwError("Imported table type is not AnyFunc"); + if (elementType != BinaryConsts::EncodedType::funcref) { + throwError("Imported table type is not funcref"); } wasm.table.exists = true; bool is_shared; @@ -1802,11 +1806,17 @@ void WasmBinaryBuilder::processFunctions() { wasm.addExport(curr); } - for (auto& iter : functionCalls) { + for (auto& iter : functionRefs) { size_t index = iter.first; - auto& calls = iter.second; - for (auto* call : calls) { - call->target = getFunctionName(index); + auto& refs = iter.second; + for (auto* ref : refs) { + if (auto* call = ref->dynCast<Call>()) { + call->target = getFunctionName(index); + } else if (auto* refFunc = ref->dynCast<RefFunc>()) { + refFunc->func = getFunctionName(index); + } else { + WASM_UNREACHABLE("Invalid type in function references"); + } } } @@ -1869,8 +1879,8 @@ void WasmBinaryBuilder::readFunctionTableDeclaration() { } wasm.table.exists = true; auto elemType = getS32LEB(); - if (elemType != BinaryConsts::EncodedType::AnyFunc) { - throwError("ElementType must be AnyFunc in MVP"); + if (elemType != BinaryConsts::EncodedType::funcref) { + throwError("ElementType must be funcref in MVP"); } bool is_shared; getResizableLimits( @@ -2117,7 +2127,8 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { visitGlobalSet((curr = allocator.alloc<GlobalSet>())->cast<GlobalSet>()); break; case BinaryConsts::Select: - visitSelect((curr = allocator.alloc<Select>())->cast<Select>()); + case BinaryConsts::SelectWithType: + visitSelect((curr = allocator.alloc<Select>())->cast<Select>(), code); break; case BinaryConsts::Return: visitReturn((curr = allocator.alloc<Return>())->cast<Return>()); @@ -2137,6 +2148,15 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { case BinaryConsts::Catch: curr = nullptr; break; + case BinaryConsts::RefNull: + visitRefNull((curr = allocator.alloc<RefNull>())->cast<RefNull>()); + break; + case BinaryConsts::RefIsNull: + visitRefIsNull((curr = allocator.alloc<RefIsNull>())->cast<RefIsNull>()); + break; + case BinaryConsts::RefFunc: + visitRefFunc((curr = allocator.alloc<RefFunc>())->cast<RefFunc>()); + break; case BinaryConsts::Try: visitTry((curr = allocator.alloc<Try>())->cast<Try>()); break; @@ -2510,7 +2530,7 @@ void WasmBinaryBuilder::visitCall(Call* curr) { curr->operands[num - i - 1] = popNonVoidExpression(); } curr->type = sig.results; - functionCalls[index].push_back(curr); // we don't know function names yet + functionRefs[index].push_back(curr); // we don't know function names yet curr->finalize(); } @@ -4326,12 +4346,24 @@ bool WasmBinaryBuilder::maybeVisitSIMDLoad(Expression*& out, uint32_t code) { return true; } -void WasmBinaryBuilder::visitSelect(Select* curr) { - BYN_TRACE("zz node: Select\n"); +void WasmBinaryBuilder::visitSelect(Select* curr, uint8_t code) { + BYN_TRACE("zz node: Select, code " << int32_t(code) << std::endl); + if (code == BinaryConsts::SelectWithType) { + size_t numTypes = getU32LEB(); + std::vector<Type> types; + for (size_t i = 0; i < numTypes; i++) { + types.push_back(getType()); + } + curr->type = Type(types); + } curr->condition = popNonVoidExpression(); curr->ifFalse = popNonVoidExpression(); curr->ifTrue = popNonVoidExpression(); - curr->finalize(); + if (code == BinaryConsts::SelectWithType) { + curr->finalize(curr->type); + } else { + curr->finalize(); + } } void WasmBinaryBuilder::visitReturn(Return* curr) { @@ -4383,6 +4415,27 @@ void WasmBinaryBuilder::visitDrop(Drop* curr) { curr->finalize(); } +void WasmBinaryBuilder::visitRefNull(RefNull* curr) { + BYN_TRACE("zz node: RefNull\n"); + curr->finalize(); +} + +void WasmBinaryBuilder::visitRefIsNull(RefIsNull* curr) { + BYN_TRACE("zz node: RefIsNull\n"); + curr->value = popNonVoidExpression(); + curr->finalize(); +} + +void WasmBinaryBuilder::visitRefFunc(RefFunc* curr) { + BYN_TRACE("zz node: RefFunc\n"); + Index index = getU32LEB(); + if (index >= functionImports.size() + functionSignatures.size()) { + throwError("ref.func: invalid call index"); + } + functionRefs[index].push_back(curr); // we don't know function names yet + curr->finalize(); +} + void WasmBinaryBuilder::visitTry(Try* curr) { BYN_TRACE("zz node: Try\n"); // For simplicity of implementation, like if scopes, we create a hidden block diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 20aff2091..3b12c4346 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -850,16 +850,22 @@ Type SExpressionWasmBuilder::stringToType(const char* str, return v128; } } + if (strncmp(str, "funcref", 7) == 0 && (prefix || str[7] == 0)) { + return funcref; + } if (strncmp(str, "anyref", 6) == 0 && (prefix || str[6] == 0)) { return anyref; } + if (strncmp(str, "nullref", 7) == 0 && (prefix || str[7] == 0)) { + return nullref; + } if (strncmp(str, "exnref", 6) == 0 && (prefix || str[6] == 0)) { return exnref; } if (allowError) { return none; } - throw ParseException("invalid wasm type"); + throw ParseException(std::string("invalid wasm type: ") + str); } Type SExpressionWasmBuilder::stringToLaneType(const char* str) { @@ -936,10 +942,16 @@ Expression* SExpressionWasmBuilder::makeUnary(Element& s, UnaryOp op) { Expression* SExpressionWasmBuilder::makeSelect(Element& s) { auto ret = allocator.alloc<Select>(); - ret->ifTrue = parseExpression(s[1]); - ret->ifFalse = parseExpression(s[2]); - ret->condition = parseExpression(s[3]); - ret->finalize(); + Index i = 1; + Type type = parseOptionalResultType(s, i); + ret->ifTrue = parseExpression(s[i++]); + ret->ifFalse = parseExpression(s[i++]); + ret->condition = parseExpression(s[i]); + if (type.isConcrete()) { + ret->finalize(type); + } else { + ret->finalize(); + } return ret; } @@ -1718,6 +1730,27 @@ Expression* SExpressionWasmBuilder::makeReturn(Element& s) { return ret; } +Expression* SExpressionWasmBuilder::makeRefNull(Element& s) { + auto ret = allocator.alloc<RefNull>(); + ret->finalize(); + return ret; +} + +Expression* SExpressionWasmBuilder::makeRefIsNull(Element& s) { + auto ret = allocator.alloc<RefIsNull>(); + ret->value = parseExpression(s[1]); + ret->finalize(); + return ret; +} + +Expression* SExpressionWasmBuilder::makeRefFunc(Element& s) { + auto func = getFunctionName(*s[1]); + auto ret = allocator.alloc<RefFunc>(); + ret->func = func; + ret->finalize(); + return ret; +} + // try-catch-end is written in the folded wast format as // (try // ... diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index 593214838..22d6a0036 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -147,8 +147,10 @@ void BinaryInstWriter::visitLoad(Load* curr) { // the pointer is unreachable, so we are never reached; just don't emit // a load return; - case anyref: // anyref cannot be loaded from memory - case exnref: // exnref cannot be loaded from memory + case funcref: + case anyref: + case nullref: + case exnref: case none: WASM_UNREACHABLE("unexpected type"); } @@ -247,8 +249,10 @@ void BinaryInstWriter::visitStore(Store* curr) { o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::V128Store); break; - case anyref: // anyref cannot be stored from memory - case exnref: // exnref cannot be stored in memory + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -642,8 +646,10 @@ void BinaryInstWriter::visitConst(Const* curr) { } break; } - case anyref: // there's no anyref.const - case exnref: // there's no exnref.const + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1541,7 +1547,15 @@ void BinaryInstWriter::visitBinary(Binary* curr) { } void BinaryInstWriter::visitSelect(Select* curr) { - o << int8_t(BinaryConsts::Select); + if (curr->type.isRef()) { + o << int8_t(BinaryConsts::SelectWithType) << U32LEB(curr->type.size()); + for (size_t i = 0; i < curr->type.size(); i++) { + o << binaryType(curr->type != Type::unreachable ? curr->type + : Type::none); + } + } else { + o << int8_t(BinaryConsts::Select); + } } void BinaryInstWriter::visitReturn(Return* curr) { @@ -1562,6 +1576,19 @@ void BinaryInstWriter::visitHost(Host* curr) { o << U32LEB(0); // Reserved flags field } +void BinaryInstWriter::visitRefNull(RefNull* curr) { + o << int8_t(BinaryConsts::RefNull); +} + +void BinaryInstWriter::visitRefIsNull(RefIsNull* curr) { + o << int8_t(BinaryConsts::RefIsNull); +} + +void BinaryInstWriter::visitRefFunc(RefFunc* curr) { + o << int8_t(BinaryConsts::RefFunc) + << U32LEB(parent.getFunctionIndex(curr->func)); +} + void BinaryInstWriter::visitTry(Try* curr) { breakStack.emplace_back(IMPOSSIBLE_CONTINUE); o << int8_t(BinaryConsts::Try); @@ -1659,11 +1686,21 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() { continue; } index += numLocalsByType[v128]; + if (type == funcref) { + mappedLocals[i] = index + currLocalsByType[funcref] - 1; + continue; + } + index += numLocalsByType[funcref]; if (type == anyref) { mappedLocals[i] = index + currLocalsByType[anyref] - 1; continue; } index += numLocalsByType[anyref]; + if (type == nullref) { + mappedLocals[i] = index + currLocalsByType[nullref] - 1; + continue; + } + index += numLocalsByType[nullref]; if (type == exnref) { mappedLocals[i] = index + currLocalsByType[exnref] - 1; continue; @@ -1671,11 +1708,12 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() { WASM_UNREACHABLE("unexpected type"); } // Emit them. - o << U32LEB((numLocalsByType[i32] ? 1 : 0) + (numLocalsByType[i64] ? 1 : 0) + - (numLocalsByType[f32] ? 1 : 0) + (numLocalsByType[f64] ? 1 : 0) + - (numLocalsByType[v128] ? 1 : 0) + - (numLocalsByType[anyref] ? 1 : 0) + - (numLocalsByType[exnref] ? 1 : 0)); + o << U32LEB( + (numLocalsByType[i32] ? 1 : 0) + (numLocalsByType[i64] ? 1 : 0) + + (numLocalsByType[f32] ? 1 : 0) + (numLocalsByType[f64] ? 1 : 0) + + (numLocalsByType[v128] ? 1 : 0) + (numLocalsByType[funcref] ? 1 : 0) + + (numLocalsByType[anyref] ? 1 : 0) + (numLocalsByType[nullref] ? 1 : 0) + + (numLocalsByType[exnref] ? 1 : 0)); if (numLocalsByType[i32]) { o << U32LEB(numLocalsByType[i32]) << binaryType(i32); } @@ -1691,9 +1729,15 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() { if (numLocalsByType[v128]) { o << U32LEB(numLocalsByType[v128]) << binaryType(v128); } + if (numLocalsByType[funcref]) { + o << U32LEB(numLocalsByType[funcref]) << binaryType(funcref); + } if (numLocalsByType[anyref]) { o << U32LEB(numLocalsByType[anyref]) << binaryType(anyref); } + if (numLocalsByType[nullref]) { + o << U32LEB(numLocalsByType[nullref]) << binaryType(nullref); + } if (numLocalsByType[exnref]) { o << U32LEB(numLocalsByType[exnref]) << binaryType(exnref); } @@ -1760,7 +1804,7 @@ StackInst* StackIRGenerator::makeStackInst(StackInst::Op op, // type. stackType = none; } else if (op != StackInst::BlockEnd && op != StackInst::IfEnd && - op != StackInst::LoopEnd) { + op != StackInst::LoopEnd && op != StackInst::TryEnd) { // If a concrete type is returned, we mark the end of the construct has // having that type (as it is pushed to the value stack at that point), // other parts are marked as none). @@ -1781,13 +1825,15 @@ void StackIRToBinaryWriter::write() { case StackInst::Basic: case StackInst::BlockBegin: case StackInst::IfBegin: - case StackInst::LoopBegin: { + case StackInst::LoopBegin: + case StackInst::TryBegin: { writer.visit(inst->origin); break; } case StackInst::BlockEnd: case StackInst::IfEnd: - case StackInst::LoopEnd: { + case StackInst::LoopEnd: + case StackInst::TryEnd: { writer.emitScopeEnd(); break; } diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index 939ee8c93..62c30d1e0 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -62,7 +62,9 @@ std::vector<std::unique_ptr<std::vector<Type>>> typeLists = [] { add({Type::f32}); add({Type::f64}); add({Type::v128}); + add({Type::funcref}); add({Type::anyref}); + add({Type::nullref}); add({Type::exnref}); return lists; }(); @@ -75,7 +77,9 @@ std::unordered_map<std::vector<Type>, uint32_t> indices = { {{Type::f32}, Type::f32}, {{Type::f64}, Type::f64}, {{Type::v128}, Type::v128}, + {{Type::funcref}, Type::funcref}, {{Type::anyref}, Type::anyref}, + {{Type::nullref}, Type::nullref}, {{Type::exnref}, Type::exnref}, }; @@ -154,8 +158,10 @@ unsigned Type::getByteSize() const { return 8; case Type::v128: return 16; - case Type::anyref: // anyref type is opaque - case Type::exnref: // exnref type is opaque + case Type::funcref: + case Type::anyref: + case Type::nullref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("invalid type"); @@ -164,7 +170,7 @@ unsigned Type::getByteSize() const { } Type Type::reinterpret() const { - assert(isSingle() && "reinterpret only works with single types"); + assert(isSingle() && "reinterpretType only works with single types"); Type singleType = *expand().begin(); switch (singleType) { case Type::i32: @@ -176,7 +182,9 @@ Type Type::reinterpret() const { case Type::f64: return i64; case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -221,6 +229,39 @@ Type Type::get(unsigned byteSize, bool float_) { WASM_UNREACHABLE("invalid size"); } +bool Type::Type::isSubType(Type left, Type right) { + if (left == right) { + return true; + } + if (left.isRef() && right.isRef() && + (right == Type::anyref || left == Type::nullref)) { + return true; + } + return false; +} + +Type Type::Type::getLeastUpperBound(Type a, Type b) { + if (a == b) { + return a; + } + if (a == Type::unreachable) { + return b; + } + if (b == Type::unreachable) { + return a; + } + if (!a.isRef() || !b.isRef()) { + return none; // a poison value that must not be consumed + } + if (a == Type::nullref) { + return b; + } + if (b == Type::nullref) { + return a; + } + return Type::anyref; +} + namespace { std::ostream& @@ -280,9 +321,15 @@ std::ostream& operator<<(std::ostream& os, Type type) { case Type::v128: os << "v128"; break; + case Type::funcref: + os << "funcref"; + break; case Type::anyref: os << "anyref"; break; + case Type::nullref: + os << "nullref"; + break; case Type::exnref: os << "exnref"; break; diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 55e115d95..7bf51c5f7 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -21,6 +21,7 @@ #include "ir/branch-utils.h" #include "ir/features.h" +#include "ir/global-utils.h" #include "ir/module-utils.h" #include "ir/utils.h" #include "support/colors.h" @@ -181,6 +182,31 @@ struct ValidationInfo { fail(text, curr, func); } } + + // Type 'left' should be a subtype of 'right'. + bool shouldBeSubType(Type left, + Type right, + Expression* curr, + const char* text, + Function* func = nullptr) { + if (Type::isSubType(left, right)) { + return true; + } + fail(text, curr, func); + return false; + } + + // Type 'left' should be a subtype of 'right', or unreachable. + bool shouldBeSubTypeOrUnreachable(Type left, + Type right, + Expression* curr, + const char* text, + Function* func = nullptr) { + if (left == Type::unreachable) { + return true; + } + return shouldBeSubType(left, right, curr, text, func); + } }; struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> { @@ -210,7 +236,7 @@ struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> { std::unordered_map<Name, BreakInfo> breakInfos; - Type returnType = unreachable; // type used in returns + std::set<Type> returnTypes; // types used in returns // Binaryen IR requires that label names must be unique - IR generators must // ensure that @@ -287,6 +313,8 @@ public: void visitDrop(Drop* curr); void visitReturn(Return* curr); void visitHost(Host* curr); + void visitRefIsNull(RefIsNull* curr); + void visitRefFunc(RefFunc* curr); void visitTry(Try* curr); void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); @@ -327,6 +355,19 @@ private: return info.shouldBeIntOrUnreachable(ty, curr, text, getFunction()); } + bool + shouldBeSubType(Type left, Type right, Expression* curr, const char* text) { + return info.shouldBeSubType(left, right, curr, text, getFunction()); + } + + bool shouldBeSubTypeOrUnreachable(Type left, + Type right, + Expression* curr, + const char* text) { + return info.shouldBeSubTypeOrUnreachable( + left, right, curr, text, getFunction()); + } + void validateAlignment( size_t align, Type type, Index bytes, bool isAtomic, Expression* curr); void validateMemBytes(uint8_t bytes, Type type, Expression* curr); @@ -364,29 +405,23 @@ void FunctionValidator::visitBlock(Block* curr) { // none or unreachable means a poison value that we should ignore - if // consumed, it will error if (info.type.isConcrete() && curr->type.isConcrete()) { - shouldBeEqual( - curr->type, + shouldBeSubType( info.type, + curr->type, curr, "block+breaks must have right type if breaks return a value"); } if (curr->type.isConcrete() && info.arity && info.type != unreachable) { - shouldBeEqual(curr->type, - info.type, - curr, - "block+breaks must have right type if breaks have arity"); + shouldBeSubType( + info.type, + curr->type, + curr, + "block+breaks must have right type if breaks have arity"); } shouldBeTrue( info.arity != BreakInfo::PoisonArity, curr, "break arities must match"); if (curr->list.size() > 0) { auto last = curr->list.back()->type; - if (last.isConcrete() && info.type != unreachable) { - shouldBeEqual(last, - info.type, - curr, - "block+breaks must have right type if block ends with " - "a reachable value"); - } if (last == none) { shouldBeTrue(info.arity == Index(0), curr, @@ -420,9 +455,9 @@ void FunctionValidator::visitBlock(Block* curr) { "not flow out a value"); } else { if (backType.isConcrete()) { - shouldBeEqual( - curr->type, + shouldBeSubType( backType, + curr->type, curr, "block with value and last element with value must match types"); } else { @@ -457,6 +492,23 @@ void FunctionValidator::visitLoop(Loop* curr) { curr, "bad body for a loop that has no value"); } + + // When there are multiple instructions within a loop, they are wrapped in a + // Block internally, so visitBlock can take care of verification. Here we + // check cases when there is only one instruction in a Loop. + if (!curr->body->is<Block>()) { + if (!curr->type.isConcrete()) { + shouldBeFalse(curr->body->type.isConcrete(), + curr, + "if loop is not returning a value, final element should " + "not flow out a value"); + } else { + shouldBeSubTypeOrUnreachable(curr->body->type, + curr->type, + curr, + "loop with value and body must match types"); + } + } } void FunctionValidator::visitIf(If* curr) { @@ -476,12 +528,12 @@ void FunctionValidator::visitIf(If* curr) { } } else { if (curr->type != unreachable) { - shouldBeEqualOrFirstIsUnreachable( + shouldBeSubTypeOrUnreachable( curr->ifTrue->type, curr->type, curr, "returning if-else's true must have right type"); - shouldBeEqualOrFirstIsUnreachable( + shouldBeSubTypeOrUnreachable( curr->ifFalse->type, curr->type, curr, @@ -499,25 +551,16 @@ void FunctionValidator::visitIf(If* curr) { } } if (curr->ifTrue->type.isConcrete()) { - shouldBeEqual(curr->type, - curr->ifTrue->type, - curr, - "if type must match concrete ifTrue"); - shouldBeEqualOrFirstIsUnreachable(curr->ifFalse->type, - curr->ifTrue->type, - curr, - "other arm must match concrete ifTrue"); + shouldBeSubType(curr->ifTrue->type, + curr->type, + curr, + "if type must match concrete ifTrue"); } if (curr->ifFalse->type.isConcrete()) { - shouldBeEqual(curr->type, - curr->ifFalse->type, - curr, - "if type must match concrete ifFalse"); - shouldBeEqualOrFirstIsUnreachable( - curr->ifTrue->type, - curr->ifFalse->type, - curr, - "other arm must match concrete ifFalse"); + shouldBeSubType(curr->ifFalse->type, + curr->type, + curr, + "if type must match concrete ifFalse"); } } } @@ -545,13 +588,7 @@ void FunctionValidator::noteBreak(Name name, Type valueType, Expression* curr) { if (!info.hasBeenSet()) { info = BreakInfo(valueType, arity); } else { - if (info.type == unreachable) { - info.type = valueType; - } else if (valueType != unreachable) { - if (valueType != info.type) { - info.type = none; // a poison value that must not be consumed - } - } + info.type = Type::getLeastUpperBound(info.type, valueType); if (arity != info.arity) { info.arity = BreakInfo::PoisonArity; } @@ -600,10 +637,10 @@ void FunctionValidator::visitCall(Call* curr) { return; } for (size_t i = 0; i < curr->operands.size(); i++) { - if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, - params[i], - curr, - "call param types must match") && + if (!shouldBeSubTypeOrUnreachable(curr->operands[i]->type, + params[i], + curr, + "call param types must match") && !info.quiet) { getStream() << "(on argument " << i << ")\n"; } @@ -653,10 +690,10 @@ void FunctionValidator::visitCallIndirect(CallIndirect* curr) { return; } for (size_t i = 0; i < curr->operands.size(); i++) { - if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, - params[i], - curr, - "call param types must match") && + if (!shouldBeSubTypeOrUnreachable(curr->operands[i]->type, + params[i], + curr, + "call param types must match") && !info.quiet) { getStream() << "(on argument " << i << ")\n"; } @@ -723,10 +760,10 @@ void FunctionValidator::visitLocalSet(LocalSet* curr) { curr, "local.set type must be correct"); } - shouldBeEqual(curr->value->type, - getFunction()->getLocalType(curr->index), - curr, - "local.set's value type must be correct"); + shouldBeSubType(curr->value->type, + getFunction()->getLocalType(curr->index), + curr, + "local.set's value type must be correct"); } } } @@ -750,10 +787,10 @@ void FunctionValidator::visitGlobalSet(GlobalSet* curr) { "global.set name must be valid (and not an import; imports " "can't be modified)")) { shouldBeTrue(global->mutable_, curr, "global.set global must be mutable"); - shouldBeEqualOrFirstIsUnreachable(curr->value->type, - global->type, - curr, - "global.set value must have right type"); + shouldBeSubTypeOrUnreachable(curr->value->type, + global->type, + curr, + "global.set value must have right type"); } } @@ -1182,12 +1219,14 @@ void FunctionValidator::validateMemBytes(uint8_t bytes, shouldBeEqual( bytes, uint8_t(16), curr, "expected v128 operation to touch 16 bytes"); break; - case anyref: // anyref cannot be stored in memory - case exnref: // exnref cannot be stored in memory - case none: - WASM_UNREACHABLE("unexpected type"); case unreachable: break; + case funcref: + case anyref: + case nullref: + case exnref: + case none: + WASM_UNREACHABLE("unexpected type"); } } @@ -1616,15 +1655,18 @@ void FunctionValidator::visitSelect(Select* curr) { shouldBeUnequal(curr->ifTrue->type, none, curr, "select left must be valid"); shouldBeUnequal( curr->ifFalse->type, none, curr, "select right must be valid"); + shouldBeUnequal(curr->type, none, curr, "select type must be valid"); shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "select condition must be valid"); - if (curr->ifTrue->type != unreachable && curr->ifFalse->type != unreachable) { - shouldBeEqual(curr->ifTrue->type, - curr->ifFalse->type, - curr, - "select sides must be equal"); + if (curr->type != unreachable) { + shouldBeTrue(Type::isSubType(curr->ifTrue->type, curr->type), + curr, + "select's left expression must be subtype of select's type"); + shouldBeTrue(Type::isSubType(curr->ifFalse->type, curr->type), + curr, + "select's right expression must be subtype of select's type"); } } @@ -1636,16 +1678,7 @@ void FunctionValidator::visitDrop(Drop* curr) { } void FunctionValidator::visitReturn(Return* curr) { - if (curr->value) { - if (returnType == unreachable) { - returnType = curr->value->type; - } else if (curr->value->type != unreachable) { - shouldBeEqual( - curr->value->type, returnType, curr, "function results must match"); - } - } else { - returnType = none; - } + returnTypes.insert(curr->value ? curr->value->type : Type::none); } void FunctionValidator::visitHost(Host* curr) { @@ -1668,32 +1701,37 @@ void FunctionValidator::visitHost(Host* curr) { } } +void FunctionValidator::visitRefIsNull(RefIsNull* curr) { + shouldBeTrue(curr->value->type == Type::unreachable || + curr->value->type.isRef(), + curr->value, + "ref.is_null's argument should be a reference type"); +} + +void FunctionValidator::visitRefFunc(RefFunc* curr) { + auto* func = getModule()->getFunctionOrNull(curr->func); + shouldBeTrue(!!func, curr, "function argument of ref.func must exist"); +} + void FunctionValidator::visitTry(Try* curr) { if (curr->type != unreachable) { - shouldBeEqualOrFirstIsUnreachable( - curr->body->type, - curr->type, - curr->body, - "try's type does not match try body's type"); - shouldBeEqualOrFirstIsUnreachable( - curr->catchBody->type, - curr->type, - curr->catchBody, - "try's type does not match catch's body type"); - } - if (curr->body->type.isConcrete()) { - shouldBeEqualOrFirstIsUnreachable( - curr->catchBody->type, - curr->body->type, - curr->catchBody, - "try's body type must match catch's body type"); - } - if (curr->catchBody->type.isConcrete()) { - shouldBeEqualOrFirstIsUnreachable( - curr->body->type, - curr->catchBody->type, - curr->body, - "try's body type must match catch's body type"); + shouldBeSubTypeOrUnreachable(curr->body->type, + curr->type, + curr->body, + "try's type does not match try body's type"); + shouldBeSubTypeOrUnreachable(curr->catchBody->type, + curr->type, + curr->catchBody, + "try's type does not match catch's body type"); + } else { + shouldBeEqual(curr->body->type, + unreachable, + curr, + "unreachable try-catch must have unreachable try body"); + shouldBeEqual(curr->catchBody->type, + unreachable, + curr, + "unreachable try-catch must have unreachable catch body"); } } @@ -1727,10 +1765,10 @@ void FunctionValidator::visitThrow(Throw* curr) { void FunctionValidator::visitRethrow(Rethrow* curr) { shouldBeEqual( curr->type, unreachable, curr, "rethrow's type must be unreachable"); - shouldBeEqual(curr->exnref->type, - exnref, - curr->exnref, - "rethrow's argument must be exnref type"); + shouldBeSubType(curr->exnref->type, + Type::exnref, + curr->exnref, + "rethrow's argument must be exnref type or its subtype"); } void FunctionValidator::visitBrOnExn(BrOnExn* curr) { @@ -1740,10 +1778,11 @@ void FunctionValidator::visitBrOnExn(BrOnExn* curr) { curr, "br_on_exn's event params and event's params are different"); noteBreak(curr->name, curr->sent, curr); - shouldBeTrue(curr->exnref->type == unreachable || - curr->exnref->type == exnref, - curr, - "br_on_exn's argument must be unreachable or exnref type"); + shouldBeSubTypeOrUnreachable( + curr->exnref->type, + Type::exnref, + curr, + "br_on_exn's argument must be unreachable or exnref type or its subtype"); if (curr->exnref->type == unreachable) { shouldBeTrue(curr->type == unreachable, curr, @@ -1779,21 +1818,22 @@ void FunctionValidator::visitFunction(Function* curr) { "all used types should be allowed"); // if function has no result, it is ignored // if body is unreachable, it might be e.g. a return - if (curr->body->type != unreachable) { - shouldBeEqual(curr->sig.results, - curr->body->type, - curr->body, - "function body type must match, if function returns"); - } - if (returnType != unreachable) { - shouldBeEqual(curr->sig.results, - returnType, - curr->body, - "function result must match, if function has returns"); + shouldBeSubTypeOrUnreachable( + curr->body->type, + curr->sig.results, + curr->body, + "function body type must match, if function returns"); + for (Type returnType : returnTypes) { + shouldBeSubTypeOrUnreachable( + returnType, + curr->sig.results, + curr->body, + "function result must match, if function has returns"); } + shouldBeTrue( breakInfos.empty(), curr->body, "all named break targets must exist"); - returnType = unreachable; + returnTypes.clear(); labelNames.clear(); // validate optional local names std::set<Name> seen; @@ -1858,8 +1898,10 @@ void FunctionValidator::validateAlignment( case v128: case unreachable: break; - case anyref: // anyref cannot be stored in memory - case exnref: // exnref cannot be stored in memory + case funcref: + case anyref: + case nullref: + case exnref: case none: WASM_UNREACHABLE("invalid type"); } @@ -1890,7 +1932,8 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) { // // The block has an added type, not derived from the ast itself, so it // is ok for it to be either i32 or unreachable. - if (!(oldType.isConcrete() && newType == unreachable)) { + if (!Type::isSubType(newType, oldType) && + !(oldType.isConcrete() && newType == Type::unreachable)) { std::ostringstream ss; ss << "stale type found in " << scope << " on " << curr << "\n(marked as " << oldType << ", should be " << newType @@ -2011,13 +2054,14 @@ static void validateGlobals(Module& module, ValidationInfo& info) { info.shouldBeTrue( curr->init != nullptr, curr->name, "global init must be non-null"); assert(curr->init); - info.shouldBeTrue(curr->init->is<Const>() || curr->init->is<GlobalGet>(), + info.shouldBeTrue(GlobalUtils::canInitializeGlobal(curr->init), curr->name, "global init must be valid"); - if (!info.shouldBeEqual(curr->type, - curr->init->type, - curr->init, - "global init must have correct type") && + + if (!info.shouldBeSubType(curr->init->type, + curr->type, + curr->init, + "global init must have correct type") && !info.quiet) { info.getStream(nullptr) << "(on global " << curr->name << ")\n"; } @@ -2118,9 +2162,9 @@ static void validateEvents(Module& module, ValidationInfo& info) { curr->name, "Event type's result type should be none"); for (auto type : curr->sig.params.expand()) { - info.shouldBeTrue(type.isInteger() || type.isFloat(), + info.shouldBeTrue(type.isConcrete(), curr->name, - "Values in an event should have integer or float type"); + "Values in an event should have concrete types"); } } } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index ff1295bad..11d203835 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -173,13 +173,19 @@ const char* getExpressionName(Expression* curr) { return "push"; case Expression::Id::PopId: return "pop"; - case Expression::TryId: + case Expression::Id::RefNullId: + return "ref.null"; + case Expression::Id::RefIsNullId: + return "ref.is_null"; + case Expression::Id::RefFuncId: + return "ref.func"; + case Expression::Id::TryId: return "try"; - case Expression::ThrowId: + case Expression::Id::ThrowId: return "throw"; - case Expression::RethrowId: + case Expression::Id::RethrowId: return "rethrow"; - case Expression::BrOnExnId: + case Expression::Id::BrOnExnId: return "br_on_exn"; case Expression::Id::NumExpressionIds: WASM_UNREACHABLE("invalid expr id"); @@ -187,6 +193,18 @@ const char* getExpressionName(Expression* curr) { WASM_UNREACHABLE("invalid expr id"); } +Literal getLiteralFromConstExpression(Expression* curr) { + if (auto* c = curr->dynCast<Const>()) { + return c->value; + } else if (curr->is<RefNull>()) { + return Literal::makeNullref(); + } else if (auto* r = curr->dynCast<RefFunc>()) { + return Literal::makeFuncref(r->func); + } else { + WASM_UNREACHABLE("Not a constant expression"); + } +} + // core AST type checking struct TypeSeeker : public PostWalker<TypeSeeker> { @@ -248,27 +266,6 @@ struct TypeSeeker : public PostWalker<TypeSeeker> { } }; -static Type mergeTypes(std::vector<Type>& types) { - Type type = unreachable; - for (auto other : types) { - // once none, stop. it then indicates a poison value, that must not be - // consumed and ignore unreachable - if (type != none) { - if (other == none) { - type = none; - } else if (other != unreachable) { - if (type == unreachable) { - type = other; - } else if (type != other) { - // poison value, we saw multiple types; this should not be consumed - type = none; - } - } - } - } - return type; -} - // a block is unreachable if one of its elements is unreachable, // and there are no branches to it static void handleUnreachable(Block* block, @@ -336,7 +333,7 @@ void Block::finalize() { } TypeSeeker seeker(this, this->name); - type = mergeTypes(seeker.types); + type = Type::mergeTypes(seeker.types); handleUnreachable(this); } @@ -364,19 +361,8 @@ void If::finalize(Type type_) { } void If::finalize() { - if (ifFalse) { - if (ifTrue->type == ifFalse->type) { - type = ifTrue->type; - } else if (ifTrue->type.isConcrete() && ifFalse->type == unreachable) { - type = ifTrue->type; - } else if (ifFalse->type.isConcrete() && ifTrue->type == unreachable) { - type = ifFalse->type; - } else { - type = none; - } - } else { - type = none; // if without else - } + type = ifFalse ? Type::getLeastUpperBound(ifTrue->type, ifFalse->type) + : Type::none; // if the arms return a value, leave it even if the condition // is unreachable, we still mark ourselves as having that type, e.g. // (if (result i32) @@ -828,13 +814,15 @@ void Binary::finalize() { } } +void Select::finalize(Type type_) { type = type_; } + void Select::finalize() { assert(ifTrue && ifFalse); if (ifTrue->type == unreachable || ifFalse->type == unreachable || condition->type == unreachable) { type = unreachable; } else { - type = ifTrue->type; + type = Type::getLeastUpperBound(ifTrue->type, ifFalse->type); } } @@ -864,16 +852,20 @@ void Host::finalize() { } } -void Try::finalize() { - if (body->type == catchBody->type) { - type = body->type; - } else if (body->type.isConcrete() && catchBody->type == unreachable) { - type = body->type; - } else if (catchBody->type.isConcrete() && body->type == unreachable) { - type = catchBody->type; - } else { - type = none; +void RefNull::finalize() { type = Type::nullref; } + +void RefIsNull::finalize() { + if (value->type == Type::unreachable) { + type = Type::unreachable; + return; } + type = Type::i32; +} + +void RefFunc::finalize() { type = Type::funcref; } + +void Try::finalize() { + type = Type::getLeastUpperBound(body->type, catchBody->type); } void Try::finalize(Type type_) { |