diff options
Diffstat (limited to 'src/passes/CodeFolding.cpp')
-rw-r--r-- | src/passes/CodeFolding.cpp | 253 |
1 files changed, 154 insertions, 99 deletions
diff --git a/src/passes/CodeFolding.cpp b/src/passes/CodeFolding.cpp index a79980cfe..0479472d8 100644 --- a/src/passes/CodeFolding.cpp +++ b/src/passes/CodeFolding.cpp @@ -57,28 +57,29 @@ #include <iterator> -#include "wasm.h" -#include "pass.h" -#include "wasm-builder.h" -#include "ir/utils.h" #include "ir/branch-utils.h" #include "ir/effects.h" #include "ir/label-utils.h" +#include "ir/utils.h" +#include "pass.h" +#include "wasm-builder.h" +#include "wasm.h" namespace wasm { static const Index WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH = 3; -struct ExpressionMarker : public PostWalker<ExpressionMarker, UnifiedExpressionVisitor<ExpressionMarker>> { +struct ExpressionMarker + : public PostWalker<ExpressionMarker, + UnifiedExpressionVisitor<ExpressionMarker>> { std::set<Expression*>& marked; - ExpressionMarker(std::set<Expression*>& marked, Expression* expr) : marked(marked) { + ExpressionMarker(std::set<Expression*>& marked, Expression* expr) + : marked(marked) { walk(expr); } - void visitExpression(Expression* expr) { - marked.insert(expr); - } + void visitExpression(Expression* expr) { marked.insert(expr); } }; struct CodeFolding : public WalkerPass<ControlFlowWalker<CodeFolding>> { @@ -91,15 +92,18 @@ struct CodeFolding : public WalkerPass<ControlFlowWalker<CodeFolding>> { struct Tail { Expression* expr; // nullptr if this is a fallthrough Block* block; // the enclosing block of code we hope to merge at its tail - Expression** pointer; // for an expr with no parent block, the location it is at, so we can replace it + Expression** pointer; // for an expr with no parent block, the location it + // is at, so we can replace it // For a fallthrough Tail(Block* block) : expr(nullptr), block(block), pointer(nullptr) {} // For a break - Tail(Expression* expr, Block* block) : expr(expr), block(block), pointer(nullptr) { + Tail(Expression* expr, Block* block) + : expr(expr), block(block), pointer(nullptr) { validate(); } - Tail(Expression* expr, Expression** pointer) : expr(expr), block(nullptr), pointer(pointer) {} + Tail(Expression* expr, Expression** pointer) + : expr(expr), block(nullptr), pointer(pointer) {} bool isFallthrough() const { return expr == nullptr; } @@ -116,11 +120,13 @@ struct CodeFolding : public WalkerPass<ControlFlowWalker<CodeFolding>> { // pass state - std::map<Name, std::vector<Tail>> breakTails; // break target name => tails that reach it + std::map<Name, std::vector<Tail>> breakTails; // break target name => tails + // that reach it std::vector<Tail> unreachableTails; // tails leading to (unreachable) - std::vector<Tail> returnTails; // tails leading to (return) - std::set<Name> unoptimizables; // break target names that we can't handle - std::set<Expression*> modifieds; // modified code should not be processed again, wait for next pass + std::vector<Tail> returnTails; // tails leading to (return) + std::set<Name> unoptimizables; // break target names that we can't handle + std::set<Expression*> modifieds; // modified code should not be processed + // again, wait for next pass // walking @@ -167,20 +173,25 @@ struct CodeFolding : public WalkerPass<ControlFlowWalker<CodeFolding>> { return; } } - // otherwise, if we have a large value, it might be worth optimizing us as well + // otherwise, if we have a large value, it might be worth optimizing us as + // well returnTails.push_back(Tail(curr, getCurrentPointer())); } void visitBlock(Block* curr) { - if (curr->list.empty()) return; - if (!curr->name.is()) return; - if (unoptimizables.count(curr->name) > 0) return; + if (curr->list.empty()) + return; + if (!curr->name.is()) + return; + if (unoptimizables.count(curr->name) > 0) + return; // we can't optimize a fallthrough value if (isConcreteType(curr->list.back()->type)) { return; } auto iter = breakTails.find(curr->name); - if (iter == breakTails.end()) return; + if (iter == breakTails.end()) + return; // looks promising auto& tails = iter->second; // see if there is a fallthrough @@ -191,23 +202,22 @@ struct CodeFolding : public WalkerPass<ControlFlowWalker<CodeFolding>> { } } if (hasFallthrough) { - tails.push_back({ Tail(curr) }); + tails.push_back({Tail(curr)}); } optimizeExpressionTails(tails, curr); } void visitIf(If* curr) { - if (!curr->ifFalse) return; + if (!curr->ifFalse) + return; // if both sides are identical, this is easy to fold if (ExpressionAnalyzer::equal(curr->ifTrue, curr->ifFalse)) { Builder builder(*getModule()); // remove if (4 bytes), remove one arm, add drop (1), add block (3), // so this must be a net savings markAsModified(curr); - auto* ret = builder.makeSequence( - builder.makeDrop(curr->condition), - curr->ifTrue - ); + auto* ret = + builder.makeSequence(builder.makeDrop(curr->condition), curr->ifTrue); // we must ensure we present the same type as the if had ret->finalize(curr->type); replaceCurrent(ret); @@ -237,9 +247,8 @@ struct CodeFolding : public WalkerPass<ControlFlowWalker<CodeFolding>> { } // we need nameless blocks, as if there is a name, someone might branch // to the end, skipping the code we want to merge - if (left && right && - !left->name.is() && !right->name.is()) { - std::vector<Tail> tails = { Tail(left), Tail(right) }; + if (left && right && !left->name.is() && !right->name.is()) { + std::vector<Tail> tails = {Tail(left), Tail(right)}; optimizeExpressionTails(tails, curr); } } @@ -251,7 +260,8 @@ struct CodeFolding : public WalkerPass<ControlFlowWalker<CodeFolding>> { anotherPass = false; super::doWalkFunction(func); optimizeTerminatingTails(unreachableTails); - // optimize returns at the end, so we can benefit from a fallthrough if there is a value TODO: separate passes for them? + // optimize returns at the end, so we can benefit from a fallthrough if + // there is a value TODO: separate passes for them? optimizeTerminatingTails(returnTails); // TODO add fallthrough for returns // TODO optimize returns not in blocks, a big return value can be worth it @@ -277,7 +287,10 @@ private: for (auto* item : items) { auto exiting = BranchUtils::getExitingBranches(item); std::vector<Name> intersection; - std::set_intersection(allTargets.begin(), allTargets.end(), exiting.begin(), exiting.end(), + std::set_intersection(allTargets.begin(), + allTargets.end(), + exiting.begin(), + exiting.end(), std::back_inserter(intersection)); if (intersection.size() > 0) { // anything exiting that is in all targets is something bad @@ -287,15 +300,18 @@ private: return true; } - // optimize tails that reach the outside of an expression. code that is identical in all - // paths leading to the block exit can be merged. + // optimize tails that reach the outside of an expression. code that is + // identical in all paths leading to the block exit can be merged. template<typename T> void optimizeExpressionTails(std::vector<Tail>& tails, T* curr) { - if (tails.size() < 2) return; + if (tails.size() < 2) + return; // see if anything is untoward, and we should not do this for (auto& tail : tails) { - if (tail.expr && modifieds.count(tail.expr) > 0) return; - if (modifieds.count(tail.block) > 0) return; + if (tail.expr && modifieds.count(tail.expr) > 0) + return; + if (modifieds.count(tail.block) > 0) + return; // if we were not modified, then we should be valid for processing tail.validate(); } @@ -316,7 +332,7 @@ private: // elements to be worth that extra block (although, there is // some chance the block would get merged higher up, see later) std::vector<Expression*> mergeable; // the elements we can merge - Index num = 0; // how many elements back from the tail to look at + Index num = 0; // how many elements back from the tail to look at Index saved = 0; // how much we can save while (1) { // check if this num is still relevant @@ -329,7 +345,8 @@ private: break; } } - if (stop) break; + if (stop) + break; auto* item = getMergeable(tails[0], num); for (auto& tail : tails) { if (!ExpressionAnalyzer::equal(item, getMergeable(tail, num))) { @@ -338,15 +355,18 @@ private: break; } } - if (stop) break; + if (stop) + break; // we may have found another one we can merge - can we move it? - if (!canMove({ item }, curr)) break; + if (!canMove({item}, curr)) + break; // we found another one we can merge mergeable.push_back(item); num++; saved += Measurer::measure(item); } - if (saved == 0) return; + if (saved == 0) + return; // we may be able to save enough. if (saved < WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH) { // it's not obvious we can save enough. see if we get rid @@ -363,13 +383,16 @@ private: if (!willEmptyBlock) { // last chance, if our parent is a block, then it should be // fine to create a new block here, it will be merged up - assert(curr == controlFlowStack.back()); // we are an if or a block, at the top + // we are an if or a block, at the top + assert(curr == controlFlowStack.back()); if (controlFlowStack.size() <= 1) { return; // no parent at all - // TODO: if we are the toplevel in the function, then in the binary format - // we might avoid emitting a block, so the same logic applies here? + // TODO: if we are the toplevel in the function, then in the binary + // format we might avoid emitting a block, so the same logic + // applies here? } - auto* parent = controlFlowStack[controlFlowStack.size() - 2]->dynCast<Block>(); + auto* parent = + controlFlowStack[controlFlowStack.size() - 2]->dynCast<Block>(); if (!parent) { return; // parent is not a block } @@ -440,15 +463,23 @@ private: // deeper merges first. // returns whether we optimized something. bool optimizeTerminatingTails(std::vector<Tail>& tails, Index num = 0) { - if (tails.size() < 2) return false; - // remove things that are untoward and cannot be optimized - tails.erase(std::remove_if(tails.begin(), tails.end(), [&](Tail& tail) { - if (tail.expr && modifieds.count(tail.expr) > 0) return true; - if (tail.block && modifieds.count(tail.block) > 0) return true; - // if we were not modified, then we should be valid for processing - tail.validate(); + if (tails.size() < 2) return false; - }), tails.end()); + // remove things that are untoward and cannot be optimized + tails.erase( + std::remove_if(tails.begin(), + tails.end(), + [&](Tail& tail) { + if (tail.expr && modifieds.count(tail.expr) > 0) + return true; + if (tail.block && modifieds.count(tail.block) > 0) + return true; + // if we were not modified, then we should be valid for + // processing + tail.validate(); + return false; + }), + tails.end()); // now let's try to find subsets that are mergeable. we don't look hard // for the most optimal; further passes may find more // effectiveSize: TODO: special-case fallthrough, matters for returns @@ -481,7 +512,7 @@ private: // estimate if a merging is worth the cost auto worthIt = [&](Index num, std::vector<Tail>& tails) { auto items = getTailItems(num, tails); // the elements we can merge - Index saved = 0; // how much we can save + Index saved = 0; // how much we can save for (auto* item : items) { saved += Measurer::measure(item) * (tails.size() - 1); } @@ -496,7 +527,8 @@ private: cost += WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH; // if we cannot merge to the end, then we definitely need 2 blocks, // and a branch - if (!canMove(items, getFunction()->body)) { // TODO: efficiency, entire body + // TODO: efficiency, entire body + if (!canMove(items, getFunction()->body)) { cost += 1 + WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH; // TODO: to do this, we need to maintain a map of element=>parent, // so that we can insert the new blocks in the right place @@ -509,64 +541,86 @@ private: // let's see if we can merge deeper than num, to num + 1 auto next = tails; // remove tails that are too short, or that we hit an item we can't handle - next.erase(std::remove_if(next.begin(), next.end(), [&](Tail& tail) { - if (effectiveSize(tail) < num + 1) return true; - auto* newItem = getItem(tail, num); - // ignore tails that break to outside blocks. we want to move code to - // the very outermost position, so such code cannot be moved - // TODO: this should not be a problem in *non*-terminating tails, - // but double-verify that - if (EffectAnalyzer(getPassOptions(), newItem).hasExternalBreakTargets()) { - return true; - } - return false; - }), next.end()); + next.erase(std::remove_if(next.begin(), + next.end(), + [&](Tail& tail) { + if (effectiveSize(tail) < num + 1) + return true; + auto* newItem = getItem(tail, num); + // ignore tails that break to outside blocks. we + // want to move code to the very outermost + // position, so such code cannot be moved + // TODO: this should not be a problem in + // *non*-terminating tails, but + // double-verify that + if (EffectAnalyzer(getPassOptions(), newItem) + .hasExternalBreakTargets()) { + return true; + } + return false; + }), + next.end()); // if we have enough to investigate, do so if (next.size() >= 2) { - // now we want to find a mergeable item - any item that is equal among a subset + // now we want to find a mergeable item - any item that is equal among a + // subset std::map<Expression*, HashType> hashes; // expression => hash value - std::map<HashType, std::vector<Expression*>> hashed; // hash value => expressions with that hash + // hash value => expressions with that hash + std::map<HashType, std::vector<Expression*>> hashed; for (auto& tail : next) { auto* item = getItem(tail, num); auto hash = hashes[item] = ExpressionAnalyzer::hash(item); hashed[hash].push_back(item); } - // look at each hash value exactly once. we do this in a deterministic order. + // look at each hash value exactly once. we do this in a deterministic + // order. std::set<HashType> seen; for (auto& tail : next) { auto* item = getItem(tail, num); auto hash = hashes[item]; - if (seen.count(hash)) continue; + if (seen.count(hash)) + continue; seen.insert(hash); auto& items = hashed[hash]; - if (items.size() == 1) continue; + if (items.size() == 1) + continue; assert(items.size() > 0); // look for an item that has another match. while (items.size() >= 2) { auto first = items[0]; std::vector<Expression*> others; - items.erase(std::remove_if(items.begin(), items.end(), [&](Expression* item) { - if (item == first || // don't bother comparing the first - ExpressionAnalyzer::equal(item, first)) { - // equal, keep it - return false; - } else { - // unequal, look at it later - others.push_back(item); - return true; - } - }), items.end()); + items.erase( + std::remove_if(items.begin(), + items.end(), + [&](Expression* item) { + if (item == + first || // don't bother comparing the first + ExpressionAnalyzer::equal(item, first)) { + // equal, keep it + return false; + } else { + // unequal, look at it later + others.push_back(item); + return true; + } + }), + items.end()); if (items.size() >= 2) { // possible merge here, investigate it auto* correct = items[0]; auto explore = next; - explore.erase(std::remove_if(explore.begin(), explore.end(), [&](Tail& tail) { - auto* item = getItem(tail, num); - return !ExpressionAnalyzer::equal(item, correct); - }), explore.end()); - // try to optimize this deeper tail. if we succeed, then stop here, as the - // changes may influence us. we leave further opts to further passes (as this - // is rare in practice, it's generally not a perf issue, but TODO optimize) + explore.erase(std::remove_if(explore.begin(), + explore.end(), + [&](Tail& tail) { + auto* item = getItem(tail, num); + return !ExpressionAnalyzer::equal( + item, correct); + }), + explore.end()); + // try to optimize this deeper tail. if we succeed, then stop here, + // as the changes may influence us. we leave further opts to further + // passes (as this is rare in practice, it's generally not a perf + // issue, but TODO optimize) if (optimizeTerminatingTails(explore, num + 1)) { return true; } @@ -578,15 +632,18 @@ private: // we explored deeper (higher num) options, but perhaps there // was nothing there while there is something we can do at this level // but if we are at num == 0, then we found nothing at all - if (num == 0) return false; + if (num == 0) + return false; // if not worth it, stop - if (!worthIt(num, tails)) return false; + if (!worthIt(num, tails)) + return false; // this is worth doing, do it! auto mergeable = getTailItems(num, tails); // the elements we can merge // since we managed a merge, then it might open up more opportunities later anotherPass = true; Builder builder(*getModule()); - LabelUtils::LabelManager labels(getFunction()); // TODO: don't create one per merge, linear in function size + // TODO: don't create one per merge, linear in function size + LabelUtils::LabelManager labels(getFunction()); Name innerName = labels.getUnique("folding-inner"); for (auto& tail : tails) { // remove the items we are merging / moving, and add a break @@ -623,7 +680,8 @@ private: // rules, and now it won't be toplevel in the function, it can // change) auto* toplevel = old->dynCast<Block>(); - if (toplevel) toplevel->finalize(); + if (toplevel) + toplevel->finalize(); if (old->type != unreachable) { inner->list.push_back(builder.makeReturn(old)); } else { @@ -649,9 +707,6 @@ private: } }; -Pass *createCodeFoldingPass() { - return new CodeFolding(); -} +Pass* createCodeFoldingPass() { return new CodeFolding(); } } // namespace wasm - |