summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r-- src/ir/module-utils.cpp 55
-rw-r--r-- src/passes/Print.cpp 46
-rw-r--r-- src/wasm-type.h 32
-rw-r--r-- src/wasm/wasm-s-parser.cpp 50
-rw-r--r-- src/wasm/wasm-type.cpp 110
5 files changed, 265 insertions, 28 deletions
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp
index 9b2293984..0d0ec531c 100644
--- a/src/ir/module-utils.cpp
+++ b/src/ir/module-utils.cpp
@@ -159,11 +159,51 @@ Counts getHeapTypeCounts(Module& wasm) {
counts.note(*super);
}
}
+
+ // Make sure we've noted the complete recursion group of each type as well.
+ auto recGroup = ht.getRecGroup();
+ for (auto type : recGroup) {
+ if (!counts.count(type)) {
+ newTypes.insert(type);
+ counts.note(type);
+ }
+ }
}
return counts;
}
+// Reorders `indexedTypes.types` in place so that the members of each recursion
+// group are adjacent. Groups are emitted in the order in which their first
+// member is encountered, and each group's members follow the group's own
+// iteration order. Does nothing unless the type system is isorecursive. Must
+// be called while `indexedTypes.indices` is still empty, since reordering
+// would invalidate any indices that had already been assigned.
+void coalesceRecGroups(IndexedHeapTypes& indexedTypes) {
+  if (getTypeSystem() != TypeSystem::Isorecursive) {
+    // No rec groups to coalesce.
+    return;
+  }
+
+  // TODO: Perform a topological sort of the recursion groups to create a valid
+  // ordering rather than this hack that just gets all the types in a group to
+  // be adjacent.
+  assert(indexedTypes.indices.empty());
+  std::unordered_set<HeapType> seen;
+  std::vector<HeapType> grouped;
+  grouped.reserve(indexedTypes.types.size());
+  for (auto type : indexedTypes.types) {
+    if (!seen.insert(type).second) {
+      // Already emitted as part of an earlier member's group.
+      continue;
+    }
+    for (auto member : type.getRecGroup()) {
+      grouped.push_back(member);
+      seen.insert(member);
+    }
+  }
+  // Every member of every group is expected to already be in the input, so
+  // regrouping must neither add nor drop types.
+  assert(grouped.size() == indexedTypes.types.size());
+  // Hand the regrouped buffer over rather than copying each element again.
+  indexedTypes.types.swap(grouped);
+}
+
+// Records the position of every type in `indexedTypes.types` in
+// `indexedTypes.indices`, keyed by the type itself.
+void setIndices(IndexedHeapTypes& indexedTypes) {
+  const auto& types = indexedTypes.types;
+  auto& indices = indexedTypes.indices;
+  for (Index pos = 0; pos < types.size(); ++pos) {
+    indices[types[pos]] = pos;
+  }
+}
+
} // anonymous namespace
std::vector<HeapType> collectHeapTypes(Module& wasm) {
@@ -179,11 +219,12 @@ std::vector<HeapType> collectHeapTypes(Module& wasm) {
IndexedHeapTypes getIndexedHeapTypes(Module& wasm) {
Counts counts = getHeapTypeCounts(wasm);
IndexedHeapTypes indexedTypes;
- Index i = 0;
for (auto& [type, _] : counts) {
indexedTypes.types.push_back(type);
- indexedTypes.indices[type] = i++;
}
+
+ coalesceRecGroups(indexedTypes);
+ setIndices(indexedTypes);
return indexedTypes;
}
@@ -199,10 +240,14 @@ IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
// Collect the results.
IndexedHeapTypes indexedTypes;
for (Index i = 0; i < sorted.size(); ++i) {
- HeapType type = sorted[i].first;
- indexedTypes.types.push_back(type);
- indexedTypes.indices[type] = i;
+ indexedTypes.types.push_back(sorted[i].first);
}
+
+ // TODO: Explicitly construct a linear extension of the partial order of
+ // recursion groups by adding edges between unrelated groups according to
+ // their use counts.
+ coalesceRecGroups(indexedTypes);
+ setIndices(indexedTypes);
return indexedTypes;
}
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index 7bcf23b8e..5bdf84aba 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -2582,7 +2582,10 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
void handleSignature(HeapType curr, Name name = Name()) {
Signature sig = curr.getSignature();
- if (!name.is() && getTypeSystem() == TypeSystem::Nominal) {
+ bool hasSupertype =
+ !name.is() && (getTypeSystem() == TypeSystem::Nominal ||
+ getTypeSystem() == TypeSystem::Isorecursive);
+ if (hasSupertype) {
o << "(func_subtype";
} else {
o << "(func";
@@ -2612,7 +2615,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
}
o << ')';
}
- if (!name.is() && getTypeSystem() == TypeSystem::Nominal) {
+ if (hasSupertype) {
o << ' ';
printSupertypeOr(curr, "func");
}
@@ -2638,21 +2641,25 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
}
}
void handleArray(HeapType curr) {
- if (getTypeSystem() == TypeSystem::Nominal) {
+ bool hasSupertype = getTypeSystem() == TypeSystem::Nominal ||
+ getTypeSystem() == TypeSystem::Isorecursive;
+ if (hasSupertype) {
o << "(array_subtype ";
} else {
o << "(array ";
}
handleFieldBody(curr.getArray().element);
- if (getTypeSystem() == TypeSystem::Nominal) {
+ if (hasSupertype) {
o << ' ';
printSupertypeOr(curr, "data");
}
o << ')';
}
void handleStruct(HeapType curr) {
+ bool hasSupertype = getTypeSystem() == TypeSystem::Nominal ||
+ getTypeSystem() == TypeSystem::Isorecursive;
const auto& fields = curr.getStruct().fields;
- if (getTypeSystem() == TypeSystem::Nominal) {
+ if (hasSupertype) {
o << "(struct_subtype ";
} else {
o << "(struct ";
@@ -2669,7 +2676,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
o << ')';
sep = " ";
}
- if (getTypeSystem() == TypeSystem::Nominal) {
+ if (hasSupertype) {
o << ' ';
printSupertypeOr(curr, "data");
}
@@ -2779,7 +2786,8 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
o << '(';
printMajor(o, "func ");
printName(curr->name, o);
- if (getTypeSystem() == TypeSystem::Nominal) {
+ if (getTypeSystem() == TypeSystem::Nominal ||
+ getTypeSystem() == TypeSystem::Isorecursive) {
o << " (type ";
printHeapType(o, curr->type, currModule) << ')';
}
@@ -3064,10 +3072,32 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
printName(curr->name, o);
}
incIndent();
+
// Use the same type order as the binary output would even though there is
// no code size benefit in the text format.
auto indexedTypes = ModuleUtils::getOptimizedIndexedHeapTypes(*curr);
+ std::optional<RecGroup> currGroup;
+ bool nontrivialGroup = false;
+ auto finishGroup = [&]() {
+ if (nontrivialGroup) {
+ decIndent();
+ o << maybeNewLine;
+ }
+ };
for (auto type : indexedTypes.types) {
+ RecGroup newGroup = type.getRecGroup();
+ if (!currGroup || *currGroup != newGroup) {
+ if (currGroup) {
+ finishGroup();
+ }
+ currGroup = newGroup;
+ nontrivialGroup = currGroup->size() > 1;
+ if (nontrivialGroup) {
+ doIndent(o, indent);
+ o << "(rec ";
+ incIndent();
+ }
+ }
doIndent(o, indent);
o << '(';
printMedium(o, "type") << ' ';
@@ -3076,6 +3106,8 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
handleHeapType(type);
o << ")" << maybeNewLine;
}
+ finishGroup();
+
ModuleUtils::iterImportedMemories(
*curr, [&](Memory* memory) { visitMemory(memory); });
ModuleUtils::iterImportedTables(*curr,
diff --git a/src/wasm-type.h b/src/wasm-type.h
index c99113317..fa5650a31 100644
--- a/src/wasm-type.h
+++ b/src/wasm-type.h
@@ -56,6 +56,7 @@ void destroyAllTypesForTestingPurposesOnly();
// data.
class Type;
class HeapType;
+class RecGroup;
struct Tuple;
struct Signature;
struct Field;
@@ -352,6 +353,9 @@ public:
// number of supertypes in its supertype chain.
size_t getDepth() const;
+ // Get the recursion group for this non-basic type.
+ RecGroup getRecGroup() const;
+
constexpr TypeID getID() const { return id; }
constexpr BasicHeapType getBasic() const {
assert(isBasic() && "Basic heap type expected");
@@ -378,6 +382,30 @@ public:
std::string toString() const;
};
+// A recursion group consisting of one or more HeapTypes. Recursion groups with
+// a single member are encoded without using any additional memory, which is
+// why iterating over a group yields HeapTypes by value rather than by
+// reference; the HeapType might have to be created on the fly.
+class RecGroup {
+  // Opaque encoding of the group. It is interpreted only by the out-of-line
+  // definitions of `size` and `Iterator::operator*`.
+  uintptr_t id;
+
+public:
+  explicit RecGroup(uintptr_t id) : id(id) {}
+  // Two groups are the same group exactly when their encodings are equal. The
+  // comparisons are const so that const RecGroups can be compared as well.
+  bool operator==(const RecGroup& other) const { return id == other.id; }
+  bool operator!=(const RecGroup& other) const { return id != other.id; }
+  // The number of HeapTypes in this group.
+  size_t size() const;
+
+  struct Iterator : ParentIndexIterator<const RecGroup*, Iterator> {
+    using value_type = HeapType;
+    using pointer = const HeapType*;
+    using reference = const HeapType&;
+    value_type operator*() const;
+  };
+
+  Iterator begin() const { return Iterator{{this, 0}}; }
+  Iterator end() const { return Iterator{{this, size()}}; }
+};
+
typedef std::vector<Type> TypeList;
// Passed by reference rather than by value because it can own an unbounded
@@ -556,6 +584,10 @@ struct TypeBuilder {
// `j`. Does nothing for equirecursive types.
void setSubType(size_t i, size_t j);
+ // Create a new recursion group covering slots [i, i + length). Groups must
+ // not overlap or go out of bounds.
+ void createRecGroup(size_t i, size_t length);
+
// Returns all of the newly constructed heap types. May only be called once
// all of the heap types have been initialized with `setHeapType`. In nominal
// mode, all of the constructed HeapTypes will be fresh and distinct. In
diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp
index f76db38af..01e1789cd 100644
--- a/src/wasm/wasm-s-parser.cpp
+++ b/src/wasm/wasm-s-parser.cpp
@@ -53,8 +53,8 @@ namespace wasm {
static Name STRUCT("struct"), FIELD("field"), ARRAY("array"),
FUNC_SUBTYPE("func_subtype"), STRUCT_SUBTYPE("struct_subtype"),
- ARRAY_SUBTYPE("array_subtype"), EXTENDS("extends"), I8("i8"), I16("i16"),
- RTT("rtt"), DECLARE("declare"), ITEM("item"), OFFSET("offset");
+ ARRAY_SUBTYPE("array_subtype"), EXTENDS("extends"), REC("rec"), I8("i8"),
+ I16("i16"), RTT("rtt"), DECLARE("declare"), ITEM("item"), OFFSET("offset");
static Address getAddress(const Element* s) { return atoll(s->c_str()); }
@@ -458,6 +458,9 @@ void SExpressionWasmBuilder::parseModuleElement(Element& curr) {
if (id == TYPE) {
return; // already done
}
+ if (id == REC) {
+ return; // already done
+ }
if (id == TAG) {
return parseTag(curr);
}
@@ -680,18 +683,29 @@ size_t SExpressionWasmBuilder::parseTypeUse(Element& s,
}
void SExpressionWasmBuilder::preParseHeapTypes(Element& module) {
+ // Iterate through each individual type definition, calling `f` with the
+ // definition and its recursion group number.
auto forEachType = [&](auto f) {
+ size_t groupNumber = 0;
for (auto* elemPtr : module) {
auto& elem = *elemPtr;
if (elementStartsWith(elem, TYPE)) {
- f(elem);
+ f(elem, groupNumber++);
+ } else if (elementStartsWith(elem, REC)) {
+ for (auto* innerPtr : elem) {
+ auto& inner = *innerPtr;
+ if (elementStartsWith(inner, TYPE)) {
+ f(inner, groupNumber);
+ }
+ }
+ ++groupNumber;
}
}
};
+ // Map type names to indices
size_t numTypes = 0;
- forEachType([&](Element& elem) {
- // Map type names to indices
+ forEachType([&](Element& elem, size_t) {
if (elem[1]->dollared()) {
std::string name = elem[1]->c_str();
if (!typeIndices.insert({name, numTypes}).second) {
@@ -703,6 +717,22 @@ void SExpressionWasmBuilder::preParseHeapTypes(Element& module) {
TypeBuilder builder(numTypes);
+ // Create recursion groups
+ size_t currGroup = 0, groupStart = 0, groupLength = 0;
+ auto finishGroup = [&]() {
+ builder.createRecGroup(groupStart, groupLength);
+ groupStart = groupStart + groupLength;
+ groupLength = 0;
+ };
+ forEachType([&](Element&, size_t group) {
+ if (group != currGroup) {
+ finishGroup();
+ currGroup = group;
+ }
+ ++groupLength;
+ });
+ finishGroup();
+
auto parseRefType = [&](Element& elem) -> Type {
// '(' 'ref' 'null'? ht ')'
auto nullable =
@@ -860,22 +890,22 @@ void SExpressionWasmBuilder::preParseHeapTypes(Element& module) {
};
size_t index = 0;
- forEachType([&](Element& elem) {
+ forEachType([&](Element& elem, size_t) {
Element& def = elem[1]->dollared() ? *elem[2] : *elem[1];
Element& kind = *def[0];
- bool nominal =
+ bool hasSupertype =
kind == FUNC_SUBTYPE || kind == STRUCT_SUBTYPE || kind == ARRAY_SUBTYPE;
if (kind == FUNC || kind == FUNC_SUBTYPE) {
- builder[index] = parseSignatureDef(def, nominal);
+ builder[index] = parseSignatureDef(def, hasSupertype);
} else if (kind == STRUCT || kind == STRUCT_SUBTYPE) {
- builder[index] = parseStructDef(def, index, nominal);
+ builder[index] = parseStructDef(def, index, hasSupertype);
} else if (kind == ARRAY || kind == ARRAY_SUBTYPE) {
builder[index] = parseArrayDef(def);
} else {
throw ParseException("unknown heaptype kind", kind.line, kind.col);
}
Element* super = nullptr;
- if (nominal) {
+ if (hasSupertype) {
super = def[def.size() - 1];
if (super->dollared()) {
// OK
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp
index 7801cc289..4c2922c73 100644
--- a/src/wasm/wasm-type.cpp
+++ b/src/wasm/wasm-type.cpp
@@ -95,6 +95,8 @@ struct TypeInfo {
bool operator!=(const TypeInfo& other) const { return !(*this == other); }
};
+using RecGroupInfo = std::vector<HeapType>;
+
struct HeapTypeInfo {
using type_t = HeapType;
// Used in assertions to ensure that temporary types don't leak into the
@@ -108,6 +110,9 @@ struct HeapTypeInfo {
// In nominal or isorecursive mode, the supertype of this HeapType, if it
// exists.
HeapTypeInfo* supertype = nullptr;
+ // In isorecursive mode, the recursion group of this type or null if the
+ // recursion group is trivial (i.e. contains only this type).
+ RecGroupInfo* recGroup = nullptr;
enum Kind {
BasicKind,
SignatureKind,
@@ -581,6 +586,7 @@ bool TypeInfo::operator==(const TypeInfo& other) const {
HeapTypeInfo::HeapTypeInfo(const HeapTypeInfo& other) {
kind = other.kind;
supertype = other.supertype;
+ recGroup = other.recGroup;
switch (kind) {
case BasicKind:
new (&basic) auto(other.basic);
@@ -765,6 +771,10 @@ struct SignatureTypeCache {
static SignatureTypeCache nominalSignatureCache;
+// Keep track of the constructed recursion groups.
+static std::mutex recGroupsMutex;
+static std::vector<std::unique_ptr<RecGroupInfo>> recGroups;
+
} // anonymous namespace
void destroyAllTypesForTestingPurposesOnly() {
@@ -1236,6 +1246,41 @@ HeapType HeapType::getLeastUpperBound(HeapType a, HeapType b) {
return TypeBounder().getLeastUpperBound(a, b);
}
+// Recursion groups with single elements are encoded as that single element's
+// type ID with the low bit set and other recursion groups are encoded with the
+// address of the vector containing their members. These encodings are disjoint
+// because the alignment of the vectors is greater than 1.
+static_assert(alignof(std::vector<HeapType>) > 1);
+
+// Returns the recursion group containing this type, which must not be a basic
+// heap type.
+RecGroup HeapType::getRecGroup() const {
+  assert(!isBasic());
+  auto* group = getHeapTypeInfo(*this)->recGroup;
+  if (group == nullptr) {
+    // No group was materialized for this type, so it is the only member of a
+    // trivial recursion group. Tag this type's own ID with the low bit to
+    // signify that the encoding is a heap type rather than the address of a
+    // vector of heap types.
+    return RecGroup(id | 1);
+  }
+  return RecGroup(uintptr_t(group));
+}
+
+// Produces the HeapType at the iterator's current position in its group.
+HeapType RecGroup::Iterator::operator*() const {
+  const uintptr_t encoded = parent->id;
+  if ((encoded & 1) == 0) {
+    // Nontrivial group: the encoding is the address of the member vector.
+    return (*(std::vector<HeapType>*)encoded)[index];
+  }
+  // Trivial group: clear the low tag bit to recover the single HeapType.
+  return HeapType(encoded & ~(uintptr_t)1);
+}
+
+// Returns the number of HeapTypes in this group: one for a trivial group
+// (low bit of the encoding set), otherwise the length of the member vector
+// that the encoding points to.
+size_t RecGroup::size() const {
+  const bool trivial = (id & 1) != 0;
+  return trivial ? 1 : ((std::vector<HeapType>*)id)->size();
+}
+
template<typename T> static std::string genericToString(const T& t) {
std::ostringstream ss;
ss << t;
@@ -1959,7 +2004,8 @@ size_t FiniteShapeHasher::hash(const TypeInfo& info) {
}
size_t FiniteShapeHasher::hash(const HeapTypeInfo& info) {
- if (getTypeSystem() == TypeSystem::Nominal) {
+ if (getTypeSystem() == TypeSystem::Nominal ||
+ getTypeSystem() == TypeSystem::Isorecursive) {
return wasm::hash(uintptr_t(&info));
}
// If the HeapTypeInfo is not finalized, then it is mutable and its shape
@@ -2080,7 +2126,8 @@ bool FiniteShapeEquator::eq(const TypeInfo& a, const TypeInfo& b) {
}
bool FiniteShapeEquator::eq(const HeapTypeInfo& a, const HeapTypeInfo& b) {
- if (getTypeSystem() == TypeSystem::Nominal) {
+ if (getTypeSystem() == TypeSystem::Nominal ||
+ getTypeSystem() == TypeSystem::Isorecursive) {
return &a == &b;
}
if (a.isFinalized != b.isFinalized) {
@@ -2233,8 +2280,14 @@ void TypeGraphWalkerBase<Self>::scanHeapType(HeapType* ht) {
} // anonymous namespace
struct TypeBuilder::Impl {
+ // Store of temporary Types. Types that need to be canonicalized will be
+ // copied into the global TypeStore.
TypeStore typeStore;
+ // Store of temporary recursion groups, which will be moved to the global
+ // collection of recursion groups as part of building.
+ std::vector<std::unique_ptr<RecGroupInfo>> recGroups;
+
struct Entry {
std::unique_ptr<HeapTypeInfo> info;
bool initialized = false;
@@ -2248,6 +2301,7 @@ struct TypeBuilder::Impl {
}
void set(HeapTypeInfo&& hti) {
hti.supertype = info->supertype;
+ hti.recGroup = info->recGroup;
*info = std::move(hti);
info->isTemp = true;
info->isFinalized = false;
@@ -2342,6 +2396,20 @@ void TypeBuilder::setSubType(size_t i, size_t j) {
sub->supertype = super;
}
+// Creates a new recursion group covering builder slots [i, i + length). The
+// range must lie within the builder and none of the covered slots may already
+// belong to a group. Ranges shorter than two slots are trivial groups and are
+// not materialized.
+void TypeBuilder::createRecGroup(size_t i, size_t length) {
+  assert(i <= size() && i + length <= size() && "group out of bounds");
+  // Only materialize nontrivial recursion groups.
+  if (length < 2) {
+    return;
+  }
+  // Record the group in this builder's own temporary store. An unqualified
+  // `recGroups` here would name the file-level global collection instead,
+  // which is guarded by `recGroupsMutex` and is only meant to receive groups
+  // when they are moved over during building.
+  auto& groups = impl->recGroups;
+  groups.emplace_back(std::make_unique<RecGroupInfo>());
+  RecGroupInfo* group = groups.back().get();
+  for (size_t slot = i; slot < i + length; ++slot) {
+    auto& info = impl->entries[slot].info;
+    assert(info->recGroup == nullptr && "group already assigned");
+    info->recGroup = group;
+  }
+}
+
namespace {
// Helper for TypeBuilder::build() that keeps track of temporary types and
@@ -3012,8 +3080,9 @@ void globallyCanonicalize(CanonicalizationState& state) {
}
void canonicalizeEquirecursive(CanonicalizationState& state) {
- // Equirecursive types always have null supertypes.
+ // Equirecursive types always have null supertypes and recursion groups.
for (auto& info : state.newInfos) {
+ info->recGroup = nullptr;
info->supertype = nullptr;
}
@@ -3036,6 +3105,14 @@ void canonicalizeEquirecursive(CanonicalizationState& state) {
}
void canonicalizeNominal(CanonicalizationState& state) {
+ // TODO: clear recursion groups once we are no longer piggybacking the
+ // isorecursive system on the nominal system.
+ // if (typeSystem != TypeSystem::Isorecursive) {
+ // for (auto& info : state.newInfos) {
+ // assert(info->recGroup == nullptr && "unexpected recursion group");
+ // }
+ // }
+
// Nominal types do not require separate canonicalization, so just validate
// that their subtyping is correct.
@@ -3043,8 +3120,6 @@ void canonicalizeNominal(CanonicalizationState& state) {
auto start = std::chrono::steady_clock::now();
#endif
- assert(typeSystem == TypeSystem::Nominal);
-
// Ensure there are no cycles in the subtype graph. This is the classic DFA
// algorithm for detecting cycles, but in the form of a simple loop because
// each node (type) has at most one child (supertype).
@@ -3115,6 +3190,29 @@ void canonicalizeNominal(CanonicalizationState& state) {
#endif
}
+// Canonicalizes newly built types in isorecursive mode: fills in the members
+// of each recursion group, runs the nominal canonicalization, and then moves
+// the builder's recursion groups into the global collection.
+void canonicalizeIsorecursive(
+  CanonicalizationState& state,
+  std::vector<std::unique_ptr<RecGroupInfo>>& recGroupInfos) {
+  // Populate every nontrivial recursion group with its member types. Types
+  // without a group pointer are in trivial groups and need no entry.
+  for (auto& info : state.newInfos) {
+    auto* group = info->recGroup;
+    if (group == nullptr) {
+      continue;
+    }
+    group->push_back(asHeapType(info));
+  }
+
+  // TODO: proper isorecursive validation and canonicalization. For now just
+  // piggyback on the nominal system.
+  canonicalizeNominal(state);
+
+  // Transfer ownership of the recursion groups to the global store while
+  // holding its mutex. TODO: after proper isorecursive canonicalization, some
+  // groups may no longer be used, so they will need to be filtered out.
+  std::lock_guard<std::mutex> lock(recGroupsMutex);
+  for (auto& group : recGroupInfos) {
+    recGroups.push_back(std::move(group));
+  }
+}
+
void canonicalizeBasicHeapTypes(CanonicalizationState& state) {
// Replace heap types backed by BasicKind HeapTypeInfos with their
// corresponding BasicHeapTypes. The heap types backed by BasicKind
@@ -3174,7 +3272,7 @@ std::vector<HeapType> TypeBuilder::build() {
canonicalizeNominal(state);
break;
case TypeSystem::Isorecursive:
- Fatal() << "Isorecursive types not yet implemented";
+ canonicalizeIsorecursive(state, impl->recGroups);
break;
}