summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/passes/TypeMerging.cpp211
-rw-r--r--test/lit/passes/type-merging.wast162
2 files changed, 319 insertions, 54 deletions
diff --git a/src/passes/TypeMerging.cpp b/src/passes/TypeMerging.cpp
index 365a6c213..6f39d586f 100644
--- a/src/passes/TypeMerging.cpp
+++ b/src/passes/TypeMerging.cpp
@@ -126,14 +126,35 @@ struct TypeMerging : public Pass {
std::vector<HeapType> getPublicChildren(HeapType type);
DFA::State<HeapType> makeDFAState(HeapType type);
void applyMerges(const TypeUpdates& merges);
+};
- bool mayBeMergeable(HeapType sub, HeapType super);
- bool mayBeMergeable(const Struct& a, const Struct& b);
- bool mayBeMergeable(Array a, Array b);
- bool mayBeMergeable(Signature a, Signature b);
- bool mayBeMergeable(Field a, Field b);
- bool mayBeMergeable(Type a, Type b);
- bool mayBeMergeable(const Tuple& a, const Tuple& b);
+// Hash and equality-compare HeapTypes based on their top-level structure (i.e.
+// "shape"), ignoring nontrivial heap type children that will not be
+// differentiated between until we run the DFA partition refinement.
+bool shapeEq(HeapType a, HeapType b);
+bool shapeEq(const Struct& a, const Struct& b);
+bool shapeEq(Array a, Array b);
+bool shapeEq(Signature a, Signature b);
+bool shapeEq(Field a, Field b);
+bool shapeEq(Type a, Type b);
+bool shapeEq(const Tuple& a, const Tuple& b);
+
+size_t shapeHash(HeapType a);
+size_t shapeHash(const Struct& a);
+size_t shapeHash(Array a);
+size_t shapeHash(Signature a);
+size_t shapeHash(Field a);
+size_t shapeHash(Type a);
+size_t shapeHash(const Tuple& a);
+
+struct ShapeEq {
+ bool operator()(const HeapType& a, const HeapType& b) const {
+ return shapeEq(a, b);
+ }
+};
+
+struct ShapeHash {
+ size_t operator()(const HeapType& type) const { return shapeHash(type); }
};
void TypeMerging::run(Module* module_) {
@@ -166,11 +187,15 @@ void TypeMerging::run(Module* module_) {
// Map each type to its partition in the list.
std::unordered_map<HeapType, Partitions::iterator> typePartitions;
+ // Map the top-level structures of root types to their partitions in the list.
+ std::unordered_map<HeapType, Partitions::iterator, ShapeHash, ShapeEq>
+ shapePartitions;
+
#if TYPE_MERGING_DEBUG
+ using Fallback = IndexedTypeNameGenerator<DefaultTypeNameGenerator>;
+ Fallback printPrivate(privates, "private.");
+ ModuleTypeNameGenerator<Fallback> print(*module, printPrivate);
auto dumpPartitions = [&]() {
- using Fallback = IndexedTypeNameGenerator<DefaultTypeNameGenerator>;
- Fallback printPrivate(privates, "private.");
- ModuleTypeNameGenerator<Fallback> print(*module, printPrivate);
size_t i = 0;
for (auto& partition : partitions) {
std::cerr << i++ << ": " << print(partition[0].val) << "\n";
@@ -182,8 +207,11 @@ void TypeMerging::run(Module* module_) {
};
#endif // TYPE_MERGING_DEBUG
- // Ensure the type has a partition and return a reference to it.
- auto ensurePartition = [&](HeapType type) {
+ // Ensure the type has a partition and return a reference to it. Since we
+ // merge up the type tree and visit supertypes first, the partition usually
+ // already exists. The exception is when the supertype is public, in which
+ // case we might not have created a partition for it yet.
+ auto ensurePartition = [&](HeapType type) -> Partitions::iterator {
auto [it, inserted] = typePartitions.insert({type, partitions.end()});
if (inserted) {
it->second = partitions.insert(partitions.end(), {makeDFAState(type)});
@@ -191,6 +219,16 @@ void TypeMerging::run(Module* module_) {
return it->second;
};
+ // Similar to the above, but look up or create a partition associated with the
+ // type's top-level shape rather than its identity.
+ auto ensureShapePartition = [&](HeapType type) -> Partitions::iterator {
+ auto [it, inserted] = shapePartitions.insert({type, partitions.end()});
+ if (inserted) {
+ it->second = partitions.insert(partitions.end(), Partition{});
+ }
+ return it->second;
+ };
+
// For each type, either create a new partition or add to its supertype's
// partition.
for (auto type : HeapTypeOrdering::SupertypesFirst(privates)) {
@@ -199,19 +237,31 @@ void TypeMerging::run(Module* module_) {
for (auto child : getPublicChildren(type)) {
ensurePartition(child);
}
+ // If the type is distinguished by the module or public, we cannot merge it,
+ // so create a new partition for it.
+ if (castTypes.count(type) || !privateTypes.count(type)) {
+ ensurePartition(type);
+ continue;
+ }
auto super = type.getSuperType();
- if (!super || !mayBeMergeable(type, *super)) {
- // Create a new partition containing just this type.
+ // If there is no supertype to merge with, then we can still merge with
+ // other root types with the same structure. Find and add to the partition
+ // with other such types.
+ if (!super) {
+ auto it = ensureShapePartition(type);
+ it->push_back(makeDFAState(type));
+ typePartitions[type] = it;
+ continue;
+ }
+ // If this type refines its supertype in some way, then we cannot merge it.
+ // Create a new partition for it.
+ if (!shapeEq(type, *super)) {
ensurePartition(type);
continue;
}
- // The current type and its supertype have the same top-level structure, so
- // merge the current type's partition into its supertype's partition. First,
- // find the supertype's partition. The supertype's partition may not exist
- // yet if the supertype is public since we don't visit public types in this
- // loop. In that case we can create a new partition for the supertype
- // because merging private types into public supertypes is fine. (In
- // contrast, merging public types into their supertypes is not fine.)
+ // The current type and its supertype have the same top-level structure and
+ // are not distinguished, so add the current type to its supertype's
+ // partition.
auto it = ensurePartition(*super);
it->push_back(makeDFAState(type));
typePartitions[type] = it;
@@ -363,59 +413,91 @@ void TypeMerging::applyMerges(const TypeUpdates& merges) {
} rewriter(*module, merges);
}
-bool TypeMerging::mayBeMergeable(HeapType sub, HeapType super) {
- // If the type is distinguishable from its supertype or public, we cannot
- // merge it.
- if (castTypes.count(sub) || !privateTypes.count(sub)) {
- return false;
+bool shapeEq(HeapType a, HeapType b) {
+ // Check whether `a` and `b` have the same top-level structure, including the
+ // position and identity of any children that are not included as transitions
+ // in the DFA, i.e. any children that are not nontrivial references.
+ if (a.isStruct() && b.isStruct()) {
+ return shapeEq(a.getStruct(), b.getStruct());
}
- // Check whether `sub` and `super` have the same top-level structure,
- // including the position and identity of any children that are not included
- // as transitions in the DFA, i.e. any children that are not nontrivial
- // references.
- if (sub.isStruct() && super.isStruct()) {
- return mayBeMergeable(sub.getStruct(), super.getStruct());
+ if (a.isArray() && b.isArray()) {
+ return shapeEq(a.getArray(), b.getArray());
}
- if (sub.isArray() && super.isArray()) {
- return mayBeMergeable(sub.getArray(), super.getArray());
- }
- if (sub.isSignature() && super.isSignature()) {
- return mayBeMergeable(sub.getSignature(), super.getSignature());
+ if (a.isSignature() && b.isSignature()) {
+ return shapeEq(a.getSignature(), b.getSignature());
}
return false;
}
-bool TypeMerging::mayBeMergeable(const Struct& a, const Struct& b) {
+size_t shapeHash(HeapType a) {
+ size_t digest;
+ if (a.isStruct()) {
+ digest = hash(0);
+ hash_combine(digest, shapeHash(a.getStruct()));
+ } else if (a.isArray()) {
+ digest = hash(1);
+ hash_combine(digest, shapeHash(a.getArray()));
+ } else if (a.isSignature()) {
+ digest = hash(2);
+ hash_combine(digest, shapeHash(a.getSignature()));
+ } else {
+ WASM_UNREACHABLE("unexpected kind");
+ }
+ return digest;
+}
+
+bool shapeEq(const Struct& a, const Struct& b) {
if (a.fields.size() != b.fields.size()) {
return false;
}
for (size_t i = 0; i < a.fields.size(); ++i) {
- if (!mayBeMergeable(a.fields[i], b.fields[i])) {
+ if (!shapeEq(a.fields[i], b.fields[i])) {
return false;
}
}
return true;
}
-bool TypeMerging::mayBeMergeable(Array a, Array b) {
- return mayBeMergeable(a.element, b.element);
+size_t shapeHash(const Struct& a) {
+ size_t digest = hash(a.fields.size());
+ for (size_t i = 0; i < a.fields.size(); ++i) {
+ hash_combine(digest, shapeHash(a.fields[i]));
+ }
+ return digest;
+}
+
+bool shapeEq(Array a, Array b) { return shapeEq(a.element, b.element); }
+
+size_t shapeHash(Array a) { return shapeHash(a.element); }
+
+bool shapeEq(Signature a, Signature b) {
+ return shapeEq(a.params, b.params) && shapeEq(a.results, b.results);
+}
+
+size_t shapeHash(Signature a) {
+ auto digest = shapeHash(a.params);
+ hash_combine(digest, shapeHash(a.results));
+ return digest;
}
-bool TypeMerging::mayBeMergeable(Signature a, Signature b) {
- return mayBeMergeable(a.params, b.params) &&
- mayBeMergeable(a.results, b.results);
+bool shapeEq(Field a, Field b) {
+ return a.packedType == b.packedType && a.mutable_ == b.mutable_ &&
+ shapeEq(a.type, b.type);
}
-bool TypeMerging::mayBeMergeable(Field a, Field b) {
- return a.packedType == b.packedType && mayBeMergeable(a.type, b.type);
+size_t shapeHash(Field a) {
+ auto digest = hash((int)a.packedType);
+ rehash(digest, (int)a.mutable_);
+ hash_combine(digest, shapeHash(a.type));
+ return digest;
}
-bool TypeMerging::mayBeMergeable(Type a, Type b) {
+bool shapeEq(Type a, Type b) {
if (a == b) {
return true;
}
if (a.isTuple() && b.isTuple()) {
- return mayBeMergeable(a.getTuple(), b.getTuple());
+ return shapeEq(a.getTuple(), b.getTuple());
}
// The only thing allowed to differ is the non-basic heap type child, since we
// don't know before running the DFA partition refinement whether different
@@ -433,18 +515,47 @@ bool TypeMerging::mayBeMergeable(Type a, Type b) {
return true;
}
-bool TypeMerging::mayBeMergeable(const Tuple& a, const Tuple& b) {
+size_t shapeHash(Type a) {
+ if (a.isTuple()) {
+ auto digest = hash(0);
+ hash_combine(digest, shapeHash(a.getTuple()));
+ return digest;
+ }
+ auto digest = hash(1);
+ if (!a.isRef()) {
+ rehash(digest, 2);
+ return digest;
+ }
+ if (a.getHeapType().isBasic()) {
+ rehash(digest, 3);
+ rehash(digest, a.getHeapType().getID());
+ return digest;
+ }
+ rehash(digest, 4);
+ rehash(digest, (int)a.getNullability());
+ return digest;
+}
+
+bool shapeEq(const Tuple& a, const Tuple& b) {
if (a.types.size() != b.types.size()) {
return false;
}
for (size_t i = 0; i < a.types.size(); ++i) {
- if (!mayBeMergeable(a.types[i], b.types[i])) {
+ if (!shapeEq(a.types[i], b.types[i])) {
return false;
}
}
return true;
}
+size_t shapeHash(const Tuple& a) {
+ auto digest = hash(a.types.size());
+ for (auto type : a.types) {
+ hash_combine(digest, shapeHash(type));
+ }
+ return digest;
+}
+
} // anonymous namespace
Pass* createTypeMergingPass() { return new TypeMerging(); }
diff --git a/test/lit/passes/type-merging.wast b/test/lit/passes/type-merging.wast
index 578d84ebb..c20e212c3 100644
--- a/test/lit/passes/type-merging.wast
+++ b/test/lit/passes/type-merging.wast
@@ -233,6 +233,35 @@
(module
(rec
;; CHECK: (rec
+ ;; CHECK-NEXT: (type $A (struct (field (ref null $A))))
+ (type $A (struct (ref null $X)))
+ (type $B (struct_subtype (ref null $Y) $A))
+ (type $X (struct (ref null $A)))
+ (type $Y (struct_subtype (ref null $B) $X))
+ )
+
+ ;; CHECK: (type $none_=>_none (func))
+
+ ;; CHECK: (func $foo (type $none_=>_none)
+ ;; CHECK-NEXT: (local $a (ref null $A))
+ ;; CHECK-NEXT: (local $b (ref null $A))
+ ;; CHECK-NEXT: (local $x (ref null $A))
+ ;; CHECK-NEXT: (local $y (ref null $A))
+ ;; CHECK-NEXT: (nop)
+ ;; CHECK-NEXT: )
+ (func $foo
+ ;; As above, but now the A->B and X->Y chains are not differentiated by the
+ ;; i32 and f32, so all four types can be merged into a single type.
+ (local $a (ref null $A))
+ (local $b (ref null $B))
+ (local $x (ref null $X))
+ (local $y (ref null $Y))
+ )
+)
+
+(module
+ (rec
+ ;; CHECK: (rec
;; CHECK-NEXT: (type $X (struct (field (ref null $A))))
;; CHECK: (type $A (struct (field (ref null $X))))
@@ -249,16 +278,141 @@
;; CHECK-NEXT: (local $b (ref null $A))
;; CHECK-NEXT: (local $x (ref null $X))
;; CHECK-NEXT: (local $y (ref null $X))
- ;; CHECK-NEXT: (nop)
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (ref.cast $A
+ ;; CHECK-NEXT: (local.get $a)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
;; CHECK-NEXT: )
(func $foo
- ;; As above, but now the A->B and X->Y chains are not differentiated by the
- ;; i32 and f32, so all four types can be merged into a single type.
- ;; TODO: This is not yet implemented. Merge the top level types.
+ ;; As above, but now there is a cast to A that prevents A and X from being
+ ;; merged.
(local $a (ref null $A))
(local $b (ref null $B))
(local $x (ref null $X))
(local $y (ref null $Y))
+
+ (drop
+ (ref.cast $A
+ (local.get $a)
+ )
+ )
+ )
+)
+
+(module
+ ;; Check that a diversity of root types are merged correctly.
+ ;; CHECK: (rec
+ ;; CHECK-NEXT: (type $M (struct (field i32) (field i32)))
+
+ ;; CHECK: (type $L (struct (field i32)))
+
+ ;; CHECK: (type $K (func (param i32 i32 i32) (result i32 i32)))
+
+ ;; CHECK: (type $J (func (param i32 i32) (result i32 i32 i32)))
+
+ ;; CHECK: (type $I (array (ref $A)))
+
+ ;; CHECK: (type $H (array (ref null $A)))
+
+ ;; CHECK: (type $G (array (ref any)))
+
+ ;; CHECK: (type $F (array anyref))
+
+ ;; CHECK: (type $E (array i64))
+
+ ;; CHECK: (type $D (array i32))
+
+ ;; CHECK: (type $C (array i16))
+
+ ;; CHECK: (type $B (array (mut i8)))
+
+ ;; CHECK: (type $A (array i8))
+ (type $A (array i8))
+ (type $A' (array i8))
+ (type $B (array (mut i8)))
+ (type $B' (array (mut i8)))
+ (type $C (array i16))
+ (type $C' (array i16))
+ (type $D (array i32))
+ (type $D' (array i32))
+ (type $E (array i64))
+ (type $E' (array i64))
+ (type $F (array anyref))
+ (type $F' (array anyref))
+ (type $G (array (ref any)))
+ (type $G' (array (ref any)))
+ (type $H (array (ref null $A)))
+ (type $H' (array (ref null $A)))
+ (type $I (array (ref $A)))
+ (type $I' (array (ref $A)))
+ (type $J (func (param i32 i32) (result i32 i32 i32)))
+ (type $J' (func (param i32 i32) (result i32 i32 i32)))
+ (type $K (func (param i32 i32 i32) (result i32 i32)))
+ (type $K' (func (param i32 i32 i32) (result i32 i32)))
+ (type $L (struct i32))
+ (type $L' (struct i32))
+ (type $M (struct i32 i32))
+ (type $M' (struct i32 i32))
+
+ ;; CHECK: (type $none_=>_none (func))
+
+ ;; CHECK: (func $foo (type $none_=>_none)
+ ;; CHECK-NEXT: (local $a (ref null $A))
+ ;; CHECK-NEXT: (local $a' (ref null $A))
+ ;; CHECK-NEXT: (local $b (ref null $B))
+ ;; CHECK-NEXT: (local $b' (ref null $B))
+ ;; CHECK-NEXT: (local $c (ref null $C))
+ ;; CHECK-NEXT: (local $c' (ref null $C))
+ ;; CHECK-NEXT: (local $d (ref null $D))
+ ;; CHECK-NEXT: (local $d' (ref null $D))
+ ;; CHECK-NEXT: (local $e (ref null $E))
+ ;; CHECK-NEXT: (local $e' (ref null $E))
+ ;; CHECK-NEXT: (local $f (ref null $F))
+ ;; CHECK-NEXT: (local $f' (ref null $F))
+ ;; CHECK-NEXT: (local $g (ref null $G))
+ ;; CHECK-NEXT: (local $g' (ref null $G))
+ ;; CHECK-NEXT: (local $h (ref null $H))
+ ;; CHECK-NEXT: (local $h' (ref null $H))
+ ;; CHECK-NEXT: (local $i (ref null $I))
+ ;; CHECK-NEXT: (local $i' (ref null $I))
+ ;; CHECK-NEXT: (local $j (ref null $J))
+ ;; CHECK-NEXT: (local $j' (ref null $J))
+ ;; CHECK-NEXT: (local $k (ref null $K))
+ ;; CHECK-NEXT: (local $k' (ref null $K))
+ ;; CHECK-NEXT: (local $l (ref null $L))
+ ;; CHECK-NEXT: (local $l' (ref null $L))
+ ;; CHECK-NEXT: (local $m (ref null $M))
+ ;; CHECK-NEXT: (local $m' (ref null $M))
+ ;; CHECK-NEXT: (nop)
+ ;; CHECK-NEXT: )
+ (func $foo
+ (local $a (ref null $A))
+ (local $a' (ref null $A'))
+ (local $b (ref null $B))
+ (local $b' (ref null $B'))
+ (local $c (ref null $C))
+ (local $c' (ref null $C'))
+ (local $d (ref null $D))
+ (local $d' (ref null $D'))
+ (local $e (ref null $E))
+ (local $e' (ref null $E'))
+ (local $f (ref null $F))
+ (local $f' (ref null $F'))
+ (local $g (ref null $G))
+ (local $g' (ref null $G'))
+ (local $h (ref null $H))
+ (local $h' (ref null $H'))
+ (local $i (ref null $I))
+ (local $i' (ref null $I'))
+ (local $j (ref null $J))
+ (local $j' (ref null $J'))
+ (local $k (ref null $K))
+ (local $k' (ref null $K'))
+ (local $l (ref null $L))
+ (local $l' (ref null $L'))
+ (local $m (ref null $M))
+ (local $m' (ref null $M'))
)
)