summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/pass.h26
-rw-r--r--src/passes/CoalesceLocals.cpp6
-rw-r--r--src/passes/DeadCodeElimination.cpp4
-rw-r--r--src/passes/DropReturnValues.cpp4
-rw-r--r--src/passes/DuplicateFunctionElimination.cpp40
-rw-r--r--src/passes/MergeBlocks.cpp4
-rw-r--r--src/passes/OptimizeInstructions.cpp4
-rw-r--r--src/passes/PostEmscripten.cpp4
-rw-r--r--src/passes/Print.cpp8
-rw-r--r--src/passes/RemoveUnusedBrs.cpp4
-rw-r--r--src/passes/RemoveUnusedNames.cpp4
-rw-r--r--src/passes/ReorderLocals.cpp4
-rw-r--r--src/passes/SimplifyLocals.cpp8
-rw-r--r--src/passes/Vacuum.cpp4
-rw-r--r--src/passes/pass.cpp91
-rw-r--r--src/wasm-printing.h2
-rw-r--r--src/wasm-traversal.h67
17 files changed, 145 insertions, 139 deletions
diff --git a/src/pass.h b/src/pass.h
index fff2b04eb..4a5fe609d 100644
--- a/src/pass.h
+++ b/src/pass.h
@@ -89,7 +89,7 @@ struct PassRunner {
}
template<class P, class Arg>
- void add(Arg& arg){
+ void add(Arg arg){
passes.push_back(new P(arg));
}
@@ -116,6 +116,9 @@ struct PassRunner {
P* getLast();
~PassRunner();
+
+private:
+ void runPassOnFunction(Pass* pass, Function* func);
};
//
@@ -136,6 +139,25 @@ public:
WASM_UNREACHABLE(); // by default, passes cannot be run this way
}
+ // Function parallelism. By default, passes are not run in parallel, but you
+ // can override this method to say that functions are parallelizable. This
+ // should always be safe *unless* you do something in the pass that makes it
+ // not thread-safe; in other words, the Module and Function objects and
+ // so forth are set up so that Functions can be processed in parallel, so
+ // if you do not ad global state that could be raced on, your pass could be
+ // function-parallel.
+ //
+ // Function-parallel passes create an instance of the Walker class per function.
+ // That means that you can't rely on Walker object properties to persist across
+ // your functions, and you can't expect a new object to be created for each
+ // function either (which could be very inefficient).
+ virtual bool isFunctionParallel() { return false; }
+
+ // This method is used to create instances per function for a function-parallel
+ // pass. You may need to override this if you subclass a Walker, as otherwise
+ // this will create the parent class.
+ virtual Pass* create() { WASM_UNREACHABLE(); }
+
std::string name;
protected:
@@ -197,7 +219,7 @@ protected:
public:
Printer() : o(std::cout) {}
- Printer(std::ostream& o) : o(o) {}
+ Printer(std::ostream* o) : o(*o) {}
void run(PassRunner* runner, Module* module) override;
};
diff --git a/src/passes/CoalesceLocals.cpp b/src/passes/CoalesceLocals.cpp
index c36371ec0..e2d11d697 100644
--- a/src/passes/CoalesceLocals.cpp
+++ b/src/passes/CoalesceLocals.cpp
@@ -159,7 +159,9 @@ struct Liveness {
};
struct CoalesceLocals : public WalkerPass<CFGWalker<CoalesceLocals, Visitor<CoalesceLocals>, Liveness>> {
- bool isFunctionParallel() { return true; }
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new CoalesceLocals; }
Index numLocals;
@@ -462,7 +464,7 @@ void CoalesceLocals::applyIndices(std::vector<Index>& indices, Expression* root)
}
struct CoalesceLocalsWithLearning : public CoalesceLocals {
- virtual CoalesceLocals* create() override { return new CoalesceLocalsWithLearning; }
+ virtual Pass* create() override { return new CoalesceLocalsWithLearning; }
virtual void pickIndices(std::vector<Index>& indices) override;
};
diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp
index e6669fdd9..08515753c 100644
--- a/src/passes/DeadCodeElimination.cpp
+++ b/src/passes/DeadCodeElimination.cpp
@@ -35,7 +35,9 @@
namespace wasm {
struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination, Visitor<DeadCodeElimination>>> {
- bool isFunctionParallel() { return true; }
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new DeadCodeElimination; }
// whether the current code is actually reachable
bool reachable = true;
diff --git a/src/passes/DropReturnValues.cpp b/src/passes/DropReturnValues.cpp
index b89de3011..a146b8a45 100644
--- a/src/passes/DropReturnValues.cpp
+++ b/src/passes/DropReturnValues.cpp
@@ -26,7 +26,9 @@
namespace wasm {
struct DropReturnValues : public WalkerPass<PostWalker<DropReturnValues, Visitor<DropReturnValues>>> {
- bool isFunctionParallel() { return true; }
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new DropReturnValues; }
std::vector<Expression*> expressionStack;
diff --git a/src/passes/DuplicateFunctionElimination.cpp b/src/passes/DuplicateFunctionElimination.cpp
index c33baedd4..715df0815 100644
--- a/src/passes/DuplicateFunctionElimination.cpp
+++ b/src/passes/DuplicateFunctionElimination.cpp
@@ -26,17 +26,13 @@
namespace wasm {
-struct FunctionHasher : public PostWalker<FunctionHasher, Visitor<FunctionHasher>> {
- bool isFunctionParallel() { return true; }
+struct FunctionHasher : public WalkerPass<PostWalker<FunctionHasher, Visitor<FunctionHasher>>> {
+ bool isFunctionParallel() override { return true; }
- FunctionHasher* create() override {
- auto* ret = new FunctionHasher;
- ret->setOutput(output);
- return ret;
- }
+ FunctionHasher(std::map<Function*, uint32_t>* output) : output(output) {}
- void setOutput(std::map<Function*, uint32_t>* output_) {
- output = output_;
+ FunctionHasher* create() override {
+ return new FunctionHasher(output);
}
void doWalkFunction(Function* func) {
@@ -63,17 +59,13 @@ private:
};
};
-struct FunctionReplacer : public PostWalker<FunctionReplacer, Visitor<FunctionReplacer>> {
- bool isFunctionParallel() { return true; }
+struct FunctionReplacer : public WalkerPass<PostWalker<FunctionReplacer, Visitor<FunctionReplacer>>> {
+ bool isFunctionParallel() override { return true; }
- FunctionReplacer* create() override {
- auto* ret = new FunctionReplacer;
- ret->setReplacements(replacements);
- return ret;
- }
+ FunctionReplacer(std::map<Name, Name>* replacements) : replacements(replacements) {}
- void setReplacements(std::map<Name, Name>* replacements_) {
- replacements = replacements_;
+ FunctionReplacer* create() override {
+ return new FunctionReplacer(replacements);
}
void visitCall(Call* curr) {
@@ -95,9 +87,9 @@ struct DuplicateFunctionElimination : public Pass {
for (auto& func : module->functions) {
hashes[func.get()] = 0; // ensure an entry for each function - we must not modify the map shape in parallel, just the values
}
- FunctionHasher hasher;
- hasher.setOutput(&hashes);
- hasher.walkModule(module);
+ PassRunner hasherRunner(module);
+ hasherRunner.add<FunctionHasher>(&hashes);
+ hasherRunner.run();
// Find hash-equal groups
std::map<uint32_t, std::vector<Function*>> hashGroups;
for (auto& func : module->functions) {
@@ -127,9 +119,9 @@ struct DuplicateFunctionElimination : public Pass {
}), v.end());
module->updateFunctionsMap();
// replace direct calls
- FunctionReplacer replacer;
- replacer.setReplacements(&replacements);
- replacer.walkModule(module);
+ PassRunner replacerRunner(module);
+ replacerRunner.add<FunctionReplacer>(&replacements);
+ replacerRunner.run();
// replace in table
for (auto& name : module->table.names) {
auto iter = replacements.find(name);
diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp
index 7b9723b0b..39d2b85c9 100644
--- a/src/passes/MergeBlocks.cpp
+++ b/src/passes/MergeBlocks.cpp
@@ -68,7 +68,9 @@
namespace wasm {
struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks, Visitor<MergeBlocks>>> {
- bool isFunctionParallel() { return true; }
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new MergeBlocks; }
void visitBlock(Block *curr) {
bool more = true;
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index 3d98d0a67..b140c5539 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -26,7 +26,9 @@
namespace wasm {
struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, Visitor<OptimizeInstructions>>> {
- bool isFunctionParallel() { return true; }
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new OptimizeInstructions; }
void visitIf(If* curr) {
// flip branches to get rid of an i32.eqz
diff --git a/src/passes/PostEmscripten.cpp b/src/passes/PostEmscripten.cpp
index 4bd828a00..527e32132 100644
--- a/src/passes/PostEmscripten.cpp
+++ b/src/passes/PostEmscripten.cpp
@@ -25,7 +25,9 @@
namespace wasm {
struct PostEmscripten : public WalkerPass<PostWalker<PostEmscripten, Visitor<PostEmscripten>>> {
- bool isFunctionParallel() { return true; }
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new PostEmscripten; }
// When we have a Load from a local value (typically a GetLocal) plus a constant offset,
// we may be able to fold it in.
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index 04b1d0edb..792d35457 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -628,9 +628,9 @@ static RegisterPass<Printer> registerPass("print", "print in s-expression format
// Prints out a minified module
class MinifiedPrinter : public Printer {
- public:
+public:
MinifiedPrinter() : Printer() {}
- MinifiedPrinter(std::ostream& o) : Printer(o) {}
+ MinifiedPrinter(std::ostream* o) : Printer(o) {}
void run(PassRunner* runner, Module* module) override {
PrintSExpression print(o);
@@ -644,9 +644,9 @@ static RegisterPass<MinifiedPrinter> registerMinifyPass("print-minified", "print
// Prints out a module withough elision, i.e., the full ast
class FullPrinter : public Printer {
- public:
+public:
FullPrinter() : Printer() {}
- FullPrinter(std::ostream& o) : Printer(o) {}
+ FullPrinter(std::ostream* o) : Printer(o) {}
void run(PassRunner* runner, Module* module) override {
PrintSExpression print(o);
diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp
index 9aa063f60..ffdba0768 100644
--- a/src/passes/RemoveUnusedBrs.cpp
+++ b/src/passes/RemoveUnusedBrs.cpp
@@ -25,7 +25,9 @@
namespace wasm {
struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<RemoveUnusedBrs>>> {
- bool isFunctionParallel() { return true; }
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new RemoveUnusedBrs; }
bool anotherCycle;
diff --git a/src/passes/RemoveUnusedNames.cpp b/src/passes/RemoveUnusedNames.cpp
index 4a37270e6..705ec2724 100644
--- a/src/passes/RemoveUnusedNames.cpp
+++ b/src/passes/RemoveUnusedNames.cpp
@@ -24,7 +24,9 @@
namespace wasm {
struct RemoveUnusedNames : public WalkerPass<PostWalker<RemoveUnusedNames, Visitor<RemoveUnusedNames>>> {
- bool isFunctionParallel() { return true; }
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new RemoveUnusedNames; }
// We maintain a list of branches that we saw in children, then when we reach
// a parent block, we know if it was branched to
diff --git a/src/passes/ReorderLocals.cpp b/src/passes/ReorderLocals.cpp
index eea8ea962..60d892b87 100644
--- a/src/passes/ReorderLocals.cpp
+++ b/src/passes/ReorderLocals.cpp
@@ -28,7 +28,9 @@
namespace wasm {
struct ReorderLocals : public WalkerPass<PostWalker<ReorderLocals, Visitor<ReorderLocals>>> {
- bool isFunctionParallel() { return true; }
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new ReorderLocals; }
std::map<Index, Index> counts; // local => times it is used
std::map<Index, Index> firstUses; // local => index in the list of which local is first seen
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index 57057e8f9..b79b55eea 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -42,7 +42,7 @@ namespace wasm {
// Helper classes
-struct GetLocalCounter : public WalkerPass<PostWalker<GetLocalCounter, Visitor<GetLocalCounter>>> {
+struct GetLocalCounter : public PostWalker<GetLocalCounter, Visitor<GetLocalCounter>> {
std::vector<int>* numGetLocals;
void visitGetLocal(GetLocal *curr) {
@@ -50,7 +50,7 @@ struct GetLocalCounter : public WalkerPass<PostWalker<GetLocalCounter, Visitor<G
}
};
-struct SetLocalRemover : public WalkerPass<PostWalker<SetLocalRemover, Visitor<SetLocalRemover>>> {
+struct SetLocalRemover : public PostWalker<SetLocalRemover, Visitor<SetLocalRemover>> {
std::vector<int>* numGetLocals;
void visitSetLocal(SetLocal *curr) {
@@ -63,7 +63,9 @@ struct SetLocalRemover : public WalkerPass<PostWalker<SetLocalRemover, Visitor<S
// Main class
struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, Visitor<SimplifyLocals>>> {
- bool isFunctionParallel() { return true; }
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new SimplifyLocals; }
// information for a set_local we can sink
struct SinkableInfo {
diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp
index 1524927ec..5e1558285 100644
--- a/src/passes/Vacuum.cpp
+++ b/src/passes/Vacuum.cpp
@@ -26,7 +26,9 @@
namespace wasm {
struct Vacuum : public WalkerPass<PostWalker<Vacuum, Visitor<Vacuum>>> {
- bool isFunctionParallel() { return true; }
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new Vacuum; }
std::vector<Expression*> expressionStack;
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index e5bbf5866..e43b712e5 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -94,64 +94,97 @@ void PassRunner::addDefaultGlobalOptimizationPasses() {
}
void PassRunner::run() {
- std::chrono::high_resolution_clock::time_point beforeEverything;
- size_t padding = 0;
if (debug) {
+ // for debug logging purposes, run each pass in full before running the other
+ std::chrono::high_resolution_clock::time_point beforeEverything;
+ size_t padding = 0;
std::cerr << "[PassRunner] running passes..." << std::endl;
beforeEverything = std::chrono::high_resolution_clock::now();
for (auto pass : passes) {
padding = std::max(padding, pass->name.size());
}
- }
- for (auto pass : passes) {
- currPass = pass;
- std::chrono::high_resolution_clock::time_point before;
- if (debug) {
+ for (auto* pass : passes) {
+ currPass = pass;
+ std::chrono::high_resolution_clock::time_point before;
std::cerr << "[PassRunner] running pass: " << pass->name << "... ";
for (size_t i = 0; i < padding - pass->name.size(); i++) {
std::cerr << ' ';
}
before = std::chrono::high_resolution_clock::now();
- }
- pass->run(this, wasm);
- if (debug) {
+ pass->run(this, wasm);
auto after = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = after - before;
std::cerr << diff.count() << " seconds." << std::endl;
}
- }
- if (debug) {
auto after = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = after - beforeEverything;
std::cerr << "[PassRunner] passes took " << diff.count() << " seconds." << std::endl;
+ } else {
+ // non-debug normal mode, run them in an optimal manner - for locality it is better
+ // to run as many passes as possible on a single function before moving to the next
+ std::vector<Pass*> stack;
+ auto flush = [&]() {
+ if (stack.size() > 0) {
+ // run the stack of passes on all the functions, in parallel
+ size_t num = ThreadPool::get()->size();
+ std::vector<std::function<ThreadWorkState ()>> doWorkers;
+ std::atomic<size_t> nextFunction;
+ nextFunction.store(0);
+ size_t numFunctions = wasm->functions.size();
+ for (size_t i = 0; i < num; i++) {
+ doWorkers.push_back([&]() {
+ auto index = nextFunction.fetch_add(1);
+ // get the next task, if there is one
+ if (index >= numFunctions) {
+ return ThreadWorkState::Finished; // nothing left
+ }
+ Function* func = wasm->functions[index].get();
+ // do the current task: run all passes on this function
+ for (auto* pass : stack) {
+ runPassOnFunction(pass, func);
+ }
+ if (index + 1 == numFunctions) {
+ return ThreadWorkState::Finished; // we did the last one
+ }
+ return ThreadWorkState::More;
+ });
+ }
+ ThreadPool::get()->work(doWorkers);
+ }
+ stack.clear();
+ };
+ for (auto* pass : passes) {
+ if (pass->isFunctionParallel()) {
+ stack.push_back(pass);
+ } else {
+ flush();
+ pass->run(this, wasm);
+ }
+ }
+ flush();
}
}
void PassRunner::runFunction(Function* func) {
- for (auto pass : passes) {
- pass->runFunction(this, wasm, func);
+ for (auto* pass : passes) {
+ runPassOnFunction(pass, func);
}
}
-template<class P>
-P* PassRunner::getLast() {
- bool found = false;
- P* ret;
- for (int i = passes.size() - 1; i >= 0; i--) {
- if (found && (ret = dynamic_cast<P*>(passes[i]))) {
- return ret;
- }
- if (passes[i] == currPass) {
- found = true;
- }
- }
- return nullptr;
-}
-
PassRunner::~PassRunner() {
for (auto pass : passes) {
delete pass;
}
}
+void PassRunner::runPassOnFunction(Pass* pass, Function* func) {
+ // function-parallel passes get a new instance per function
+ if (pass->isFunctionParallel()) {
+ auto instance = std::unique_ptr<Pass>(pass->create());
+ instance->runFunction(this, wasm, func);
+ } else {
+ pass->runFunction(this, wasm, func);
+ }
+}
+
} // namespace wasm
diff --git a/src/wasm-printing.h b/src/wasm-printing.h
index d1f1c42a5..2f1c97831 100644
--- a/src/wasm-printing.h
+++ b/src/wasm-printing.h
@@ -27,7 +27,7 @@ namespace wasm {
struct WasmPrinter {
static std::ostream& printModule(Module* module, std::ostream& o) {
PassRunner passRunner(module);
- passRunner.add<Printer>(o);
+ passRunner.add<Printer>(&o);
passRunner.run();
return o;
}
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index e5da420b0..cf023d3fa 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -140,28 +140,6 @@ struct UnifiedExpressionVisitor : public Visitor<SubType> {
//
template<typename SubType, typename VisitorType>
struct Walker : public VisitorType {
- // Function parallelism. By default, walks are not run in parallel, but you
- // can override this method to say that functions are parallelizable. This
- // should always be safe *unless* you do something in the pass that makes it
- // not thread-safe; in other words, the Module and Function objects and
- // so forth are set up so that Functions can be processed in parallel, so
- // if you do not ad global state that could be raced on, your pass could be
- // function-parallel.
- //
- // Function-parallel passes create an instance of the Walker class per function.
- // That means that you can't rely on Walker object properties to persist across
- // your functions, and you can't expect a new object to be created for each
- // function either (which could be very inefficient).
- bool isFunctionParallel() { return false; }
-
- // This method is used to create instances per function for a function-parallel
- // pass. You may need to override this if you subclass a Walker, as otherwise
- // this will create the parent class.
- // Note that this returns nullptr, and we check if the result is nullptr and
- // do new SubType later. This is important since non-function parallel
- // passes may not be constructable via new SubType.
- virtual SubType* create() { return nullptr; }
-
// Useful methods for visitor implementions
// Replace the current node. You can call this in your visit*() methods.
@@ -214,49 +192,8 @@ struct Walker : public VisitorType {
for (auto& curr : module->exports) {
self->visitExport(curr.get());
}
-
- auto processFunction = [this](Module* module, SubType* instance, Function* func) {
- std::unique_ptr<SubType> allocated;
- if (!instance) {
- instance = create();
- if (!instance) instance = new SubType;
- assert(module);
- instance->setModule(module);
- allocated = std::unique_ptr<SubType>(instance);
- }
- instance->walkFunction(func);
- };
-
- // if this is not a function-parallel traversal, run
- // sequentially
- if (!self->isFunctionParallel()) {
- for (auto& curr : module->functions) {
- processFunction(nullptr, self, curr.get());
- }
- } else {
- // execute in parallel on helper threads
- size_t num = ThreadPool::get()->size();
- std::vector<std::function<ThreadWorkState ()>> doWorkers;
- std::atomic<size_t> nextFunction;
- nextFunction.store(0);
- size_t numFunctions = module->functions.size();
- for (size_t i = 0; i < num; i++) {
- doWorkers.push_back([&nextFunction, numFunctions, &module, processFunction]() {
- auto index = nextFunction.fetch_add(1);
- // get the next task, if there is one
- if (index >= numFunctions) {
- return ThreadWorkState::Finished; // nothing left
- }
- Function* curr = module->functions[index].get();
- // do the current task
- processFunction(module, nullptr, curr);
- if (index + 1 == numFunctions) {
- return ThreadWorkState::Finished; // we did the last one
- }
- return ThreadWorkState::More;
- });
- }
- ThreadPool::get()->work(doWorkers);
+ for (auto& curr : module->functions) {
+ self->walkFunction(curr.get());
}
self->visitTable(&module->table);
self->visitMemory(&module->memory);