summary refs log tree commit diff
path: root/src/wasm-traversal.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm-traversal.h')
-rw-r--r-- src/wasm-traversal.h 67
1 files changed, 55 insertions, 12 deletions
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index efbe12586..031343d46 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -28,6 +28,7 @@
#define wasm_traversal_h
#include "wasm.h"
+#include "support/threads.h"
namespace wasm {
@@ -112,22 +113,29 @@ struct Walker : public Visitor<SubType> {
// passes that need to do the same thing for every node type.
void visitExpression(Expression* curr) {}
+ // Function parallelism. By default, walks are not run in parallel, but you
+ // can override this method to say that functions are parallelizable. This
+ // should always be safe *unless* you do something in the pass that makes it
+ // not thread-safe; in other words, the Module and Function objects and
+ // so forth are set up so that Functions can be processed in parallel, so
+ // if you do not add global state that could be raced on, your pass could be
+ // function-parallel.
+ //
+ // Function-parallel passes create an instance of the Walker class per core.
+ // That means that you can't rely on Walker object properties to persist across
+ // your functions, and you can't expect a new object to be created for each
+ // function either (which could be very inefficient).
+ bool isFunctionParallel() { return false; }
+
// Node replacing as we walk - call replaceCurrent from
// your visitors.
- Expression *replace = nullptr;
-
void replaceCurrent(Expression *expression) {
replace = expression;
}
// Walk starting
- void startWalk(Function *func) {
- SubType* self = static_cast<SubType*>(this);
- self->walk(func->body);
- }
-
void startWalk(Module *module) {
// Dispatch statically through the SubType.
SubType* self = static_cast<SubType*>(this);
@@ -140,9 +148,42 @@ struct Walker : public Visitor<SubType> {
for (auto curr : module->exports) {
self->visitExport(curr);
}
- for (auto curr : module->functions) {
- self->startWalk(curr);
- self->visitFunction(curr);
+
+ // if this is not a function-parallel traversal, run
+ // sequentially
+ if (!self->isFunctionParallel()) {
+ for (auto curr : module->functions) {
+ self->walk(curr->body);
+ self->visitFunction(curr);
+ }
+ } else {
+ // execute in parallel on helper threads
+ size_t num = ThreadPool::get()->size();
+ std::vector<std::unique_ptr<SubType>> instances;
+ std::vector<std::function<ThreadWorkState ()>> doWorkers;
+ std::atomic<size_t> nextFunction;
+ nextFunction.store(0);
+ size_t numFunctions = module->functions.size();
+ for (size_t i = 0; i < num; i++) {
+ auto* instance = new SubType();
+ instances.push_back(std::unique_ptr<SubType>(instance));
+ doWorkers.push_back([instance, &nextFunction, numFunctions, &module]() {
+ auto index = nextFunction.fetch_add(1);
+ // get the next task, if there is one
+ if (index >= numFunctions) {
+ return ThreadWorkState::Finished; // nothing left
+ }
+ Function* curr = module->functions[index];
+ // do the current task
+ instance->walk(curr->body);
+ instance->visitFunction(curr);
+ if (index + 1 == numFunctions) {
+ return ThreadWorkState::Finished; // we did the last one
+ }
+ return ThreadWorkState::More;
+ });
+ }
+ ThreadPool::get()->work(doWorkers);
}
self->visitTable(&module->table);
self->visitMemory(&module->memory);
@@ -161,8 +202,6 @@ struct Walker : public Visitor<SubType> {
Task(TaskFunc func, Expression** currp) : func(func), currp(currp) {}
};
- std::vector<Task> stack;
-
void pushTask(TaskFunc func, Expression** currp) {
stack.emplace_back(func, currp);
}
@@ -216,6 +255,10 @@ struct Walker : public Visitor<SubType> {
static void doVisitHost(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitHost((*currp)->cast<Host>()); }
static void doVisitNop(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitNop((*currp)->cast<Nop>()); }
static void doVisitUnreachable(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitUnreachable((*currp)->cast<Unreachable>()); }
+
+private:
+ Expression *replace = nullptr; // a node to replace
+ std::vector<Task> stack; // stack of tasks
};
// Walks in post-order, i.e., children first. When there isn't an obvious