summaryrefslogtreecommitdiff
path: root/src/wasm/wasm-binary.cpp
diff options
context:
space:
mode:
authorAlon Zakai <azakai@google.com>2019-11-26 15:22:04 -0800
committerGitHub <noreply@github.com>2019-11-26 15:22:04 -0800
commitec53d11e0792884e1125fe5a1a437a5eff260259 (patch)
tree0ebeab40cade82309507aceab7f810e9a37929fb /src/wasm/wasm-binary.cpp
parent7665f703f4e3437564be25ae276e1daaedd98d79 (diff)
downloadbinaryen-ec53d11e0792884e1125fe5a1a437a5eff260259.tar.gz
binaryen-ec53d11e0792884e1125fe5a1a437a5eff260259.tar.bz2
binaryen-ec53d11e0792884e1125fe5a1a437a5eff260259.zip
Refactor and optimize binary writing type collection (#2478)
Create a new ParallelFunctionAnalysis helper, which lets us run in parallel on all functions and collect info from them, without manually handling locks etc. Use that in the binary writing code's type collection logic, avoiding a lock for each type increment. Also add Signature printing which was useful to debug this.
Diffstat (limited to 'src/wasm/wasm-binary.cpp')
-rw-r--r--src/wasm/wasm-binary.cpp74
1 files changed, 29 insertions, 45 deletions
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp
index 2336da912..52380041d 100644
--- a/src/wasm/wasm-binary.cpp
+++ b/src/wasm/wasm-binary.cpp
@@ -16,8 +16,8 @@
#include <algorithm>
#include <fstream>
-#include <shared_mutex>
+#include "ir/module-utils.h"
#include "support/bits.h"
#include "wasm-binary.h"
#include "wasm-stack.h"
@@ -25,9 +25,30 @@
namespace wasm {
void WasmBinaryWriter::prepare() {
- // Collect function types and their frequencies
- using Counts = std::unordered_map<Signature, size_t>;
- using AtomicCounts = std::unordered_map<Signature, std::atomic_size_t>;
+ // Collect function types and their frequencies. Collect information in each
+ // function in parallel, then merge.
+ typedef std::unordered_map<Signature, size_t> Counts;
+ ModuleUtils::ParallelFunctionAnalysis<Counts> analysis(
+ *wasm, [&](Function* func, Counts& counts) {
+ if (func->imported()) {
+ return;
+ }
+ struct TypeCounter : PostWalker<TypeCounter> {
+ Module& wasm;
+ Counts& counts;
+
+ TypeCounter(Module& wasm, Counts& counts)
+ : wasm(wasm), counts(counts) {}
+
+ void visitCallIndirect(CallIndirect* curr) {
+ auto* type = wasm.getFunctionType(curr->fullType);
+ Signature sig(Type(type->params), type->result);
+ counts[sig]++;
+ }
+ };
+ TypeCounter(*wasm, counts).walk(func->body);
+ });
+ // Collect all the counts.
Counts counts;
for (auto& curr : wasm->functions) {
counts[Signature(Type(curr->params), curr->result)]++;
@@ -35,49 +56,12 @@ void WasmBinaryWriter::prepare() {
for (auto& curr : wasm->events) {
counts[curr->sig]++;
}
-
- // Parallelize collection of call_indirect type counts
- struct TypeCounter : WalkerPass<PostWalker<TypeCounter>> {
- AtomicCounts& counts;
- std::shared_timed_mutex& mutex;
- TypeCounter(AtomicCounts& counts, std::shared_timed_mutex& mutex)
- : counts(counts), mutex(mutex) {}
- bool isFunctionParallel() override { return true; }
- bool modifiesBinaryenIR() override { return false; }
- void visitCallIndirect(CallIndirect* curr) {
- auto* type = getModule()->getFunctionType(curr->fullType);
- Signature sig(Type(type->params), type->result);
- {
- std::shared_lock<std::shared_timed_mutex> lock(mutex);
- auto it = counts.find(sig);
- if (it != counts.end()) {
- it->second++;
- return;
- }
- }
- {
- std::lock_guard<std::shared_timed_mutex> lock(mutex);
- counts[sig]++;
- }
+ for (auto& pair : analysis.map) {
+ Counts& functionCounts = pair.second;
+ for (auto& innerPair : functionCounts) {
+ counts[innerPair.first] += innerPair.second;
}
- Pass* create() override { return new TypeCounter(counts, mutex); }
- };
-
- std::shared_timed_mutex mutex;
- AtomicCounts parallelCounts;
- for (auto& kv : counts) {
- parallelCounts[kv.first] = 0;
}
-
- TypeCounter counter(parallelCounts, mutex);
- PassRunner runner(wasm);
- runner.setIsNested(true);
- counter.run(&runner, wasm);
-
- for (auto& kv : parallelCounts) {
- counts[kv.first] += kv.second;
- }
-
std::vector<std::pair<Signature, size_t>> sorted(counts.begin(),
counts.end());
std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {