diff options
-rw-r--r-- | src/ir/module-utils.cpp | 101 | ||||
-rw-r--r-- | src/support/topological_sort.h | 118 |
2 files changed, 152 insertions, 67 deletions
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp index e1c7aaccb..b5d0d7d48 100644 --- a/src/ir/module-utils.cpp +++ b/src/ir/module-utils.cpp @@ -16,6 +16,7 @@ #include "module-utils.h" #include "support/insert_ordered.h" +#include "support/topological_sort.h" namespace wasm::ModuleUtils { @@ -189,38 +190,6 @@ void setIndices(IndexedHeapTypes& indexedTypes) { } } -// A utility data structure for performing topological sorts. The elements to be -// sorted should all be `push`ed to the stack, then while the stack is not -// empty, predecessors of the top element should be pushed. On each iteration, -// if the top of the stack is unchanged after pushing all the predecessors, that -// means the predecessors have all been finished so the current element is -// finished as well and should be popped. -template<typename T> struct TopologicalSortStack { - std::vector<T> workStack; - // Remember which items we have finished so we don't visit them again. - std::unordered_set<T> finished; - - bool empty() const { return workStack.empty(); } - - void push(T item) { - if (finished.count(item)) { - return; - } - workStack.push_back(item); - } - - T& peek() { return workStack.back(); } - - void pop() { - // Pop until the stack is empty or it has an unfinished item on top. - finished.insert(workStack.back()); - workStack.pop_back(); - while (!workStack.empty() && finished.count(workStack.back())) { - workStack.pop_back(); - } - } -}; - } // anonymous namespace std::vector<HeapType> collectHeapTypes(Module& wasm) { @@ -271,7 +240,16 @@ IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) { } }; - std::unordered_map<RecGroup, GroupInfo> groupInfos; + struct GroupInfoMap : std::unordered_map<RecGroup, GroupInfo> { + void sort(std::vector<RecGroup>& groups) { + std::sort(groups.begin(), groups.end(), [&](auto& a, auto& b) { + return this->at(a) < this->at(b); + }); + } + }; + + // Collect the information that will be used to sort the recursion groups. + GroupInfoMap groupInfos; for (auto& [type, _] : counts) { RecGroup group = type.getRecGroup(); // Try to initialize a new info or get the existing info. @@ -295,51 +273,40 @@ IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) { info.useCount /= group.size(); } - // Sort the preds of each group by increasing use count so the topological - // sort visits the most used first. Break ties with the group's appearance - // index to ensure determinism. - auto sortGroups = [&](std::vector<RecGroup>& groups) { - std::sort(groups.begin(), groups.end(), [&](auto& a, auto& b) { - return groupInfos.at(a) < groupInfos.at(b); - }); - }; + // Sort the predecessors so the most used will be visited first. for (auto& [group, info] : groupInfos) { info.sortedPreds.insert( info.sortedPreds.end(), info.preds.begin(), info.preds.end()); - sortGroups(info.sortedPreds); + groupInfos.sort(info.sortedPreds); info.preds.clear(); } - // Sort all the groups so the topological sort visits the most used first. - std::vector<RecGroup> sortedGroups; - sortedGroups.reserve(groupInfos.size()); - for (auto& [group, _] : groupInfos) { - sortedGroups.push_back(group); - } - sortGroups(sortedGroups); - - // Perform the topological sort. - TopologicalSortStack<RecGroup> workStack; - for (auto group : sortedGroups) { - workStack.push(group); - } - sortedGroups.clear(); - while (!workStack.empty()) { - auto group = workStack.peek(); - for (auto pred : groupInfos.at(group).sortedPreds) { - workStack.push(pred); + struct RecGroupSort : TopologicalSort<RecGroup, RecGroupSort> { + GroupInfoMap& groupInfos; + RecGroupSort(GroupInfoMap& groupInfos) : groupInfos(groupInfos) { + // Sort all the groups so the topological sort visits the most used first. + std::vector<RecGroup> sortedGroups; + sortedGroups.reserve(groupInfos.size()); + for (auto& [group, _] : groupInfos) { + sortedGroups.push_back(group); + } + groupInfos.sort(sortedGroups); + for (auto group : sortedGroups) { + push(group); + } } - if (workStack.peek() == group) { - // All the predecessors are finished, so `group` is too. - workStack.pop(); - sortedGroups.push_back(group); + + void pushPredecessors(RecGroup group) { + for (auto pred : groupInfos.at(group).sortedPreds) { + push(pred); + } } - } + }; - // Collect and return the grouped types. + // Perform the topological sort and collect the types. IndexedHeapTypes indexedTypes; indexedTypes.types.reserve(counts.size()); - for (auto group : sortedGroups) { + for (auto group : RecGroupSort(groupInfos)) { for (auto member : group) { indexedTypes.types.push_back(member); } diff --git a/src/support/topological_sort.h b/src/support/topological_sort.h new file mode 100644 index 000000000..91353dd37 --- /dev/null +++ b/src/support/topological_sort.h @@ -0,0 +1,118 @@ +/* + * Copyright 2022 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_support_topological_sort_h +#define wasm_support_topological_sort_h + +#include <cstddef> +#include <iterator> +#include <unordered_set> +#include <vector> + +namespace wasm { + +// CRTP utility that provides an iterator through arbitrary directed acyclic +// graphs of data that will visit the data in a topologically sorted order +// (https://en.wikipedia.org/wiki/Topological_sorting). In other words, the +// iterator will produce each item only after all that items predecessors have +// been produced. +// +// Subclasses should call `push` on all the root items in their constructors and +// implement a `void pushPredecessors(T item)` method that calls `push` on all +// the immediate predecessors of `item`. +// +// Cycles in the graph are not detected and will result in an infinite loop. +template<typename T, typename Subtype> struct TopologicalSort { +private: + // The DFS work list. + std::vector<T> workStack; + + // Remember which items we have finished so we don't visit them again. + std::unordered_set<T> finished; + + // Should be overridden by `Subtype`. + void pushPredecessors(T item) { + static_assert(&TopologicalSort<T, Subtype>::pushPredecessors != + &Subtype::pushPredecessors, + "TopologicalSort subclass must implement `pushPredecessors`"); + } + + // Pop until the stack is empty or it has an unfinished item on top. + void finishCurr() { + finished.insert(workStack.back()); + workStack.pop_back(); + while (!workStack.empty() && finished.count(workStack.back())) { + workStack.pop_back(); + } + } + + // Advance until the next item to be finished is on top of the stack or the + // stack is empty. + void stepToNext() { + while (!workStack.empty()) { + T item = workStack.back(); + static_cast<Subtype*>(this)->pushPredecessors(item); + if (workStack.back() == item) { + // No unfinished predecessors, so this is the next item in the sort. + break; + } + } + } + +protected: + // Call this from the `Subtype` constructor to add the root items and from + // `Subtype::pushPredecessors` to add predecessors. + void push(T item) { + if (finished.count(item)) { + return; + } + workStack.push_back(item); + } + +public: + struct Iterator { + using value_type = T; + using difference_type = std::ptrdiff_t; + using reference = T&; + using pointer = T*; + using iterator_category = std::input_iterator_tag; + + TopologicalSort<T, Subtype>* parent; + + bool isEnd() const { return !parent || parent->workStack.empty(); } + bool operator==(Iterator& other) const { return isEnd() == other.isEnd(); } + bool operator!=(Iterator& other) const { return !(*this == other); } + T operator*() { return parent->workStack.back(); } + void operator++(int) { + parent->finishCurr(); + parent->stepToNext(); + } + Iterator& operator++() { + (*this)++; + return *this; + } + }; + + Iterator begin() { + stepToNext(); + return {this}; + } + Iterator end() { return {nullptr}; } +}; + +} // namespace wasm + +#endif // wasm_support_topological_sort_h |