summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThomas Lively <tlively@google.com>2024-08-07 20:49:55 -0400
committerGitHub <noreply@github.com>2024-08-07 17:49:55 -0700
commit2397f2af4512c31e1e54c0e0168302ab1ee06d58 (patch)
treeb469981d3a6bcd93809a7d136bd83d4f6703b887
parentfb6ead80296471276f4cee05f920e6fe8aba67c5 (diff)
downloadbinaryen-2397f2af4512c31e1e54c0e0168302ab1ee06d58.tar.gz
binaryen-2397f2af4512c31e1e54c0e0168302ab1ee06d58.tar.bz2
binaryen-2397f2af4512c31e1e54c0e0168302ab1ee06d58.zip
Add a utility for comparing and hashing rec group shapes (#6808)
This is very similar to the internal utilities for canonicalizing rec groups in the type system implementation, except that the new utility also supports ordered comparison of rec groups, and of course the new utility only uses the public type API. A follow-up PR will replace the internal implementation of rec group comparison and hashing in the type system with this one. Another follow-up PR will use this new utility in a type optimization.
-rw-r--r--src/tools/wasm-fuzz-types.cpp94
-rw-r--r--src/wasm-type-shape.h74
-rw-r--r--src/wasm-type.h2
-rw-r--r--src/wasm/CMakeLists.txt1
-rw-r--r--src/wasm/wasm-type-shape.cpp348
5 files changed, 519 insertions, 0 deletions
diff --git a/src/tools/wasm-fuzz-types.cpp b/src/tools/wasm-fuzz-types.cpp
index ecd8883b1..074d235c9 100644
--- a/src/tools/wasm-fuzz-types.cpp
+++ b/src/tools/wasm-fuzz-types.cpp
@@ -23,6 +23,7 @@
#include "tools/fuzzing/heap-types.h"
#include "tools/fuzzing/random.h"
#include "wasm-type-printing.h"
+#include "wasm-type-shape.h"
namespace wasm {
@@ -54,6 +55,7 @@ struct Fuzzer {
void checkLUBs() const;
void checkCanonicalization();
void checkInhabitable();
+ void checkRecGroupShapes();
};
void Fuzzer::run(uint64_t seed) {
@@ -88,6 +90,7 @@ void Fuzzer::run(uint64_t seed) {
checkLUBs();
checkCanonicalization();
checkInhabitable();
+ checkRecGroupShapes();
}
void Fuzzer::printTypes(const std::vector<HeapType>& types) {
@@ -509,6 +512,97 @@ void Fuzzer::checkInhabitable() {
}
}
+void Fuzzer::checkRecGroupShapes() {
+ using ShapeHash = std::hash<RecGroupShape>;
+
+ // Collect the groups and order types by index.
+ std::vector<std::vector<HeapType>> groups;
+ std::unordered_map<HeapType, Index> typeIndices;
+ for (auto type : types) {
+ typeIndices.insert({type, typeIndices.size()});
+ // We know we are at the beginning of a new rec group when we see a type
+ // that is at index zero of its rec group.
+ if (type.getRecGroupIndex() == 0) {
+ groups.push_back({type});
+ } else {
+ assert(!groups.empty());
+ groups.back().push_back(type);
+ }
+ }
+
+ auto less = [&typeIndices](HeapType a, HeapType b) {
+ return typeIndices.at(a) < typeIndices.at(b);
+ };
+
+ for (size_t i = 0; i < groups.size(); ++i) {
+ ComparableRecGroupShape shape(groups[i], less);
+ // A rec group should compare equal to itself.
+ if (shape != shape) {
+ Fatal() << "Rec group shape " << i << " not equal to itself";
+ }
+
+ // Its hash should be deterministic
+ auto hash = ShapeHash{}(shape);
+ if (hash != ShapeHash{}(shape)) {
+ Fatal() << "Rec group shape " << i << " has non-deterministic hash";
+ }
+
+ // Check how it compares to other groups.
+ for (size_t j = i + 1; j < groups.size(); ++j) {
+ ComparableRecGroupShape other(groups[j], less);
+ bool isLess = shape < other;
+ bool isEq = shape == other;
+ bool isGreater = shape > other;
+ if (isLess + isEq + isGreater == 0) {
+ Fatal() << "Rec groups " << i << " and " << j
+ << " do not have comparable shapes";
+ }
+ if (isLess + isEq + isGreater > 1) {
+ std::string comparisons;
+ auto append = [&](std::string comp) {
+ comparisons = comparisons == "" ? comp : comparisons + ", " + comp;
+ };
+ if (isLess) {
+ append("<");
+ }
+ if (isEq) {
+ append("==");
+ }
+ if (isGreater) {
+ append(">");
+ }
+ Fatal() << "Rec groups " << i << " and " << j << " compare "
+ << comparisons;
+ }
+
+ auto otherHash = ShapeHash{}(other);
+ if (isEq) {
+ if (hash != otherHash) {
+ Fatal() << "Equivalent rec groups " << i << " and " << j
+ << " do not have equivalent hashes";
+ }
+ } else {
+ // Hash collisions are technically possible, but should be rare enough
+ // that we can consider them bugs if the fuzzer finds them.
+ if (hash == otherHash) {
+ Fatal() << "Hash collision between rec groups " << i << " and " << j;
+ }
+ }
+
+ if (j + 1 < groups.size()) {
+ // Check transitivity.
+ RecGroupShape third(groups[j + 1]);
+ if ((isLess && other <= third && shape >= third) ||
+ (isEq && other == third && shape != third) ||
+ (isGreater && other >= third && shape <= third)) {
+ Fatal() << "Comparison between rec groups " << i << ", " << j
+ << ", and " << (j + 1) << " is not transitive";
+ }
+ }
+ }
+ }
+}
+
} // namespace wasm
int main(int argc, const char* argv[]) {
diff --git a/src/wasm-type-shape.h b/src/wasm-type-shape.h
new file mode 100644
index 000000000..5eb4250f0
--- /dev/null
+++ b/src/wasm-type-shape.h
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2024 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_wasm_type_shape_h
+#define wasm_wasm_type_shape_h
+
+#include <functional>
+#include <vector>
+
+#include "wasm-type.h"
+
+namespace wasm {
+
+// Provides hashing and equality comparison for a sequence of types. The hashing
+// and equality differentiate the top-level structure of each type in the
+// sequence and the equality of referenced heap types that are not in the
+// recursion group, but for references to types that are in the recursion group,
+// it considers only the index of the referenced type within the group. That
+// means that recursion groups containing different types can compare and hash
+// as equal as long as their internal structure and external references are the
+// same.
+struct RecGroupShape {
+ const std::vector<HeapType>& types;
+
+ RecGroupShape(const std::vector<HeapType>& types) : types(types) {}
+
+ bool operator==(const RecGroupShape& other) const;
+ bool operator!=(const RecGroupShape& other) const {
+ return !(*this == other);
+ }
+};
+
+// Extends `RecGroupShape` with ordered comparison of rec group structures.
+// Requires the user to supply a global ordering on heap types to be able to
+// compare differing references to external types.
+// TODO: This can all be upgraded to use C++20 three-way comparisons.
+struct ComparableRecGroupShape : RecGroupShape {
+ std::function<bool(HeapType, HeapType)> less;
+
+ ComparableRecGroupShape(const std::vector<HeapType>& types,
+ std::function<bool(HeapType, HeapType)> less)
+ : RecGroupShape(types), less(less) {}
+
+ bool operator<(const RecGroupShape& other) const;
+ bool operator>(const RecGroupShape& other) const;
+ bool operator<=(const RecGroupShape& other) const { return !(*this > other); }
+ bool operator>=(const RecGroupShape& other) const { return !(*this < other); }
+};
+
+} // namespace wasm
+
+namespace std {
+
+template<> class hash<wasm::RecGroupShape> {
+public:
+ size_t operator()(const wasm::RecGroupShape& shape) const;
+};
+
+} // namespace std
+
+#endif // wasm_wasm_type_shape_h
diff --git a/src/wasm-type.h b/src/wasm-type.h
index 88590cc17..e2eec5f35 100644
--- a/src/wasm-type.h
+++ b/src/wasm-type.h
@@ -434,6 +434,8 @@ public:
// Get the recursion group for this non-basic type.
RecGroup getRecGroup() const;
+
+ // Get the index of this non-basic type within its recursion group.
size_t getRecGroupIndex() const;
constexpr TypeID getID() const { return id; }
diff --git a/src/wasm/CMakeLists.txt b/src/wasm/CMakeLists.txt
index 16b6d8aed..7a7b26ead 100644
--- a/src/wasm/CMakeLists.txt
+++ b/src/wasm/CMakeLists.txt
@@ -12,6 +12,7 @@ set(wasm_SOURCES
wasm-stack.cpp
wasm-stack-opts.cpp
wasm-type.cpp
+ wasm-type-shape.cpp
wasm-validator.cpp
${wasm_HEADERS}
)
diff --git a/src/wasm/wasm-type-shape.cpp b/src/wasm/wasm-type-shape.cpp
new file mode 100644
index 000000000..99398bb7b
--- /dev/null
+++ b/src/wasm/wasm-type-shape.cpp
@@ -0,0 +1,348 @@
+/*
+ * Copyright 2024 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "wasm-type-shape.h"
+#include "support/hash.h"
+#include "wasm-type.h"
+
+namespace wasm {
+
+namespace {
+
+enum Comparison { EQ, LT, GT };
+
+template<typename CompareTypes> struct RecGroupComparator {
+ std::unordered_map<HeapType, Index> indicesA;
+ std::unordered_map<HeapType, Index> indicesB;
+ CompareTypes compareTypes;
+
+ RecGroupComparator(CompareTypes compareTypes) : compareTypes(compareTypes) {}
+
+ Comparison compare(const RecGroupShape& a, const RecGroupShape& b) {
+ if (a.types.size() != b.types.size()) {
+ return a.types.size() < b.types.size() ? LT : GT;
+ }
+ // Initialize index maps.
+ for (auto type : a.types) {
+ indicesA.insert({type, indicesA.size()});
+ }
+ for (auto type : b.types) {
+ indicesB.insert({type, indicesB.size()});
+ }
+ // Compare types until we find a difference.
+ for (size_t i = 0; i < a.types.size(); ++i) {
+ auto cmp = compareDefinition(a.types[i], b.types[i]);
+ if (cmp == EQ) {
+ continue;
+ }
+ return cmp;
+ }
+ // Never found a difference.
+ return EQ;
+ }
+
+ Comparison compareDefinition(HeapType a, HeapType b) {
+ if (a.isShared() != b.isShared()) {
+ return a.isShared() < b.isShared() ? LT : GT;
+ }
+ if (a.isOpen() != b.isOpen()) {
+ return a.isOpen() < b.isOpen() ? LT : GT;
+ }
+ auto aSuper = a.getDeclaredSuperType();
+ auto bSuper = b.getDeclaredSuperType();
+ if (aSuper.has_value() != bSuper.has_value()) {
+ return aSuper.has_value() < bSuper.has_value() ? LT : GT;
+ }
+ if (aSuper) {
+ if (auto cmp = compare(*aSuper, *bSuper); cmp != EQ) {
+ return cmp;
+ }
+ }
+ auto aKind = a.getKind();
+ auto bKind = b.getKind();
+ if (aKind != bKind) {
+ return aKind < bKind ? LT : GT;
+ }
+ switch (aKind) {
+ case HeapTypeKind::Func:
+ return compare(a.getSignature(), b.getSignature());
+ case HeapTypeKind::Struct:
+ return compare(a.getStruct(), b.getStruct());
+ case HeapTypeKind::Array:
+ return compare(a.getArray(), b.getArray());
+ case HeapTypeKind::Cont:
+ return compare(a.getContinuation(), b.getContinuation());
+ case HeapTypeKind::Basic:
+ break;
+ }
+ WASM_UNREACHABLE("unexpected kind");
+ }
+
+ Comparison compare(Signature a, Signature b) {
+ if (auto cmp = compare(a.params, b.params); cmp != EQ) {
+ return cmp;
+ }
+ return compare(a.results, b.results);
+ }
+
+ Comparison compare(const Struct& a, const Struct& b) {
+ if (a.fields.size() != b.fields.size()) {
+ return a.fields.size() < b.fields.size() ? LT : GT;
+ }
+ for (size_t i = 0; i < a.fields.size(); ++i) {
+ if (auto cmp = compare(a.fields[i], b.fields[i]); cmp != EQ) {
+ return cmp;
+ }
+ }
+ return EQ;
+ }
+
+ Comparison compare(Array a, Array b) { return compare(a.element, b.element); }
+
+ Comparison compare(Continuation a, Continuation b) {
+ return compare(a.type, b.type);
+ }
+
+ Comparison compare(Field a, Field b) {
+ if (a.mutable_ != b.mutable_) {
+ return a.mutable_ < b.mutable_ ? LT : GT;
+ }
+ if (a.isPacked() != b.isPacked()) {
+ return b.isPacked() < a.isPacked() ? LT : GT;
+ }
+ if (a.packedType != b.packedType) {
+ return a.packedType < b.packedType ? LT : GT;
+ }
+ return compare(a.type, b.type);
+ }
+
+ Comparison compare(Type a, Type b) {
+ if (a.isBasic() != b.isBasic()) {
+ return b.isBasic() < a.isBasic() ? LT : GT;
+ }
+ if (a.isBasic()) {
+ if (a.getBasic() != b.getBasic()) {
+ return a.getBasic() < b.getBasic() ? LT : GT;
+ }
+ return EQ;
+ }
+ if (a.isTuple() != b.isTuple()) {
+ return a.isTuple() < b.isTuple() ? LT : GT;
+ }
+ if (a.isTuple()) {
+ return compare(a.getTuple(), b.getTuple());
+ }
+ assert(a.isRef() && b.isRef());
+ if (a.isNullable() != b.isNullable()) {
+ return a.isNullable() < b.isNullable() ? LT : GT;
+ }
+ return compare(a.getHeapType(), b.getHeapType());
+ }
+
+ Comparison compare(const Tuple& a, const Tuple& b) {
+ if (a.size() != b.size()) {
+ return a.size() < b.size() ? LT : GT;
+ }
+ for (size_t i = 0; i < a.size(); ++i) {
+ if (auto cmp = compare(a[i], b[i]); cmp != EQ) {
+ return cmp;
+ }
+ }
+ return EQ;
+ }
+
+ Comparison compare(HeapType a, HeapType b) {
+ if (a.isBasic() != b.isBasic()) {
+ return b.isBasic() < a.isBasic() ? LT : GT;
+ }
+ if (a.isBasic()) {
+ if (a.getID() != b.getID()) {
+ return a.getID() < b.getID() ? LT : GT;
+ }
+ return EQ;
+ }
+ auto itA = indicesA.find(a);
+ auto itB = indicesB.find(b);
+ bool foundA = itA != indicesA.end();
+ bool foundB = itB != indicesB.end();
+ if (foundA != foundB) {
+ return foundB < foundA ? LT : GT;
+ }
+ if (foundA) {
+ auto indexA = itA->second;
+ auto indexB = itB->second;
+ if (indexA != indexB) {
+ return indexA < indexB ? LT : GT;
+ }
+ }
+ // These types are external to the group, so fall back to the provided
+ // comparator.
+ return compareTypes(a, b);
+ }
+};
+
+// Deduction guide to satisfy -Wctad-maybe-unsupported.
+template<typename CompareTypes>
+RecGroupComparator(CompareTypes) -> RecGroupComparator<CompareTypes>;
+
+struct RecGroupHasher {
+ std::unordered_map<HeapType, Index> typeIndices;
+
+ size_t hash(const RecGroupShape& shape) {
+ for (auto type : shape.types) {
+ typeIndices.insert({type, typeIndices.size()});
+ }
+ size_t digest = wasm::hash(shape.types.size());
+ for (auto type : shape.types) {
+ hash_combine(digest, hashDefinition(type));
+ }
+ return digest;
+ }
+
+ size_t hashDefinition(HeapType type) {
+ size_t digest = wasm::hash(type.isShared());
+ wasm::rehash(digest, type.isOpen());
+ auto super = type.getDeclaredSuperType();
+ wasm::rehash(digest, super.has_value());
+ if (super) {
+ hash_combine(digest, hash(*super));
+ }
+ auto kind = type.getKind();
+ // Mix in very random numbers to differentiate the kinds.
+ switch (kind) {
+ case HeapTypeKind::Func:
+ wasm::rehash(digest, 1904683903);
+ hash_combine(digest, hash(type.getSignature()));
+ return digest;
+ case HeapTypeKind::Struct:
+ wasm::rehash(digest, 3273309159);
+ hash_combine(digest, hash(type.getStruct()));
+ return digest;
+ case HeapTypeKind::Array:
+ wasm::rehash(digest, 4254688366);
+ hash_combine(digest, hash(type.getArray()));
+ return digest;
+ case HeapTypeKind::Cont:
+ wasm::rehash(digest, 2381496927);
+ hash_combine(digest, hash(type.getContinuation()));
+ return digest;
+ case HeapTypeKind::Basic:
+ break;
+ }
+ WASM_UNREACHABLE("unexpected kind");
+ }
+
+ size_t hash(Signature sig) {
+ size_t digest = hash(sig.params);
+ hash_combine(digest, hash(sig.results));
+ return digest;
+ }
+
+ size_t hash(const Struct& struct_) {
+ size_t digest = wasm::hash(struct_.fields.size());
+ for (auto field : struct_.fields) {
+ hash_combine(digest, hash(field));
+ }
+ return digest;
+ }
+
+ size_t hash(Array array) { return hash(array.element); }
+
+ size_t hash(Continuation cont) { return hash(cont.type); }
+
+ size_t hash(Field field) {
+ size_t digest = wasm::hash(field.mutable_);
+ wasm::rehash(digest, field.packedType);
+ hash_combine(digest, hash(field.type));
+ return digest;
+ }
+
+ size_t hash(Type type) {
+ size_t digest = wasm::hash(type.isBasic());
+ if (type.isBasic()) {
+ wasm::rehash(digest, type.getBasic());
+ return digest;
+ }
+ wasm::rehash(digest, type.isTuple());
+ if (type.isTuple()) {
+ hash_combine(digest, hash(type.getTuple()));
+ return digest;
+ }
+ assert(type.isRef());
+ wasm::rehash(digest, type.isNullable());
+ hash_combine(digest, hash(type.getHeapType()));
+ return digest;
+ }
+
+ size_t hash(const Tuple& tuple) {
+ size_t digest = wasm::hash(tuple.size());
+ for (auto type : tuple) {
+ hash_combine(digest, hash(type));
+ }
+ return digest;
+ }
+
+ size_t hash(HeapType type) {
+ size_t digest = wasm::hash(type.isBasic());
+ if (type.isBasic()) {
+ wasm::rehash(digest, type.getID());
+ return digest;
+ }
+ auto it = typeIndices.find(type);
+ wasm::rehash(digest, it != typeIndices.end());
+ if (it != typeIndices.end()) {
+ wasm::rehash(digest, it->second);
+ return digest;
+ }
+ wasm::rehash(digest, type.getID());
+ return digest;
+ }
+};
+
+Comparison compareComparable(const ComparableRecGroupShape& a,
+ const RecGroupShape& b) {
+ return RecGroupComparator{[&](HeapType ht1, HeapType ht2) {
+ return a.less(ht1, ht2) ? LT : a.less(ht2, ht1) ? GT : EQ;
+ }}
+ .compare(a, b);
+}
+
+} // anonymous namespace
+
+bool RecGroupShape::operator==(const RecGroupShape& other) const {
+ return EQ == RecGroupComparator{[](HeapType a, HeapType b) {
+ return a == b ? EQ : LT;
+ }}.compare(*this, other);
+}
+
+bool ComparableRecGroupShape::operator<(const RecGroupShape& other) const {
+ return LT == compareComparable(*this, other);
+}
+
+bool ComparableRecGroupShape::operator>(const RecGroupShape& other) const {
+ return GT == compareComparable(*this, other);
+}
+
+} // namespace wasm
+
+namespace std {
+
+size_t
+hash<wasm::RecGroupShape>::operator()(const wasm::RecGroupShape& shape) const {
+ return wasm::RecGroupHasher{}.hash(shape);
+}
+
+} // namespace std