summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlon Zakai <alonzakai@gmail.com>2015-11-01 17:12:20 -0800
committerAlon Zakai <alonzakai@gmail.com>2015-11-01 17:12:20 -0800
commit01fa6454d5b375ea85ba8662306cb0ac30a1de29 (patch)
tree5946f77916332dca6c28fc573024b200029641bf
parent5ad0538c67dcc27c28fb99e058995de110550de1 (diff)
downloadbinaryen-01fa6454d5b375ea85ba8662306cb0ac30a1de29.tar.gz
binaryen-01fa6454d5b375ea85ba8662306cb0ac30a1de29.tar.bz2
binaryen-01fa6454d5b375ea85ba8662306cb0ac30a1de29.zip
fix wasm walker replacement logic
-rw-r--r--src/asm2wasm.h19
-rw-r--r--src/wasm.h119
2 files changed, 76 insertions, 62 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h
index 0ac8ac83e..6c14e5757 100644
--- a/src/asm2wasm.h
+++ b/src/asm2wasm.h
@@ -1031,20 +1031,23 @@ void Asm2WasmBuilder::optimize() {
struct BlockBreakOptimizer : public WasmWalker {
BlockBreakOptimizer() : WasmWalker(nullptr) {}
- Expression* visitBlock(Block *curr) override {
+ void visitBlock(Block *curr) override {
if (curr->list.size() > 1) {
// we can't remove the block, but if it ends in a break on this very block, then just put the value there
Break *last = curr->list[curr->list.size()-1]->dyn_cast<Break>();
if (last && last->value && last->name == curr->name) {
curr->list[curr->list.size()-1] = last->value;
}
- return curr;
+ return;
}
// just one element; maybe we can return just the element
- if (curr->name.isNull()) return curr->list[0];
+ if (curr->name.isNull()) {
+ replaceCurrent(curr->list[0]);
+ return;
+ }
// we might be broken to, but if it's a trivial singleton child break, we can optimize here as well
Break *child = curr->list[0]->dyn_cast<Break>();
- if (!child || child->name != curr->name || !child->value) return curr;
+ if (!child || child->name != curr->name || !child->value) return;
struct BreakSeeker : public WasmWalker {
IString target; // look for this one
@@ -1052,7 +1055,7 @@ void Asm2WasmBuilder::optimize() {
BreakSeeker(IString target) : target(target), found(false) {}
- Expression* visitBreak(Break *curr) override {
+ void visitBreak(Break *curr) override {
if (curr->name == target) found++;
}
};
@@ -1060,9 +1063,9 @@ void Asm2WasmBuilder::optimize() {
// look in the child's children to see if there are more uses of this name
BreakSeeker breakSeeker(curr->name);
breakSeeker.walk(child->value);
- if (breakSeeker.found == 0) return child->value;
-
- return curr; // failed to optimize
+ if (breakSeeker.found == 0) {
+ replaceCurrent(child->value);
+ }
}
};
diff --git a/src/wasm.h b/src/wasm.h
index 9ef39e42f..95b44c344 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -985,39 +985,45 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) {
// Simple WebAssembly children-first walking
//
-struct WasmWalker : public WasmVisitor<Expression*> {
+struct WasmWalker : public WasmVisitor<void> {
wasm::Arena* allocator; // use an existing allocator, or null if no allocations
+ Expression* replace;
- WasmWalker() : allocator(nullptr) {}
- WasmWalker(wasm::Arena* allocator) : allocator(allocator) {}
-
- // Each method receives an AST pointer, and it is replaced with what is returned.
- Expression* visitBlock(Block *curr) override { return curr; };
- Expression* visitIf(If *curr) override { return curr; };
- Expression* visitLoop(Loop *curr) override { return curr; };
- Expression* visitLabel(Label *curr) override { return curr; };
- Expression* visitBreak(Break *curr) override { return curr; };
- Expression* visitSwitch(Switch *curr) override { return curr; };
- Expression* visitCall(Call *curr) override { return curr; };
- Expression* visitCallImport(CallImport *curr) override { return curr; };
- Expression* visitCallIndirect(CallIndirect *curr) override { return curr; };
- Expression* visitGetLocal(GetLocal *curr) override { return curr; };
- Expression* visitSetLocal(SetLocal *curr) override { return curr; };
- Expression* visitLoad(Load *curr) override { return curr; };
- Expression* visitStore(Store *curr) override { return curr; };
- Expression* visitConst(Const *curr) override { return curr; };
- Expression* visitUnary(Unary *curr) override { return curr; };
- Expression* visitBinary(Binary *curr) override { return curr; };
- Expression* visitCompare(Compare *curr) override { return curr; };
- Expression* visitConvert(Convert *curr) override { return curr; };
- Expression* visitHost(Host *curr) override { return curr; };
- Expression* visitNop(Nop *curr) override { return curr; };
+ WasmWalker() : allocator(nullptr), replace(nullptr) {}
+ WasmWalker(wasm::Arena* allocator) : allocator(allocator), replace(nullptr) {}
+
+ // the visit* methods can call this to replace the current node
+ void replaceCurrent(Expression *expression) {
+ replace = expression;
+ }
+
+ // By default, do nothing
+ void visitBlock(Block *curr) override {};
+ void visitIf(If *curr) override {};
+ void visitLoop(Loop *curr) override {};
+ void visitLabel(Label *curr) override {};
+ void visitBreak(Break *curr) override {};
+ void visitSwitch(Switch *curr) override {};
+ void visitCall(Call *curr) override {};
+ void visitCallImport(CallImport *curr) override {};
+ void visitCallIndirect(CallIndirect *curr) override {};
+ void visitGetLocal(GetLocal *curr) override {};
+ void visitSetLocal(SetLocal *curr) override {};
+ void visitLoad(Load *curr) override {};
+ void visitStore(Store *curr) override {};
+ void visitConst(Const *curr) override {};
+ void visitUnary(Unary *curr) override {};
+ void visitBinary(Binary *curr) override {};
+ void visitCompare(Compare *curr) override {};
+ void visitConvert(Convert *curr) override {};
+ void visitHost(Host *curr) override {};
+ void visitNop(Nop *curr) override {};
// children-first
- Expression *walk(Expression *curr) {
- if (!curr) return curr;
+ void walk(Expression*& curr) {
+ if (!curr) return;
- struct ChildWalker : public WasmVisitor<void> {
+ struct ChildWalker : public WasmVisitor {
WasmWalker& parent;
ChildWalker(WasmWalker& parent) : parent(parent) {}
@@ -1025,77 +1031,77 @@ struct WasmWalker : public WasmVisitor<Expression*> {
void visitBlock(Block *curr) override {
ExpressionList& list = curr->list;
for (size_t z = 0; z < list.size(); z++) {
- list[z] = parent.walk(list[z]);
+ parent.walk(list[z]);
}
}
void visitIf(If *curr) override {
- curr->condition = parent.walk(curr->condition);
- curr->ifTrue = parent.walk(curr->ifTrue);
- curr->ifFalse = parent.walk(curr->ifFalse);
+ parent.walk(curr->condition);
+ parent.walk(curr->ifTrue);
+ parent.walk(curr->ifFalse);
}
void visitLoop(Loop *curr) override {
- curr->body = parent.walk(curr->body);
+ parent.walk(curr->body);
}
void visitLabel(Label *curr) override {}
void visitBreak(Break *curr) override {
- curr->value = parent.walk(curr->value);
+ parent.walk(curr->value);
}
void visitSwitch(Switch *curr) override {
- curr->value = parent.walk(curr->value);
+ parent.walk(curr->value);
for (auto& case_ : curr->cases) {
- case_.body = parent.walk(case_.body);
+ parent.walk(case_.body);
}
- curr->default_ = parent.walk(curr->default_);
+ parent.walk(curr->default_);
}
void visitCall(Call *curr) override {
ExpressionList& list = curr->operands;
for (size_t z = 0; z < list.size(); z++) {
- list[z] = parent.walk(list[z]);
+ parent.walk(list[z]);
}
}
void visitCallImport(CallImport *curr) override {
ExpressionList& list = curr->operands;
for (size_t z = 0; z < list.size(); z++) {
- list[z] = parent.walk(list[z]);
+ parent.walk(list[z]);
}
}
void visitCallIndirect(CallIndirect *curr) override {
- curr->target = parent.walk(curr->target);
+ parent.walk(curr->target);
ExpressionList& list = curr->operands;
for (size_t z = 0; z < list.size(); z++) {
- list[z] = parent.walk(list[z]);
+ parent.walk(list[z]);
}
}
void visitGetLocal(GetLocal *curr) override {}
void visitSetLocal(SetLocal *curr) override {
- curr->value = parent.walk(curr->value);
+ parent.walk(curr->value);
}
void visitLoad(Load *curr) override {
- curr->ptr = parent.walk(curr->ptr);
+ parent.walk(curr->ptr);
}
void visitStore(Store *curr) override {
- curr->ptr = parent.walk(curr->ptr);
- curr->value = parent.walk(curr->value);
+ parent.walk(curr->ptr);
+ parent.walk(curr->value);
}
void visitConst(Const *curr) override {}
void visitUnary(Unary *curr) override {
- curr->value = parent.walk(curr->value);
+ parent.walk(curr->value);
}
void visitBinary(Binary *curr) override {
- curr->left = parent.walk(curr->left);
- curr->right = parent.walk(curr->right);
+ parent.walk(curr->left);
+ parent.walk(curr->right);
}
void visitCompare(Compare *curr) override {
- curr->left = parent.walk(curr->left);
- curr->right = parent.walk(curr->right);
+ parent.walk(curr->left);
+ parent.walk(curr->right);
}
void visitConvert(Convert *curr) override {
- curr->value = parent.walk(curr->value);
+ parent.walk(curr->value);
}
void visitHost(Host *curr) override {
ExpressionList& list = curr->operands;
for (size_t z = 0; z < list.size(); z++) {
- list[z] = parent.walk(list[z]);
+ parent.walk(list[z]);
}
}
void visitNop(Nop *curr) override {}
@@ -1103,11 +1109,16 @@ struct WasmWalker : public WasmVisitor<Expression*> {
ChildWalker(*this).visit(curr);
- return visit(curr);
+ visit(curr);
+
+ if (replace) {
+ curr = replace;
+ replace = nullptr;
+ }
}
void startWalk(Function *func) {
- func->body = walk(func->body);
+ walk(func->body);
}
};