diff options
author | Alon Zakai <alonzakai@gmail.com> | 2015-11-01 17:12:20 -0800 |
---|---|---|
committer | Alon Zakai <alonzakai@gmail.com> | 2015-11-01 17:12:20 -0800 |
commit | 01fa6454d5b375ea85ba8662306cb0ac30a1de29 (patch) | |
tree | 5946f77916332dca6c28fc573024b200029641bf | |
parent | 5ad0538c67dcc27c28fb99e058995de110550de1 (diff) | |
download | binaryen-01fa6454d5b375ea85ba8662306cb0ac30a1de29.tar.gz binaryen-01fa6454d5b375ea85ba8662306cb0ac30a1de29.tar.bz2 binaryen-01fa6454d5b375ea85ba8662306cb0ac30a1de29.zip |
fix wasm walker replacement logic
-rw-r--r-- | src/asm2wasm.h | 19 | ||||
-rw-r--r-- | src/wasm.h | 119 |
2 files changed, 76 insertions, 62 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h index 0ac8ac83e..6c14e5757 100644 --- a/src/asm2wasm.h +++ b/src/asm2wasm.h @@ -1031,20 +1031,23 @@ void Asm2WasmBuilder::optimize() { struct BlockBreakOptimizer : public WasmWalker { BlockBreakOptimizer() : WasmWalker(nullptr) {} - Expression* visitBlock(Block *curr) override { + void visitBlock(Block *curr) override { if (curr->list.size() > 1) { // we can't remove the block, but if it ends in a break on this very block, then just put the value there Break *last = curr->list[curr->list.size()-1]->dyn_cast<Break>(); if (last && last->value && last->name == curr->name) { curr->list[curr->list.size()-1] = last->value; } - return curr; + return; } // just one element; maybe we can return just the element - if (curr->name.isNull()) return curr->list[0]; + if (curr->name.isNull()) { + replaceCurrent(curr->list[0]); + return; + } // we might be broken to, but if it's a trivial singleton child break, we can optimize here as well Break *child = curr->list[0]->dyn_cast<Break>(); - if (!child || child->name != curr->name || !child->value) return curr; + if (!child || child->name != curr->name || !child->value) return; struct BreakSeeker : public WasmWalker { IString target; // look for this one @@ -1052,7 +1055,7 @@ void Asm2WasmBuilder::optimize() { BreakSeeker(IString target) : target(target), found(false) {} - Expression* visitBreak(Break *curr) override { + void visitBreak(Break *curr) override { if (curr->name == target) found++; } }; @@ -1060,9 +1063,9 @@ void Asm2WasmBuilder::optimize() { // look in the child's children to see if there are more uses of this name BreakSeeker breakSeeker(curr->name); breakSeeker.walk(child->value); - if (breakSeeker.found == 0) return child->value; - - return curr; // failed to optimize + if (breakSeeker.found == 0) { + replaceCurrent(child->value); + } } }; diff --git a/src/wasm.h b/src/wasm.h index 9ef39e42f..95b44c344 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -985,39 +985,45 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { // Simple WebAssembly children-first walking // -struct WasmWalker : public WasmVisitor<Expression*> { +struct WasmWalker : public WasmVisitor<void> { wasm::Arena* allocator; // use an existing allocator, or null if no allocations + Expression* replace; - WasmWalker() : allocator(nullptr) {} - WasmWalker(wasm::Arena* allocator) : allocator(allocator) {} - - // Each method receives an AST pointer, and it is replaced with what is returned. - Expression* visitBlock(Block *curr) override { return curr; }; - Expression* visitIf(If *curr) override { return curr; }; - Expression* visitLoop(Loop *curr) override { return curr; }; - Expression* visitLabel(Label *curr) override { return curr; }; - Expression* visitBreak(Break *curr) override { return curr; }; - Expression* visitSwitch(Switch *curr) override { return curr; }; - Expression* visitCall(Call *curr) override { return curr; }; - Expression* visitCallImport(CallImport *curr) override { return curr; }; - Expression* visitCallIndirect(CallIndirect *curr) override { return curr; }; - Expression* visitGetLocal(GetLocal *curr) override { return curr; }; - Expression* visitSetLocal(SetLocal *curr) override { return curr; }; - Expression* visitLoad(Load *curr) override { return curr; }; - Expression* visitStore(Store *curr) override { return curr; }; - Expression* visitConst(Const *curr) override { return curr; }; - Expression* visitUnary(Unary *curr) override { return curr; }; - Expression* visitBinary(Binary *curr) override { return curr; }; - Expression* visitCompare(Compare *curr) override { return curr; }; - Expression* visitConvert(Convert *curr) override { return curr; }; - Expression* visitHost(Host *curr) override { return curr; }; - Expression* visitNop(Nop *curr) override { return curr; }; + WasmWalker() : allocator(nullptr), replace(nullptr) {} + WasmWalker(wasm::Arena* allocator) : allocator(allocator), replace(nullptr) {} + + // the visit* methods can call this to replace the current node + void replaceCurrent(Expression *expression) { + replace = expression; + } + + // By default, do nothing + void visitBlock(Block *curr) override {}; + void visitIf(If *curr) override {}; + void visitLoop(Loop *curr) override {}; + void visitLabel(Label *curr) override {}; + void visitBreak(Break *curr) override {}; + void visitSwitch(Switch *curr) override {}; + void visitCall(Call *curr) override {}; + void visitCallImport(CallImport *curr) override {}; + void visitCallIndirect(CallIndirect *curr) override {}; + void visitGetLocal(GetLocal *curr) override {}; + void visitSetLocal(SetLocal *curr) override {}; + void visitLoad(Load *curr) override {}; + void visitStore(Store *curr) override {}; + void visitConst(Const *curr) override {}; + void visitUnary(Unary *curr) override {}; + void visitBinary(Binary *curr) override {}; + void visitCompare(Compare *curr) override {}; + void visitConvert(Convert *curr) override {}; + void visitHost(Host *curr) override {}; + void visitNop(Nop *curr) override {}; // children-first - Expression *walk(Expression *curr) { - if (!curr) return curr; + void walk(Expression*& curr) { + if (!curr) return; - struct ChildWalker : public WasmVisitor<void> { + struct ChildWalker : public WasmVisitor { WasmWalker& parent; ChildWalker(WasmWalker& parent) : parent(parent) {} @@ -1025,77 +1031,77 @@ struct WasmWalker : public WasmVisitor<Expression*> { void visitBlock(Block *curr) override { ExpressionList& list = curr->list; for (size_t z = 0; z < list.size(); z++) { - list[z] = parent.walk(list[z]); + parent.walk(list[z]); } } void visitIf(If *curr) override { - curr->condition = parent.walk(curr->condition); - curr->ifTrue = parent.walk(curr->ifTrue); - curr->ifFalse = parent.walk(curr->ifFalse); + parent.walk(curr->condition); + parent.walk(curr->ifTrue); + parent.walk(curr->ifFalse); } void visitLoop(Loop *curr) override { - curr->body = parent.walk(curr->body); + parent.walk(curr->body); } void visitLabel(Label *curr) override {} void visitBreak(Break *curr) override { - curr->value = parent.walk(curr->value); + parent.walk(curr->value); } void visitSwitch(Switch *curr) override { - curr->value = parent.walk(curr->value); + parent.walk(curr->value); for (auto& case_ : curr->cases) { - case_.body = parent.walk(case_.body); + parent.walk(case_.body); } - curr->default_ = parent.walk(curr->default_); + parent.walk(curr->default_); } void visitCall(Call *curr) override { ExpressionList& list = curr->operands; for (size_t z = 0; z < list.size(); z++) { - list[z] = parent.walk(list[z]); + parent.walk(list[z]); } } void visitCallImport(CallImport *curr) override { ExpressionList& list = curr->operands; for (size_t z = 0; z < list.size(); z++) { - list[z] = parent.walk(list[z]); + parent.walk(list[z]); } } void visitCallIndirect(CallIndirect *curr) override { - curr->target = parent.walk(curr->target); + parent.walk(curr->target); ExpressionList& list = curr->operands; for (size_t z = 0; z < list.size(); z++) { - list[z] = parent.walk(list[z]); + parent.walk(list[z]); } } void visitGetLocal(GetLocal *curr) override {} void visitSetLocal(SetLocal *curr) override { - curr->value = parent.walk(curr->value); + parent.walk(curr->value); } void visitLoad(Load *curr) override { - curr->ptr = parent.walk(curr->ptr); + parent.walk(curr->ptr); } void visitStore(Store *curr) override { - curr->ptr = parent.walk(curr->ptr); - curr->value = parent.walk(curr->value); + parent.walk(curr->ptr); + parent.walk(curr->value); } void visitConst(Const *curr) override {} void visitUnary(Unary *curr) override { - curr->value = parent.walk(curr->value); + parent.walk(curr->value); } void visitBinary(Binary *curr) override { - curr->left = parent.walk(curr->left); - curr->right = parent.walk(curr->right); + parent.walk(curr->left); + parent.walk(curr->right); } void visitCompare(Compare *curr) override { - curr->left = parent.walk(curr->left); - curr->right = parent.walk(curr->right); + parent.walk(curr->left); + parent.walk(curr->right); } void visitConvert(Convert *curr) override { - curr->value = parent.walk(curr->value); + parent.walk(curr->value); } void visitHost(Host *curr) override { ExpressionList& list = curr->operands; for (size_t z = 0; z < list.size(); z++) { - list[z] = parent.walk(list[z]); + parent.walk(list[z]); } } void visitNop(Nop *curr) override {} @@ -1103,11 +1109,16 @@ struct WasmWalker : public WasmVisitor<Expression*> { ChildWalker(*this).visit(curr); - return visit(curr); + visit(curr); + + if (replace) { + curr = replace; + replace = nullptr; + } } void startWalk(Function *func) { - func->body = walk(func->body); + walk(func->body); } }; |