diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ir/module-utils.h | 99 | ||||
-rw-r--r-- | src/wasm-type.h | 9 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 74 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 4 |
4 files changed, 105 insertions, 81 deletions
diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h index 0fd478551..f2fda908c 100644 --- a/src/ir/module-utils.h +++ b/src/ir/module-utils.h @@ -262,33 +262,20 @@ template<typename T> inline void iterDefinedEvents(Module& wasm, T visitor) { } } -// Helper class for analyzing the call graph. -// -// Provides hooks for running some initial calculation on each function (which -// is done in parallel), writing to a FunctionInfo structure for each function. -// Then you can call propagateBack() to propagate a property of interest to the -// calling functions, transitively. -// -// For example, if some functions are known to call an import "foo", then you -// can use this to find which functions call something that might eventually -// reach foo, by initially marking the direct callers as "calling foo" and -// propagating that backwards. -template<typename T> struct CallGraphPropertyAnalysis { +// Helper class for performing an operation on all the functions in the module, +// in parallel, with an Info object for each one that can contain results of +// some computation that the operation performs. +// The operation performend should not modify the wasm module in any way. +// TODO: enforce this +template<typename T> struct ParallelFunctionAnalysis { Module& wasm; - // The basic information for each function about whom it calls and who is - // called by it. - struct FunctionInfo { - std::set<Function*> callsTo; - std::set<Function*> calledBy; - }; - typedef std::map<Function*, T> Map; Map map; typedef std::function<void(Function*, T&)> Func; - CallGraphPropertyAnalysis(Module& wasm, Func work) : wasm(wasm) { + ParallelFunctionAnalysis(Module& wasm, Func work) : wasm(wasm) { // Fill in map, as we operate on it in parallel (each function to its own // entry). for (auto& func : wasm.functions) { @@ -304,30 +291,78 @@ template<typename T> struct CallGraphPropertyAnalysis { struct Mapper : public WalkerPass<PostWalker<Mapper>> { bool isFunctionParallel() override { return true; } + bool modifiesBinaryenIR() override { return false; } - Mapper(Module* module, Map* map, Func work) + Mapper(Module& module, Map& map, Func work) : module(module), map(map), work(work) {} Mapper* create() override { return new Mapper(module, map, work); } - void visitCall(Call* curr) { - (*map)[this->getFunction()].callsTo.insert( - module->getFunction(curr->target)); - } - - void visitFunction(Function* curr) { - assert((*map).count(curr)); - work(curr, (*map)[curr]); + void doWalkFunction(Function* curr) { + assert(map.count(curr)); + work(curr, map[curr]); } private: - Module* module; - Map* map; + Module& module; + Map& map; Func work; }; PassRunner runner(&wasm); - Mapper(&wasm, &map, work).run(&runner, &wasm); + Mapper(wasm, map, work).run(&runner, &wasm); + } +}; + +// Helper class for analyzing the call graph. +// +// Provides hooks for running some initial calculation on each function (which +// is done in parallel), writing to a FunctionInfo structure for each function. +// Then you can call propagateBack() to propagate a property of interest to the +// calling functions, transitively. +// +// For example, if some functions are known to call an import "foo", then you +// can use this to find which functions call something that might eventually +// reach foo, by initially marking the direct callers as "calling foo" and +// propagating that backwards. +template<typename T> struct CallGraphPropertyAnalysis { + Module& wasm; + + // The basic information for each function about whom it calls and who is + // called by it. + struct FunctionInfo { + std::set<Function*> callsTo; + std::set<Function*> calledBy; + }; + + typedef std::map<Function*, T> Map; + Map map; + + typedef std::function<void(Function*, T&)> Func; + + CallGraphPropertyAnalysis(Module& wasm, Func work) : wasm(wasm) { + ParallelFunctionAnalysis<T> analysis(wasm, [&](Function* func, T& info) { + work(func, info); + if (func->imported()) { + return; + } + struct Mapper : public PostWalker<Mapper> { + Mapper(Module* module, T& info, Func work) + : module(module), info(info), work(work) {} + + void visitCall(Call* curr) { + info.callsTo.insert(module->getFunction(curr->target)); + } + + private: + Module* module; + T& info; + Func work; + } mapper(&wasm, info, work); + mapper.walk(func->body); + }); + + map.swap(analysis.map); // Find what is called by what. for (auto& pair : map) { diff --git a/src/wasm-type.h b/src/wasm-type.h index 7b3845aec..ddfb7c9a1 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -96,10 +96,6 @@ struct ResultType { std::string toString() const; }; -std::ostream& operator<<(std::ostream& os, Type t); -std::ostream& operator<<(std::ostream& os, ParamType t); -std::ostream& operator<<(std::ostream& os, ResultType t); - struct Signature { Type params; Type results; @@ -112,6 +108,11 @@ struct Signature { bool operator<(const Signature& other) const; }; +std::ostream& operator<<(std::ostream& os, Type t); +std::ostream& operator<<(std::ostream& os, ParamType t); +std::ostream& operator<<(std::ostream& os, ResultType t); +std::ostream& operator<<(std::ostream& os, Signature t); + constexpr Type none = Type::none; constexpr Type i32 = Type::i32; constexpr Type i64 = Type::i64; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 2336da912..52380041d 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -16,8 +16,8 @@ #include <algorithm> #include <fstream> -#include <shared_mutex> +#include "ir/module-utils.h" #include "support/bits.h" #include "wasm-binary.h" #include "wasm-stack.h" @@ -25,9 +25,30 @@ namespace wasm { void WasmBinaryWriter::prepare() { - // Collect function types and their frequencies - using Counts = std::unordered_map<Signature, size_t>; - using AtomicCounts = std::unordered_map<Signature, std::atomic_size_t>; + // Collect function types and their frequencies. Collect information in each + // function in parallel, then merge. + typedef std::unordered_map<Signature, size_t> Counts; + ModuleUtils::ParallelFunctionAnalysis<Counts> analysis( + *wasm, [&](Function* func, Counts& counts) { + if (func->imported()) { + return; + } + struct TypeCounter : PostWalker<TypeCounter> { + Module& wasm; + Counts& counts; + + TypeCounter(Module& wasm, Counts& counts) + : wasm(wasm), counts(counts) {} + + void visitCallIndirect(CallIndirect* curr) { + auto* type = wasm.getFunctionType(curr->fullType); + Signature sig(Type(type->params), type->result); + counts[sig]++; + } + }; + TypeCounter(*wasm, counts).walk(func->body); + }); + // Collect all the counts. Counts counts; for (auto& curr : wasm->functions) { counts[Signature(Type(curr->params), curr->result)]++; @@ -35,49 +56,12 @@ void WasmBinaryWriter::prepare() { for (auto& curr : wasm->events) { counts[curr->sig]++; } - - // Parallelize collection of call_indirect type counts - struct TypeCounter : WalkerPass<PostWalker<TypeCounter>> { - AtomicCounts& counts; - std::shared_timed_mutex& mutex; - TypeCounter(AtomicCounts& counts, std::shared_timed_mutex& mutex) - : counts(counts), mutex(mutex) {} - bool isFunctionParallel() override { return true; } - bool modifiesBinaryenIR() override { return false; } - void visitCallIndirect(CallIndirect* curr) { - auto* type = getModule()->getFunctionType(curr->fullType); - Signature sig(Type(type->params), type->result); - { - std::shared_lock<std::shared_timed_mutex> lock(mutex); - auto it = counts.find(sig); - if (it != counts.end()) { - it->second++; - return; - } - } - { - std::lock_guard<std::shared_timed_mutex> lock(mutex); - counts[sig]++; - } + for (auto& pair : analysis.map) { + Counts& functionCounts = pair.second; + for (auto& innerPair : functionCounts) { + counts[innerPair.first] += innerPair.second; } - Pass* create() override { return new TypeCounter(counts, mutex); } - }; - - std::shared_timed_mutex mutex; - AtomicCounts parallelCounts; - for (auto& kv : counts) { - parallelCounts[kv.first] = 0; } - - TypeCounter counter(parallelCounts, mutex); - PassRunner runner(wasm); - runner.setIsNested(true); - counter.run(&runner, wasm); - - for (auto& kv : parallelCounts) { - counts[kv.first] += kv.second; - } - std::vector<std::pair<Signature, size_t>> sorted(counts.begin(), counts.end()); std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) { diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index e114a5540..50ddc82f8 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -222,6 +222,10 @@ std::ostream& operator<<(std::ostream& os, ResultType param) { return printPrefixedTypes(os, "result", param.type); } +std::ostream& operator<<(std::ostream& os, Signature sig) { + return os << "Signature(" << sig.params << " => " << sig.results << ")"; +} + std::string Type::toString() const { return genericToString(*this); } std::string ParamType::toString() const { return genericToString(*this); } |