diff options
-rw-r--r-- | src/analysis/lattices/inverted.h | 2 | ||||
-rw-r--r-- | src/analysis/lattices/lift.h | 1 | ||||
-rw-r--r-- | src/tools/wasm-fuzz-lattices.cpp | 641 |
3 files changed, 525 insertions, 119 deletions
diff --git a/src/analysis/lattices/inverted.h b/src/analysis/lattices/inverted.h index 92b965323..c69f71a8b 100644 --- a/src/analysis/lattices/inverted.h +++ b/src/analysis/lattices/inverted.h @@ -35,7 +35,7 @@ template<FullLattice L> struct Inverted { Element getBottom() const noexcept { return lattice.getTop(); } Element getTop() const noexcept { return lattice.getBottom(); } LatticeComparison compare(const Element& a, const Element& b) const noexcept { - return reverseComparison(lattice.compare(a, b)); + return lattice.compare(b, a); } bool join(Element& self, Element other) const noexcept { return lattice.meet(self, other); diff --git a/src/analysis/lattices/lift.h b/src/analysis/lattices/lift.h index a458fd6b6..fc7c9754f 100644 --- a/src/analysis/lattices/lift.h +++ b/src/analysis/lattices/lift.h @@ -40,6 +40,7 @@ template<Lattice L> struct Lift { }; L lattice; + Lift(L&& lattice) : lattice(std::move(lattice)) {} Element getBottom() const noexcept { return {std::nullopt}; } diff --git a/src/tools/wasm-fuzz-lattices.cpp b/src/tools/wasm-fuzz-lattices.cpp index 390104223..abc855467 100644 --- a/src/tools/wasm-fuzz-lattices.cpp +++ b/src/tools/wasm-fuzz-lattices.cpp @@ -17,8 +17,15 @@ #include <optional> #include <random> #include <string> +#include <type_traits> +#include <variant> #include "analysis/lattice.h" +#include "analysis/lattices/bool.h" +#include "analysis/lattices/flat.h" +#include "analysis/lattices/int.h" +#include "analysis/lattices/inverted.h" +#include "analysis/lattices/lift.h" #include "analysis/lattices/stack.h" #include "analysis/liveness-transfer-function.h" #include "analysis/reaching-definitions-transfer-function.h" @@ -35,7 +42,6 @@ using namespace analysis; // Helps printing error messages. std::string LatticeComparisonNames[4] = { "No Relation", "Equal", "Less", "Greater"}; -std::string LatticeComparisonSymbols[4] = {"?", "=", "<", ">"}; uint64_t getSeed() { // Return a (truly) random 64-bit value. @@ -43,6 +49,414 @@ uint64_t getSeed() { return std::uniform_int_distribution<uint64_t>{}(rand); } +// Actually a pointer to `L::ElementImpl`, but we erase the type to avoid +// getting into a situation where `L` satisfying `Lattice` or `FullLattice` +// circularly requires that `L` satisfies `Lattice` or `FullLattice`. C++ does +// not allow concepts to depend on themselves. Also make the pointer copyable to +// satisfy that constraint on lattice elements. +template<typename L> +struct RandomElement : std::unique_ptr<void, void (*)(void*)> { + RandomElement() = default; + + RandomElement(typename L::ElementImpl&& other) + : std::unique_ptr<void, void (*)(void*)>( + new typename L::ElementImpl(std::move(other)), + [](void* e) { delete static_cast<typename L::ElementImpl*>(e); }) {} + + RandomElement(const RandomElement& other) + : RandomElement([&]() { + auto copy = *other; + return copy; + }()) {} + + RandomElement(RandomElement&& other) = default; + RandomElement& operator=(const RandomElement& other) { + if (this != &other) { + new (this) RandomElement(other); + } + return *this; + } + + RandomElement& operator=(RandomElement&& other) = default; + + typename L::ElementImpl& operator*() { + return *static_cast<typename L::ElementImpl*>(get()); + } + + const typename L::ElementImpl& operator*() const { + return *static_cast<const typename L::ElementImpl*>(get()); + } + + typename L::ElementImpl* operator->() { return &*(*this); } + + const typename L::ElementImpl* operator->() const { return &*(*this); } +}; + +struct RandomFullLattice { + // The inner lattice and lattice element types. These must be defined later + // because they depend on `RandomFullLattice` satisfying `FullLattice`, but + // that requires the type to be complete. + struct L; + struct ElementImpl; + using Element = RandomElement<RandomFullLattice>; + + Random& rand; + + // Indirect because L recursively contains RandomFullLattice. + std::unique_ptr<L> lattice; + + RandomFullLattice(Random& rand, + size_t depth = 0, + std::optional<uint32_t> maybePick = std::nullopt); + + // Make a random element of this lattice. + Element makeElement() const noexcept; + + Element getBottom() const noexcept; + Element getTop() const noexcept; + LatticeComparison compare(const Element& a, const Element& b) const noexcept; + bool join(Element& a, const Element& b) const noexcept; + bool meet(Element& a, const Element& b) const noexcept; +}; + +struct RandomLattice { + // The inner lattice and lattice element types. These must be defined later + // because they depend on `RandomLattice` satisfying `Lattice`, but that + // requires the type to be complete. + struct L; + struct ElementImpl; + using Element = RandomElement<RandomLattice>; + + Random& rand; + + // Indirect because L recursively contains RandomLattice. + std::unique_ptr<L> lattice; + + RandomLattice(Random& rand, size_t depth = 0); + + // Make a random element of this lattice. + Element makeElement() const noexcept; + + Element getBottom() const noexcept; + LatticeComparison compare(const Element& a, const Element& b) const noexcept; + bool join(Element& a, const Element& b) const noexcept; +}; + +#if __cplusplus >= 202002L +static_assert(FullLattice<RandomFullLattice>); +static_assert(Lattice<RandomLattice>); +#endif + +struct RandomFullLattice::L + : std::variant<Bool, UInt32, Inverted<RandomFullLattice>> {}; + +struct RandomFullLattice::ElementImpl + : std::variant<typename Bool::Element, + typename UInt32::Element, + typename Inverted<RandomFullLattice>::Element> {}; + +struct RandomLattice::L + : std::variant<RandomFullLattice, Flat<uint32_t>, Lift<RandomLattice>> {}; + +struct RandomLattice::ElementImpl + : std::variant<typename RandomFullLattice::Element, + typename Flat<uint32_t>::Element, + typename Lift<RandomLattice>::Element> {}; + +RandomFullLattice::RandomFullLattice(Random& rand, + size_t depth, + std::optional<uint32_t> maybePick) + : rand(rand) { + // TODO: Limit the depth once we get lattices with more fan-out. + uint32_t pick = maybePick ? *maybePick : rand.upTo(3); + switch (pick) { + case 0: + lattice = std::make_unique<L>(L{Bool{}}); + return; + case 1: + lattice = std::make_unique<L>(L{UInt32{}}); + return; + case 2: + lattice = + std::make_unique<L>(L{Inverted{RandomFullLattice{rand, depth + 1}}}); + return; + } + WASM_UNREACHABLE("unexpected pick"); +} + +RandomLattice::RandomLattice(Random& rand, size_t depth) : rand(rand) { + // TODO: Limit the depth once we get lattices with more fan-out. + uint32_t pick = rand.upTo(5); + switch (pick) { + case 0: + case 1: + case 2: + lattice = std::make_unique<L>(L{RandomFullLattice{rand, depth, pick}}); + return; + case 3: + lattice = std::make_unique<L>(L{Flat<uint32_t>{}}); + return; + case 4: + lattice = std::make_unique<L>(L{Lift{RandomLattice{rand, depth + 1}}}); + return; + } + WASM_UNREACHABLE("unexpected pick"); +} + +RandomFullLattice::Element RandomFullLattice::makeElement() const noexcept { + if (std::get_if<Bool>(lattice.get())) { + return ElementImpl{rand.pick(true, false)}; + } + if (std::get_if<UInt32>(lattice.get())) { + return ElementImpl{rand.upToSquared(33)}; + } + if (const auto* l = std::get_if<Inverted<RandomFullLattice>>(lattice.get())) { + return ElementImpl{l->lattice.makeElement()}; + } + WASM_UNREACHABLE("unexpected lattice"); +} + +RandomLattice::Element RandomLattice::makeElement() const noexcept { + if (const auto* l = std::get_if<RandomFullLattice>(lattice.get())) { + return ElementImpl{l->makeElement()}; + } + if (const auto* l = std::get_if<Flat<uint32_t>>(lattice.get())) { + auto pick = rand.upTo(6); + switch (pick) { + case 4: + return ElementImpl{l->getBottom()}; + case 5: + return ElementImpl{l->getTop()}; + default: + return ElementImpl{l->get(std::move(pick))}; + } + } + if (const auto* l = std::get_if<Lift<RandomLattice>>(lattice.get())) { + return ElementImpl{rand.oneIn(4) ? l->getBottom() + : l->get(l->lattice.makeElement())}; + } + WASM_UNREACHABLE("unexpected lattice"); +} + +void indent(std::ostream& os, int depth) { + for (int i = 0; i < depth; ++i) { + os << " "; + } +} + +void printFullElement(std::ostream& os, + const typename RandomFullLattice::Element& elem, + int depth) { + indent(os, depth); + + if (const auto* e = std::get_if<typename Bool::Element>(&*elem)) { + os << (*e ? "true" : "false") << "\n"; + } else if (const auto* e = std::get_if<typename UInt32::Element>(&*elem)) { + os << *e << "\n"; + } else if (const auto* e = + std::get_if<typename Inverted<RandomFullLattice>::Element>( + &*elem)) { + os << "Inverted(\n"; + printFullElement(os, *e, depth + 1); + indent(os, depth); + os << ")\n"; + } +} + +void printElement(std::ostream& os, + const typename RandomLattice::Element& elem, + int depth = 0) { + if (const auto* e = + std::get_if<typename RandomFullLattice::Element>(&*elem)) { + printFullElement(os, *e, depth); + return; + } + + indent(os, depth); + + if (const auto* e = std::get_if<typename Flat<uint32_t>::Element>(&*elem)) { + if (e->isBottom()) { + os << "flat bot\n"; + } else if (e->isTop()) { + os << "flat top\n"; + } else { + os << "flat " << *e->getVal() << "\n"; + } + } else if (const auto* e = + std::get_if<typename Lift<RandomLattice>::Element>(&*elem)) { + if (e->isBottom()) { + os << "lift bot\n"; + } else { + os << "Lifted(\n"; + printElement(os, **e, depth + 1); + indent(os, depth); + os << ")\n"; + } + } +} + +std::ostream& operator<<(std::ostream& os, + const typename RandomLattice::Element& elem) { + printElement(os, elem); + return os; +} + +// Check that random lattices have the correct mathematical properties by +// checking the relationships between random elements. +void checkLatticeProperties(Random& rand) { + RandomLattice lattice(rand); + + // Generate the lattice elements we will perform checks on. + typename RandomLattice::Element elems[3] = { + lattice.makeElement(), lattice.makeElement(), lattice.makeElement()}; + + // Calculate the relations between the generated elements. + LatticeComparison relation[3][3]; + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + relation[i][j] = lattice.compare(elems[i], elems[j]); + } + } + + // Reflexivity: x == x + for (int i = 0; i < 3; ++i) { + if (lattice.compare(elems[i], elems[i]) != EQUAL) { + Fatal() << "Lattice element is not reflexive:\n" << elems[i]; + } + } + + // Anti-symmetry: x < y implies y > x, etc. + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + auto forward = relation[i][j]; + auto reverse = relation[j][i]; + if (reverseComparison(forward) != reverse) { + Fatal() + << "Lattice elements are not anti-symmetric.\nFirst element:\n\n" + << elems[i] << "\nSecond element:\n\n" + << elems[j] + << "\nForward relation: " << LatticeComparisonNames[forward] + << "\nReverse relation: " << LatticeComparisonNames[reverse] << "\n"; + } + } + } + + // Transitivity: x < y and y < z imply x < z, etc. + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + auto ij = relation[i][j], jk = relation[j][k], ik = relation[i][k]; + if (ij == NO_RELATION || jk == NO_RELATION) { + continue; + } + if ((ij == LESS && jk == GREATER) || (ij == GREATER && jk == LESS)) { + continue; + } + auto expected = ij == EQUAL ? jk : ij; + if (ik != expected) { + Fatal() << "Lattice elements are not transitive.\nFrist element:\n\n" + << elems[i] << "\nSecond element:\n\n" + << elems[j] << "\nThird element:\n\n" + << elems[k] << "\nFirst to second relation: " + << LatticeComparisonNames[ij] + << "\nSecond to third relation: " + << LatticeComparisonNames[jk] + << "\nFirst to thrid relation: " << LatticeComparisonNames[ik] + << "\n"; + } + } + } + } + + // Joins (i.e. least upper bounds) + for (int i = 0; i < 3; ++i) { + { + // Identity: elem u bot = elem + auto join = elems[i]; + lattice.join(join, lattice.getBottom()); + if (lattice.compare(join, elems[i]) != EQUAL) { + Fatal() + << "Join of element and bottom is not equal to element:\nElement:\n\n" + << elems[i] << "\nJoin:\n\n" + << join; + } + } + { + // Identity: bot u elem = elem + auto join = lattice.getBottom(); + lattice.join(join, elems[i]); + if (lattice.compare(join, elems[i]) != EQUAL) { + Fatal() + << "Join of bottom and element is not equal to element:\nElement:\n\n" + << elems[i] << "\nJoin:\n\n" + << join; + } + } + { + // Identity: elem u elem = elem + auto join = elems[i]; + lattice.join(join, elems[i]); + if (lattice.compare(join, elems[i]) != EQUAL) { + Fatal() + << "Join of element with itself equal to element:\nElement:\n\n" + << elems[i] << "\nJoin:\n\n" + << join; + } + } + for (int j = 0; j < 3; ++j) { + // Commutativity: x u y = y u x + auto ij = elems[i]; + bool ijModified = lattice.join(ij, elems[j]); + auto ji = elems[j]; + bool jiModified = lattice.join(ji, elems[i]); + if (lattice.compare(ij, ji) != EQUAL) { + Fatal() << "Join is not commutative:\nFirst element:\n\n" + << elems[i] << "\nSecond element:\n\n" + << elems[j] << "\nJoin(first, second):\n\n" + << ij << "\nJoin(second, first):\n\n" + << ji; + } + + // Identity: x < y implies x u y = y + if (relation[i][j] == LESS) { + if (lattice.compare(ij, elems[j]) != EQUAL) { + Fatal() + << "Join is not equal to greater element:\nLesser element:\n\n" + << elems[i] << "\nGreater element:\n\n" + << elems[j] << "\nJoin:\n\n" + << ij; + } + if (jiModified) { + Fatal() + << "Join incorrectly reported modification:\nLesser element:\n\n" + << elems[i] << "\nGreater element:\n\n" + << elems[j]; + } + if (!ijModified) { + Fatal() + << "Join should have reported modification:\nLesser element:\n\n" + << elems[i] << "\nGreater element:\n\n" + << elems[j]; + } + } + + for (int k = 0; k < 3; ++k) { + // *Least* upper bound: x <= z && y <= z implies x u y <= z + if (relation[i][k] == LESS && relation[j][k] == LESS) { + auto IJtoK = lattice.compare(ij, elems[k]); + if (IJtoK != EQUAL && IJtoK != LESS) { + Fatal() << "Join is not least upper bound:\nFirst element:\n\n" + << elems[i] << "\nSecond element:\n\n" + << elems[j] << "\nJoin:\n\n" + << ij << "\nOther upper bound:\n\n" + << elems[k]; + } + } + } + } + } +} + // Utility class which provides methods to check properties of the transfer // function and lattice of an analysis. template<Lattice L, TransferFunction TxFn> struct AnalysisChecker { @@ -85,109 +499,7 @@ template<Lattice L, TransferFunction TxFn> struct AnalysisChecker { os << "\nfor " << funcName << " to test " << txfnName << ".\n\n"; } - // Checks reflexivity of a lattice element, i.e. x = x. - void checkReflexivity(typename L::Element& element) { - LatticeComparison result = lattice.compare(element, element); - if (result != LatticeComparison::EQUAL) { - std::stringstream ss; - printFailureInfo(ss); - ss << "Element "; - element.print(ss); - ss << " is not reflexive.\n"; - Fatal() << ss.str(); - } - } - - // Anti-Symmetry is defined as x <= y and y <= x imply x = y. Due to the - // fact that the compare(x, y) function of the lattice explicitly tells - // us if two lattice elements are <, =, or = instead of providing a - // <= comparison, it is not useful to check for anti-symmetry as it is defined - // in the fuzzer. - // - // Instead, we check for a related concept that x < y implies y > x, and - // vice versa in this checkAntiSymmetry function. - void checkAntiSymmetry(typename L::Element& x, typename L::Element& y) { - LatticeComparison result = lattice.compare(x, y); - LatticeComparison reverseResult = lattice.compare(y, x); - - if (reverseComparison(result) != reverseResult) { - std::stringstream ss; - printFailureInfo(ss); - x.print(ss); - ss << " " << LatticeComparisonNames[result] << " "; - y.print(ss); - ss << " but reverse direction comparison is " - << LatticeComparisonNames[reverseResult] << ".\n"; - Fatal() << ss.str(); - } - } - -private: - // Prints the error message when a triple of lattice elements violates - // transitivity. - void printTransitivityError(std::ostream& os, - typename L::Element& a, - typename L::Element& b, - typename L::Element& c, - LatticeComparison ab, - LatticeComparison bc, - LatticeComparison ac) { - printFailureInfo(os); - os << "Elements a = "; - a.print(os); - os << ", b = "; - b.print(os); - os << ", and c = "; - c.print(os); - os << " are not transitive. a" << LatticeComparisonSymbols[ab] << "b and b" - << LatticeComparisonSymbols[bc] << "c, but a" - << LatticeComparisonSymbols[ac] << "c.\n"; - } - - // Returns true if given a-b and b-c comparisons, the a-c comparison violates - // transitivity. - bool violatesTransitivity(LatticeComparison ab, - LatticeComparison bc, - LatticeComparison ac) { - if (ab != LatticeComparison::NO_RELATION && - (bc == LatticeComparison::EQUAL || bc == ab) && ab != ac) { - return true; - } else if (bc != LatticeComparison::NO_RELATION && - (ab == LatticeComparison::EQUAL || ab == bc) && bc != ac) { - return true; - } - return false; - } - public: - // Given three lattice elements x, y, and z, checks if transitivity holds - // between them. - void checkTransitivity(typename L::Element& x, - typename L::Element& y, - typename L::Element& z) { - LatticeComparison xy = lattice.compare(x, y); - LatticeComparison yz = lattice.compare(y, z); - LatticeComparison xz = lattice.compare(x, z); - - LatticeComparison yx = reverseComparison(xy); - LatticeComparison zy = reverseComparison(yz); - - // Cover all permutations of x, y, and z. - if (violatesTransitivity(xy, yz, xz)) { - std::stringstream ss; - printTransitivityError(ss, x, y, z, xy, yz, xz); - Fatal() << ss.str(); - } else if (violatesTransitivity(yx, xz, yz)) { - std::stringstream ss; - printTransitivityError(ss, y, x, z, yx, xz, yz); - Fatal() << ss.str(); - } else if (violatesTransitivity(xz, zy, xy)) { - std::stringstream ss; - printTransitivityError(ss, x, z, y, xz, zy, xy); - Fatal() << ss.str(); - } - } - // Given two input - output lattice pairs of a transfer function, checks if // the transfer function is monotonic. If this is violated, then we print out // the CFG block input which caused the transfer function to exhibit @@ -239,19 +551,6 @@ public: Fatal() << ss.str(); } - // Checks lattice-only properties for a triple of lattices. - void checkLatticeElements(typename L::Element x, - typename L::Element y, - typename L::Element z) { - checkReflexivity(x); - checkReflexivity(y); - checkReflexivity(z); - checkAntiSymmetry(x, y); - checkAntiSymmetry(x, z); - checkAntiSymmetry(y, z); - checkTransitivity(x, y, z); - } - // Checks transfer function relevant properties given a CFG and three input // states. It does this by applying the transfer function on each CFG block // using the same three input states each time and then checking properties on @@ -311,7 +610,6 @@ struct LivenessChecker { checker.printVerboseFunctionCase(std::cout, x, y, z); } - checker.checkLatticeElements(x, y, z); checker.checkTransferFunction(cfg, x, y, z); } }; @@ -357,11 +655,117 @@ struct ReachingDefinitionsChecker { checker.printVerboseFunctionCase(std::cout, x, y, z); } - checker.checkLatticeElements(x, y, z); checker.checkTransferFunction(cfg, x, y, z); } }; +// Uninteresting implementation details for RandomFullLattice and RandomLattice. + +RandomFullLattice::Element RandomFullLattice::getBottom() const noexcept { + return std::visit([](const auto& l) { return ElementImpl{l.getBottom()}; }, + *lattice); +} + +RandomFullLattice::Element RandomFullLattice::getTop() const noexcept { + return std::visit([](const auto& l) { return ElementImpl{l.getTop()}; }, + *lattice); +} + +// TODO: use std::remove_cvref_t from C++20 instead. +template<typename T> using bare = std::remove_cv_t<std::remove_reference_t<T>>; + +LatticeComparison RandomFullLattice::compare(const Element& a, + const Element& b) const noexcept { + return std::visit( + [](const auto& l, + const auto& elemA, + const auto& elemB) -> LatticeComparison { + using ElemT = typename bare<decltype(l)>::Element; + using A = bare<decltype(elemA)>; + using B = bare<decltype(elemB)>; + if constexpr (std::is_same_v<ElemT, A> && std::is_same_v<ElemT, B>) { + return l.compare(elemA, elemB); + } + WASM_UNREACHABLE("unexpected element types"); + }, + *lattice, + *a, + *b); +} + +bool RandomFullLattice::join(Element& a, const Element& b) const noexcept { + return std::visit( + [](const auto& l, auto& elemA, const auto& elemB) -> bool { + using ElemT = typename bare<decltype(l)>::Element; + using A = bare<decltype(elemA)>; + using B = bare<decltype(elemB)>; + if constexpr (std::is_same_v<ElemT, A> && std::is_same_v<ElemT, B>) { + return l.join(elemA, elemB); + } + WASM_UNREACHABLE("unexpected element types"); + }, + *lattice, + *a, + *b); +} + +bool RandomFullLattice::meet(Element& a, const Element& b) const noexcept { + return std::visit( + [](const auto& l, auto& elemA, const auto& elemB) -> bool { + using ElemT = typename bare<decltype(l)>::Element; + using A = bare<decltype(elemA)>; + using B = bare<decltype(elemB)>; + if constexpr (std::is_same_v<ElemT, A> && std::is_same_v<ElemT, B>) { + return l.meet(elemA, elemB); + } + WASM_UNREACHABLE("unexpected element types"); + }, + *lattice, + *a, + *b); +} + +RandomLattice::Element RandomLattice::getBottom() const noexcept { + return std::visit([](const auto& l) { return ElementImpl{l.getBottom()}; }, + *lattice); +} + +LatticeComparison RandomLattice::compare(const Element& a, + const Element& b) const noexcept { + return std::visit( + [](const auto& l, + const auto& elemA, + const auto& elemB) -> LatticeComparison { + using ElemT = typename bare<decltype(l)>::Element; + using A = bare<decltype(elemA)>; + using B = bare<decltype(elemB)>; + if constexpr (std::is_same_v<ElemT, A> && std::is_same_v<ElemT, B>) { + return l.compare(elemA, elemB); + } + WASM_UNREACHABLE("unexpected element types"); + }, + *lattice, + *a, + *b); +} + +bool RandomLattice::join(Element& a, const Element& b) const noexcept { + return std::visit( + [](const auto& l, auto& elemA, const auto& elemB) -> bool { + using ElemT = typename bare<decltype(l)>::Element; + using A = bare<decltype(elemA)>; + using B = bare<decltype(elemB)>; + if constexpr (std::is_same_v<ElemT, A> && std::is_same_v<ElemT, B>) { + return l.join(elemA, elemB); + } + WASM_UNREACHABLE("unexpected element types"); + }, + *lattice, + *a, + *b); +} + +// The main entry point. struct Fuzzer { bool verbose; @@ -380,6 +784,7 @@ struct Fuzzer { } Random rand(std::move(funcBytes)); + checkLatticeProperties(rand); CFG cfg = CFG::fromFunction(func); |