summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
author: Alon Zakai <alonzakai@gmail.com> 2016-04-11 19:43:58 -0700
committer: Alon Zakai <alonzakai@gmail.com> 2016-04-11 19:43:58 -0700
commit: 73c606a04d01dc7018d028eed3216a507ab03ee9 (patch)
tree: bf4255d2fead5fdf721ea178607abfc408cd6ac8 /src
parent: 65d9334b3066bae667e729f3202f7aa2d7c11530 (diff)
parent: 1044d6cbca6d279d457cdd1cf7000671ec48e841 (diff)
download: binaryen-73c606a04d01dc7018d028eed3216a507ab03ee9.tar.gz
binaryen-73c606a04d01dc7018d028eed3216a507ab03ee9.tar.bz2
binaryen-73c606a04d01dc7018d028eed3216a507ab03ee9.zip
Merge pull request #334 from WebAssembly/opts2
More optimizations
Diffstat (limited to 'src')
-rw-r--r-- src/asm2wasm.h 4
-rw-r--r-- src/ast_utils.h 38
-rw-r--r-- src/binaryen-shell.cpp 4
-rw-r--r-- src/passes/MergeBlocks.cpp 2
-rw-r--r-- src/passes/OptimizeInstructions.cpp 4
-rw-r--r-- src/passes/PostEmscripten.cpp 6
-rw-r--r-- src/passes/Print.cpp 10
-rw-r--r-- src/passes/RemoveUnusedBrs.cpp 10
-rw-r--r-- src/passes/SimplifyLocals.cpp 57
-rw-r--r-- src/s2wasm.h 2
-rw-r--r-- src/wasm-s-parser.h 2
-rw-r--r-- src/wasm-traversal.h 1
-rw-r--r-- src/wasm.h 2
13 files changed, 102 insertions, 40 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h
index 72a8b9048..a5a76b396 100644
--- a/src/asm2wasm.h
+++ b/src/asm2wasm.h
@@ -433,7 +433,7 @@ private:
// ensure a nameless block
Block* blockify(Expression* expression) {
- if (expression->is<Block>() && !expression->cast<Block>()->name.is()) return expression->dyn_cast<Block>();
+ if (expression->is<Block>() && !expression->cast<Block>()->name.is()) return expression->dynCast<Block>();
auto ret = allocator.alloc<Block>();
ret->list.push_back(expression);
ret->finalize();
@@ -1351,7 +1351,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) {
auto ret = processStatements(ast[1], 0);
if (name.is()) {
breakStack.pop_back();
- Block* block = ret->dyn_cast<Block>();
+ Block* block = ret->dynCast<Block>();
if (block && block->name.isNull()) {
block->name = name;
} else {
diff --git a/src/ast_utils.h b/src/ast_utils.h
index 561e12983..f97697265 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -40,26 +40,38 @@ struct BreakSeeker : public PostWalker<BreakSeeker> {
};
// Look for side effects, including control flow
-// TODO: look at individual locals
+// TODO: optimize
struct EffectAnalyzer : public PostWalker<EffectAnalyzer> {
bool branches = false;
bool calls = false;
- bool readsLocal = false;
- bool writesLocal = false;
+ std::set<Name> localsRead;
+ std::set<Name> localsWritten;
bool readsMemory = false;
bool writesMemory = false;
- bool accessesLocal() { return readsLocal || writesLocal; }
+ bool accessesLocal() { return localsRead.size() + localsWritten.size() > 0; }
bool accessesMemory() { return calls || readsMemory || writesMemory; }
- bool hasSideEffects() { return calls || writesLocal || writesMemory; }
- bool hasAnything() { return branches || calls || readsLocal || writesLocal || readsMemory || writesMemory; }
+ bool hasSideEffects() { return calls || localsWritten.size() > 0 || writesMemory; }
+ bool hasAnything() { return branches || calls || accessesLocal() || readsMemory || writesMemory; }
// checks if these effects would invalidate another set (e.g., if we write, we invalidate someone that reads, they can't be moved past us)
bool invalidates(EffectAnalyzer& other) {
- return branches || other.branches
- || ((writesMemory || calls) && other.accessesMemory()) || (writesLocal && other.accessesLocal())
- || (accessesMemory() && (other.writesMemory || other.calls)) || (accessesLocal() && other.writesLocal);
+ if (branches || other.branches
+ || ((writesMemory || calls) && other.accessesMemory())
+ || (accessesMemory() && (other.writesMemory || other.calls))) {
+ return true;
+ }
+ assert(localsWritten.size() + localsRead.size() <= 1); // the code below is fast on that case, of one element vs many
+ for (auto local : localsWritten) {
+ if (other.localsWritten.count(local) || other.localsRead.count(local)) {
+ return true;
+ }
+ }
+ for (auto local : localsRead) {
+ if (other.localsWritten.count(local)) return true;
+ }
+ return false;
}
// the checks above happen after the node's children were processed, in the order of execution
@@ -85,8 +97,12 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer> {
void visitCall(Call *curr) { calls = true; }
void visitCallImport(CallImport *curr) { calls = true; }
void visitCallIndirect(CallIndirect *curr) { calls = true; }
- void visitGetLocal(GetLocal *curr) { readsLocal = true; }
- void visitSetLocal(SetLocal *curr) { writesLocal = true; }
+ void visitGetLocal(GetLocal *curr) {
+ localsRead.insert(curr->name);
+ }
+ void visitSetLocal(SetLocal *curr) {
+ localsWritten.insert(curr->name);
+ }
void visitLoad(Load *curr) { readsMemory = true; }
void visitStore(Store *curr) { writesMemory = true; }
void visitReturn(Return *curr) { branches = true; }
diff --git a/src/binaryen-shell.cpp b/src/binaryen-shell.cpp
index 7f5b3077e..dd0f0eeca 100644
--- a/src/binaryen-shell.cpp
+++ b/src/binaryen-shell.cpp
@@ -52,7 +52,7 @@ struct Invocation {
name = invoke[1]->str();
for (size_t j = 2; j < invoke.size(); j++) {
Expression* argument = builder.parseExpression(*invoke[j]);
- arguments.push_back(argument->dyn_cast<Const>()->value);
+ arguments.push_back(argument->dynCast<Const>()->value);
}
}
@@ -150,7 +150,7 @@ static void run_asserts(size_t* i, bool* checked, AllocatingModule* wasm,
if (curr.size() >= 3) {
Literal expected = builder->get()
->parseExpression(*curr[2])
- ->dyn_cast<Const>()
+ ->dynCast<Const>()
->value;
std::cerr << "seen " << result << ", expected " << expected << '\n';
verify_result(expected, result);
diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp
index ab210123c..578d4fc45 100644
--- a/src/passes/MergeBlocks.cpp
+++ b/src/passes/MergeBlocks.cpp
@@ -29,7 +29,7 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> {
while (more) {
more = false;
for (size_t i = 0; i < curr->list.size(); i++) {
- Block* child = curr->list[i]->dyn_cast<Block>();
+ Block* child = curr->list[i]->dynCast<Block>();
if (!child) continue;
if (child->name.is()) continue; // named blocks can have breaks to them (and certainly do, if we ran RemoveUnusedNames and RemoveUnusedBrs)
ExpressionList merged;
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index 3d89d7af4..ca79468f5 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -29,7 +29,7 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions>
void visitIf(If* curr) {
// flip branches to get rid of an i32.eqz
if (curr->ifFalse) {
- auto condition = curr->condition->dyn_cast<Unary>();
+ auto condition = curr->condition->dynCast<Unary>();
if (condition && condition->op == EqZ && condition->value->type == i32) {
curr->condition = condition->value;
std::swap(curr->ifTrue, curr->ifFalse);
@@ -39,7 +39,7 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions>
void visitUnary(Unary* curr) {
if (curr->op == EqZ) {
// fold comparisons that flow into an EqZ
- auto* child = curr->value->dyn_cast<Binary>();
+ auto* child = curr->value->dynCast<Binary>();
if (child && (child->type == i32 || child->type == i64)) {
switch (child->op) {
case Eq: child->op = Ne; break;
diff --git a/src/passes/PostEmscripten.cpp b/src/passes/PostEmscripten.cpp
index 99b172d65..effbad30a 100644
--- a/src/passes/PostEmscripten.cpp
+++ b/src/passes/PostEmscripten.cpp
@@ -44,12 +44,12 @@ struct PostEmscripten : public WalkerPass<PostWalker<PostEmscripten>> {
void visitMemoryOp(T *curr) {
if (curr->offset) return;
Expression* ptr = curr->ptr;
- auto add = ptr->dyn_cast<Binary>();
+ auto add = ptr->dynCast<Binary>();
if (!add || add->op != Add) return;
assert(add->type == i32);
- auto c = add->right->dyn_cast<Const>();
+ auto c = add->right->dynCast<Const>();
if (!c) {
- c = add->left->dyn_cast<Const>();
+ c = add->left->dynCast<Const>();
if (c) {
// if one is a const, it's ok to swap
add->left = add->right;
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index 7b68d956e..ce974c5f0 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -105,14 +105,14 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
incIndent();
printFullLine(curr->condition);
// ifTrue and ifFalse have implicit blocks, avoid printing them if possible
- if (!fullAST && curr->ifTrue->is<Block>() && curr->ifTrue->dyn_cast<Block>()->name.isNull() && curr->ifTrue->dyn_cast<Block>()->list.size() == 1) {
- printFullLine(curr->ifTrue->dyn_cast<Block>()->list.back());
+ if (!fullAST && curr->ifTrue->is<Block>() && curr->ifTrue->dynCast<Block>()->name.isNull() && curr->ifTrue->dynCast<Block>()->list.size() == 1) {
+ printFullLine(curr->ifTrue->dynCast<Block>()->list.back());
} else {
printFullLine(curr->ifTrue);
}
if (curr->ifFalse) {
- if (!fullAST && curr->ifFalse->is<Block>() && curr->ifFalse->dyn_cast<Block>()->name.isNull() && curr->ifFalse->dyn_cast<Block>()->list.size() == 1) {
- printFullLine(curr->ifFalse->dyn_cast<Block>()->list.back());
+ if (!fullAST && curr->ifFalse->is<Block>() && curr->ifFalse->dynCast<Block>()->name.isNull() && curr->ifFalse->dynCast<Block>()->list.size() == 1) {
+ printFullLine(curr->ifFalse->dynCast<Block>()->list.back());
} else {
printFullLine(curr->ifFalse);
}
@@ -129,7 +129,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> {
o << ' ' << curr->in;
}
incIndent();
- auto block = curr->body->dyn_cast<Block>();
+ auto block = curr->body->dynCast<Block>();
if (!fullAST && block && block->name.isNull()) {
// wasm spec has loops containing children directly, while our ast
// has a single child for simplicity. print out the optimal form.
diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp
index 998142724..41db36d2c 100644
--- a/src/passes/RemoveUnusedBrs.cpp
+++ b/src/passes/RemoveUnusedBrs.cpp
@@ -30,7 +30,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
void visitIf(If* curr) {
if (!curr->ifFalse) {
// try to reduce an if (condition) br => br_if (condition) , which might open up other optimization opportunities
- Break* br = curr->ifTrue->dyn_cast<Break>();
+ Break* br = curr->ifTrue->dynCast<Break>();
if (br && !br->condition) { // TODO: if there is a condition, join them
br->condition = curr->condition;
replaceCurrent(br);
@@ -40,7 +40,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
if (isConcreteWasmType(curr->type)) return; // already has a returned value
// an if_else that indirectly returns a value by breaking to the same target can potentially remove both breaks, and break outside once
auto getLast = [](Expression *side) -> Expression* {
- Block* b = side->dyn_cast<Block>();
+ Block* b = side->dynCast<Block>();
if (!b) return nullptr;
if (b->list.size() == 0) return nullptr;
return b->list.back();
@@ -49,7 +49,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
Expression* last = getLast(side);
if (!last) return Name();
Block* b = side->cast<Block>();
- Break* br = last->dyn_cast<Break>();
+ Break* br = last->dynCast<Break>();
if (!br) return Name();
if (br->condition) return Name();
if (!br->value) return Name();
@@ -76,14 +76,14 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
if (curr->list.size() == 0) return;
// preparation - remove all code after an unconditional break, since it can't execute, and it might confuse us (we look at the last)
for (size_t i = 0; i < curr->list.size()-1; i++) {
- Break* br = curr->list[i]->dyn_cast<Break>();
+ Break* br = curr->list[i]->dynCast<Break>();
if (br && !br->condition) {
curr->list.resize(i+1);
break;
}
}
Expression* last = curr->list.back();
- if (Break* br = last->dyn_cast<Break>()) {
+ if (Break* br = last->dynCast<Break>()) {
if (br->condition) return;
if (br->name == curr->name) {
if (!br->value) {
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index 53e77eb22..77e4f788a 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -17,7 +17,9 @@
//
// Locals-related optimizations
//
-// This "sinks" set_locals, pushing them to the next get_local where possible
+// This "sinks" set_locals, pushing them to the next get_local where possible,
+// and removing the set if there are no gets remaining (the latter is
+// particularly useful in ssa mode, but not only).
#include <wasm.h>
#include <wasm-traversal.h>
@@ -39,6 +41,12 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals>>
// locals in current linear execution trace, which we try to sink
std::map<Name, SinkableInfo> sinkables;
+ // name => # of get_locals for it
+ std::map<Name, int> numGetLocals;
+
+ // for each set_local, its origin pointer
+ std::map<SetLocal*, Expression**> setLocalOrigins;
+
void noteNonLinear() {
sinkables.clear();
}
@@ -52,6 +60,8 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals>>
*found->second.item = curr;
ExpressionManipulator::nop(curr);
sinkables.erase(found);
+ } else {
+ numGetLocals[curr->name]++;
}
}
@@ -79,21 +89,32 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals>>
}
static void visitPre(SimplifyLocals* self, Expression** currp) {
+ Expression* curr = *currp;
+
EffectAnalyzer effects;
- if (effects.checkPre(*currp)) {
+ if (effects.checkPre(curr)) {
self->checkInvalidations(effects);
}
}
static void visitPost(SimplifyLocals* self, Expression** currp) {
+ Expression* curr = *currp;
+
EffectAnalyzer effects;
- if (effects.checkPost(*currp)) {
+ if (effects.checkPost(curr)) {
self->checkInvalidations(effects);
}
+
+ // noting origins in the post means it happens after a
+ // get_local was replaced by a set_local in a sinking
+ // operation, so we track those movements properly.
+ if (curr->is<SetLocal>()) {
+ self->setLocalOrigins[curr->cast<SetLocal>()] = currp;
+ }
}
static void tryMarkSinkable(SimplifyLocals* self, Expression** currp) {
- auto* curr = (*currp)->dyn_cast<SetLocal>();
+ auto* curr = (*currp)->dynCast<SetLocal>();
if (curr) {
Name name = curr->name;
assert(self->sinkables.count(name) == 0);
@@ -107,7 +128,6 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals>>
auto* curr = *currp;
-
if (curr->is<Block>()) {
// special-case blocks, by marking their children as locals.
// TODO sink from elsewhere? (need to make sure value is not used)
@@ -129,6 +149,33 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals>>
self->pushTask(visitPre, currp);
}
+
+ void visitFunction(Function *curr) {
+ // after optimizing a function, we can see if we have set_locals
+ // for a local with no remaining gets, in which case, we can
+ // remove the set.
+ std::vector<SetLocal*> optimizables;
+ for (auto pair : setLocalOrigins) {
+ SetLocal* curr = pair.first;
+ if (numGetLocals[curr->name] == 0) {
+ // no gets, can remove the set and leave just the value
+ optimizables.push_back(curr);
+ }
+ }
+ for (auto* curr : optimizables) {
+ Expression** origin = setLocalOrigins[curr];
+ *origin = curr->value;
+ // nested set_locals need to be handled properly.
+ // consider (set_local x (set_local y (..))), where both can be
+ // reduced to their values, and we might do it in either
+ // order.
+ if (curr->value->is<SetLocal>()) {
+ setLocalOrigins[curr->value->cast<SetLocal>()] = origin;
+ }
+ }
+ numGetLocals.clear();
+ setLocalOrigins.clear();
+ }
};
static RegisterPass<SimplifyLocals> registerPass("simplify-locals", "miscellaneous locals-related optimizations");
diff --git a/src/s2wasm.h b/src/s2wasm.h
index 9d166624c..9bdc4961d 100644
--- a/src/s2wasm.h
+++ b/src/s2wasm.h
@@ -1086,7 +1086,7 @@ class S2WasmBuilder {
for (auto block : loopBlocks) {
block->name = Name();
}
- func->body->dyn_cast<Block>()->finalize();
+ func->body->dynCast<Block>()->finalize();
wasm.addFunction(func);
}
diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h
index 0d2ecd570..40f2a41a0 100644
--- a/src/wasm-s-parser.h
+++ b/src/wasm-s-parser.h
@@ -897,7 +897,7 @@ private:
auto* ret = parseExpression(&s);
labelStack.pop_back();
if (explicitThenElse) {
- ret->dyn_cast<Block>()->name = name;
+ ret->dynCast<Block>()->name = name;
} else {
// add a block if we must
if (BreakSeeker::has(ret, name)) {
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index d1fff2753..c8de5d886 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -33,7 +33,6 @@ namespace wasm {
template<typename SubType, typename ReturnType = void>
struct Visitor {
- virtual ~Visitor() {}
// Expression visitors
ReturnType visitBlock(Block *curr) {}
ReturnType visitIf(If *curr) {}
diff --git a/src/wasm.h b/src/wasm.h
index c9494d4d6..2b04ed8cd 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -758,7 +758,7 @@ public:
}
template<class T>
- T* dyn_cast() {
+ T* dynCast() {
return _id == T()._id ? (T*)this : nullptr;
}