diff options
Diffstat (limited to 'src/passes/SimplifyLocals.cpp')
-rw-r--r-- | src/passes/SimplifyLocals.cpp | 372 |
1 files changed, 300 insertions, 72 deletions
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index f0c786b1c..785bccf06 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -21,20 +21,52 @@ // and removing the set if there are no gets remaining (the latter is // particularly useful in ssa mode, but not only). // +// We also note where set_locals coalesce: if all breaks of a block set +// a specific local, we can use a block return value for it, in effect +// removing multiple set_locals and replacing them with one that the +// block returns to. Further optimization rounds then have the opportunity +// to remove that set_local as well. TODO: support partial traces; right +// now, whenever control flow splits, we invalidate everything. This is +// enough for SSA form, but not otherwise. +// // After this pass, some locals may be completely unused. reorder-locals // can get rid of those (the operation is trivial there after it sorts by use // frequency). #include <wasm.h> +#include <wasm-builder.h> #include <wasm-traversal.h> #include <pass.h> #include <ast_utils.h> namespace wasm { +// Helper classes + +struct GetLocalCounter : public WalkerPass<PostWalker<GetLocalCounter, Visitor<GetLocalCounter>>> { + std::vector<int>* numGetLocals; + + void visitGetLocal(GetLocal *curr) { + (*numGetLocals)[curr->index]++; + } +}; + +struct SetLocalRemover : public WalkerPass<PostWalker<SetLocalRemover, Visitor<SetLocalRemover>>> { + std::vector<int>* numGetLocals; + + void visitSetLocal(SetLocal *curr) { + if ((*numGetLocals)[curr->index] == 0) { + replaceCurrent(curr->value); + } + } +}; + +// Main class + struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, Visitor<SimplifyLocals>>> { bool isFunctionParallel() { return true; } + // information for a set_local we can sink struct SinkableInfo { Expression** item; EffectAnalyzer effects; @@ -44,19 +76,98 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, } }; + // a list of sinkables in a linear execution trace + typedef std::map<Index, SinkableInfo> Sinkables; + // locals in current linear execution trace, which we try to sink - std::map<Index, SinkableInfo> sinkables; + Sinkables sinkables; - bool sunk; + // Information about an exit from a block: the break, and the + // sinkables. For the final exit from a block (falling off) + // exitter is null. + struct BlockBreak { + Break* br; + Sinkables sinkables; + }; - // local => # of get_locals for it - std::vector<int> numGetLocals; + // a list of all sinkable traces that exit a block. the last + // is falling off the end, others are branches. this is used for + // block returns + std::map<Name, std::vector<BlockBreak>> blockBreaks; - // for each set_local, its origin pointer - std::map<SetLocal*, Expression**> setLocalOrigins; + // blocks that are the targets of a switch; we need to know this + // since we can't produce a block return value for them. + std::set<Name> unoptimizableBlocks; - void noteNonLinear() { - sinkables.clear(); + // A stack of sinkables from the current traversal state. When + // execution reaches an if-else, it splits, and can then + // be merged on return. + std::vector<Sinkables> ifStack; + + // whether we need to run an additional cycle + bool anotherCycle; + + static void doNoteNonLinear(SimplifyLocals* self, Expression** currp) { + auto* curr = *currp; + if (curr->is<Break>()) { + auto* br = curr->cast<Break>(); + if (br->value) { + // value means the block already has a return value + self->unoptimizableBlocks.insert(br->name); + } else { + self->blockBreaks[br->name].push_back({ br, std::move(self->sinkables) }); + } + } else if (curr->is<Block>()) { + return; // handled in visitBlock + } else if (curr->is<If>()) { + assert(!curr->cast<If>()->ifFalse); // if-elses are handled by doNoteIfElse* methods + } else if (curr->is<Switch>()) { + auto* sw = curr->cast<Switch>(); + for (auto target : sw->targets) { + self->unoptimizableBlocks.insert(target); + } + // TODO: we could use this info to stop gathering data on these blocks + } + self->sinkables.clear(); + } + + static void doNoteIfElseCondition(SimplifyLocals* self, Expression** currp) { + // we processed the condition of this if-else, and now control flow branches + // into either the true or the false sides + assert((*currp)->cast<If>()->ifFalse); + self->sinkables.clear(); + } + + static void doNoteIfElseTrue(SimplifyLocals* self, Expression** currp) { + // we processed the ifTrue side of this if-else, save it on the stack + assert((*currp)->cast<If>()->ifFalse); + self->ifStack.push_back(std::move(self->sinkables)); + } + + static void doNoteIfElseFalse(SimplifyLocals* self, Expression** currp) { + // we processed the ifFalse side of this if-else, we can now try to + // mere with the ifTrue side and optimize a return value, if possible + auto* iff = (*currp)->cast<If>(); + assert(iff->ifFalse); + self->optimizeIfReturn(iff, currp, self->ifStack.back()); + self->ifStack.pop_back(); + self->sinkables.clear(); + } + + void visitBlock(Block* curr) { + bool hasBreaks = curr->name.is() && blockBreaks[curr->name].size() > 0; + + optimizeBlockReturn(curr); // can modify blockBreaks + + // post-block cleanups + if (curr->name.is()) { + unoptimizableBlocks.erase(curr->name); + } + if (hasBreaks) { + // more than one path to here, so nonlinear + sinkables.clear(); + blockBreaks.erase(curr->name); + } } void visitGetLocal(GetLocal *curr) { @@ -68,19 +179,7 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, *found->second.item = curr; ExpressionManipulator::nop(curr); sinkables.erase(found); - sunk = true; - } else { - numGetLocals[curr->index]++; - } - } - - void visitSetLocal(SetLocal *curr) { - // if we are a potentially-sinkable thing, then the previous - // store is dead, leave just the value - auto found = sinkables.find(curr->index); - if (found != sinkables.end()) { - *found->second.item = (*found->second.item)->cast<SetLocal>()->value; - sinkables.erase(found); + anotherCycle = true; } } @@ -97,6 +196,8 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, } } + std::vector<Expression*> expressionStack; + static void visitPre(SimplifyLocals* self, Expression** currp) { Expression* curr = *currp; @@ -104,31 +205,150 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, if (effects.checkPre(curr)) { self->checkInvalidations(effects); } + + self->expressionStack.push_back(curr); } static void visitPost(SimplifyLocals* self, Expression** currp) { - Expression* curr = *currp; + // perform main SetLocal processing here, since we may be the result of + // replaceCurrent, i.e., the visitor was not called. + auto* set = (*currp)->dynCast<SetLocal>(); + + if (set) { + // if we see a set that was already potentially-sinkable, then the previous + // store is dead, leave just the value + auto found = self->sinkables.find(set->index); + if (found != self->sinkables.end()) { + *found->second.item = (*found->second.item)->cast<SetLocal>()->value; + self->sinkables.erase(found); + self->anotherCycle = true; + } + } EffectAnalyzer effects; - if (effects.checkPost(curr)) { + if (effects.checkPost(*currp)) { self->checkInvalidations(effects); } - // noting origins in the post means it happens after a - // get_local was replaced by a set_local in a sinking - // operation, so we track those movements properly. - if (curr->is<SetLocal>()) { - self->setLocalOrigins[curr->cast<SetLocal>()] = currp; + if (set) { + // we may be a replacement for the current node, update the stack + self->expressionStack.pop_back(); + self->expressionStack.push_back(set); + if (!ExpressionAnalyzer::isResultUsed(self->expressionStack, self->getFunction())) { + Index index = set->index; + assert(self->sinkables.count(index) == 0); + self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp))); + } + } + + self->expressionStack.pop_back(); + } + + std::vector<Block*> blocksToEnlarge; + std::vector<If*> ifsToEnlarge; + + void optimizeBlockReturn(Block* block) { + if (!block->name.is() || unoptimizableBlocks.count(block->name) > 0) { + return; } + auto breaks = std::move(blockBreaks[block->name]); + blockBreaks.erase(block->name); + if (breaks.size() == 0) return; // block has no branches TODO we might optimize trivial stuff here too + assert(!breaks[0].br->value); // block does not already have a return value (if one break has one, they all do) + // look for a set_local that is present in them all + bool found = false; + Index sharedIndex = -1; + for (auto& sinkable : sinkables) { + Index index = sinkable.first; + bool inAll = true; + for (size_t j = 0; j < breaks.size(); j++) { + if (breaks[j].sinkables.count(index) == 0) { + inAll = false; + break; + } + } + if (inAll) { + sharedIndex = index; + found = true; + break; + } + } + if (!found) return; + // Great, this local is set in them all, we can optimize! + if (block->list.size() == 0 || !block->list.back()->is<Nop>()) { + // We can't do this here, since we can't push to the block - + // it would invalidate sinkable pointers. So we queue a request + // to grow the block at the end of the turn, we'll get this next + // cycle. + blocksToEnlarge.push_back(block); + return; + } + // move block set_local's value to the end, in return position, and nop the set + auto* blockSetLocalPointer = sinkables.at(sharedIndex).item; + auto* value = (*blockSetLocalPointer)->cast<SetLocal>()->value; + block->list[block->list.size() - 1] = value; + block->type = value->type; + ExpressionManipulator::nop(*blockSetLocalPointer); + for (size_t j = 0; j < breaks.size(); j++) { + // move break set_local's value to the break + auto* breakSetLocalPointer = breaks[j].sinkables.at(sharedIndex).item; + assert(!breaks[j].br->value); + breaks[j].br->value = (*breakSetLocalPointer)->cast<SetLocal>()->value; + ExpressionManipulator::nop(*breakSetLocalPointer); + } + // finally, create a set_local on the block itself + auto* newSetLocal = Builder(*getModule()).makeSetLocal(sharedIndex, block); + replaceCurrent(newSetLocal); + sinkables.clear(); + anotherCycle = true; } - static void tryMarkSinkable(SimplifyLocals* self, Expression** currp) { - auto* curr = (*currp)->dynCast<SetLocal>(); - if (curr) { - Index index = curr->index; - assert(self->sinkables.count(index) == 0); - self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp))); + // optimize set_locals from both sides of an if into a return value + void optimizeIfReturn(If* iff, Expression** currp, Sinkables& ifTrue) { + assert(iff->ifFalse); + // if this if already has a result that is used, we can't do anything + assert(expressionStack.back() == iff); + if (ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) return; + // We now have the sinkables from both sides of the if. + Sinkables& ifFalse = sinkables; + Index sharedIndex = -1; + bool found = false; + for (auto& sinkable : ifTrue) { + Index index = sinkable.first; + if (ifFalse.count(index) > 0) { + sharedIndex = index; + found = true; + break; + } + } + if (!found) return; + // great, we can optimize! + // ensure we have a place to write the return values for, if not, we + // need another cycle + auto* ifTrueBlock = iff->ifTrue->dynCast<Block>(); + auto* ifFalseBlock = iff->ifFalse->dynCast<Block>(); + if (!ifTrueBlock || ifTrueBlock->list.size() == 0 || !ifTrueBlock->list.back()->is<Nop>() || + !ifFalseBlock || ifFalseBlock->list.size() == 0 || !ifFalseBlock->list.back()->is<Nop>()) { + ifsToEnlarge.push_back(iff); + return; } + // all set, go + auto *ifTrueItem = ifTrue.at(sharedIndex).item; + ifTrueBlock->list[ifTrueBlock->list.size() - 1] = (*ifTrueItem)->cast<SetLocal>()->value; + ExpressionManipulator::nop(*ifTrueItem); + ifTrueBlock->finalize(); + assert(ifTrueBlock->type != none); + auto *ifFalseItem = ifFalse.at(sharedIndex).item; + ifFalseBlock->list[ifFalseBlock->list.size() - 1] = (*ifFalseItem)->cast<SetLocal>()->value; + ExpressionManipulator::nop(*ifFalseItem); + ifFalseBlock->finalize(); + assert(ifTrueBlock->type != none); + iff->finalize(); // update type + assert(iff->type != none); + // finally, create a set_local on the iff itself + auto* newSetLocal = Builder(*getModule()).makeSetLocal(sharedIndex, iff); + *currp = newSetLocal; + anotherCycle = true; } // override scan to add a pre and a post check task to all nodes @@ -137,21 +357,14 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, auto* curr = *currp; - if (curr->is<Block>()) { - // special-case blocks, by marking their children as locals. - // TODO sink from elsewhere? (need to make sure value is not used) - self->pushTask(SimplifyLocals::doNoteNonLinear, currp); - auto& list = curr->cast<Block>()->list; - int size = list.size(); - // we can't sink the last element, as it might be a return value; - // and anyhow, control flow is nonlinear at the end of the block so - // it would be invalidated. - for (int i = size - 1; i >= 0; i--) { - if (i < size - 1) { - self->pushTask(tryMarkSinkable, &list[i]); - } - self->pushTask(scan, &list[i]); - } + if (curr->is<If>() && curr->cast<If>()->ifFalse) { + // handle if-elses in a special manner, using the ifStack + self->pushTask(SimplifyLocals::doNoteIfElseFalse, currp); + self->pushTask(SimplifyLocals::scan, &curr->cast<If>()->ifFalse); + self->pushTask(SimplifyLocals::doNoteIfElseTrue, currp); + self->pushTask(SimplifyLocals::scan, &curr->cast<If>()->ifTrue); + self->pushTask(SimplifyLocals::doNoteIfElseCondition, currp); + self->pushTask(SimplifyLocals::scan, &curr->cast<If>()->condition); } else { WalkerPass<LinearExecutionWalker<SimplifyLocals, Visitor<SimplifyLocals>>>::scan(self, currp); } @@ -166,37 +379,52 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, // c(x, y) // the load cannot cross the store, but y can be sunk, after which so can x do { - numGetLocals.resize(getFunction()->getNumLocals()); - sunk = false; + anotherCycle = false; // main operation WalkerPass<LinearExecutionWalker<SimplifyLocals, Visitor<SimplifyLocals>>>::walk(root); - // after optimizing a function, we can see if we have set_locals - // for a local with no remaining gets, in which case, we can - // remove the set. - std::vector<SetLocal*> optimizables; - for (auto pair : setLocalOrigins) { - SetLocal* curr = pair.first; - if (numGetLocals[curr->index] == 0) { - // no gets, can remove the set and leave just the value - optimizables.push_back(curr); + // enlarge blocks that were marked, for the next round + if (blocksToEnlarge.size() > 0) { + for (auto* block : blocksToEnlarge) { + block->list.push_back(getModule()->allocator.alloc<Nop>()); } + blocksToEnlarge.clear(); + anotherCycle = true; } - for (auto* curr : optimizables) { - Expression** origin = setLocalOrigins[curr]; - *origin = curr->value; - // nested set_values need to be handled properly. - // consider (set_local x (set_local y (..)), where both can be - // reduced to their values, and we might do it in either - // order. - if (curr->value->is<SetLocal>()) { - setLocalOrigins[curr->value->cast<SetLocal>()] = origin; + // enlarge ifs that were marked, for the next round + if (ifsToEnlarge.size() > 0) { + for (auto* iff : ifsToEnlarge) { + auto ifTrue = Builder(*getModule()).blockify(iff->ifTrue); + iff->ifTrue = ifTrue; + if (ifTrue->list.size() == 0 || !ifTrue->list.back()->is<Nop>()) { + ifTrue->list.push_back(getModule()->allocator.alloc<Nop>()); + } + auto ifFalse = Builder(*getModule()).blockify(iff->ifFalse); + iff->ifFalse = ifFalse; + if (ifFalse->list.size() == 0 || !ifFalse->list.back()->is<Nop>()) { + ifFalse->list.push_back(getModule()->allocator.alloc<Nop>()); + } } + ifsToEnlarge.clear(); + anotherCycle = true; } // clean up - numGetLocals.clear(); - setLocalOrigins.clear(); sinkables.clear(); - } while (sunk); + blockBreaks.clear(); + unoptimizableBlocks.clear(); + } while (anotherCycle); + // Finally, after optimizing a function, we can see if we have set_locals + // for a local with no remaining gets, in which case, we can + // remove the set. + // First, count get_locals + std::vector<int> numGetLocals; // local => # of get_locals for it + numGetLocals.resize(getFunction()->getNumLocals()); + GetLocalCounter counter; + counter.numGetLocals = &numGetLocals; + counter.walk(root); + // Second, remove unneeded sets + SetLocalRemover remover; + remover.numGetLocals = &numGetLocals; + remover.walk(root); } }; |