From 1fa7f7fef75d2cb795f04b2737efb46a280c9cdc Mon Sep 17 00:00:00 2001 From: Alon Zakai Date: Fri, 22 Apr 2016 15:29:16 -0700 Subject: optimize block and if returns, by merging set_locals that flow out of them --- src/passes/SimplifyLocals.cpp | 372 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 300 insertions(+), 72 deletions(-) (limited to 'src/passes/SimplifyLocals.cpp') diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index f0c786b1c..785bccf06 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -21,20 +21,52 @@ // and removing the set if there are no gets remaining (the latter is // particularly useful in ssa mode, but not only). // +// We also note where set_locals coalesce: if all breaks of a block set +// a specific local, we can use a block return value for it, in effect +// removing multiple set_locals and replacing them with one that the +// block returns to. Further optimization rounds then have the opportunity +// to remove that set_local as well. TODO: support partial traces; right +// now, whenever control flow splits, we invalidate everything. This is +// enough for SSA form, but not otherwise. +// // After this pass, some locals may be completely unused. reorder-locals // can get rid of those (the operation is trivial there after it sorts by use // frequency). #include +#include #include #include #include namespace wasm { +// Helper classes + +struct GetLocalCounter : public WalkerPass>> { + std::vector* numGetLocals; + + void visitGetLocal(GetLocal *curr) { + (*numGetLocals)[curr->index]++; + } +}; + +struct SetLocalRemover : public WalkerPass>> { + std::vector* numGetLocals; + + void visitSetLocal(SetLocal *curr) { + if ((*numGetLocals)[curr->index] == 0) { + replaceCurrent(curr->value); + } + } +}; + +// Main class + struct SimplifyLocals : public WalkerPass>> { bool isFunctionParallel() { return true; } + // information for a set_local we can sink struct SinkableInfo { Expression** item; EffectAnalyzer effects; @@ -44,19 +76,98 @@ struct SimplifyLocals : public WalkerPass Sinkables; + // locals in current linear execution trace, which we try to sink - std::map sinkables; + Sinkables sinkables; - bool sunk; + // Information about an exit from a block: the break, and the + // sinkables. For the final exit from a block (falling off) + // exitter is null. + struct BlockBreak { + Break* br; + Sinkables sinkables; + }; - // local => # of get_locals for it - std::vector numGetLocals; + // a list of all sinkable traces that exit a block. the last + // is falling off the end, others are branches. this is used for + // block returns + std::map> blockBreaks; - // for each set_local, its origin pointer - std::map setLocalOrigins; + // blocks that are the targets of a switch; we need to know this + // since we can't produce a block return value for them. + std::set unoptimizableBlocks; - void noteNonLinear() { - sinkables.clear(); + // A stack of sinkables from the current traversal state. When + // execution reaches an if-else, it splits, and can then + // be merged on return. + std::vector ifStack; + + // whether we need to run an additional cycle + bool anotherCycle; + + static void doNoteNonLinear(SimplifyLocals* self, Expression** currp) { + auto* curr = *currp; + if (curr->is()) { + auto* br = curr->cast(); + if (br->value) { + // value means the block already has a return value + self->unoptimizableBlocks.insert(br->name); + } else { + self->blockBreaks[br->name].push_back({ br, std::move(self->sinkables) }); + } + } else if (curr->is()) { + return; // handled in visitBlock + } else if (curr->is()) { + assert(!curr->cast()->ifFalse); // if-elses are handled by doNoteIfElse* methods + } else if (curr->is()) { + auto* sw = curr->cast(); + for (auto target : sw->targets) { + self->unoptimizableBlocks.insert(target); + } + // TODO: we could use this info to stop gathering data on these blocks + } + self->sinkables.clear(); + } + + static void doNoteIfElseCondition(SimplifyLocals* self, Expression** currp) { + // we processed the condition of this if-else, and now control flow branches + // into either the true or the false sides + assert((*currp)->cast()->ifFalse); + self->sinkables.clear(); + } + + static void doNoteIfElseTrue(SimplifyLocals* self, Expression** currp) { + // we processed the ifTrue side of this if-else, save it on the stack + assert((*currp)->cast()->ifFalse); + self->ifStack.push_back(std::move(self->sinkables)); + } + + static void doNoteIfElseFalse(SimplifyLocals* self, Expression** currp) { + // we processed the ifFalse side of this if-else, we can now try to + // mere with the ifTrue side and optimize a return value, if possible + auto* iff = (*currp)->cast(); + assert(iff->ifFalse); + self->optimizeIfReturn(iff, currp, self->ifStack.back()); + self->ifStack.pop_back(); + self->sinkables.clear(); + } + + void visitBlock(Block* curr) { + bool hasBreaks = curr->name.is() && blockBreaks[curr->name].size() > 0; + + optimizeBlockReturn(curr); // can modify blockBreaks + + // post-block cleanups + if (curr->name.is()) { + unoptimizableBlocks.erase(curr->name); + } + if (hasBreaks) { + // more than one path to here, so nonlinear + sinkables.clear(); + blockBreaks.erase(curr->name); + } } void visitGetLocal(GetLocal *curr) { @@ -68,19 +179,7 @@ struct SimplifyLocals : public WalkerPasssecond.item = curr; ExpressionManipulator::nop(curr); sinkables.erase(found); - sunk = true; - } else { - numGetLocals[curr->index]++; - } - } - - void visitSetLocal(SetLocal *curr) { - // if we are a potentially-sinkable thing, then the previous - // store is dead, leave just the value - auto found = sinkables.find(curr->index); - if (found != sinkables.end()) { - *found->second.item = (*found->second.item)->cast()->value; - sinkables.erase(found); + anotherCycle = true; } } @@ -97,6 +196,8 @@ struct SimplifyLocals : public WalkerPass expressionStack; + static void visitPre(SimplifyLocals* self, Expression** currp) { Expression* curr = *currp; @@ -104,31 +205,150 @@ struct SimplifyLocals : public WalkerPasscheckInvalidations(effects); } + + self->expressionStack.push_back(curr); } static void visitPost(SimplifyLocals* self, Expression** currp) { - Expression* curr = *currp; + // perform main SetLocal processing here, since we may be the result of + // replaceCurrent, i.e., the visitor was not called. + auto* set = (*currp)->dynCast(); + + if (set) { + // if we see a set that was already potentially-sinkable, then the previous + // store is dead, leave just the value + auto found = self->sinkables.find(set->index); + if (found != self->sinkables.end()) { + *found->second.item = (*found->second.item)->cast()->value; + self->sinkables.erase(found); + self->anotherCycle = true; + } + } EffectAnalyzer effects; - if (effects.checkPost(curr)) { + if (effects.checkPost(*currp)) { self->checkInvalidations(effects); } - // noting origins in the post means it happens after a - // get_local was replaced by a set_local in a sinking - // operation, so we track those movements properly. - if (curr->is()) { - self->setLocalOrigins[curr->cast()] = currp; + if (set) { + // we may be a replacement for the current node, update the stack + self->expressionStack.pop_back(); + self->expressionStack.push_back(set); + if (!ExpressionAnalyzer::isResultUsed(self->expressionStack, self->getFunction())) { + Index index = set->index; + assert(self->sinkables.count(index) == 0); + self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp))); + } + } + + self->expressionStack.pop_back(); + } + + std::vector blocksToEnlarge; + std::vector ifsToEnlarge; + + void optimizeBlockReturn(Block* block) { + if (!block->name.is() || unoptimizableBlocks.count(block->name) > 0) { + return; } + auto breaks = std::move(blockBreaks[block->name]); + blockBreaks.erase(block->name); + if (breaks.size() == 0) return; // block has no branches TODO we might optimize trivial stuff here too + assert(!breaks[0].br->value); // block does not already have a return value (if one break has one, they all do) + // look for a set_local that is present in them all + bool found = false; + Index sharedIndex = -1; + for (auto& sinkable : sinkables) { + Index index = sinkable.first; + bool inAll = true; + for (size_t j = 0; j < breaks.size(); j++) { + if (breaks[j].sinkables.count(index) == 0) { + inAll = false; + break; + } + } + if (inAll) { + sharedIndex = index; + found = true; + break; + } + } + if (!found) return; + // Great, this local is set in them all, we can optimize! + if (block->list.size() == 0 || !block->list.back()->is()) { + // We can't do this here, since we can't push to the block - + // it would invalidate sinkable pointers. So we queue a request + // to grow the block at the end of the turn, we'll get this next + // cycle. + blocksToEnlarge.push_back(block); + return; + } + // move block set_local's value to the end, in return position, and nop the set + auto* blockSetLocalPointer = sinkables.at(sharedIndex).item; + auto* value = (*blockSetLocalPointer)->cast()->value; + block->list[block->list.size() - 1] = value; + block->type = value->type; + ExpressionManipulator::nop(*blockSetLocalPointer); + for (size_t j = 0; j < breaks.size(); j++) { + // move break set_local's value to the break + auto* breakSetLocalPointer = breaks[j].sinkables.at(sharedIndex).item; + assert(!breaks[j].br->value); + breaks[j].br->value = (*breakSetLocalPointer)->cast()->value; + ExpressionManipulator::nop(*breakSetLocalPointer); + } + // finally, create a set_local on the block itself + auto* newSetLocal = Builder(*getModule()).makeSetLocal(sharedIndex, block); + replaceCurrent(newSetLocal); + sinkables.clear(); + anotherCycle = true; } - static void tryMarkSinkable(SimplifyLocals* self, Expression** currp) { - auto* curr = (*currp)->dynCast(); - if (curr) { - Index index = curr->index; - assert(self->sinkables.count(index) == 0); - self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp))); + // optimize set_locals from both sides of an if into a return value + void optimizeIfReturn(If* iff, Expression** currp, Sinkables& ifTrue) { + assert(iff->ifFalse); + // if this if already has a result that is used, we can't do anything + assert(expressionStack.back() == iff); + if (ExpressionAnalyzer::isResultUsed(expressionStack, getFunction())) return; + // We now have the sinkables from both sides of the if. + Sinkables& ifFalse = sinkables; + Index sharedIndex = -1; + bool found = false; + for (auto& sinkable : ifTrue) { + Index index = sinkable.first; + if (ifFalse.count(index) > 0) { + sharedIndex = index; + found = true; + break; + } + } + if (!found) return; + // great, we can optimize! + // ensure we have a place to write the return values for, if not, we + // need another cycle + auto* ifTrueBlock = iff->ifTrue->dynCast(); + auto* ifFalseBlock = iff->ifFalse->dynCast(); + if (!ifTrueBlock || ifTrueBlock->list.size() == 0 || !ifTrueBlock->list.back()->is() || + !ifFalseBlock || ifFalseBlock->list.size() == 0 || !ifFalseBlock->list.back()->is()) { + ifsToEnlarge.push_back(iff); + return; } + // all set, go + auto *ifTrueItem = ifTrue.at(sharedIndex).item; + ifTrueBlock->list[ifTrueBlock->list.size() - 1] = (*ifTrueItem)->cast()->value; + ExpressionManipulator::nop(*ifTrueItem); + ifTrueBlock->finalize(); + assert(ifTrueBlock->type != none); + auto *ifFalseItem = ifFalse.at(sharedIndex).item; + ifFalseBlock->list[ifFalseBlock->list.size() - 1] = (*ifFalseItem)->cast()->value; + ExpressionManipulator::nop(*ifFalseItem); + ifFalseBlock->finalize(); + assert(ifTrueBlock->type != none); + iff->finalize(); // update type + assert(iff->type != none); + // finally, create a set_local on the iff itself + auto* newSetLocal = Builder(*getModule()).makeSetLocal(sharedIndex, iff); + *currp = newSetLocal; + anotherCycle = true; } // override scan to add a pre and a post check task to all nodes @@ -137,21 +357,14 @@ struct SimplifyLocals : public WalkerPassis()) { - // special-case blocks, by marking their children as locals. - // TODO sink from elsewhere? (need to make sure value is not used) - self->pushTask(SimplifyLocals::doNoteNonLinear, currp); - auto& list = curr->cast()->list; - int size = list.size(); - // we can't sink the last element, as it might be a return value; - // and anyhow, control flow is nonlinear at the end of the block so - // it would be invalidated. - for (int i = size - 1; i >= 0; i--) { - if (i < size - 1) { - self->pushTask(tryMarkSinkable, &list[i]); - } - self->pushTask(scan, &list[i]); - } + if (curr->is() && curr->cast()->ifFalse) { + // handle if-elses in a special manner, using the ifStack + self->pushTask(SimplifyLocals::doNoteIfElseFalse, currp); + self->pushTask(SimplifyLocals::scan, &curr->cast()->ifFalse); + self->pushTask(SimplifyLocals::doNoteIfElseTrue, currp); + self->pushTask(SimplifyLocals::scan, &curr->cast()->ifTrue); + self->pushTask(SimplifyLocals::doNoteIfElseCondition, currp); + self->pushTask(SimplifyLocals::scan, &curr->cast()->condition); } else { WalkerPass>>::scan(self, currp); } @@ -166,37 +379,52 @@ struct SimplifyLocals : public WalkerPassgetNumLocals()); - sunk = false; + anotherCycle = false; // main operation WalkerPass>>::walk(root); - // after optimizing a function, we can see if we have set_locals - // for a local with no remaining gets, in which case, we can - // remove the set. - std::vector optimizables; - for (auto pair : setLocalOrigins) { - SetLocal* curr = pair.first; - if (numGetLocals[curr->index] == 0) { - // no gets, can remove the set and leave just the value - optimizables.push_back(curr); + // enlarge blocks that were marked, for the next round + if (blocksToEnlarge.size() > 0) { + for (auto* block : blocksToEnlarge) { + block->list.push_back(getModule()->allocator.alloc()); } + blocksToEnlarge.clear(); + anotherCycle = true; } - for (auto* curr : optimizables) { - Expression** origin = setLocalOrigins[curr]; - *origin = curr->value; - // nested set_values need to be handled properly. - // consider (set_local x (set_local y (..)), where both can be - // reduced to their values, and we might do it in either - // order. - if (curr->value->is()) { - setLocalOrigins[curr->value->cast()] = origin; + // enlarge ifs that were marked, for the next round + if (ifsToEnlarge.size() > 0) { + for (auto* iff : ifsToEnlarge) { + auto ifTrue = Builder(*getModule()).blockify(iff->ifTrue); + iff->ifTrue = ifTrue; + if (ifTrue->list.size() == 0 || !ifTrue->list.back()->is()) { + ifTrue->list.push_back(getModule()->allocator.alloc()); + } + auto ifFalse = Builder(*getModule()).blockify(iff->ifFalse); + iff->ifFalse = ifFalse; + if (ifFalse->list.size() == 0 || !ifFalse->list.back()->is()) { + ifFalse->list.push_back(getModule()->allocator.alloc()); + } } + ifsToEnlarge.clear(); + anotherCycle = true; } // clean up - numGetLocals.clear(); - setLocalOrigins.clear(); sinkables.clear(); - } while (sunk); + blockBreaks.clear(); + unoptimizableBlocks.clear(); + } while (anotherCycle); + // Finally, after optimizing a function, we can see if we have set_locals + // for a local with no remaining gets, in which case, we can + // remove the set. + // First, count get_locals + std::vector numGetLocals; // local => # of get_locals for it + numGetLocals.resize(getFunction()->getNumLocals()); + GetLocalCounter counter; + counter.numGetLocals = &numGetLocals; + counter.walk(root); + // Second, remove unneeded sets + SetLocalRemover remover; + remover.numGetLocals = &numGetLocals; + remover.walk(root); } }; -- cgit v1.2.3