diff options
-rw-r--r-- | src/tools/wasm-fuzz-types.cpp | 94 | ||||
-rw-r--r-- | src/wasm-type-shape.h | 74 | ||||
-rw-r--r-- | src/wasm-type.h | 2 | ||||
-rw-r--r-- | src/wasm/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/wasm/wasm-type-shape.cpp | 348 |
5 files changed, 519 insertions, 0 deletions
diff --git a/src/tools/wasm-fuzz-types.cpp b/src/tools/wasm-fuzz-types.cpp index ecd8883b1..074d235c9 100644 --- a/src/tools/wasm-fuzz-types.cpp +++ b/src/tools/wasm-fuzz-types.cpp @@ -23,6 +23,7 @@ #include "tools/fuzzing/heap-types.h" #include "tools/fuzzing/random.h" #include "wasm-type-printing.h" +#include "wasm-type-shape.h" namespace wasm { @@ -54,6 +55,7 @@ struct Fuzzer { void checkLUBs() const; void checkCanonicalization(); void checkInhabitable(); + void checkRecGroupShapes(); }; void Fuzzer::run(uint64_t seed) { @@ -88,6 +90,7 @@ void Fuzzer::run(uint64_t seed) { checkLUBs(); checkCanonicalization(); checkInhabitable(); + checkRecGroupShapes(); } void Fuzzer::printTypes(const std::vector<HeapType>& types) { @@ -509,6 +512,97 @@ void Fuzzer::checkInhabitable() { } } +void Fuzzer::checkRecGroupShapes() { + using ShapeHash = std::hash<RecGroupShape>; + + // Collect the groups and order types by index. + std::vector<std::vector<HeapType>> groups; + std::unordered_map<HeapType, Index> typeIndices; + for (auto type : types) { + typeIndices.insert({type, typeIndices.size()}); + // We know we are at the beginning of a new rec group when we see a type + // that is at index zero of its rec group. + if (type.getRecGroupIndex() == 0) { + groups.push_back({type}); + } else { + assert(!groups.empty()); + groups.back().push_back(type); + } + } + + auto less = [&typeIndices](HeapType a, HeapType b) { + return typeIndices.at(a) < typeIndices.at(b); + }; + + for (size_t i = 0; i < groups.size(); ++i) { + ComparableRecGroupShape shape(groups[i], less); + // A rec group should compare equal to itself. + if (shape != shape) { + Fatal() << "Rec group shape " << i << " not equal to itself"; + } + + // Its hash should be deterministic + auto hash = ShapeHash{}(shape); + if (hash != ShapeHash{}(shape)) { + Fatal() << "Rec group shape " << i << " has non-deterministic hash"; + } + + // Check how it compares to other groups. + for (size_t j = i + 1; j < groups.size(); ++j) { + ComparableRecGroupShape other(groups[j], less); + bool isLess = shape < other; + bool isEq = shape == other; + bool isGreater = shape > other; + if (isLess + isEq + isGreater == 0) { + Fatal() << "Rec groups " << i << " and " << j + << " do not have comparable shapes"; + } + if (isLess + isEq + isGreater > 1) { + std::string comparisons; + auto append = [&](std::string comp) { + comparisons = comparisons == "" ? comp : comparisons + ", " + comp; + }; + if (isLess) { + append("<"); + } + if (isEq) { + append("=="); + } + if (isGreater) { + append(">"); + } + Fatal() << "Rec groups " << i << " and " << j << " compare " + << comparisons; + } + + auto otherHash = ShapeHash{}(other); + if (isEq) { + if (hash != otherHash) { + Fatal() << "Equivalent rec groups " << i << " and " << j + << " do not have equivalent hashes"; + } + } else { + // Hash collisions are technically possible, but should be rare enough + // that we can consider them bugs if the fuzzer finds them. + if (hash == otherHash) { + Fatal() << "Hash collision between rec groups " << i << " and " << j; + } + } + + if (j + 1 < groups.size()) { + // Check transitivity. + RecGroupShape third(groups[j + 1]); + if ((isLess && other <= third && shape >= third) || + (isEq && other == third && shape != third) || + (isGreater && other >= third && shape <= third)) { + Fatal() << "Comparison between rec groups " << i << ", " << j + << ", and " << (j + 1) << " is not transitive"; + } + } + } + } +} + } // namespace wasm int main(int argc, const char* argv[]) { diff --git a/src/wasm-type-shape.h b/src/wasm-type-shape.h new file mode 100644 index 000000000..5eb4250f0 --- /dev/null +++ b/src/wasm-type-shape.h @@ -0,0 +1,74 @@ +/* + * Copyright 2024 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_wasm_type_shape_h +#define wasm_wasm_type_shape_h + +#include <functional> +#include <vector> + +#include "wasm-type.h" + +namespace wasm { + +// Provides hashing and equality comparison for a sequence of types. The hashing +// and equality differentiate the top-level structure of each type in the +// sequence and the equality of referenced heap types that are not in the +// recursion group, but for references to types that are in the recursion group, +// it considers only the index of the referenced type within the group. That +// means that recursion groups containing different types can compare and hash +// as equal as long as their internal structure and external references are the +// same. +struct RecGroupShape { + const std::vector<HeapType>& types; + + RecGroupShape(const std::vector<HeapType>& types) : types(types) {} + + bool operator==(const RecGroupShape& other) const; + bool operator!=(const RecGroupShape& other) const { + return !(*this == other); + } +}; + +// Extends `RecGroupShape` with ordered comparison of rec group structures. +// Requires the user to supply a global ordering on heap types to be able to +// compare differing references to external types. +// TODO: This can all be upgraded to use C++20 three-way comparisons. +struct ComparableRecGroupShape : RecGroupShape { + std::function<bool(HeapType, HeapType)> less; + + ComparableRecGroupShape(const std::vector<HeapType>& types, + std::function<bool(HeapType, HeapType)> less) + : RecGroupShape(types), less(less) {} + + bool operator<(const RecGroupShape& other) const; + bool operator>(const RecGroupShape& other) const; + bool operator<=(const RecGroupShape& other) const { return !(*this > other); } + bool operator>=(const RecGroupShape& other) const { return !(*this < other); } +}; + +} // namespace wasm + +namespace std { + +template<> class hash<wasm::RecGroupShape> { +public: + size_t operator()(const wasm::RecGroupShape& shape) const; +}; + +} // namespace std + +#endif // wasm_wasm_type_shape_h diff --git a/src/wasm-type.h b/src/wasm-type.h index 88590cc17..e2eec5f35 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -434,6 +434,8 @@ public: // Get the recursion group for this non-basic type. RecGroup getRecGroup() const; + + // Get the index of this non-basic type within its recursion group. size_t getRecGroupIndex() const; constexpr TypeID getID() const { return id; } diff --git a/src/wasm/CMakeLists.txt b/src/wasm/CMakeLists.txt index 16b6d8aed..7a7b26ead 100644 --- a/src/wasm/CMakeLists.txt +++ b/src/wasm/CMakeLists.txt @@ -12,6 +12,7 @@ set(wasm_SOURCES wasm-stack.cpp wasm-stack-opts.cpp wasm-type.cpp + wasm-type-shape.cpp wasm-validator.cpp ${wasm_HEADERS} ) diff --git a/src/wasm/wasm-type-shape.cpp b/src/wasm/wasm-type-shape.cpp new file mode 100644 index 000000000..99398bb7b --- /dev/null +++ b/src/wasm/wasm-type-shape.cpp @@ -0,0 +1,348 @@ +/* + * Copyright 2024 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "wasm-type-shape.h" +#include "support/hash.h" +#include "wasm-type.h" + +namespace wasm { + +namespace { + +enum Comparison { EQ, LT, GT }; + +template<typename CompareTypes> struct RecGroupComparator { + std::unordered_map<HeapType, Index> indicesA; + std::unordered_map<HeapType, Index> indicesB; + CompareTypes compareTypes; + + RecGroupComparator(CompareTypes compareTypes) : compareTypes(compareTypes) {} + + Comparison compare(const RecGroupShape& a, const RecGroupShape& b) { + if (a.types.size() != b.types.size()) { + return a.types.size() < b.types.size() ? LT : GT; + } + // Initialize index maps. + for (auto type : a.types) { + indicesA.insert({type, indicesA.size()}); + } + for (auto type : b.types) { + indicesB.insert({type, indicesB.size()}); + } + // Compare types until we find a difference. + for (size_t i = 0; i < a.types.size(); ++i) { + auto cmp = compareDefinition(a.types[i], b.types[i]); + if (cmp == EQ) { + continue; + } + return cmp; + } + // Never found a difference. + return EQ; + } + + Comparison compareDefinition(HeapType a, HeapType b) { + if (a.isShared() != b.isShared()) { + return a.isShared() < b.isShared() ? LT : GT; + } + if (a.isOpen() != b.isOpen()) { + return a.isOpen() < b.isOpen() ? LT : GT; + } + auto aSuper = a.getDeclaredSuperType(); + auto bSuper = b.getDeclaredSuperType(); + if (aSuper.has_value() != bSuper.has_value()) { + return aSuper.has_value() < bSuper.has_value() ? LT : GT; + } + if (aSuper) { + if (auto cmp = compare(*aSuper, *bSuper); cmp != EQ) { + return cmp; + } + } + auto aKind = a.getKind(); + auto bKind = b.getKind(); + if (aKind != bKind) { + return aKind < bKind ? LT : GT; + } + switch (aKind) { + case HeapTypeKind::Func: + return compare(a.getSignature(), b.getSignature()); + case HeapTypeKind::Struct: + return compare(a.getStruct(), b.getStruct()); + case HeapTypeKind::Array: + return compare(a.getArray(), b.getArray()); + case HeapTypeKind::Cont: + return compare(a.getContinuation(), b.getContinuation()); + case HeapTypeKind::Basic: + break; + } + WASM_UNREACHABLE("unexpected kind"); + } + + Comparison compare(Signature a, Signature b) { + if (auto cmp = compare(a.params, b.params); cmp != EQ) { + return cmp; + } + return compare(a.results, b.results); + } + + Comparison compare(const Struct& a, const Struct& b) { + if (a.fields.size() != b.fields.size()) { + return a.fields.size() < b.fields.size() ? LT : GT; + } + for (size_t i = 0; i < a.fields.size(); ++i) { + if (auto cmp = compare(a.fields[i], b.fields[i]); cmp != EQ) { + return cmp; + } + } + return EQ; + } + + Comparison compare(Array a, Array b) { return compare(a.element, b.element); } + + Comparison compare(Continuation a, Continuation b) { + return compare(a.type, b.type); + } + + Comparison compare(Field a, Field b) { + if (a.mutable_ != b.mutable_) { + return a.mutable_ < b.mutable_ ? LT : GT; + } + if (a.isPacked() != b.isPacked()) { + return b.isPacked() < a.isPacked() ? LT : GT; + } + if (a.packedType != b.packedType) { + return a.packedType < b.packedType ? LT : GT; + } + return compare(a.type, b.type); + } + + Comparison compare(Type a, Type b) { + if (a.isBasic() != b.isBasic()) { + return b.isBasic() < a.isBasic() ? LT : GT; + } + if (a.isBasic()) { + if (a.getBasic() != b.getBasic()) { + return a.getBasic() < b.getBasic() ? LT : GT; + } + return EQ; + } + if (a.isTuple() != b.isTuple()) { + return a.isTuple() < b.isTuple() ? LT : GT; + } + if (a.isTuple()) { + return compare(a.getTuple(), b.getTuple()); + } + assert(a.isRef() && b.isRef()); + if (a.isNullable() != b.isNullable()) { + return a.isNullable() < b.isNullable() ? LT : GT; + } + return compare(a.getHeapType(), b.getHeapType()); + } + + Comparison compare(const Tuple& a, const Tuple& b) { + if (a.size() != b.size()) { + return a.size() < b.size() ? LT : GT; + } + for (size_t i = 0; i < a.size(); ++i) { + if (auto cmp = compare(a[i], b[i]); cmp != EQ) { + return cmp; + } + } + return EQ; + } + + Comparison compare(HeapType a, HeapType b) { + if (a.isBasic() != b.isBasic()) { + return b.isBasic() < a.isBasic() ? LT : GT; + } + if (a.isBasic()) { + if (a.getID() != b.getID()) { + return a.getID() < b.getID() ? LT : GT; + } + return EQ; + } + auto itA = indicesA.find(a); + auto itB = indicesB.find(b); + bool foundA = itA != indicesA.end(); + bool foundB = itB != indicesB.end(); + if (foundA != foundB) { + return foundB < foundA ? LT : GT; + } + if (foundA) { + auto indexA = itA->second; + auto indexB = itB->second; + if (indexA != indexB) { + return indexA < indexB ? LT : GT; + } + } + // These types are external to the group, so fall back to the provided + // comparator. + return compareTypes(a, b); + } +}; + +// Deduction guide to satisfy -Wctad-maybe-unsupported. +template<typename CompareTypes> +RecGroupComparator(CompareTypes) -> RecGroupComparator<CompareTypes>; + +struct RecGroupHasher { + std::unordered_map<HeapType, Index> typeIndices; + + size_t hash(const RecGroupShape& shape) { + for (auto type : shape.types) { + typeIndices.insert({type, typeIndices.size()}); + } + size_t digest = wasm::hash(shape.types.size()); + for (auto type : shape.types) { + hash_combine(digest, hashDefinition(type)); + } + return digest; + } + + size_t hashDefinition(HeapType type) { + size_t digest = wasm::hash(type.isShared()); + wasm::rehash(digest, type.isOpen()); + auto super = type.getDeclaredSuperType(); + wasm::rehash(digest, super.has_value()); + if (super) { + hash_combine(digest, hash(*super)); + } + auto kind = type.getKind(); + // Mix in very random numbers to differentiate the kinds. + switch (kind) { + case HeapTypeKind::Func: + wasm::rehash(digest, 1904683903); + hash_combine(digest, hash(type.getSignature())); + return digest; + case HeapTypeKind::Struct: + wasm::rehash(digest, 3273309159); + hash_combine(digest, hash(type.getStruct())); + return digest; + case HeapTypeKind::Array: + wasm::rehash(digest, 4254688366); + hash_combine(digest, hash(type.getArray())); + return digest; + case HeapTypeKind::Cont: + wasm::rehash(digest, 2381496927); + hash_combine(digest, hash(type.getContinuation())); + return digest; + case HeapTypeKind::Basic: + break; + } + WASM_UNREACHABLE("unexpected kind"); + } + + size_t hash(Signature sig) { + size_t digest = hash(sig.params); + hash_combine(digest, hash(sig.results)); + return digest; + } + + size_t hash(const Struct& struct_) { + size_t digest = wasm::hash(struct_.fields.size()); + for (auto field : struct_.fields) { + hash_combine(digest, hash(field)); + } + return digest; + } + + size_t hash(Array array) { return hash(array.element); } + + size_t hash(Continuation cont) { return hash(cont.type); } + + size_t hash(Field field) { + size_t digest = wasm::hash(field.mutable_); + wasm::rehash(digest, field.packedType); + hash_combine(digest, hash(field.type)); + return digest; + } + + size_t hash(Type type) { + size_t digest = wasm::hash(type.isBasic()); + if (type.isBasic()) { + wasm::rehash(digest, type.getBasic()); + return digest; + } + wasm::rehash(digest, type.isTuple()); + if (type.isTuple()) { + hash_combine(digest, hash(type.getTuple())); + return digest; + } + assert(type.isRef()); + wasm::rehash(digest, type.isNullable()); + hash_combine(digest, hash(type.getHeapType())); + return digest; + } + + size_t hash(const Tuple& tuple) { + size_t digest = wasm::hash(tuple.size()); + for (auto type : tuple) { + hash_combine(digest, hash(type)); + } + return digest; + } + + size_t hash(HeapType type) { + size_t digest = wasm::hash(type.isBasic()); + if (type.isBasic()) { + wasm::rehash(digest, type.getID()); + return digest; + } + auto it = typeIndices.find(type); + wasm::rehash(digest, it != typeIndices.end()); + if (it != typeIndices.end()) { + wasm::rehash(digest, it->second); + return digest; + } + wasm::rehash(digest, type.getID()); + return digest; + } +}; + +Comparison compareComparable(const ComparableRecGroupShape& a, + const RecGroupShape& b) { + return RecGroupComparator{[&](HeapType ht1, HeapType ht2) { + return a.less(ht1, ht2) ? LT : a.less(ht2, ht1) ? GT : EQ; + }} + .compare(a, b); +} + +} // anonymous namespace + +bool RecGroupShape::operator==(const RecGroupShape& other) const { + return EQ == RecGroupComparator{[](HeapType a, HeapType b) { + return a == b ? EQ : LT; + }}.compare(*this, other); +} + +bool ComparableRecGroupShape::operator<(const RecGroupShape& other) const { + return LT == compareComparable(*this, other); +} + +bool ComparableRecGroupShape::operator>(const RecGroupShape& other) const { + return GT == compareComparable(*this, other); +} + +} // namespace wasm + +namespace std { + +size_t +hash<wasm::RecGroupShape>::operator()(const wasm::RecGroupShape& shape) const { + return wasm::RecGroupHasher{}.hash(shape); +} + +} // namespace std |