diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ir/module-utils.cpp | 55 | ||||
-rw-r--r-- | src/passes/Print.cpp | 46 | ||||
-rw-r--r-- | src/wasm-type.h | 32 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 50 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 110 |
5 files changed, 265 insertions, 28 deletions
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp index 9b2293984..0d0ec531c 100644 --- a/src/ir/module-utils.cpp +++ b/src/ir/module-utils.cpp @@ -159,11 +159,51 @@ Counts getHeapTypeCounts(Module& wasm) { counts.note(*super); } } + + // Make sure we've noted the complete recursion group of each type as well. + auto recGroup = ht.getRecGroup(); + for (auto type : recGroup) { + if (!counts.count(type)) { + newTypes.insert(type); + counts.note(type); + } + } } return counts; } +void coalesceRecGroups(IndexedHeapTypes& indexedTypes) { + if (getTypeSystem() != TypeSystem::Isorecursive) { + // No rec groups to coalesce. + return; + } + + // TODO: Perform a topological sort of the recursion groups to create a valid + // ordering rather than this hack that just gets all the types in a group to + // be adjacent. + assert(indexedTypes.indices.empty()); + std::unordered_set<HeapType> seen; + std::vector<HeapType> grouped; + grouped.reserve(indexedTypes.types.size()); + for (auto type : indexedTypes.types) { + if (seen.insert(type).second) { + for (auto member : type.getRecGroup()) { + grouped.push_back(member); + seen.insert(member); + } + } + } + assert(grouped.size() == indexedTypes.types.size()); + indexedTypes.types = grouped; +} + +void setIndices(IndexedHeapTypes& indexedTypes) { + for (Index i = 0; i < indexedTypes.types.size(); i++) { + indexedTypes.indices[indexedTypes.types[i]] = i; + } +} + } // anonymous namespace std::vector<HeapType> collectHeapTypes(Module& wasm) { @@ -179,11 +219,12 @@ std::vector<HeapType> collectHeapTypes(Module& wasm) { IndexedHeapTypes getIndexedHeapTypes(Module& wasm) { Counts counts = getHeapTypeCounts(wasm); IndexedHeapTypes indexedTypes; - Index i = 0; for (auto& [type, _] : counts) { indexedTypes.types.push_back(type); - indexedTypes.indices[type] = i++; } + + coalesceRecGroups(indexedTypes); + setIndices(indexedTypes); return indexedTypes; } @@ -199,10 +240,14 @@ IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) { // Collect the results. IndexedHeapTypes indexedTypes; for (Index i = 0; i < sorted.size(); ++i) { - HeapType type = sorted[i].first; - indexedTypes.types.push_back(type); - indexedTypes.indices[type] = i; + indexedTypes.types.push_back(sorted[i].first); } + + // TODO: Explicitly construct a linear extension of the partial order of + // recursion groups by adding edges between unrelated groups according to + // their use counts. + coalesceRecGroups(indexedTypes); + setIndices(indexedTypes); return indexedTypes; } diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 7bcf23b8e..5bdf84aba 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -2582,7 +2582,10 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { void handleSignature(HeapType curr, Name name = Name()) { Signature sig = curr.getSignature(); - if (!name.is() && getTypeSystem() == TypeSystem::Nominal) { + bool hasSupertype = + !name.is() && (getTypeSystem() == TypeSystem::Nominal || + getTypeSystem() == TypeSystem::Isorecursive); + if (hasSupertype) { o << "(func_subtype"; } else { o << "(func"; @@ -2612,7 +2615,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { } o << ')'; } - if (!name.is() && getTypeSystem() == TypeSystem::Nominal) { + if (hasSupertype) { o << ' '; printSupertypeOr(curr, "func"); } @@ -2638,21 +2641,25 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { } } void handleArray(HeapType curr) { - if (getTypeSystem() == TypeSystem::Nominal) { + bool hasSupertype = getTypeSystem() == TypeSystem::Nominal || + getTypeSystem() == TypeSystem::Isorecursive; + if (hasSupertype) { o << "(array_subtype "; } else { o << "(array "; } handleFieldBody(curr.getArray().element); - if (getTypeSystem() == TypeSystem::Nominal) { + if (hasSupertype) { o << ' '; printSupertypeOr(curr, "data"); } o << ')'; } void handleStruct(HeapType curr) { + bool hasSupertype = getTypeSystem() == TypeSystem::Nominal || + getTypeSystem() == TypeSystem::Isorecursive; const auto& fields = curr.getStruct().fields; - if (getTypeSystem() == TypeSystem::Nominal) { + if (hasSupertype) { o << "(struct_subtype "; } else { o << "(struct "; @@ -2669,7 +2676,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { o << ')'; sep = " "; } - if (getTypeSystem() == TypeSystem::Nominal) { + if (hasSupertype) { o << ' '; printSupertypeOr(curr, "data"); } @@ -2779,7 +2786,8 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { o << '('; printMajor(o, "func "); printName(curr->name, o); - if (getTypeSystem() == TypeSystem::Nominal) { + if (getTypeSystem() == TypeSystem::Nominal || + getTypeSystem() == TypeSystem::Isorecursive) { o << " (type "; printHeapType(o, curr->type, currModule) << ')'; } @@ -3064,10 +3072,32 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { printName(curr->name, o); } incIndent(); + // Use the same type order as the binary output would even though there is // no code size benefit in the text format. auto indexedTypes = ModuleUtils::getOptimizedIndexedHeapTypes(*curr); + std::optional<RecGroup> currGroup; + bool nontrivialGroup = false; + auto finishGroup = [&]() { + if (nontrivialGroup) { + decIndent(); + o << maybeNewLine; + } + }; for (auto type : indexedTypes.types) { + RecGroup newGroup = type.getRecGroup(); + if (!currGroup || *currGroup != newGroup) { + if (currGroup) { + finishGroup(); + } + currGroup = newGroup; + nontrivialGroup = currGroup->size() > 1; + if (nontrivialGroup) { + doIndent(o, indent); + o << "(rec "; + incIndent(); + } + } doIndent(o, indent); o << '('; printMedium(o, "type") << ' '; @@ -3076,6 +3106,8 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { handleHeapType(type); o << ")" << maybeNewLine; } + finishGroup(); + ModuleUtils::iterImportedMemories( *curr, [&](Memory* memory) { visitMemory(memory); }); ModuleUtils::iterImportedTables(*curr, diff --git a/src/wasm-type.h b/src/wasm-type.h index c99113317..fa5650a31 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -56,6 +56,7 @@ void destroyAllTypesForTestingPurposesOnly(); // data. class Type; class HeapType; +class RecGroup; struct Tuple; struct Signature; struct Field; @@ -352,6 +353,9 @@ public: // number of supertypes in its supertype chain. size_t getDepth() const; + // Get the recursion group for this non-basic type. + RecGroup getRecGroup() const; + constexpr TypeID getID() const { return id; } constexpr BasicHeapType getBasic() const { assert(isBasic() && "Basic heap type expected"); @@ -378,6 +382,30 @@ public: std::string toString() const; }; +// A recursion group consisting of one or more HeapTypes. HeapTypes with single +// members are encoded without using any additional memory, which is why +// `getHeapTypes` has to return a vector by value; it might have to create one +// on the fly. +class RecGroup { + uintptr_t id; + +public: + explicit RecGroup(uintptr_t id) : id(id) {} + bool operator==(const RecGroup& other) { return id == other.id; } + bool operator!=(const RecGroup& other) { return id != other.id; } + size_t size() const; + + struct Iterator : ParentIndexIterator<const RecGroup*, Iterator> { + using value_type = HeapType; + using pointer = const HeapType*; + using reference = const HeapType&; + value_type operator*() const; + }; + + Iterator begin() const { return Iterator{{this, 0}}; } + Iterator end() const { return Iterator{{this, size()}}; } +}; + typedef std::vector<Type> TypeList; // Passed by reference rather than by value because it can own an unbounded @@ -556,6 +584,10 @@ struct TypeBuilder { // `j`. Does nothing for equirecursive types. void setSubType(size_t i, size_t j); + // Create a new recursion group covering slots [i, i + length). Groups must + // not overlap or go out of bounds. + void createRecGroup(size_t i, size_t length); + // Returns all of the newly constructed heap types. May only be called once // all of the heap types have been initialized with `setHeapType`. In nominal // mode, all of the constructed HeapTypes will be fresh and distinct. In diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index f76db38af..01e1789cd 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -53,8 +53,8 @@ namespace wasm { static Name STRUCT("struct"), FIELD("field"), ARRAY("array"), FUNC_SUBTYPE("func_subtype"), STRUCT_SUBTYPE("struct_subtype"), - ARRAY_SUBTYPE("array_subtype"), EXTENDS("extends"), I8("i8"), I16("i16"), - RTT("rtt"), DECLARE("declare"), ITEM("item"), OFFSET("offset"); + ARRAY_SUBTYPE("array_subtype"), EXTENDS("extends"), REC("rec"), I8("i8"), + I16("i16"), RTT("rtt"), DECLARE("declare"), ITEM("item"), OFFSET("offset"); static Address getAddress(const Element* s) { return atoll(s->c_str()); } @@ -458,6 +458,9 @@ void SExpressionWasmBuilder::parseModuleElement(Element& curr) { if (id == TYPE) { return; // already done } + if (id == REC) { + return; // already done + } if (id == TAG) { return parseTag(curr); } @@ -680,18 +683,29 @@ size_t SExpressionWasmBuilder::parseTypeUse(Element& s, } void SExpressionWasmBuilder::preParseHeapTypes(Element& module) { + // Iterate through each individual type definition, calling `f` with the + // definition and its recursion group number. auto forEachType = [&](auto f) { + size_t groupNumber = 0; for (auto* elemPtr : module) { auto& elem = *elemPtr; if (elementStartsWith(elem, TYPE)) { - f(elem); + f(elem, groupNumber++); + } else if (elementStartsWith(elem, REC)) { + for (auto* innerPtr : elem) { + auto& inner = *innerPtr; + if (elementStartsWith(inner, TYPE)) { + f(inner, groupNumber); + } + } + ++groupNumber; } } }; + // Map type names to indices size_t numTypes = 0; - forEachType([&](Element& elem) { - // Map type names to indices + forEachType([&](Element& elem, size_t) { if (elem[1]->dollared()) { std::string name = elem[1]->c_str(); if (!typeIndices.insert({name, numTypes}).second) { @@ -703,6 +717,22 @@ void SExpressionWasmBuilder::preParseHeapTypes(Element& module) { TypeBuilder builder(numTypes); + // Create recursion groups + size_t currGroup = 0, groupStart = 0, groupLength = 0; + auto finishGroup = [&]() { + builder.createRecGroup(groupStart, groupLength); + groupStart = groupStart + groupLength; + groupLength = 0; + }; + forEachType([&](Element&, size_t group) { + if (group != currGroup) { + finishGroup(); + currGroup = group; + } + ++groupLength; + }); + finishGroup(); + auto parseRefType = [&](Element& elem) -> Type { // '(' 'ref' 'null'? ht ')' auto nullable = @@ -860,22 +890,22 @@ void SExpressionWasmBuilder::preParseHeapTypes(Element& module) { }; size_t index = 0; - forEachType([&](Element& elem) { + forEachType([&](Element& elem, size_t) { Element& def = elem[1]->dollared() ? *elem[2] : *elem[1]; Element& kind = *def[0]; - bool nominal = + bool hasSupertype = kind == FUNC_SUBTYPE || kind == STRUCT_SUBTYPE || kind == ARRAY_SUBTYPE; if (kind == FUNC || kind == FUNC_SUBTYPE) { - builder[index] = parseSignatureDef(def, nominal); + builder[index] = parseSignatureDef(def, hasSupertype); } else if (kind == STRUCT || kind == STRUCT_SUBTYPE) { - builder[index] = parseStructDef(def, index, nominal); + builder[index] = parseStructDef(def, index, hasSupertype); } else if (kind == ARRAY || kind == ARRAY_SUBTYPE) { builder[index] = parseArrayDef(def); } else { throw ParseException("unknown heaptype kind", kind.line, kind.col); } Element* super = nullptr; - if (nominal) { + if (hasSupertype) { super = def[def.size() - 1]; if (super->dollared()) { // OK diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index 7801cc289..4c2922c73 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -95,6 +95,8 @@ struct TypeInfo { bool operator!=(const TypeInfo& other) const { return !(*this == other); } }; +using RecGroupInfo = std::vector<HeapType>; + struct HeapTypeInfo { using type_t = HeapType; // Used in assertions to ensure that temporary types don't leak into the @@ -108,6 +110,9 @@ struct HeapTypeInfo { // In nominal or isorecursive mode, the supertype of this HeapType, if it // exists. HeapTypeInfo* supertype = nullptr; + // In isorecursive mode, the recursion group of this type or null if the + // recursion group is trivial (i.e. contains only this type). + RecGroupInfo* recGroup = nullptr; enum Kind { BasicKind, SignatureKind, @@ -581,6 +586,7 @@ bool TypeInfo::operator==(const TypeInfo& other) const { HeapTypeInfo::HeapTypeInfo(const HeapTypeInfo& other) { kind = other.kind; supertype = other.supertype; + recGroup = other.recGroup; switch (kind) { case BasicKind: new (&basic) auto(other.basic); @@ -765,6 +771,10 @@ struct SignatureTypeCache { static SignatureTypeCache nominalSignatureCache; +// Keep track of the constructed recursion groups. +static std::mutex recGroupsMutex; +static std::vector<std::unique_ptr<RecGroupInfo>> recGroups; + } // anonymous namespace void destroyAllTypesForTestingPurposesOnly() { @@ -1236,6 +1246,41 @@ HeapType HeapType::getLeastUpperBound(HeapType a, HeapType b) { return TypeBounder().getLeastUpperBound(a, b); } +// Recursion groups with single elements are encoded as that single element's +// type ID with the low bit set and other recursion groups are encoded with the +// address of the vector containing their members. These encodings are disjoint +// because the alignment of the vectors is greater than 1. +static_assert(alignof(std::vector<HeapType>) > 1); + +RecGroup HeapType::getRecGroup() const { + assert(!isBasic()); + if (auto* info = getHeapTypeInfo(*this)->recGroup) { + return RecGroup(uintptr_t(info)); + } else { + // Mark the low bit to signify that this is a trivial recursion group and + // points to a heap type info rather than a vector of heap types. + return RecGroup(id | 1); + } +} + +HeapType RecGroup::Iterator::operator*() const { + if (parent->id & 1) { + // This is a trivial recursion group. Mask off the low bit to recover the + // single HeapType. + return {HeapType(parent->id & ~(uintptr_t)1)}; + } else { + return (*(std::vector<HeapType>*)parent->id)[index]; + } +} + +size_t RecGroup::size() const { + if (id & 1) { + return 1; + } else { + return ((std::vector<HeapType>*)id)->size(); + } +} + template<typename T> static std::string genericToString(const T& t) { std::ostringstream ss; ss << t; @@ -1959,7 +2004,8 @@ size_t FiniteShapeHasher::hash(const TypeInfo& info) { } size_t FiniteShapeHasher::hash(const HeapTypeInfo& info) { - if (getTypeSystem() == TypeSystem::Nominal) { + if (getTypeSystem() == TypeSystem::Nominal || + getTypeSystem() == TypeSystem::Isorecursive) { return wasm::hash(uintptr_t(&info)); } // If the HeapTypeInfo is not finalized, then it is mutable and its shape @@ -2080,7 +2126,8 @@ bool FiniteShapeEquator::eq(const TypeInfo& a, const TypeInfo& b) { } bool FiniteShapeEquator::eq(const HeapTypeInfo& a, const HeapTypeInfo& b) { - if (getTypeSystem() == TypeSystem::Nominal) { + if (getTypeSystem() == TypeSystem::Nominal || + getTypeSystem() == TypeSystem::Isorecursive) { return &a == &b; } if (a.isFinalized != b.isFinalized) { @@ -2233,8 +2280,14 @@ void TypeGraphWalkerBase<Self>::scanHeapType(HeapType* ht) { } // anonymous namespace struct TypeBuilder::Impl { + // Store of temporary Types. Types that need to be canonicalized will be + // copied into the global TypeStore. TypeStore typeStore; + // Store of temporary recursion groups, which will be moved to the global + // collection of recursion groups as part of building. + std::vector<std::unique_ptr<RecGroupInfo>> recGroups; + struct Entry { std::unique_ptr<HeapTypeInfo> info; bool initialized = false; @@ -2248,6 +2301,7 @@ struct TypeBuilder::Impl { } void set(HeapTypeInfo&& hti) { hti.supertype = info->supertype; + hti.recGroup = info->recGroup; *info = std::move(hti); info->isTemp = true; info->isFinalized = false; @@ -2342,6 +2396,20 @@ void TypeBuilder::setSubType(size_t i, size_t j) { sub->supertype = super; } +void TypeBuilder::createRecGroup(size_t i, size_t length) { + assert(i <= size() && i + length <= size() && "group out of bounds"); + // Only materialize nontrivial recursion groups. + if (length < 2) { + return; + } + recGroups.emplace_back(std::make_unique<RecGroupInfo>()); + for (; length > 0; --length) { + auto& info = impl->entries[i + length - 1].info; + assert(info->recGroup == nullptr && "group already assigned"); + info->recGroup = recGroups.back().get(); + } +} + namespace { // Helper for TypeBuilder::build() that keeps track of temporary types and @@ -3012,8 +3080,9 @@ void globallyCanonicalize(CanonicalizationState& state) { } void canonicalizeEquirecursive(CanonicalizationState& state) { - // Equirecursive types always have null supertypes. + // Equirecursive types always have null supertypes and recursion groups. for (auto& info : state.newInfos) { + info->recGroup = nullptr; info->supertype = nullptr; } @@ -3036,6 +3105,14 @@ void canonicalizeEquirecursive(CanonicalizationState& state) { } void canonicalizeNominal(CanonicalizationState& state) { + // TODO: clear recursion groups once we are no longer piggybacking the + // isorecursive system on the nominal system. + // if (typeSystem != TypeSystem::Isorecursive) { + // for (auto& info : state.newInfos) { + // assert(info->recGroup == nullptr && "unexpected recursion group"); + // } + // } + // Nominal types do not require separate canonicalization, so just validate // that their subtyping is correct. @@ -3043,8 +3120,6 @@ void canonicalizeNominal(CanonicalizationState& state) { auto start = std::chrono::steady_clock::now(); #endif - assert(typeSystem == TypeSystem::Nominal); - // Ensure there are no cycles in the subtype graph. This is the classic DFA // algorithm for detecting cycles, but in the form of a simple loop because // each node (type) has at most one child (supertype). @@ -3115,6 +3190,29 @@ void canonicalizeNominal(CanonicalizationState& state) { #endif } +void canonicalizeIsorecursive( + CanonicalizationState& state, + std::vector<std::unique_ptr<RecGroupInfo>>& recGroupInfos) { + // Fill out the recursion groups. + for (auto& info : state.newInfos) { + if (info->recGroup != nullptr) { + info->recGroup->push_back(asHeapType(info)); + } + } + + // TODO: proper isorecursive validation and canonicalization. For now just + // piggyback on the nominal system. + canonicalizeNominal(state); + + // Move the recursion groups into the global store. TODO: after proper + // isorecursive canonicalization, some groups may no longer be used, so they + // will need to be filtered out. + std::lock_guard<std::mutex> lock(recGroupsMutex); + for (auto& info : recGroupInfos) { + recGroups.emplace_back(std::move(info)); + } +} + void canonicalizeBasicHeapTypes(CanonicalizationState& state) { // Replace heap types backed by BasicKind HeapTypeInfos with their // corresponding BasicHeapTypes. The heap types backed by BasicKind @@ -3174,7 +3272,7 @@ std::vector<HeapType> TypeBuilder::build() { canonicalizeNominal(state); break; case TypeSystem::Isorecursive: - Fatal() << "Isorecursive types not yet implemented"; + canonicalizeIsorecursive(state, impl->recGroups); break; } |