diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/SimplifyLocals.cpp | 14 | ||||
-rw-r--r-- | src/wasm-traversal.h | 26 |
2 files changed, 32 insertions, 8 deletions
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index 98806410f..5c7b7475a 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -45,16 +45,14 @@ struct SimplifyLocals : public WalkerPass<FastExecutionWalker<SimplifyLocals>> { void visitBlock(Block *curr) { // note locals, we can sink them from here TODO sink from elsewhere? - ExpressionList& list = curr->list; - for (size_t z = 0; z < list.size(); z++) { - walk(list[z]); - auto* item = list[z]; - if (item->is<SetLocal>()) { - Name name = item->cast<SetLocal>()->name; + derecurseBlocks(curr, [&](Block* block) {}, [&](Block* block, Expression*& child) { + walk(child); + if (child->is<SetLocal>()) { + Name name = child->cast<SetLocal>()->name; assert(sinkables.count(name) == 0); - sinkables.emplace(std::make_pair(name, SinkableInfo(&list[z]))); + sinkables.emplace(std::make_pair(name, SinkableInfo(&child))); } - } + }, [&](Block* block) {}); } void visitGetLocal(GetLocal *curr) { diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 04db3cb35..24ec4905c 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -98,6 +98,32 @@ struct WasmVisitor { #undef DELEGATE + // Helper method to de-recurse blocks, which often nest in their first position very heavily + void derecurseBlocks(Block* block, std::function<void (Block*)> preBlock, + std::function<void (Block*, Expression*&)> onChild, + std::function<void (Block*)> postBlock) { + std::vector<Block*> stack; + stack.push_back(block); + while (block->list.size() > 0 && block->list[0]->is<Block>()) { + block = block->list[0]->cast<Block>(); + stack.push_back(block); + } + for (size_t i = 0; i < stack.size(); i++) { + preBlock(stack[i]); + } + for (int i = int(stack.size()) - 1; i >= 0; i--) { + auto* block = stack[i]; + auto& list = block->list; + for (size_t j = 0; j < list.size(); j++) { + if (i < int(stack.size()) - 1 && j == 0) { + // nested block, we already called its pre + } else { + onChild(block, list[j]); + } + } + postBlock(block); + } + } }; // |