summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/passes/SimplifyLocals.cpp14
-rw-r--r--src/wasm-traversal.h26
2 files changed, 32 insertions, 8 deletions
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index 98806410f..5c7b7475a 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -45,16 +45,14 @@ struct SimplifyLocals : public WalkerPass<FastExecutionWalker<SimplifyLocals>> {
void visitBlock(Block *curr) {
// note locals, we can sink them from here TODO sink from elsewhere?
- ExpressionList& list = curr->list;
- for (size_t z = 0; z < list.size(); z++) {
- walk(list[z]);
- auto* item = list[z];
- if (item->is<SetLocal>()) {
- Name name = item->cast<SetLocal>()->name;
+ derecurseBlocks(curr, [&](Block* block) {}, [&](Block* block, Expression*& child) {
+ walk(child);
+ if (child->is<SetLocal>()) {
+ Name name = child->cast<SetLocal>()->name;
assert(sinkables.count(name) == 0);
- sinkables.emplace(std::make_pair(name, SinkableInfo(&list[z])));
+ sinkables.emplace(std::make_pair(name, SinkableInfo(&child)));
}
- }
+ }, [&](Block* block) {});
}
void visitGetLocal(GetLocal *curr) {
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index 04db3cb35..24ec4905c 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -98,6 +98,32 @@ struct WasmVisitor {
#undef DELEGATE
+ // Helper method to de-recurse blocks, which often nest in their first position very heavily
+ void derecurseBlocks(Block* block, std::function<void (Block*)> preBlock,
+ std::function<void (Block*, Expression*&)> onChild,
+ std::function<void (Block*)> postBlock) {
+ std::vector<Block*> stack;
+ stack.push_back(block);
+ while (block->list.size() > 0 && block->list[0]->is<Block>()) {
+ block = block->list[0]->cast<Block>();
+ stack.push_back(block);
+ }
+ for (size_t i = 0; i < stack.size(); i++) {
+ preBlock(stack[i]);
+ }
+ for (int i = int(stack.size()) - 1; i >= 0; i--) {
+ auto* block = stack[i];
+ auto& list = block->list;
+ for (size_t j = 0; j < list.size(); j++) {
+ if (i < int(stack.size()) - 1 && j == 0) {
+ // nested block, we already called its pre
+ } else {
+ onChild(block, list[j]);
+ }
+ }
+ postBlock(block);
+ }
+ }
};
//