summaryrefslogtreecommitdiff
path: root/src/passes/SimplifyLocals.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/SimplifyLocals.cpp')
-rw-r--r--src/passes/SimplifyLocals.cpp372
1 files changed, 300 insertions, 72 deletions
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index f0c786b1c..785bccf06 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -21,20 +21,52 @@
// and removing the set if there are no gets remaining (the latter is
// particularly useful in ssa mode, but not only).
//
+// We also note where set_locals coalesce: if all breaks of a block set
+// a specific local, we can use a block return value for it, in effect
+// removing multiple set_locals and replacing them with one that the
+// block returns to. Further optimization rounds then have the opportunity
+// to remove that set_local as well. TODO: support partial traces; right
+// now, whenever control flow splits, we invalidate everything. This is
+// enough for SSA form, but not otherwise.
+//
// After this pass, some locals may be completely unused. reorder-locals
// can get rid of those (the operation is trivial there after it sorts by use
// frequency).
#include <wasm.h>
+#include <wasm-builder.h>
#include <wasm-traversal.h>
#include <pass.h>
#include <ast_utils.h>
namespace wasm {
+// Helper classes
+
+struct GetLocalCounter : public WalkerPass<PostWalker<GetLocalCounter, Visitor<GetLocalCounter>>> {
+ std::vector<int>* numGetLocals;
+
+ void visitGetLocal(GetLocal *curr) {
+ (*numGetLocals)[curr->index]++;
+ }
+};
+
+struct SetLocalRemover : public WalkerPass<PostWalker<SetLocalRemover, Visitor<SetLocalRemover>>> {
+ std::vector<int>* numGetLocals;
+
+ void visitSetLocal(SetLocal *curr) {
+ if ((*numGetLocals)[curr->index] == 0) {
+ replaceCurrent(curr->value);
+ }
+ }
+};
+
+// Main class
+
struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, Visitor<SimplifyLocals>>> {
bool isFunctionParallel() { return true; }
+ // information for a set_local we can sink
struct SinkableInfo {
Expression** item;
EffectAnalyzer effects;
@@ -44,19 +76,98 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
}
};
+ // a list of sinkables in a linear execution trace
+ typedef std::map<Index, SinkableInfo> Sinkables;
+
// locals in current linear execution trace, which we try to sink
- std::map<Index, SinkableInfo> sinkables;
+ Sinkables sinkables;
- bool sunk;
+ // Information about an exit from a block: the break, and the
+ // sinkables. For the final exit from a block (falling off)
+ // exitter is null.
+ struct BlockBreak {
+ Break* br;
+ Sinkables sinkables;
+ };
- // local => # of get_locals for it
- std::vector<int> numGetLocals;
+ // a list of all sinkable traces that exit a block. the last
+ // is falling off the end, others are branches. this is used for
+ // block returns
+ std::map<Name, std::vector<BlockBreak>> blockBreaks;
- // for each set_local, its origin pointer
- std::map<SetLocal*, Expression**> setLocalOrigins;
+ // blocks that are the targets of a switch; we need to know this
+ // since we can't produce a block return value for them.
+ std::set<Name> unoptimizableBlocks;
- void noteNonLinear() {
- sinkables.clear();
+ // A stack of sinkables from the current traversal state. When
+ // execution reaches an if-else, it splits, and can then
+ // be merged on return.
+ std::vector<Sinkables> ifStack;
+
+ // whether we need to run an additional cycle
+ bool anotherCycle;
+
+ static void doNoteNonLinear(SimplifyLocals* self, Expression** currp) {
+ auto* curr = *currp;
+ if (curr->is<Break>()) {
+ auto* br = curr->cast<Break>();
+ if (br->value) {
+ // value means the block already has a return value
+ self->unoptimizableBlocks.insert(br->name);
+ } else {
+ self->blockBreaks[br->name].push_back({ br, std::move(self->sinkables) });
+ }
+ } else if (curr->is<Block>()) {
+ return; // handled in visitBlock
+ } else if (curr->is<If>()) {
+ assert(!curr->cast<If>()->ifFalse); // if-elses are handled by doNoteIfElse* methods
+ } else if (curr->is<Switch>()) {
+ auto* sw = curr->cast<Switch>();
+ for (auto target : sw->targets) {
+ self->unoptimizableBlocks.insert(target);
+ }
+ // TODO: we could use this info to stop gathering data on these blocks
+ }
+ self->sinkables.clear();
+ }
+
+ static void doNoteIfElseCondition(SimplifyLocals* self, Expression** currp) {
+ // we processed the condition of this if-else, and now control flow branches
+ // into either the true or the false sides
+ assert((*currp)->cast<If>()->ifFalse);
+ self->sinkables.clear();
+ }
+
+ static void doNoteIfElseTrue(SimplifyLocals* self, Expression** currp) {
+ // we processed the ifTrue side of this if-else, save it on the stack
+ assert((*currp)->cast<If>()->ifFalse);
+ self->ifStack.push_back(std::move(self->sinkables));
+ }
+
+ static void doNoteIfElseFalse(SimplifyLocals* self, Expression** currp) {
+ // we processed the ifFalse side of this if-else, we can now try to
+ // mere with the ifTrue side and optimize a return value, if possible
+ auto* iff = (*currp)->cast<If>();
+ assert(iff->ifFalse);
+ self->optimizeIfReturn(iff, currp, self->ifStack.back());
+ self->ifStack.pop_back();
+ self->sinkables.clear();
+ }
+
+ void visitBlock(Block* curr) {
+ bool hasBreaks = curr->name.is() && blockBreaks[curr->name].size() > 0;
+
+ optimizeBlockReturn(curr); // can modify blockBreaks
+
+ // post-block cleanups
+ if (curr->name.is()) {
+ unoptimizableBlocks.erase(curr->name);
+ }
+ if (hasBreaks) {
+ // more than one path to here, so nonlinear
+ sinkables.clear();
+ blockBreaks.erase(curr->name);
+ }
}
void visitGetLocal(GetLocal *curr) {
@@ -68,19 +179,7 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
*found->second.item = curr;
ExpressionManipulator::nop(curr);
sinkables.erase(found);
- sunk = true;
- } else {
- numGetLocals[curr->index]++;
- }
- }
-
- void visitSetLocal(SetLocal *curr) {
- // if we are a potentially-sinkable thing, then the previous
- // store is dead, leave just the value
- auto found = sinkables.find(curr->index);
- if (found != sinkables.end()) {
- *found->second.item = (*found->second.item)->cast<SetLocal>()->value;
- sinkables.erase(found);
+ anotherCycle = true;
}
}
@@ -97,6 +196,8 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
}
}
+ std::vector<Expression*> expressionStack;
+
static void visitPre(SimplifyLocals* self, Expression** currp) {
Expression* curr = *currp;
@@ -104,31 +205,150 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
if (effects.checkPre(curr)) {
self->checkInvalidations(effects);
}
+
+ self->expressionStack.push_back(curr);
}
static void visitPost(SimplifyLocals* self, Expression** currp) {
- Expression* curr = *currp;
+ // perform main SetLocal processing here, since we may be the result of
+ // replaceCurrent, i.e., the visitor was not called.
+ auto* set = (*currp)->dynCast<SetLocal>();
+
+ if (set) {
+ // if we see a set that was already potentially-sinkable, then the previous
+ // store is dead, leave just the value
+ auto found = self->sinkables.find(set->index);
+ if (found != self->sinkables.end()) {
+ *found->second.item = (*found->second.item)->cast<SetLocal>()->value;
+ self->sinkables.erase(found);
+ self->anotherCycle = true;
+ }
+ }
EffectAnalyzer effects;
- if (effects.checkPost(curr)) {
+ if (effects.checkPost(*currp)) {
self->checkInvalidations(effects);
}
- // noting origins in the post means it happens after a
- // get_local was replaced by a set_local in a sinking
- // operation, so we track those movements properly.
- if (curr->is<SetLocal>()) {
- self->setLocalOrigins[curr->cast<SetLocal>()] = currp;
+ if (set) {
+ // we may be a replacement for the current node, update the stack
+ self->expressionStack.pop_back();
+ self->expressionStack.push_back(set);
+ if (!ExpressionAnalyzer::isResultUsed(self->expressionStack, self->getFunction())) {
+ Index index = set->index;
+ assert(self->sinkables.count(index) == 0);
+ self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp)));
+ }
+ }
+
+ self->expressionStack.pop_back();
+ }
+
+ std::vector<Block*> blocksToEnlarge;
+ std::vector<If*> ifsToEnlarge;
+
+ void optimizeBlockReturn(Block* block) {
+ if (!block->name.is() || unoptimizableBlocks.count(block->name) > 0) {
+ return;
}
+ auto breaks = std::move(blockBreaks[block->name]);
+ blockBreaks.erase(block->name);
+ if (breaks.size() == 0) return; // block has no branches TODO we might optimize trivial stuff here too
+ assert(!breaks[0].br->value); // block does not already have a return value (if one break has one, they all do)
+ // look for a set_local that is present in them all
+ bool found = false;
+ Index sharedIndex = -1;
+ for (auto& sinkable : sinkables) {
+ Index index = sinkable.first;
+ bool inAll = true;
+ for (size_t j = 0; j < breaks.size(); j++) {
+ if (breaks[j].sinkables.count(index) == 0) {
+ inAll = false;
+ break;
+ }
+ }
+ if (inAll) {
+ sharedIndex = index;
+ found = true;
+ break;
+ }
+ }
+ if (!found) return;
+ // Great, this local is set in them all, we can optimize!
+ if (block->list.size() == 0 || !block->list.back()->is<Nop>()) {
+ // We can't do this here, since we can't push to the block -
+ // it would invalidate sinkable pointers. So we queue a request
+ // to grow the block at the end of the turn, we'll get this next
+ // cycle.
+ blocksToEnlarge.push_back(block);
+ return;
+ }
+ // move block set_local's value to the end, in return position, and nop the set
+ auto* blockSetLocalPointer = sinkables.at(sharedIndex).item;
+ auto* value = (*blockSetLocalPointer)->cast<SetLocal>()->value;
+ block->list[block->list.size() - 1] = value;
+ block->type = value->type;
+ ExpressionManipulator::nop(*blockSetLocalPointer);
+ for (size_t j = 0; j < breaks.size(); j++) {
+ // move break set_local's value to the break
+ auto* breakSetLocalPointer = breaks[j].sinkables.at(sharedIndex).item;
+ assert(!breaks[j].br->value);
+ breaks[j].br->value = (*breakSetLocalPointer)->cast<SetLocal>()->value;
+ ExpressionManipulator::nop(*breakSetLocalPointer);
+ }
+ // finally, create a set_local on the block itself
+ auto* newSetLocal = Builder(*getModule()).makeSetLocal(sharedIndex, block);
+ replaceCurrent(newSetLocal);
+ sinkables.clear();
+ anotherCycle = true;
}
- static void tryMarkSinkable(SimplifyLocals* self, Expression** currp) {
- auto* curr = (*currp)->dynCast<SetLocal>();
- if (curr) {
- Index index = curr->index;
- assert(self->sinkables.count(index) == 0);
- self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp)));
+ // optimize set_locals from both sides of an if into a return value
+ void optimizeIfReturn(If* iff, Expression** currp, Sinkables& ifTrue) {
+ assert(iff->ifFalse);
+ // if this if already has a result that is used, we can't do anything
+ assert(expressionStack.back() == iff);
+ if (ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) return;
+ // We now have the sinkables from both sides of the if.
+ Sinkables& ifFalse = sinkables;
+ Index sharedIndex = -1;
+ bool found = false;
+ for (auto& sinkable : ifTrue) {
+ Index index = sinkable.first;
+ if (ifFalse.count(index) > 0) {
+ sharedIndex = index;
+ found = true;
+ break;
+ }
+ }
+ if (!found) return;
+ // great, we can optimize!
+ // ensure we have a place to write the return values for, if not, we
+ // need another cycle
+ auto* ifTrueBlock = iff->ifTrue->dynCast<Block>();
+ auto* ifFalseBlock = iff->ifFalse->dynCast<Block>();
+ if (!ifTrueBlock || ifTrueBlock->list.size() == 0 || !ifTrueBlock->list.back()->is<Nop>() ||
+ !ifFalseBlock || ifFalseBlock->list.size() == 0 || !ifFalseBlock->list.back()->is<Nop>()) {
+ ifsToEnlarge.push_back(iff);
+ return;
}
+ // all set, go
+ auto *ifTrueItem = ifTrue.at(sharedIndex).item;
+ ifTrueBlock->list[ifTrueBlock->list.size() - 1] = (*ifTrueItem)->cast<SetLocal>()->value;
+ ExpressionManipulator::nop(*ifTrueItem);
+ ifTrueBlock->finalize();
+ assert(ifTrueBlock->type != none);
+ auto *ifFalseItem = ifFalse.at(sharedIndex).item;
+ ifFalseBlock->list[ifFalseBlock->list.size() - 1] = (*ifFalseItem)->cast<SetLocal>()->value;
+ ExpressionManipulator::nop(*ifFalseItem);
+ ifFalseBlock->finalize();
+ assert(ifTrueBlock->type != none);
+ iff->finalize(); // update type
+ assert(iff->type != none);
+ // finally, create a set_local on the iff itself
+ auto* newSetLocal = Builder(*getModule()).makeSetLocal(sharedIndex, iff);
+ *currp = newSetLocal;
+ anotherCycle = true;
}
// override scan to add a pre and a post check task to all nodes
@@ -137,21 +357,14 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
auto* curr = *currp;
- if (curr->is<Block>()) {
- // special-case blocks, by marking their children as locals.
- // TODO sink from elsewhere? (need to make sure value is not used)
- self->pushTask(SimplifyLocals::doNoteNonLinear, currp);
- auto& list = curr->cast<Block>()->list;
- int size = list.size();
- // we can't sink the last element, as it might be a return value;
- // and anyhow, control flow is nonlinear at the end of the block so
- // it would be invalidated.
- for (int i = size - 1; i >= 0; i--) {
- if (i < size - 1) {
- self->pushTask(tryMarkSinkable, &list[i]);
- }
- self->pushTask(scan, &list[i]);
- }
+ if (curr->is<If>() && curr->cast<If>()->ifFalse) {
+ // handle if-elses in a special manner, using the ifStack
+ self->pushTask(SimplifyLocals::doNoteIfElseFalse, currp);
+ self->pushTask(SimplifyLocals::scan, &curr->cast<If>()->ifFalse);
+ self->pushTask(SimplifyLocals::doNoteIfElseTrue, currp);
+ self->pushTask(SimplifyLocals::scan, &curr->cast<If>()->ifTrue);
+ self->pushTask(SimplifyLocals::doNoteIfElseCondition, currp);
+ self->pushTask(SimplifyLocals::scan, &curr->cast<If>()->condition);
} else {
WalkerPass<LinearExecutionWalker<SimplifyLocals, Visitor<SimplifyLocals>>>::scan(self, currp);
}
@@ -166,37 +379,52 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
// c(x, y)
// the load cannot cross the store, but y can be sunk, after which so can x
do {
- numGetLocals.resize(getFunction()->getNumLocals());
- sunk = false;
+ anotherCycle = false;
// main operation
WalkerPass<LinearExecutionWalker<SimplifyLocals, Visitor<SimplifyLocals>>>::walk(root);
- // after optimizing a function, we can see if we have set_locals
- // for a local with no remaining gets, in which case, we can
- // remove the set.
- std::vector<SetLocal*> optimizables;
- for (auto pair : setLocalOrigins) {
- SetLocal* curr = pair.first;
- if (numGetLocals[curr->index] == 0) {
- // no gets, can remove the set and leave just the value
- optimizables.push_back(curr);
+ // enlarge blocks that were marked, for the next round
+ if (blocksToEnlarge.size() > 0) {
+ for (auto* block : blocksToEnlarge) {
+ block->list.push_back(getModule()->allocator.alloc<Nop>());
}
+ blocksToEnlarge.clear();
+ anotherCycle = true;
}
- for (auto* curr : optimizables) {
- Expression** origin = setLocalOrigins[curr];
- *origin = curr->value;
- // nested set_values need to be handled properly.
- // consider (set_local x (set_local y (..)), where both can be
- // reduced to their values, and we might do it in either
- // order.
- if (curr->value->is<SetLocal>()) {
- setLocalOrigins[curr->value->cast<SetLocal>()] = origin;
+ // enlarge ifs that were marked, for the next round
+ if (ifsToEnlarge.size() > 0) {
+ for (auto* iff : ifsToEnlarge) {
+ auto ifTrue = Builder(*getModule()).blockify(iff->ifTrue);
+ iff->ifTrue = ifTrue;
+ if (ifTrue->list.size() == 0 || !ifTrue->list.back()->is<Nop>()) {
+ ifTrue->list.push_back(getModule()->allocator.alloc<Nop>());
+ }
+ auto ifFalse = Builder(*getModule()).blockify(iff->ifFalse);
+ iff->ifFalse = ifFalse;
+ if (ifFalse->list.size() == 0 || !ifFalse->list.back()->is<Nop>()) {
+ ifFalse->list.push_back(getModule()->allocator.alloc<Nop>());
+ }
}
+ ifsToEnlarge.clear();
+ anotherCycle = true;
}
// clean up
- numGetLocals.clear();
- setLocalOrigins.clear();
sinkables.clear();
- } while (sunk);
+ blockBreaks.clear();
+ unoptimizableBlocks.clear();
+ } while (anotherCycle);
+ // Finally, after optimizing a function, we can see if we have set_locals
+ // for a local with no remaining gets, in which case, we can
+ // remove the set.
+ // First, count get_locals
+ std::vector<int> numGetLocals; // local => # of get_locals for it
+ numGetLocals.resize(getFunction()->getNumLocals());
+ GetLocalCounter counter;
+ counter.numGetLocals = &numGetLocals;
+ counter.walk(root);
+ // Second, remove unneeded sets
+ SetLocalRemover remover;
+ remover.numGetLocals = &numGetLocals;
+ remover.walk(root);
}
};