diff options
-rw-r--r-- | src/wasm/wasm-type.cpp | 304 | ||||
-rw-r--r-- | test/example/type-builder.cpp | 39 | ||||
-rw-r--r-- | test/example/type-builder.txt | 51 | ||||
-rw-r--r-- | test/lit/passes/roundtrip-gc-types.wast | 27 |
4 files changed, 338 insertions, 83 deletions
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index f60f1c3a7..87b562937 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -27,6 +27,12 @@ #include "wasm-features.h" #include "wasm-type.h" +#define TRACE_CANONICALIZATION 0 + +#if TRACE_CANONICALIZATION +#include <iostream> +#endif + namespace wasm { namespace { @@ -65,6 +71,12 @@ struct TypeInfo { bool isNullable() const { return kind == RefKind && ref.nullable; } + // If this TypeInfo represents a Type that can be represented more simply, set + // `out` to be that simpler Type and return true. For example, this handles + // canonicalizing the TypeInfo representing (ref null any) into the BasicType + // anyref. It also handles eliminating singleton tuple types. + bool getCanonical(Type& out) const; + bool operator==(const TypeInfo& other) const; bool operator!=(const TypeInfo& other) const { return !(*this == other); } }; @@ -106,6 +118,11 @@ struct HeapTypeInfo { constexpr bool isArray() const { return kind == ArrayKind; } constexpr bool isData() const { return isStruct() || isArray(); } + // If this HeapTypeInfo represents a HeapType that can be represented more + // simply, set `out` to be that simpler HeapType and return true. This handles + // turning BasicKind HeapTypes into their corresponding BasicHeapTypes. + bool getCanonical(HeapType& out) const; + HeapTypeInfo& operator=(const HeapTypeInfo& other); bool operator==(const HeapTypeInfo& other) const; bool operator!=(const HeapTypeInfo& other) const { return !(*this == other); } @@ -251,9 +268,7 @@ namespace std { template<> class hash<wasm::TypeInfo> { public: - size_t operator()(const wasm::TypeInfo& info) const { - return wasm::FiniteShapeHasher().hash(info); - } + size_t operator()(const wasm::TypeInfo& info) const; }; template<> class hash<wasm::HeapTypeInfo> { @@ -310,6 +325,26 @@ bool isTemp(HeapType type) { return !type.isBasic() && getHeapTypeInfo(type)->isTemp; } +// Given a Type that may or may not be backed by the simplest possible +// representation, return the equivalent type that is definitely backed by the +// simplest possible representation. +Type asCanonical(Type type) { + if (!type.isBasic()) { + getTypeInfo(type)->getCanonical(type); + } + return type; +} + +// Given a HeapType that may or may not be backed by the simplest possible +// representation, return the equivalent type that is definitely backed by the +// simplest possible representation. +HeapType asCanonical(HeapType type) { + if (!type.isBasic()) { + getHeapTypeInfo(type)->getCanonical(type); + } + return type; +} + TypeInfo::TypeInfo(const TypeInfo& other) { kind = other.kind; switch (kind) { @@ -341,11 +376,67 @@ TypeInfo::~TypeInfo() { WASM_UNREACHABLE("unexpected kind"); } +bool TypeInfo::getCanonical(Type& out) const { + if (isTuple()) { + if (tuple.types.size() == 0) { + out = Type::none; + return true; + } + if (tuple.types.size() == 1) { + out = tuple.types[0]; + return true; + } + } + if (isRef()) { + HeapType basic = asCanonical(ref.heapType); + if (basic.isBasic()) { + if (ref.nullable) { + switch (basic.getBasic()) { + case HeapType::func: + out = Type::funcref; + return true; + case HeapType::ext: + out = Type::externref; + return true; + case HeapType::any: + out = Type::anyref; + return true; + case HeapType::eq: + out = Type::eqref; + return true; + case HeapType::i31: + case HeapType::data: + break; + } + } else { + if (basic == HeapType::i31) { + out = Type::i31ref; + return true; + } + if (basic == HeapType::data) { + out = Type::dataref; + return true; + } + } + } + } + return false; +} + bool TypeInfo::operator==(const TypeInfo& other) const { - // TypeInfos with the same shape are considered equivalent. This is important - // during global canonicalization, when newly created canonically-shaped - // graphs are checked against the existing globally canonical graphs. - return FiniteShapeEquator().eq(*this, other); + if (kind != other.kind) { + return false; + } + switch (kind) { + case TupleKind: + return tuple == other.tuple; + case RefKind: + return ref.nullable == other.ref.nullable && + ref.heapType == other.ref.heapType; + case RttKind: + return rtt == other.rtt; + } + WASM_UNREACHABLE("unexpected kind"); } HeapTypeInfo::HeapTypeInfo(const HeapTypeInfo& other) { @@ -384,6 +475,14 @@ HeapTypeInfo::~HeapTypeInfo() { WASM_UNREACHABLE("unexpected kind"); } +bool HeapTypeInfo::getCanonical(HeapType& out) const { + if (isFinalized && kind == BasicKind) { + out = basic; + return true; + } + return false; +} + HeapTypeInfo& HeapTypeInfo::operator=(const HeapTypeInfo& other) { if (&other != this) { this->~HeapTypeInfo(); @@ -393,6 +492,10 @@ HeapTypeInfo& HeapTypeInfo::operator=(const HeapTypeInfo& other) { } bool HeapTypeInfo::operator==(const HeapTypeInfo& other) const { + // HeapTypeInfos with the same shape are considered equivalent. This is + // important during global canonicalization, when newly created + // canonically-shaped graphs are checked against the existing globally + // canonical graphs. return FiniteShapeEquator().eq(*this, other); } @@ -409,23 +512,15 @@ template<typename Info> struct Store { bool isGlobalStore(); #endif - TypeID recordCanonical(std::unique_ptr<Info>&& info); typename Info::type_t canonicalize(const Info& info); -}; + typename Info::type_t canonicalize(std::unique_ptr<Info>&& info); -struct TypeStore : Store<TypeInfo> { - Type canonicalize(TypeInfo info); +private: + TypeID recordCanonical(std::unique_ptr<Info>&& info); }; -struct HeapTypeStore : Store<HeapTypeInfo> { - HeapType canonicalize(const HeapTypeInfo& info) { - if (info.kind == HeapTypeInfo::BasicKind) { - return info.basic; - } - return Store<HeapTypeInfo>::canonicalize(info); - } - HeapType canonicalize(std::unique_ptr<HeapTypeInfo>&& info); -}; +using TypeStore = Store<TypeInfo>; +using HeapTypeStore = Store<HeapTypeInfo>; TypeStore globalTypeStore; HeapTypeStore globalHeapTypeStore; @@ -454,17 +549,11 @@ template<typename Info> bool Store<Info>::isGlobalStore() { #endif template<typename Info> -TypeID Store<Info>::recordCanonical(std::unique_ptr<Info>&& info) { - assert((!isGlobalStore() || !info->isTemp) && "Leaking temporary type!"); - TypeID id = uintptr_t(info.get()); - assert(id > Info::type_t::_last_basic_type); - typeIDs[*info] = id; - constructedTypes.emplace_back(std::move(info)); - return id; -} - -template<typename Info> typename Info::type_t Store<Info>::canonicalize(const Info& info) { + typename Info::type_t canonical; + if (info.getCanonical(canonical)) { + return canonical; + } std::lock_guard<std::mutex> lock(mutex); auto indexIt = typeIDs.find(std::cref(info)); if (indexIt != typeIDs.end()) { @@ -473,9 +562,11 @@ typename Info::type_t Store<Info>::canonicalize(const Info& info) { return typename Info::type_t(recordCanonical(std::make_unique<Info>(info))); } -HeapType HeapTypeStore::canonicalize(std::unique_ptr<HeapTypeInfo>&& info) { - if (info->kind == HeapTypeInfo::BasicKind) { - return info->basic; +template<typename Info> +typename Info::type_t Store<Info>::canonicalize(std::unique_ptr<Info>&& info) { + typename Info::type_t canonical; + if (info->getCanonical(canonical)) { + return canonical; } std::lock_guard<std::mutex> lock(mutex); auto indexIt = typeIDs.find(std::cref(*info)); @@ -486,40 +577,14 @@ HeapType HeapTypeStore::canonicalize(std::unique_ptr<HeapTypeInfo>&& info) { return HeapType(recordCanonical(std::move(info))); } -Type TypeStore::canonicalize(TypeInfo info) { - if (info.isTuple()) { - if (info.tuple.types.size() == 0) { - return Type::none; - } - if (info.tuple.types.size() == 1) { - return info.tuple.types[0]; - } - } - if (info.isRef() && info.ref.heapType.isBasic()) { - if (info.ref.nullable) { - switch (info.ref.heapType.getBasic()) { - case HeapType::func: - return Type::funcref; - case HeapType::ext: - return Type::externref; - case HeapType::any: - return Type::anyref; - case HeapType::eq: - return Type::eqref; - case HeapType::i31: - case HeapType::data: - break; - } - } else { - if (info.ref.heapType == HeapType::i31) { - return Type::i31ref; - } - if (info.ref.heapType == HeapType::data) { - return Type::dataref; - } - } - } - return Store<TypeInfo>::canonicalize(info); +template<typename Info> +TypeID Store<Info>::recordCanonical(std::unique_ptr<Info>&& info) { + assert((!isGlobalStore() || !info->isTemp) && "Leaking temporary type!"); + TypeID id = uintptr_t(info.get()); + assert(id > Info::type_t::_last_basic_type); + typeIDs[*info] = id; + constructedTypes.emplace_back(std::move(info)); + return id; } } // anonymous namespace @@ -1237,7 +1302,8 @@ Type TypeBounder::getLeastUpperBound(Type a, Type b) { // Array is arbitrary; it might as well have been a Struct. builder.grow(1); builder[builder.size() - 1] = Array(Field(tempLUB, Mutable)); - return builder.build().back().getArray().element.type; + std::vector<HeapType> built = builder.build(); + return built.back().getArray().element.type; } bool TypeBounder::lub(Type a, Type b, Type& out) { @@ -1524,7 +1590,10 @@ std::ostream& TypePrinter::print(HeapType heapType) { if (isTemp(heapType)) { os << "[T]"; } - if (heapType.isSignature()) { + if (getHeapTypeInfo(heapType)->kind == HeapTypeInfo::BasicKind) { + os << '*'; + print(getHeapTypeInfo(heapType)->basic); + } else if (heapType.isSignature()) { print(heapType.getSignature()); } else if (heapType.isStruct()) { print(heapType.getStruct()); @@ -1622,6 +1691,7 @@ std::ostream& TypePrinter::print(const Rtt& rtt) { } size_t FiniteShapeHasher::hash(Type type) { + type = asCanonical(type); size_t digest = wasm::hash(type.isBasic()); if (type.isBasic()) { rehash(digest, type.getID()); @@ -1632,6 +1702,7 @@ size_t FiniteShapeHasher::hash(Type type) { } size_t FiniteShapeHasher::hash(HeapType heapType) { + heapType = asCanonical(heapType); size_t digest = wasm::hash(heapType.isBasic()); if (heapType.isBasic()) { rehash(digest, heapType.getID()); @@ -1682,8 +1753,7 @@ size_t FiniteShapeHasher::hash(const HeapTypeInfo& info) { rehash(digest, info.kind); switch (info.kind) { case HeapTypeInfo::BasicKind: - hash_combine(digest, wasm::hash(info.basic)); - return digest; + WASM_UNREACHABLE("Basic HeapTypeInfo should have been canonicalized"); case HeapTypeInfo::SignatureKind: hash_combine(digest, hash(info.signature)); return digest; @@ -1737,6 +1807,8 @@ size_t FiniteShapeHasher::hash(const Rtt& rtt) { } bool FiniteShapeEquator::eq(Type a, Type b) { + a = asCanonical(a); + b = asCanonical(b); if (a.isBasic() != b.isBasic()) { return false; } else if (a.isBasic()) { @@ -1747,6 +1819,8 @@ bool FiniteShapeEquator::eq(Type a, Type b) { } bool FiniteShapeEquator::eq(HeapType a, HeapType b) { + a = asCanonical(a); + b = asCanonical(b); if (a.isBasic() != b.isBasic()) { return false; } else if (a.isBasic()) { @@ -1797,7 +1871,7 @@ bool FiniteShapeEquator::eq(const HeapTypeInfo& a, const HeapTypeInfo& b) { } switch (a.kind) { case HeapTypeInfo::BasicKind: - return a.basic == b.basic; + WASM_UNREACHABLE("Basic HeapTypeInfo should have been canonicalized"); case HeapTypeInfo::SignatureKind: return eq(a.signature, b.signature); case HeapTypeInfo::StructKind: @@ -2000,10 +2074,28 @@ private: void initializePartitions(); void translatePartitionsToTypes(); - std::vector<HeapType*> getChildren(HeapType type); + // Returns pointers to the HeapType's immediate descendant compound HeapTypes. + // For determining partitions and state transitions, BasicKind HeapTypes are + // treated identically to basic HeapTypes and are not included in the results + // of `getChildren`. For translating the partitions back into types, though, + // it is important that BasicKind children are included so they can be updated + // to refer to their corresponding shape-canonicalized HeapTypeInfo in the + // results. TODO: Consolidate all type scanning in one utility. + std::vector<HeapType*> getChildren(HeapType type, bool includeBasic = false); const TypeSet& getPredsOf(HeapType type, size_t symbol); TypeSet getIntersection(const TypeSet& a, const TypeSet& b); TypeSet getDifference(const TypeSet& a, const TypeSet& b); + +#if TRACE_CANONICALIZATION + void dumpPartitions() { + for (auto& partition : partitions) { + for (HeapType type : partition) { + std::cerr << type << '\n'; + } + std::cerr << '\n'; + } + } +#endif }; ShapeCanonicalizer::ShapeCanonicalizer(const std::vector<HeapType>& input) @@ -2011,6 +2103,11 @@ ShapeCanonicalizer::ShapeCanonicalizer(const std::vector<HeapType>& input) initializePredecessors(); initializePartitions(); +#if TRACE_CANONICALIZATION + std::cerr << "Initial partitions:\n"; + dumpPartitions(); +#endif + // The Hopcroft's algorithm's list of partitions that may still be // distinguishing partitions. Starts out containing all partitions. std::set<size_t> distinguishers; @@ -2071,6 +2168,11 @@ ShapeCanonicalizer::ShapeCanonicalizer(const std::vector<HeapType>& input) } } +#if TRACE_PARTITIONS + std::cerr << "Final partitions:\n"; + dumpPartitions(); +#endif + translatePartitionsToTypes(); } @@ -2124,7 +2226,7 @@ void ShapeCanonicalizer::translatePartitionsToTypes() { infos.back()->isTemp = true; } for (auto& info : infos) { - for (auto* child : getChildren(asHeapType(info))) { + for (auto* child : getChildren(asHeapType(info), true)) { auto partitionIt = partitionIndices.find(*child); if (partitionIt == partitionIndices.end()) { // This child has already been replaced. @@ -2135,11 +2237,16 @@ void ShapeCanonicalizer::translatePartitionsToTypes() { } } -std::vector<HeapType*> ShapeCanonicalizer::getChildren(HeapType heapType) { +std::vector<HeapType*> ShapeCanonicalizer::getChildren(HeapType heapType, + bool includeBasic) { std::vector<HeapType*> children; auto noteChild = [&](HeapType* child) { - if (!child->isBasic()) { + HeapType type = *child; + if (!includeBasic) { + type = asCanonical(type); + } + if (!type.isBasic()) { children.push_back(child); } }; @@ -2274,6 +2381,14 @@ GlobalCanonicalizer::GlobalCanonicalizer( scanList.push_back(&results.back()); } +#if TRACE_CANONICALIZATION + std::cerr << "Initial Types:\n"; + for (HeapType type : results) { + std::cerr << type << '\n'; + } + std::cerr << '\n'; +#endif + // Traverse the type graph reachable from the heap types, collecting a list of // type and heap type use sites that need to be patched with canonical types. while (scanList.size() != 0) { @@ -2297,13 +2412,19 @@ GlobalCanonicalizer::GlobalCanonicalizer( // Canonicalize non-tuple Types (which never directly refer to other Types) // before tuple Types to avoid canonicalizing a tuple that still contains // non-canonical Types. + std::unordered_map<HeapType, HeapType> canonicalHeapTypes; for (auto& info : infos) { HeapType original = asHeapType(info); HeapType canonical = globalHeapTypeStore.canonicalize(std::move(info)); if (original != canonical) { - for (HeapType* use : heapTypeLocations.at(original)) { - *use = canonical; - } + canonicalHeapTypes[original] = canonical; + } + } + for (auto& pair : canonicalHeapTypes) { + HeapType original = pair.first; + HeapType canonical = pair.second; + for (HeapType* use : heapTypeLocations.at(original)) { + *use = canonical; } } auto canonicalizeTypes = [&](bool tuples) { @@ -2320,6 +2441,14 @@ GlobalCanonicalizer::GlobalCanonicalizer( }; canonicalizeTypes(false); canonicalizeTypes(true); + +#if TRACE_CANONICALIZATION + std::cerr << "Final Types:\n"; + for (HeapType type : results) { + std::cerr << type << '\n'; + } + std::cerr << '\n'; +#endif } template<typename T1, typename T2> @@ -2472,4 +2601,21 @@ size_t hash<wasm::Rtt>::operator()(const wasm::Rtt& rtt) const { return digest; } +size_t hash<wasm::TypeInfo>::operator()(const wasm::TypeInfo& info) const { + auto digest = wasm::hash(info.kind); + switch (info.kind) { + case wasm::TypeInfo::TupleKind: + wasm::rehash(digest, info.tuple); + return digest; + case wasm::TypeInfo::RefKind: + wasm::rehash(digest, info.ref.nullable); + wasm::rehash(digest, info.ref.heapType); + return digest; + case wasm::TypeInfo::RttKind: + wasm::rehash(digest, info.rtt); + return digest; + } + WASM_UNREACHABLE("unexpected kind"); +} + } // namespace std diff --git a/test/example/type-builder.cpp b/test/example/type-builder.cpp index a228697e9..8d87c5054 100644 --- a/test/example/type-builder.cpp +++ b/test/example/type-builder.cpp @@ -103,6 +103,32 @@ void test_canonicalization() { assert(built[3] == sig); } +// Check that defined basic HeapTypes are handled correctly. +void test_basic() { + std::cout << ";; Test basic\n"; + + TypeBuilder builder(6); + + Type anyref = builder.getTempRefType(builder[4], Nullable); + Type i31ref = builder.getTempRefType(builder[5], NonNullable); + + builder[0] = Signature(Type::anyref, Type::i31ref); + builder[1] = Signature(anyref, Type::i31ref); + builder[2] = Signature(Type::anyref, i31ref); + builder[3] = Signature(anyref, i31ref); + builder[4] = HeapType::any; + builder[5] = HeapType::i31; + + std::vector<HeapType> built = builder.build(); + + assert(built[0] == HeapType(Signature(Type::anyref, Type::i31ref))); + assert(built[1] == built[0]); + assert(built[2] == built[1]); + assert(built[3] == built[2]); + assert(built[4] == HeapType::any); + assert(built[5] == HeapType::i31); +} + void test_recursive() { std::cout << ";; Test recursive types\n"; @@ -467,8 +493,13 @@ void test_lub() { } int main() { - test_builder(); - test_canonicalization(); - test_recursive(); - test_lub(); + // Run the tests twice to ensure things still work when the global stores are + // already populated. + for (size_t i = 0; i < 2; ++i) { + test_builder(); + test_canonicalization(); + test_basic(); + test_recursive(); + test_lub(); + } } diff --git a/test/example/type-builder.txt b/test/example/type-builder.txt index d9f469461..a219816e6 100644 --- a/test/example/type-builder.txt +++ b/test/example/type-builder.txt @@ -21,6 +21,57 @@ After building types: (rtt 0 $array) => (rtt 0 (array (mut externref))) ;; Test canonicalization +;; Test basic +;; Test recursive types +(func (result (ref null ...1))) + +(func (result (ref null ...1))) +(func (result (ref null ...1))) + +(func (result (ref null ...1))) +(func (result (ref null ...1))) +(func (result (ref null ...1))) +(func (result (ref null ...1))) +(func (result (ref null ...1))) + +(func (result (ref null ...1) (ref null (func)))) +(func (result (ref null ...1) (ref null (func)))) +(func) +(func) +(func (result (ref null (func (result ...1 (ref null (func))))))) +(func (result (ref null (func (result ...1 (ref null (func))))))) + +(func (result (ref null ...1))) +(func (result (ref null ...1))) + +(func (param anyref) (result (ref null ...1))) +(func (param anyref) (result (ref null ...1))) + +;; Test LUBs +;; Test TypeBuilder +Before setting heap types: +(ref $sig) => [T](ref [T](func)) +(ref $struct) => [T](ref [T](func)) +(ref $array) => [T](ref [T](func)) +(ref null $array) => [T](ref null [T](func)) +(rtt 0 $array) => [T](rtt 0 [T](func)) + +After setting heap types: +(ref $sig) => [T](ref [T](func (param [T](ref [T](struct (field [T](ref null [T](array (mut externref))) (mut [T](rtt 0 [T](array (mut externref)))))))) (result [T](ref [T](array (mut externref))) i32))) +(ref $struct) => [T](ref [T](struct (field [T](ref null [T](array (mut externref))) (mut [T](rtt 0 [T](array (mut externref))))))) +(ref $array) => [T](ref [T](array (mut externref))) +(ref null $array) => [T](ref null [T](array (mut externref))) +(rtt 0 $array) => [T](rtt 0 [T](array (mut externref))) + +After building types: +(ref $sig) => (ref (func (param (ref (struct (field (ref null (array (mut externref))) (mut (rtt 0 (array (mut externref)))))))) (result (ref (array (mut externref))) i32))) +(ref $struct) => (ref (struct (field (ref null (array (mut externref))) (mut (rtt 0 (array (mut externref))))))) +(ref $array) => (ref (array (mut externref))) +(ref null $array) => (ref null (array (mut externref))) +(rtt 0 $array) => (rtt 0 (array (mut externref))) + +;; Test canonicalization +;; Test basic ;; Test recursive types (func (result (ref null ...1))) diff --git a/test/lit/passes/roundtrip-gc-types.wast b/test/lit/passes/roundtrip-gc-types.wast new file mode 100644 index 000000000..449df6639 --- /dev/null +++ b/test/lit/passes/roundtrip-gc-types.wast @@ -0,0 +1,27 @@ +;; RUN: wasm-opt %s -all --roundtrip -S -o - | filecheck %s + +;; Regression test for an issue in which roundtripping failed to reproduce the +;; original types because type canonicalization was incorrect when the canonical +;; types already existed in the store. + +;; CHECK: (module +;; CHECK-NEXT: (type $A (struct (field (ref $C)))) +;; CHECK-NEXT: (type $B (func (param (ref $A)) (result (ref $B)))) +;; CHECK-NEXT: (type $C (struct (field (mut (ref $B))))) +;; CHECK-NEXT: (type $D (struct (field (ref $C)) (field (ref $A)))) +;; CHECK-NEXT: (global $g0 (rtt 0 $A) (rtt.canon $A)) +;; CHECK-NEXT: (global $g1 (rtt 1 $D) (rtt.sub $D +;; CHECK-NEXT: (global.get $g0) +;; CHECK-NEXT: )) +;; CHECK-NEXT: ) + +(module + (type $A (struct (field (ref $C)))) + (type $B (func (param (ref $A)) (result (ref $B)))) + (type $C (struct (field (mut (ref $B))))) + (type $D (struct (field (ref $C)) (field (ref $A)))) + (global $g0 (rtt 0 $A) (rtt.canon $A)) + (global $g1 (rtt 1 $D) (rtt.sub $D + (global.get $g0) + )) +) |