diff options
author | Thomas Lively <7121787+tlively@users.noreply.github.com> | 2021-02-18 17:48:58 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-18 17:48:58 -0800 |
commit | 22fc60d86538a6111f7b953fd70362ee73dce7d8 (patch) | |
tree | 21e3a7de3bc25516e77bb88cfec2ed8e925f679b /src | |
parent | 3e31f4fd583324ad446fb96bc0d073e141157f7f (diff) | |
download | binaryen-22fc60d86538a6111f7b953fd70362ee73dce7d8.tar.gz binaryen-22fc60d86538a6111f7b953fd70362ee73dce7d8.tar.bz2 binaryen-22fc60d86538a6111f7b953fd70362ee73dce7d8.zip |
Fix TypeBuilder canonicalization (#3578)
When types or heap types were used multiple times in a TypeBuilder instance, it
was possible for the canonicalization algorithm to canonicalize a parent type
before canonicalizing all of its component child types, leaking the temporary
types into globally interned types. This bug led to incorrect canonicalization
results and use-after-free bugs.
The cause of the bug was that types were canonicalized in the reverse of the
order that they were visited in, but children were visited after the first
occurrence of their parents, not necessarily after the last occurrence of their
parents. One fix could have been to remove the logic that prevented types from
being visited multiple times so that children would always be visited after
their parents. That simple fix, however, would not scale gracefully to handle
recursive types because it would require some way to detect recursions without
accidentally reintroducing these bugs.
This PR implements a more robust solution: topologically sorting the traversed
types to ensure that children are canonicalized before their parents. This
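solution will be trivial to adapt for recursive types because recursive types
are trivial to detect from the reachability graph used to perform the
topological sort: a type is recursive exactly when it appears in its own set of
ancestors once the transitive closure has been computed.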
Diffstat (limited to 'src')
-rw-r--r-- | src/wasm/wasm-type.cpp | 123 |
1 file changed, 96 insertions, 27 deletions
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index ce871c728..87dd2b289 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include <algorithm> #include <array> #include <cassert> #include <shared_mutex> @@ -1135,13 +1136,14 @@ struct Canonicalizer { // The work list of Types and HeapTypes remaining to be scanned. std::vector<Item> scanList; - // The list of Types and HeapTypes to visit constructed in forward preorder - // and eventually traversed in reverse to give a reverse postorder. - std::vector<Item> visitList; + // Maps Type and HeapType IDs to the IDs of their ancestor Types and HeapTypes + // in the type graph. Only considers compound Types and HeapTypes. + std::unordered_map<TypeID, std::unordered_set<TypeID>> ancestors; - // Maps Type and HeapType IDs to the IDs of Types and HeapTypes they can - // reach in the type graph. Only considers compound Types and HeapTypes. - std::unordered_map<TypeID, std::unordered_set<TypeID>> reaches; + // Maps each temporary Type and HeapType to the locations where they will have + // to be replaced with canonical Types and HeapTypes. + std::unordered_map<Type, std::vector<Type*>> typeLocations; + std::unordered_map<HeapType, std::vector<HeapType*>> heapTypeLocations; // Maps Types and HeapTypes backed by the TypeBuilder's Stores to globally // canonical Types and HeapTypes. @@ -1155,7 +1157,8 @@ struct Canonicalizer { template<typename T1, typename T2> void noteChild(T1 parent, T2* child); void scanHeapType(HeapType* ht); void scanType(Type* child); - void makeReachabilityFixedPoint(); + void makeAncestorsFixedPoint(); + std::vector<Item> getOrderedItems(); // Replaces the pointee Type or HeapType of `type` with its globally canonical // equivalent, recording the substitution for future use in either @@ -1199,22 +1202,22 @@ Canonicalizer::Canonicalizer(TypeBuilder& builder) : builder(builder) { // Check for recursive types and heap types. 
TODO: pre-canonicalize these into // their minimal finite representations. - makeReachabilityFixedPoint(); - for (auto& reach : reaches) { - if (reach.second.count(reach.first) != 0) { + makeAncestorsFixedPoint(); + for (auto& kv : ancestors) { + if (kv.second.count(kv.first) != 0) { WASM_UNREACHABLE("TODO: support recursive types"); } } // Visit the types and heap types in reverse postorder, replacing them with // their canonicalized versions. - for (auto it = visitList.rbegin(); it != visitList.rend(); ++it) { - switch (it->kind) { + for (auto it : getOrderedItems()) { + switch (it.kind) { case Item::TypeKind: - canonicalize(it->type, canonicalTypes); + canonicalize(it.type, canonicalTypes); break; case Item::HeapTypeKind: - canonicalize(it->heapType, canonicalHeapTypes); + canonicalize(it.heapType, canonicalHeapTypes); break; } } @@ -1223,14 +1226,14 @@ Canonicalizer::Canonicalizer(TypeBuilder& builder) : builder(builder) { template<typename T1, typename T2> void Canonicalizer::noteChild(T1 parent, T2* child) { if (child->isCompound()) { - reaches[parent.getID()].insert(child->getID()); + ancestors[child->getID()].insert(parent.getID()); scanList.push_back(child); } } void Canonicalizer::scanHeapType(HeapType* ht) { assert(ht->isCompound()); - visitList.push_back(ht); + heapTypeLocations[*ht].push_back(ht); if (scanned.count(ht->getID())) { return; } @@ -1255,7 +1258,7 @@ void Canonicalizer::scanHeapType(HeapType* ht) { void Canonicalizer::scanType(Type* type) { assert(type->isCompound()); - visitList.push_back(type); + typeLocations[*type].push_back(type); if (scanned.count(type->getID())) { return; } @@ -1277,27 +1280,93 @@ void Canonicalizer::scanType(Type* type) { } } -void Canonicalizer::makeReachabilityFixedPoint() { +void Canonicalizer::makeAncestorsFixedPoint() { // Naively calculate the transitive closure of the reachability graph. 
bool changed; do { changed = false; - for (auto& entry : reaches) { - auto& reachable = entry.second; - std::unordered_set<TypeID> nextReachable; - for (auto& other : reachable) { - auto& otherReaches = reaches[other]; - nextReachable.insert(otherReaches.begin(), otherReaches.end()); + for (auto& entry : ancestors) { + auto& succs = entry.second; + std::unordered_set<TypeID> nextAncestors; + for (auto& other : succs) { + auto& otherAncestors = ancestors[other]; + nextAncestors.insert(otherAncestors.begin(), otherAncestors.end()); } - size_t oldSize = reachable.size(); - reachable.insert(nextReachable.begin(), nextReachable.end()); - if (reachable.size() != oldSize) { + size_t oldSize = succs.size(); + succs.insert(nextAncestors.begin(), nextAncestors.end()); + if (succs.size() != oldSize) { changed = true; } } } while (changed); } +std::vector<Canonicalizer::Item> Canonicalizer::getOrderedItems() { + // Topologically sort the Types and HeapTypes so that all children are + // canonicalized before their parents. + + std::vector<TypeID> sorted; + std::unordered_set<TypeID> seen; + + // Topologically sort so that all parents are pushed before their children. + // This sort will be reversed later to have children before their parents. + std::function<void(TypeID)> visit = [&](TypeID i) { + if (seen.count(i)) { + return; + } + // Push ancestors of the current type before pushing the current type. + auto it = ancestors.find(i); + if (it != ancestors.end()) { + for (auto ancestor : it->second) { + visit(ancestor); + } + } + seen.insert(i); + sorted.push_back(i); + }; + + // Collect the items to be sorted, including all the result HeapTypes and any + // type that participates in the ancestor relation. + std::unordered_set<TypeID> allIDs; + for (auto ht : results) { + allIDs.insert(ht.getID()); + } + for (auto& kv : ancestors) { + allIDs.insert(kv.first); + for (auto& id : kv.second) { + allIDs.insert(id); + } + } + + // Perform the sort. 
+ for (TypeID i : allIDs) { + visit(i); + } + + // Swap order to have all children before their parents. + std::reverse(sorted.begin(), sorted.end()); + + // Create a list of Items to canonicalize in place according to the + // topological ordering of types and heap types. + std::vector<Item> items; + for (TypeID id : sorted) { + // IDs may be Types or HeapTypes, so just try both. + auto typeIt = typeLocations.find(Type(id)); + if (typeIt != typeLocations.end()) { + for (Type* loc : typeIt->second) { + items.emplace_back(loc); + } + } else { + auto heapTypeIt = heapTypeLocations.find(HeapType(id)); + assert(heapTypeIt != heapTypeLocations.end()); + for (HeapType* loc : heapTypeIt->second) { + items.emplace_back(loc); + } + } + } + return items; +} + template<typename T> void Canonicalizer::canonicalize(T* type, std::unordered_map<T, T>& canonicals) { |