summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ir/module-utils.h99
-rw-r--r--src/wasm-type.h9
-rw-r--r--src/wasm/wasm-binary.cpp74
-rw-r--r--src/wasm/wasm-type.cpp4
4 files changed, 105 insertions, 81 deletions
diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h
index 0fd478551..f2fda908c 100644
--- a/src/ir/module-utils.h
+++ b/src/ir/module-utils.h
@@ -262,33 +262,20 @@ template<typename T> inline void iterDefinedEvents(Module& wasm, T visitor) {
}
}
-// Helper class for analyzing the call graph.
-//
-// Provides hooks for running some initial calculation on each function (which
-// is done in parallel), writing to a FunctionInfo structure for each function.
-// Then you can call propagateBack() to propagate a property of interest to the
-// calling functions, transitively.
-//
-// For example, if some functions are known to call an import "foo", then you
-// can use this to find which functions call something that might eventually
-// reach foo, by initially marking the direct callers as "calling foo" and
-// propagating that backwards.
-template<typename T> struct CallGraphPropertyAnalysis {
+// Helper class for performing an operation on all the functions in the module,
+// in parallel, with an Info object for each one that can contain results of
+// some computation that the operation performs.
+// The operation performend should not modify the wasm module in any way.
+// TODO: enforce this
+template<typename T> struct ParallelFunctionAnalysis {
Module& wasm;
- // The basic information for each function about whom it calls and who is
- // called by it.
- struct FunctionInfo {
- std::set<Function*> callsTo;
- std::set<Function*> calledBy;
- };
-
typedef std::map<Function*, T> Map;
Map map;
typedef std::function<void(Function*, T&)> Func;
- CallGraphPropertyAnalysis(Module& wasm, Func work) : wasm(wasm) {
+ ParallelFunctionAnalysis(Module& wasm, Func work) : wasm(wasm) {
// Fill in map, as we operate on it in parallel (each function to its own
// entry).
for (auto& func : wasm.functions) {
@@ -304,30 +291,78 @@ template<typename T> struct CallGraphPropertyAnalysis {
struct Mapper : public WalkerPass<PostWalker<Mapper>> {
bool isFunctionParallel() override { return true; }
+ bool modifiesBinaryenIR() override { return false; }
- Mapper(Module* module, Map* map, Func work)
+ Mapper(Module& module, Map& map, Func work)
: module(module), map(map), work(work) {}
Mapper* create() override { return new Mapper(module, map, work); }
- void visitCall(Call* curr) {
- (*map)[this->getFunction()].callsTo.insert(
- module->getFunction(curr->target));
- }
-
- void visitFunction(Function* curr) {
- assert((*map).count(curr));
- work(curr, (*map)[curr]);
+ void doWalkFunction(Function* curr) {
+ assert(map.count(curr));
+ work(curr, map[curr]);
}
private:
- Module* module;
- Map* map;
+ Module& module;
+ Map& map;
Func work;
};
PassRunner runner(&wasm);
- Mapper(&wasm, &map, work).run(&runner, &wasm);
+ Mapper(wasm, map, work).run(&runner, &wasm);
+ }
+};
+
+// Helper class for analyzing the call graph.
+//
+// Provides hooks for running some initial calculation on each function (which
+// is done in parallel), writing to a FunctionInfo structure for each function.
+// Then you can call propagateBack() to propagate a property of interest to the
+// calling functions, transitively.
+//
+// For example, if some functions are known to call an import "foo", then you
+// can use this to find which functions call something that might eventually
+// reach foo, by initially marking the direct callers as "calling foo" and
+// propagating that backwards.
+template<typename T> struct CallGraphPropertyAnalysis {
+ Module& wasm;
+
+ // The basic information for each function about whom it calls and who is
+ // called by it.
+ struct FunctionInfo {
+ std::set<Function*> callsTo;
+ std::set<Function*> calledBy;
+ };
+
+ typedef std::map<Function*, T> Map;
+ Map map;
+
+ typedef std::function<void(Function*, T&)> Func;
+
+ CallGraphPropertyAnalysis(Module& wasm, Func work) : wasm(wasm) {
+ ParallelFunctionAnalysis<T> analysis(wasm, [&](Function* func, T& info) {
+ work(func, info);
+ if (func->imported()) {
+ return;
+ }
+ struct Mapper : public PostWalker<Mapper> {
+ Mapper(Module* module, T& info, Func work)
+ : module(module), info(info), work(work) {}
+
+ void visitCall(Call* curr) {
+ info.callsTo.insert(module->getFunction(curr->target));
+ }
+
+ private:
+ Module* module;
+ T& info;
+ Func work;
+ } mapper(&wasm, info, work);
+ mapper.walk(func->body);
+ });
+
+ map.swap(analysis.map);
// Find what is called by what.
for (auto& pair : map) {
diff --git a/src/wasm-type.h b/src/wasm-type.h
index 7b3845aec..ddfb7c9a1 100644
--- a/src/wasm-type.h
+++ b/src/wasm-type.h
@@ -96,10 +96,6 @@ struct ResultType {
std::string toString() const;
};
-std::ostream& operator<<(std::ostream& os, Type t);
-std::ostream& operator<<(std::ostream& os, ParamType t);
-std::ostream& operator<<(std::ostream& os, ResultType t);
-
struct Signature {
Type params;
Type results;
@@ -112,6 +108,11 @@ struct Signature {
bool operator<(const Signature& other) const;
};
+std::ostream& operator<<(std::ostream& os, Type t);
+std::ostream& operator<<(std::ostream& os, ParamType t);
+std::ostream& operator<<(std::ostream& os, ResultType t);
+std::ostream& operator<<(std::ostream& os, Signature t);
+
constexpr Type none = Type::none;
constexpr Type i32 = Type::i32;
constexpr Type i64 = Type::i64;
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp
index 2336da912..52380041d 100644
--- a/src/wasm/wasm-binary.cpp
+++ b/src/wasm/wasm-binary.cpp
@@ -16,8 +16,8 @@
#include <algorithm>
#include <fstream>
-#include <shared_mutex>
+#include "ir/module-utils.h"
#include "support/bits.h"
#include "wasm-binary.h"
#include "wasm-stack.h"
@@ -25,9 +25,30 @@
namespace wasm {
void WasmBinaryWriter::prepare() {
- // Collect function types and their frequencies
- using Counts = std::unordered_map<Signature, size_t>;
- using AtomicCounts = std::unordered_map<Signature, std::atomic_size_t>;
+ // Collect function types and their frequencies. Collect information in each
+ // function in parallel, then merge.
+ typedef std::unordered_map<Signature, size_t> Counts;
+ ModuleUtils::ParallelFunctionAnalysis<Counts> analysis(
+ *wasm, [&](Function* func, Counts& counts) {
+ if (func->imported()) {
+ return;
+ }
+ struct TypeCounter : PostWalker<TypeCounter> {
+ Module& wasm;
+ Counts& counts;
+
+ TypeCounter(Module& wasm, Counts& counts)
+ : wasm(wasm), counts(counts) {}
+
+ void visitCallIndirect(CallIndirect* curr) {
+ auto* type = wasm.getFunctionType(curr->fullType);
+ Signature sig(Type(type->params), type->result);
+ counts[sig]++;
+ }
+ };
+ TypeCounter(*wasm, counts).walk(func->body);
+ });
+ // Collect all the counts.
Counts counts;
for (auto& curr : wasm->functions) {
counts[Signature(Type(curr->params), curr->result)]++;
@@ -35,49 +56,12 @@ void WasmBinaryWriter::prepare() {
for (auto& curr : wasm->events) {
counts[curr->sig]++;
}
-
- // Parallelize collection of call_indirect type counts
- struct TypeCounter : WalkerPass<PostWalker<TypeCounter>> {
- AtomicCounts& counts;
- std::shared_timed_mutex& mutex;
- TypeCounter(AtomicCounts& counts, std::shared_timed_mutex& mutex)
- : counts(counts), mutex(mutex) {}
- bool isFunctionParallel() override { return true; }
- bool modifiesBinaryenIR() override { return false; }
- void visitCallIndirect(CallIndirect* curr) {
- auto* type = getModule()->getFunctionType(curr->fullType);
- Signature sig(Type(type->params), type->result);
- {
- std::shared_lock<std::shared_timed_mutex> lock(mutex);
- auto it = counts.find(sig);
- if (it != counts.end()) {
- it->second++;
- return;
- }
- }
- {
- std::lock_guard<std::shared_timed_mutex> lock(mutex);
- counts[sig]++;
- }
+ for (auto& pair : analysis.map) {
+ Counts& functionCounts = pair.second;
+ for (auto& innerPair : functionCounts) {
+ counts[innerPair.first] += innerPair.second;
}
- Pass* create() override { return new TypeCounter(counts, mutex); }
- };
-
- std::shared_timed_mutex mutex;
- AtomicCounts parallelCounts;
- for (auto& kv : counts) {
- parallelCounts[kv.first] = 0;
}
-
- TypeCounter counter(parallelCounts, mutex);
- PassRunner runner(wasm);
- runner.setIsNested(true);
- counter.run(&runner, wasm);
-
- for (auto& kv : parallelCounts) {
- counts[kv.first] += kv.second;
- }
-
std::vector<std::pair<Signature, size_t>> sorted(counts.begin(),
counts.end());
std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp
index e114a5540..50ddc82f8 100644
--- a/src/wasm/wasm-type.cpp
+++ b/src/wasm/wasm-type.cpp
@@ -222,6 +222,10 @@ std::ostream& operator<<(std::ostream& os, ResultType param) {
return printPrefixedTypes(os, "result", param.type);
}
+std::ostream& operator<<(std::ostream& os, Signature sig) {
+ return os << "Signature(" << sig.params << " => " << sig.results << ")";
+}
+
std::string Type::toString() const { return genericToString(*this); }
std::string ParamType::toString() const { return genericToString(*this); }