diff options
Diffstat (limited to 'src/wasm')
-rw-r--r-- | src/wasm/literal.cpp | 146 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 168 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 30 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 29 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 156 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 102 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 9 |
7 files changed, 463 insertions, 177 deletions
diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index 43f407525..09022eea0 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -44,19 +44,25 @@ Literal::Literal(Type type) : type(type) { memset(&v128, 0, 16); return; case Type::none: - return; case Type::unreachable: - break; + WASM_UNREACHABLE("Invalid literal type"); + return; } } - if (isData()) { - assert(!type.isNonNullable()); + if (type.isNull()) { + assert(type.isNullable()); new (&gcData) std::shared_ptr<GCData>(); - } else { - // For anything else, zero out all the union data. - memset(&v128, 0, 16); + return; } + + if (type.isRef() && type.getHeapType() == HeapType::i31) { + assert(type.isNonNullable()); + i32 = 0; + return; + } + + WASM_UNREACHABLE("Unexpected literal type"); } Literal::Literal(const uint8_t init[16]) : type(Type::v128) { @@ -64,9 +70,9 @@ Literal::Literal(const uint8_t init[16]) : type(Type::v128) { } Literal::Literal(std::shared_ptr<GCData> gcData, HeapType type) - : gcData(gcData), type(type, gcData ? NonNullable : Nullable) { + : gcData(gcData), type(type, NonNullable) { // The type must be a proper type for GC data. - assert(isData()); + assert((isData() && gcData) || (type.isBottom() && !gcData)); } Literal::Literal(const Literal& other) : type(other.type) { @@ -89,6 +95,10 @@ Literal::Literal(const Literal& other) : type(other.type) { break; } } + if (other.isNull()) { + new (&gcData) std::shared_ptr<GCData>(); + return; + } if (other.isData()) { new (&gcData) std::shared_ptr<GCData>(other.gcData); return; @@ -98,23 +108,30 @@ Literal::Literal(const Literal& other) : type(other.type) { return; } if (type.isRef()) { + assert(!type.isNullable()); auto heapType = type.getHeapType(); if (heapType.isBasic()) { switch (heapType.getBasic()) { - case HeapType::ext: - case HeapType::any: - case HeapType::eq: - return; // null case HeapType::i31: i32 = other.i32; return; + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + // Null + return; + case HeapType::ext: + case HeapType::any: + WASM_UNREACHABLE("TODO: extern literals"); + case HeapType::eq: case HeapType::func: case HeapType::data: + WASM_UNREACHABLE("invalid type"); case HeapType::string: case HeapType::stringview_wtf8: case HeapType::stringview_wtf16: case HeapType::stringview_iter: - WASM_UNREACHABLE("invalid type"); + WASM_UNREACHABLE("TODO: string literals"); } } } @@ -125,7 +142,7 @@ Literal::~Literal() { if (type.isBasic()) { return; } - if (isData()) { + if (isNull() || isData()) { gcData.~shared_ptr(); } } @@ -239,7 +256,7 @@ std::array<uint8_t, 16> Literal::getv128() const { } std::shared_ptr<GCData> Literal::getGCData() const { - assert(isData()); + assert(isNull() || isData()); return gcData; } @@ -325,11 +342,6 @@ void Literal::getBits(uint8_t (&buf)[16]) const { } bool Literal::operator==(const Literal& other) const { - // The types must be identical, unless both are references - in that case, - // nulls of different types *do* compare equal. - if (type.isRef() && other.type.isRef() && (isNull() || other.isNull())) { - return isNull() && other.isNull(); - } if (type != other.type) { return false; } @@ -350,7 +362,9 @@ bool Literal::operator==(const Literal& other) const { } } else if (type.isRef()) { assert(type.isRef()); - // Note that we've already handled nulls earlier. + if (type.isNull()) { + return true; + } if (type.isFunction()) { assert(func.is() && other.func.is()); return func == other.func; @@ -361,8 +375,6 @@ bool Literal::operator==(const Literal& other) const { if (type.getHeapType() == HeapType::i31) { return i32 == other.i32; } - // other non-null reference type literals cannot represent concrete values, - // i.e. there is no concrete anyref or eqref other than null. WASM_UNREACHABLE("unexpected type"); } WASM_UNREACHABLE("unexpected type"); @@ -463,52 +475,8 @@ void Literal::printVec128(std::ostream& o, const std::array<uint8_t, 16>& v) { std::ostream& operator<<(std::ostream& o, Literal literal) { prepareMinorColor(o); - if (literal.type.isFunction()) { - if (literal.isNull()) { - o << "funcref(null)"; - } else { - o << "funcref(" << literal.getFunc() << ")"; - } - } else if (literal.type.isRef()) { - if (literal.isData()) { - auto data = literal.getGCData(); - if (data) { - o << "[ref " << data->type << ' ' << data->values << ']'; - } else { - o << "[ref null " << literal.type << ']'; - } - } else { - switch (literal.type.getHeapType().getBasic()) { - case HeapType::ext: - assert(literal.isNull() && "unexpected non-null externref literal"); - o << "externref(null)"; - break; - case HeapType::any: - assert(literal.isNull() && "unexpected non-null anyref literal"); - o << "anyref(null)"; - break; - case HeapType::eq: - assert(literal.isNull() && "unexpected non-null eqref literal"); - o << "eqref(null)"; - break; - case HeapType::i31: - if (literal.isNull()) { - o << "i31ref(null)"; - } else { - o << "i31ref(" << literal.geti31() << ")"; - } - break; - case HeapType::func: - case HeapType::data: - case HeapType::string: - case HeapType::stringview_wtf8: - case HeapType::stringview_wtf16: - case HeapType::stringview_iter: - WASM_UNREACHABLE("type should have been handled above"); - } - } - } else { - TODO_SINGLE_COMPOUND(literal.type); + assert(literal.type.isSingle()); + if (literal.type.isBasic()) { switch (literal.type.getBasic()) { case Type::none: o << "?"; @@ -532,6 +500,44 @@ std::ostream& operator<<(std::ostream& o, Literal literal) { case Type::unreachable: WASM_UNREACHABLE("unexpected type"); } + } else { + assert(literal.type.isRef()); + auto heapType = literal.type.getHeapType(); + if (heapType.isBasic()) { + switch (heapType.getBasic()) { + case HeapType::i31: + o << "i31ref(" << literal.geti31() << ")"; + break; + case HeapType::none: + o << "nullref"; + break; + case HeapType::noext: + o << "nullexternref"; + break; + case HeapType::nofunc: + o << "nullfuncref"; + break; + case HeapType::ext: + case HeapType::any: + WASM_UNREACHABLE("TODO: extern literals"); + case HeapType::eq: + case HeapType::func: + case HeapType::data: + WASM_UNREACHABLE("invalid type"); + case HeapType::string: + case HeapType::stringview_wtf8: + case HeapType::stringview_wtf16: + case HeapType::stringview_iter: + WASM_UNREACHABLE("TODO: string literals"); + } + } else if (heapType.isSignature()) { + o << "funcref(" << literal.getFunc() << ")"; + } else { + assert(literal.isData()); + auto data = literal.getGCData(); + assert(data); + o << "[ref " << data->type << ' ' << data->values << ']'; + } } restoreNormalColor(o); return o; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 55dafadd4..f2698bd79 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -1429,6 +1429,25 @@ void WasmBinaryWriter::writeType(Type type) { case HeapType::stringview_iter: o << S32LEB(BinaryConsts::EncodedType::stringview_iter); return; + case HeapType::none: + o << S32LEB(BinaryConsts::EncodedType::nullref); + return; + case HeapType::noext: + // See comment on writeHeapType. + if (!wasm->features.hasGC()) { + o << S32LEB(BinaryConsts::EncodedType::externref); + } else { + o << S32LEB(BinaryConsts::EncodedType::nullexternref); + } + return; + case HeapType::nofunc: + // See comment on writeHeapType. + if (!wasm->features.hasGC()) { + o << S32LEB(BinaryConsts::EncodedType::funcref); + } else { + o << S32LEB(BinaryConsts::EncodedType::nullfuncref); + } + return; } } if (type.isNullable()) { @@ -1468,46 +1487,63 @@ void WasmBinaryWriter::writeType(Type type) { } void WasmBinaryWriter::writeHeapType(HeapType type) { + // ref.null always has a bottom heap type in Binaryen IR, but those types are + // only actually valid with GC enabled. When GC is not enabled, emit the + // corresponding valid top types instead. + if (!wasm->features.hasGC()) { + if (type == HeapType::nofunc || type.isSignature()) { + type = HeapType::func; + } else if (type == HeapType::noext) { + type = HeapType::ext; + } + } + if (type.isSignature() || type.isStruct() || type.isArray()) { o << S64LEB(getTypeIndex(type)); // TODO: Actually s33 return; } int ret = 0; - if (type.isBasic()) { - switch (type.getBasic()) { - case HeapType::ext: - ret = BinaryConsts::EncodedHeapType::ext; - break; - case HeapType::func: - ret = BinaryConsts::EncodedHeapType::func; - break; - case HeapType::any: - ret = BinaryConsts::EncodedHeapType::any; - break; - case HeapType::eq: - ret = BinaryConsts::EncodedHeapType::eq; - break; - case HeapType::i31: - ret = BinaryConsts::EncodedHeapType::i31; - break; - case HeapType::data: - ret = BinaryConsts::EncodedHeapType::data; - break; - case HeapType::string: - ret = BinaryConsts::EncodedHeapType::string; - break; - case HeapType::stringview_wtf8: - ret = BinaryConsts::EncodedHeapType::stringview_wtf8_heap; - break; - case HeapType::stringview_wtf16: - ret = BinaryConsts::EncodedHeapType::stringview_wtf16_heap; - break; - case HeapType::stringview_iter: - ret = BinaryConsts::EncodedHeapType::stringview_iter_heap; - break; - } - } else { - WASM_UNREACHABLE("TODO: compound GC types"); + assert(type.isBasic()); + switch (type.getBasic()) { + case HeapType::ext: + ret = BinaryConsts::EncodedHeapType::ext; + break; + case HeapType::func: + ret = BinaryConsts::EncodedHeapType::func; + break; + case HeapType::any: + ret = BinaryConsts::EncodedHeapType::any; + break; + case HeapType::eq: + ret = BinaryConsts::EncodedHeapType::eq; + break; + case HeapType::i31: + ret = BinaryConsts::EncodedHeapType::i31; + break; + case HeapType::data: + ret = BinaryConsts::EncodedHeapType::data; + break; + case HeapType::string: + ret = BinaryConsts::EncodedHeapType::string; + break; + case HeapType::stringview_wtf8: + ret = BinaryConsts::EncodedHeapType::stringview_wtf8_heap; + break; + case HeapType::stringview_wtf16: + ret = BinaryConsts::EncodedHeapType::stringview_wtf16_heap; + break; + case HeapType::stringview_iter: + ret = BinaryConsts::EncodedHeapType::stringview_iter_heap; + break; + case HeapType::none: + ret = BinaryConsts::EncodedHeapType::none; + break; + case HeapType::noext: + ret = BinaryConsts::EncodedHeapType::noext; + break; + case HeapType::nofunc: + ret = BinaryConsts::EncodedHeapType::nofunc; + break; } o << S64LEB(ret); // TODO: Actually s33 } @@ -1867,6 +1903,15 @@ bool WasmBinaryBuilder::getBasicType(int32_t code, Type& out) { case BinaryConsts::EncodedType::stringview_iter: out = Type(HeapType::stringview_iter, Nullable); return true; + case BinaryConsts::EncodedType::nullref: + out = Type(HeapType::none, Nullable); + return true; + case BinaryConsts::EncodedType::nullexternref: + out = Type(HeapType::noext, Nullable); + return true; + case BinaryConsts::EncodedType::nullfuncref: + out = Type(HeapType::nofunc, Nullable); + return true; default: return false; } @@ -1904,6 +1949,15 @@ bool WasmBinaryBuilder::getBasicHeapType(int64_t code, HeapType& out) { case BinaryConsts::EncodedHeapType::stringview_iter_heap: out = HeapType::stringview_iter; return true; + case BinaryConsts::EncodedHeapType::none: + out = HeapType::none; + return true; + case BinaryConsts::EncodedHeapType::noext: + out = HeapType::noext; + return true; + case BinaryConsts::EncodedHeapType::nofunc: + out = HeapType::nofunc; + return true; default: return false; } @@ -2849,7 +2903,14 @@ void WasmBinaryBuilder::skipUnreachableCode() { expressionStack = savedStack; return; } - pushExpression(curr); + if (curr->type == Type::unreachable) { + // Nothing before this unreachable should be available to future + // expressions. They will get `(unreachable)`s if they try to pop past + // this point. + expressionStack.clear(); + } else { + pushExpression(curr); + } } } @@ -6530,7 +6591,7 @@ void WasmBinaryBuilder::visitDrop(Drop* curr) { void WasmBinaryBuilder::visitRefNull(RefNull* curr) { BYN_TRACE("zz node: RefNull\n"); - curr->finalize(getHeapType()); + curr->finalize(getHeapType().getBottom()); } void WasmBinaryBuilder::visitRefIs(RefIs* curr, uint8_t code) { @@ -6941,28 +7002,29 @@ bool WasmBinaryBuilder::maybeVisitStructNew(Expression*& out, uint32_t code) { } bool WasmBinaryBuilder::maybeVisitStructGet(Expression*& out, uint32_t code) { - StructGet* curr; + bool signed_ = false; switch (code) { case BinaryConsts::StructGet: - curr = allocator.alloc<StructGet>(); + case BinaryConsts::StructGetU: break; case BinaryConsts::StructGetS: - curr = allocator.alloc<StructGet>(); - curr->signed_ = true; - break; - case BinaryConsts::StructGetU: - curr = allocator.alloc<StructGet>(); - curr->signed_ = false; + signed_ = true; break; default: return false; } auto heapType = getIndexedHeapType(); - curr->index = getU32LEB(); - curr->ref = popNonVoidExpression(); - validateHeapTypeUsingChild(curr->ref, heapType); - curr->finalize(); - out = curr; + if (!heapType.isStruct()) { + throwError("Expected struct heaptype"); + } + auto index = getU32LEB(); + if (index >= heapType.getStruct().fields.size()) { + throwError("Struct field index out of bounds"); + } + auto type = heapType.getStruct().fields[index].type; + auto ref = popNonVoidExpression(); + validateHeapTypeUsingChild(ref, heapType); + out = Builder(wasm).makeStructGet(index, ref, type, signed_); return true; } @@ -7022,10 +7084,14 @@ bool WasmBinaryBuilder::maybeVisitArrayGet(Expression*& out, uint32_t code) { return false; } auto heapType = getIndexedHeapType(); + if (!heapType.isArray()) { + throwError("Expected array heaptype"); + } + auto type = heapType.getArray().element.type; auto* index = popNonVoidExpression(); auto* ref = popNonVoidExpression(); validateHeapTypeUsingChild(ref, heapType); - out = Builder(wasm).makeArrayGet(ref, index, signed_); + out = Builder(wasm).makeArrayGet(ref, index, type, signed_); return true; } diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 6de8744a6..b23419071 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -1196,6 +1196,15 @@ Type SExpressionWasmBuilder::stringToType(const char* str, if (strncmp(str, "stringview_iter", 15) == 0 && (prefix || str[15] == 0)) { return Type(HeapType::stringview_iter, Nullable); } + if (strncmp(str, "nullref", 7) == 0 && (prefix || str[7] == 0)) { + return Type(HeapType::none, Nullable); + } + if (strncmp(str, "nullexternref", 13) == 0 && (prefix || str[13] == 0)) { + return Type(HeapType::noext, Nullable); + } + if (strncmp(str, "nullfuncref", 11) == 0 && (prefix || str[11] == 0)) { + return Type(HeapType::nofunc, Nullable); + } if (allowError) { return Type::none; } @@ -1249,6 +1258,17 @@ HeapType SExpressionWasmBuilder::stringToHeapType(const char* str, return HeapType::stringview_iter; } } + if (str[0] == 'n') { + if (strncmp(str, "none", 4) == 0 && (prefix || str[4] == 0)) { + return HeapType::none; + } + if (strncmp(str, "noextern", 8) == 0 && (prefix || str[8] == 0)) { + return HeapType::noext; + } + if (strncmp(str, "nofunc", 6) == 0 && (prefix || str[6] == 0)) { + return HeapType::nofunc; + } + } throw ParseException(std::string("invalid wasm heap type: ") + str); } @@ -2615,9 +2635,9 @@ Expression* SExpressionWasmBuilder::makeRefNull(Element& s) { // (ref.null func), or it may be the name of a defined type, such as // (ref.null $struct.FOO) if (s[1]->dollared()) { - ret->finalize(parseHeapType(*s[1])); + ret->finalize(parseHeapType(*s[1]).getBottom()); } else { - ret->finalize(stringToHeapType(s[1]->str())); + ret->finalize(stringToHeapType(s[1]->str()).getBottom()); } return ret; } @@ -2990,10 +3010,14 @@ Expression* SExpressionWasmBuilder::makeArrayInitStatic(Element& s) { Expression* SExpressionWasmBuilder::makeArrayGet(Element& s, bool signed_) { auto heapType = parseHeapType(*s[1]); + if (!heapType.isArray()) { + throw ParseException("bad array heap type", s.line, s.col); + } auto ref = parseExpression(*s[2]); + auto type = heapType.getArray().element.type; validateHeapTypeUsingChild(ref, heapType, s); auto index = parseExpression(*s[3]); - return Builder(wasm).makeArrayGet(ref, index, signed_); + return Builder(wasm).makeArrayGet(ref, index, type, signed_); } Expression* SExpressionWasmBuilder::makeArraySet(Element& s) { diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index 71bc98928..13d85d338 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -2014,7 +2014,10 @@ void BinaryInstWriter::visitI31Get(I31Get* curr) { void BinaryInstWriter::visitCallRef(CallRef* curr) { assert(curr->target->type != Type::unreachable); - // TODO: `emitUnreachable` if target has bottom type. + if (curr->target->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(curr->isReturn ? BinaryConsts::RetCallRef : BinaryConsts::CallRef); parent.writeIndexedHeapType(curr->target->type.getHeapType()); @@ -2090,6 +2093,10 @@ void BinaryInstWriter::visitStructNew(StructNew* curr) { } void BinaryInstWriter::visitStructGet(StructGet* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } const auto& heapType = curr->ref->type.getHeapType(); const auto& field = heapType.getStruct().fields[curr->index]; int8_t op; @@ -2106,6 +2113,10 @@ void BinaryInstWriter::visitStructGet(StructGet* curr) { } void BinaryInstWriter::visitStructSet(StructSet* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::StructSet); parent.writeIndexedHeapType(curr->ref->type.getHeapType()); o << U32LEB(curr->index); @@ -2129,6 +2140,10 @@ void BinaryInstWriter::visitArrayInit(ArrayInit* curr) { } void BinaryInstWriter::visitArrayGet(ArrayGet* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } auto heapType = curr->ref->type.getHeapType(); const auto& field = heapType.getArray().element; int8_t op; @@ -2144,16 +2159,28 @@ void BinaryInstWriter::visitArrayGet(ArrayGet* curr) { } void BinaryInstWriter::visitArraySet(ArraySet* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::ArraySet); parent.writeIndexedHeapType(curr->ref->type.getHeapType()); } void BinaryInstWriter::visitArrayLen(ArrayLen* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::ArrayLen); parent.writeIndexedHeapType(curr->ref->type.getHeapType()); } void BinaryInstWriter::visitArrayCopy(ArrayCopy* curr) { + if (curr->srcRef->type.isNull() || curr->destRef->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::ArrayCopy); parent.writeIndexedHeapType(curr->destRef->type.getHeapType()); parent.writeIndexedHeapType(curr->srcRef->type.getHeapType()); diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index d24e42acb..43b381a1b 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -578,6 +578,15 @@ std::optional<HeapType> getBasicHeapTypeLUB(HeapType::BasicHeapType a, if (a == b) { return a; } + if (HeapType(a).getBottom() != HeapType(b).getBottom()) { + return {}; + } + if (HeapType(a).isBottom()) { + return b; + } + if (HeapType(b).isBottom()) { + return a; + } // Canonicalize to have `a` be the lesser type. if (unsigned(a) > unsigned(b)) { std::swap(a, b); @@ -585,7 +594,7 @@ std::optional<HeapType> getBasicHeapTypeLUB(HeapType::BasicHeapType a, switch (a) { case HeapType::ext: case HeapType::func: - return {}; + return std::nullopt; case HeapType::any: return {HeapType::any}; case HeapType::eq: @@ -604,6 +613,11 @@ std::optional<HeapType> getBasicHeapTypeLUB(HeapType::BasicHeapType a, case HeapType::stringview_wtf16: case HeapType::stringview_iter: return {HeapType::any}; + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + // Bottom types already handled. + break; } WASM_UNREACHABLE("unexpected basic type"); } @@ -1085,6 +1099,12 @@ FeatureSet Type::getFeatures() const { case HeapType::stringview_wtf16: case HeapType::stringview_iter: return FeatureSet::ReferenceTypes | FeatureSet::Strings; + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + // Technically introduced in GC, but used internally as part of + // ref.null with just reference types. + return FeatureSet::ReferenceTypes; } } // Note: Technically typed function references also require the typed @@ -1360,6 +1380,29 @@ bool HeapType::isArray() const { } } +bool HeapType::isBottom() const { + if (isBasic()) { + switch (getBasic()) { + case ext: + case func: + case any: + case eq: + case i31: + case data: + case string: + case stringview_wtf8: + case stringview_wtf16: + case stringview_iter: + return false; + case none: + case noext: + case nofunc: + return true; + } + } + return false; +} + Signature HeapType::getSignature() const { assert(isSignature()); return getHeapTypeInfo(*this)->signature; @@ -1420,11 +1463,52 @@ size_t HeapType::getDepth() const { case HeapType::stringview_iter: depth += 2; break; + case HeapType::none: + case HeapType::nofunc: + case HeapType::noext: + // Bottom types are infinitely deep. + depth = size_t(-1l); } } return depth; } +HeapType::BasicHeapType HeapType::getBottom() const { + if (isBasic()) { + switch (getBasic()) { + case ext: + return noext; + case func: + return nofunc; + case any: + case eq: + case i31: + case data: + case string: + case stringview_wtf8: + case stringview_wtf16: + case stringview_iter: + case none: + return none; + case noext: + return noext; + case nofunc: + return nofunc; + } + } + auto* info = getHeapTypeInfo(*this); + switch (info->kind) { + case HeapTypeInfo::BasicKind: + return HeapType(info->basic).getBottom(); + case HeapTypeInfo::SignatureKind: + return nofunc; + case HeapTypeInfo::StructKind: + case HeapTypeInfo::ArrayKind: + return none; + } + WASM_UNREACHABLE("unexpected kind"); +} + bool HeapType::isSubType(HeapType left, HeapType right) { // As an optimization, in the common case do not even construct a SubTyper. if (left == right) { @@ -1451,6 +1535,15 @@ std::optional<HeapType> HeapType::getLeastUpperBound(HeapType a, HeapType b) { if (a == b) { return a; } + if (a.getBottom() != b.getBottom()) { + return {}; + } + if (a.isBottom()) { + return b; + } + if (b.isBottom()) { + return a; + } if (getTypeSystem() == TypeSystem::Equirecursive) { return TypeBounder().getLeastUpperBound(a, b); } @@ -1653,27 +1746,34 @@ bool SubTyper::isSubType(HeapType a, HeapType b) { if (b.isBasic()) { switch (b.getBasic()) { case HeapType::ext: - return a == HeapType::ext; + return a == HeapType::noext; case HeapType::func: - return a.isSignature(); + return a == HeapType::nofunc || a.isSignature(); case HeapType::any: - return a != HeapType::ext && !a.isFunction(); + return a == HeapType::eq || a == HeapType::i31 || a == HeapType::data || + a == HeapType::none || a.isData(); case HeapType::eq: - return a == HeapType::i31 || a.isData(); + return a == HeapType::i31 || a == HeapType::data || + a == HeapType::none || a.isData(); case HeapType::i31: - return false; + return a == HeapType::none; case HeapType::data: - return a.isData(); + return a == HeapType::none || a.isData(); case HeapType::string: case HeapType::stringview_wtf8: case HeapType::stringview_wtf16: case HeapType::stringview_iter: + return a == HeapType::none; + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: return false; } } if (a.isBasic()) { - // Basic HeapTypes are never subtypes of compound HeapTypes. - return false; + // Basic HeapTypes are only subtypes of compound HeapTypes if they are + // bottom types. + return a == b.getBottom(); } if (typeSystem == TypeSystem::Nominal || typeSystem == TypeSystem::Isorecursive) { @@ -1823,6 +1923,15 @@ std::optional<HeapType> TypeBounder::lub(HeapType a, HeapType b) { if (a == b) { return a; } + if (a.getBottom() != b.getBottom()) { + return {}; + } + if (a.isBottom()) { + return b; + } + if (b.isBottom()) { + return a; + } if (a.isBasic() || b.isBasic()) { return getBasicHeapTypeLUB(getBasicHeapSupertype(a), @@ -2000,12 +2109,18 @@ std::ostream& TypePrinter::print(Type type) { // Print shorthands for certain basic heap types. if (type.isNullable()) { switch (heapType.getBasic()) { + case HeapType::ext: + return os << "externref"; case HeapType::func: return os << "funcref"; case HeapType::any: return os << "anyref"; case HeapType::eq: return os << "eqref"; + case HeapType::i31: + return os << "i31ref"; + case HeapType::data: + return os << "dataref"; case HeapType::string: return os << "stringref"; case HeapType::stringview_wtf8: @@ -2014,17 +2129,12 @@ std::ostream& TypePrinter::print(Type type) { return os << "stringview_wtf16"; case HeapType::stringview_iter: return os << "stringview_iter"; - default: - break; - } - } else { - switch (heapType.getBasic()) { - case HeapType::i31: - return os << "i31ref"; - case HeapType::data: - return os << "dataref"; - default: - break; + case HeapType::none: + return os << "nullref"; + case HeapType::noext: + return os << "nullexternref"; + case HeapType::nofunc: + return os << "nullfuncref"; } } } @@ -2063,6 +2173,12 @@ std::ostream& TypePrinter::print(HeapType type) { return os << "stringview_wtf16"; case HeapType::stringview_iter: return os << "stringview_iter"; + case HeapType::none: + return os << "none"; + case HeapType::noext: + return os << "noextern"; + case HeapType::nofunc: + return os << "nofunc"; } } diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 5e3fdc6e7..ba309ddea 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -2110,13 +2110,12 @@ void FunctionValidator::visitRefNull(RefNull* curr) { shouldBeTrue(!getFunction() || getModule()->features.hasReferenceTypes(), curr, "ref.null requires reference-types to be enabled"); + if (!shouldBeTrue( + curr->type.isNullable(), curr, "ref.null types must be nullable")) { + return; + } shouldBeTrue( - curr->type.isNullable(), curr, "ref.null types must be nullable"); - - // The type of the null must also be valid for the features. - shouldBeTrue(curr->type.getFeatures() <= getModule()->features, - curr->type, - "ref.null type should be allowed"); + curr->type.isNull(), curr, "ref.null must have a bottom heap type"); } void FunctionValidator::visitRefIs(RefIs* curr) { @@ -2454,12 +2453,15 @@ void FunctionValidator::visitCallRef(CallRef* curr) { validateReturnCall(curr); shouldBeTrue( getModule()->features.hasGC(), curr, "call_ref requires gc to be enabled"); - if (curr->target->type != Type::unreachable) { - if (shouldBeTrue(curr->target->type.isFunction(), - curr, - "call_ref target must be a function reference")) { - validateCallParamsAndResult(curr, curr->target->type.getHeapType()); - } + if (curr->target->type == Type::unreachable || + (curr->target->type.isRef() && + curr->target->type.getHeapType() == HeapType::nofunc)) { + return; + } + if (shouldBeTrue(curr->target->type.isFunction(), + curr, + "call_ref target must be a function reference")) { + validateCallParamsAndResult(curr, curr->target->type.getHeapType()); } } @@ -2580,7 +2582,7 @@ void FunctionValidator::visitStructGet(StructGet* curr) { shouldBeTrue(getModule()->features.hasGC(), curr, "struct.get requires gc to be enabled"); - if (curr->ref->type == Type::unreachable) { + if (curr->type == Type::unreachable || curr->ref->type.isNull()) { return; } if (!shouldBeTrue(curr->ref->type.isStruct(), @@ -2610,22 +2612,28 @@ void FunctionValidator::visitStructSet(StructSet* curr) { if (curr->ref->type == Type::unreachable) { return; } - if (!shouldBeTrue(curr->ref->type.isStruct(), + if (!shouldBeTrue(curr->ref->type.isRef(), curr->ref, - "struct.set ref must be a struct")) { + "struct.set ref must be a reference type")) { return; } - if (curr->ref->type != Type::unreachable) { - const auto& fields = curr->ref->type.getHeapType().getStruct().fields; - shouldBeTrue(curr->index < fields.size(), curr, "bad struct.get field"); - auto& field = fields[curr->index]; - shouldBeSubType(curr->value->type, - field.type, - curr, - "struct.set must have the proper type"); - shouldBeEqual( - field.mutable_, Mutable, curr, "struct.set field must be mutable"); + auto type = curr->ref->type.getHeapType(); + if (type == HeapType::none) { + return; } + if (!shouldBeTrue( + type.isStruct(), curr->ref, "struct.set ref must be a struct")) { + return; + } + const auto& fields = type.getStruct().fields; + shouldBeTrue(curr->index < fields.size(), curr, "bad struct.get field"); + auto& field = fields[curr->index]; + shouldBeSubType(curr->value->type, + field.type, + curr, + "struct.set must have the proper type"); + shouldBeEqual( + field.mutable_, Mutable, curr, "struct.set field must be mutable"); } void FunctionValidator::visitArrayNew(ArrayNew* curr) { @@ -2688,7 +2696,18 @@ void FunctionValidator::visitArrayGet(ArrayGet* curr) { if (curr->type == Type::unreachable) { return; } - const auto& element = curr->ref->type.getHeapType().getArray().element; + // TODO: array rather than data once we've implemented that. + if (!shouldBeSubType(curr->ref->type, + Type(HeapType::data, Nullable), + curr, + "array.get target should be an array reference")) { + return; + } + auto heapType = curr->ref->type.getHeapType(); + if (heapType == HeapType::none) { + return; + } + const auto& element = heapType.getArray().element; // If the type is not packed, it must be marked internally as unsigned, by // convention. if (element.type != Type::i32 || element.packedType == Field::not_packed) { @@ -2706,6 +2725,17 @@ void FunctionValidator::visitArraySet(ArraySet* curr) { if (curr->type == Type::unreachable) { return; } + // TODO: array rather than data once we've implemented that. + if (!shouldBeSubType(curr->ref->type, + Type(HeapType::data, Nullable), + curr, + "array.set target should be an array reference")) { + return; + } + auto heapType = curr->ref->type.getHeapType(); + if (heapType == HeapType::none) { + return; + } const auto& element = curr->ref->type.getHeapType().getArray().element; shouldBeSubType(curr->value->type, element.type, @@ -2736,9 +2766,23 @@ void FunctionValidator::visitArrayCopy(ArrayCopy* curr) { if (curr->type == Type::unreachable) { return; } - const auto& srcElement = curr->srcRef->type.getHeapType().getArray().element; - const auto& destElement = - curr->destRef->type.getHeapType().getArray().element; + if (!shouldBeSubType(curr->srcRef->type, + Type(HeapType::data, Nullable), + curr, + "array.copy source should be an array reference") || + !shouldBeSubType(curr->destRef->type, + Type(HeapType::data, Nullable), + curr, + "array.copy destination should be an array reference")) { + return; + } + auto srcHeapType = curr->srcRef->type.getHeapType(); + auto destHeapType = curr->destRef->type.getHeapType(); + if (srcHeapType == HeapType::none || destHeapType == HeapType::none) { + return; + } + const auto& srcElement = srcHeapType.getArray().element; + const auto& destElement = destHeapType.getArray().element; shouldBeSubType(srcElement.type, destElement.type, curr, diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 724fc12e2..27690f43e 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -796,7 +796,10 @@ void MemoryGrow::finalize() { } } -void RefNull::finalize(HeapType heapType) { type = Type(heapType, Nullable); } +void RefNull::finalize(HeapType heapType) { + assert(heapType.isBottom()); + type = Type(heapType, Nullable); +} void RefNull::finalize(Type type_) { type = type_; } @@ -1033,7 +1036,7 @@ void StructNew::finalize() { void StructGet::finalize() { if (ref->type == Type::unreachable) { type = Type::unreachable; - } else { + } else if (!ref->type.isNull()) { type = ref->type.getHeapType().getStruct().fields[index].type; } } @@ -1066,7 +1069,7 @@ void ArrayInit::finalize() { void ArrayGet::finalize() { if (ref->type == Type::unreachable || index->type == Type::unreachable) { type = Type::unreachable; - } else { + } else if (!ref->type.isNull()) { type = ref->type.getHeapType().getArray().element.type; } } |