diff options
Diffstat (limited to 'src/wasm-traversal.h')
-rw-r--r-- | src/wasm-traversal.h | 67 |
1 files changed, 55 insertions, 12 deletions
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index efbe12586..031343d46 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -28,6 +28,7 @@ #define wasm_traversal_h #include "wasm.h" +#include "support/threads.h" namespace wasm { @@ -112,22 +113,29 @@ struct Walker : public Visitor<SubType> { // passes that need to do the same thing for every node type. void visitExpression(Expression* curr) {} + // Function parallelism. By default, walks are not run in parallel, but you + // can override this method to say that functions are parallelizable. This + // should always be safe *unless* you do something in the pass that makes it + // not thread-safe; in other words, the Module and Function objects and + // so forth are set up so that Functions can be processed in parallel, so + // if you do not add global state that could be raced on, your pass could be + // function-parallel. + // + // Function-parallel passes create an instance of the Walker class per core. + // That means that you can't rely on Walker object properties to persist across + // your functions, and you can't expect a new object to be created for each + // function either (which could be very inefficient). + bool isFunctionParallel() { return false; } + // Node replacing as we walk - call replaceCurrent from // your visitors. - Expression *replace = nullptr; - void replaceCurrent(Expression *expression) { replace = expression; } // Walk starting - void startWalk(Function *func) { - SubType* self = static_cast<SubType*>(this); - self->walk(func->body); - } - void startWalk(Module *module) { // Dispatch statically through the SubType. 
SubType* self = static_cast<SubType*>(this); @@ -140,9 +148,42 @@ struct Walker : public Visitor<SubType> { for (auto curr : module->exports) { self->visitExport(curr); } - for (auto curr : module->functions) { - self->startWalk(curr); - self->visitFunction(curr); + + // if this is not a function-parallel traversal, run + // sequentially + if (!self->isFunctionParallel()) { + for (auto curr : module->functions) { + self->walk(curr->body); + self->visitFunction(curr); + } + } else { + // execute in parallel on helper threads + size_t num = ThreadPool::get()->size(); + std::vector<std::unique_ptr<SubType>> instances; + std::vector<std::function<ThreadWorkState ()>> doWorkers; + std::atomic<size_t> nextFunction; + nextFunction.store(0); + size_t numFunctions = module->functions.size(); + for (size_t i = 0; i < num; i++) { + auto* instance = new SubType(); + instances.push_back(std::unique_ptr<SubType>(instance)); + doWorkers.push_back([instance, &nextFunction, numFunctions, &module]() { + auto index = nextFunction.fetch_add(1); + // get the next task, if there is one + if (index >= numFunctions) { + return ThreadWorkState::Finished; // nothing left + } + Function* curr = module->functions[index]; + // do the current task + instance->walk(curr->body); + instance->visitFunction(curr); + if (index + 1 == numFunctions) { + return ThreadWorkState::Finished; // we did the last one + } + return ThreadWorkState::More; + }); + } + ThreadPool::get()->work(doWorkers); } self->visitTable(&module->table); self->visitMemory(&module->memory); @@ -161,8 +202,6 @@ struct Walker : public Visitor<SubType> { Task(TaskFunc func, Expression** currp) : func(func), currp(currp) {} }; - std::vector<Task> stack; - void pushTask(TaskFunc func, Expression** currp) { stack.emplace_back(func, currp); } @@ -216,6 +255,10 @@ struct Walker : public Visitor<SubType> { static void doVisitHost(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitHost((*currp)->cast<Host>()); } 
static void doVisitNop(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitNop((*currp)->cast<Nop>()); } static void doVisitUnreachable(SubType* self, Expression** currp) { self->visitExpression(*currp); self->visitUnreachable((*currp)->cast<Unreachable>()); } + +private: + Expression *replace = nullptr; // a node to replace + std::vector<Task> stack; // stack of tasks }; // Walks in post-order, i.e., children first. When there isn't an obvious |