/*
 * Copyright 2017 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Folds duplicate code together, saving space (and possibly phis in
// the wasm VM, which can save time).
//
// We fold tails of code where they merge and moving the code
// to the merge point is helpful. There are two cases here: (1) expressions,
// in which we merge to right after the expression itself, in these cases:
//   * blocks, we merge the fallthrough + the breaks
//   * if-else, we merge the arms
// and (2) the function body as a whole, in which we can merge returns or
// unreachables, putting the merged code at the end of the function body.
//
// For example, with an if-else, we might merge this:
//  (if (condition)
//    (block
//      A
//      C
//    )
//    (block
//      B
//      C
//    )
//  )
// to
//  (if (condition)
//    (block
//      A
//    )
//    (block
//      B
//    )
//  )
//  C
//
// Note that the merged code, C in the example above, can be anything,
// including code with control flow. If C is identical in all the locations,
// then it must be safe to merge (if it contains a branch to something
// higher up, then since our branch target names are unique, it must be
// to the same thing, and after merging it can still reach it).
// #include #include "ir/branch-utils.h" #include "ir/effects.h" #include "ir/eh-utils.h" #include "ir/find_all.h" #include "ir/label-utils.h" #include "ir/utils.h" #include "pass.h" #include "wasm-builder.h" #include "wasm.h" namespace wasm { static const Index WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH = 3; struct ExpressionMarker : public PostWalker> { std::set& marked; ExpressionMarker(std::set& marked, Expression* expr) : marked(marked) { walk(expr); } void visitExpression(Expression* expr) { marked.insert(expr); } }; struct CodeFolding : public WalkerPass< ControlFlowWalker>> { bool isFunctionParallel() override { return true; } std::unique_ptr create() override { return std::make_unique(); } // information about a "tail" - code that reaches a point that we can // merge (e.g., a branch and some code leading up to it) struct Tail { Expression* expr; // nullptr if this is a fallthrough Block* block; // the enclosing block of code we hope to merge at its tail Expression** pointer; // for an expr with no parent block, the location it // is at, so we can replace it // For a fallthrough Tail(Block* block) : expr(nullptr), block(block), pointer(nullptr) {} // For a break Tail(Expression* expr, Block* block) : expr(expr), block(block), pointer(nullptr) {} Tail(Expression* expr, Expression** pointer) : expr(expr), block(nullptr), pointer(pointer) {} bool isFallthrough() const { return expr == nullptr; } }; // state // Set when we optimized and believe another pass is warranted. bool anotherPass; // Set when we optimized in a manner that requires EH fixups specifically, // which is generally the case when we wrap things in a block. 
bool needEHFixups; // pass state std::map> breakTails; // break target name => tails // that reach it std::vector unreachableTails; // tails leading to (unreachable) std::vector returnTails; // tails leading to (return) std::set unoptimizables; // break target names that we can't handle std::set modifieds; // modified code should not be processed // again, wait for next pass // walking void visitExpression(Expression* curr) { // For any branching instruction not explicitly handled by this pass, mark // the labels it branches to unoptimizable. // TODO: Handle folding br_on* instructions. br_on_null could be folded with // other kinds of branches and br_on_non_null, br_on_cast, and // br_on_cast_fail instructions could be folded with other copies of // themselves. BranchUtils::operateOnScopeNameUses( curr, [&](Name label) { unoptimizables.insert(label); }); } void visitBreak(Break* curr) { if (curr->condition) { unoptimizables.insert(curr->name); } else { // we can only optimize if we are at the end of the parent block. // TODO: Relax this. 
Block* parent = controlFlowStack.back()->dynCast(); if (parent && curr == parent->list.back()) { breakTails[curr->name].push_back(Tail(curr, parent)); } else { unoptimizables.insert(curr->name); } } } void visitUnreachable(Unreachable* curr) { // we can only optimize if we are at the end of the parent block if (!controlFlowStack.empty()) { Block* parent = controlFlowStack.back()->dynCast(); if (parent && curr == parent->list.back()) { unreachableTails.push_back(Tail(curr, parent)); } } } void handleReturn(Expression* curr) { if (!controlFlowStack.empty()) { // we can easily optimize if we are at the end of the parent block Block* parent = controlFlowStack.back()->dynCast(); if (parent && curr == parent->list.back()) { returnTails.push_back(Tail(curr, parent)); return; } } // otherwise, if we have a large value, it might be worth optimizing us as // well returnTails.push_back(Tail(curr, getCurrentPointer())); } void visitReturn(Return* curr) { handleReturn(curr); } void visitCall(Call* curr) { if (curr->isReturn) { handleReturn(curr); } } void visitCallIndirect(CallIndirect* curr) { if (curr->isReturn) { handleReturn(curr); } } void visitCallRef(CallRef* curr) { if (curr->isReturn) { handleReturn(curr); } } void visitBlock(Block* curr) { if (curr->list.empty()) { return; } if (!curr->name.is()) { return; } if (unoptimizables.count(curr->name) > 0) { return; } auto iter = breakTails.find(curr->name); if (iter == breakTails.end()) { return; } // Looks promising. auto& tails = iter->second; // If the end of the block cannot be reached, then we don't need to include // it in the set of folded tails. 
bool includeFallthrough = !std::any_of(curr->list.begin(), curr->list.end(), [&](auto* child) { return child->type == Type::unreachable; }); if (includeFallthrough) { tails.push_back({Tail(curr)}); } optimizeExpressionTails(tails, curr); } void visitIf(If* curr) { if (!curr->ifFalse) { return; } if (curr->condition->type == Type::unreachable) { // If the arms are foldable and concrete, we would be replacing an // unreachable If with a concrete block, which may or may not be valid, // depending on the context. Leave this for DCE rather than trying to // handle that. return; } // If both are blocks, look for a tail we can merge. auto* left = curr->ifTrue->dynCast(); auto* right = curr->ifFalse->dynCast(); // If one is a block and the other isn't, and the non-block is a tail of the // other, we can fold that - for our convenience, we just add a block and // run the rest of the optimization mormally. auto maybeAddBlock = [this](Block* block, Expression*& other) -> Block* { // If other is a suffix of the block, wrap it in a block. if (block->list.empty() || !ExpressionAnalyzer::equal(other, block->list.back())) { return nullptr; } // Do it, assign to the out param `other`, and return the block. Builder builder(*getModule()); auto* ret = builder.makeBlock(other); other = ret; return ret; }; if (left && !right) { right = maybeAddBlock(left, curr->ifFalse); } else if (!left && right) { left = maybeAddBlock(right, curr->ifTrue); } // We need nameless blocks, as if there is a name, someone might branch to // the end, skipping the code we want to merge. 
if (left && right && !left->name.is() && !right->name.is()) { std::vector tails = {Tail(left), Tail(right)}; optimizeExpressionTails(tails, curr); } } void doWalkFunction(Function* func) { anotherPass = true; while (anotherPass) { anotherPass = false; needEHFixups = false; Super::doWalkFunction(func); optimizeTerminatingTails(unreachableTails); // optimize returns at the end, so we can benefit from a fallthrough if // there is a value TODO: separate passes for them? optimizeTerminatingTails(returnTails); // TODO add fallthrough for returns // TODO optimize returns not in blocks, a big return value can be worth it // clean up breakTails.clear(); unreachableTails.clear(); returnTails.clear(); unoptimizables.clear(); modifieds.clear(); if (needEHFixups) { EHUtils::handleBlockNestedPops(func, *getModule()); } } } private: // check if we can move a list of items out of another item. we can't do so // if one of the items has a branch to something inside outOf that is not // inside that item bool canMove(const std::vector& items, Expression* outOf) { auto allTargets = BranchUtils::getBranchTargets(outOf); for (auto* item : items) { auto exiting = BranchUtils::getExitingBranches(item); std::vector intersection; std::set_intersection(allTargets.begin(), allTargets.end(), exiting.begin(), exiting.end(), std::back_inserter(intersection)); if (intersection.size() > 0) { // anything exiting that is in all targets is something bad return false; } if (getModule()->features.hasExceptionHandling()) { EffectAnalyzer effects(getPassOptions(), *getModule(), item); // Pop instructions are pseudoinstructions used only after 'catch' to // simulate its behavior. We cannot move expressions containing pops if // they are not enclosed in a 'catch' body, because a pop instruction // should follow right after 'catch'. 
if (effects.danglingPop) { return false; } // When an expression can throw and it is within a try/try_table scope, // taking it out of the try/try_table scope changes the program's // behavior, because the expression that would otherwise have been // caught by the try/try_table now throws up to the next try/try_table // scope or even up to the caller. We restrict the move if 'outOf' // contains a 'try' or 'try_table' anywhere in it. This is a // conservative approximation because there can be cases that // 'try'/'try_table' is within the expression that may throw so it is // safe to take the expression out. // TODO: optimize this check to avoid two FindAlls. if (effects.throws() && (FindAll(outOf).has() || FindAll(outOf).has())) { return false; } } } return true; } // optimize tails that reach the outside of an expression. code that is // identical in all paths leading to the block exit can be merged. template void optimizeExpressionTails(std::vector& tails, T* curr) { auto oldType = curr->type; if (tails.size() < 2) { return; } // see if anything is untoward, and we should not do this for (auto& tail : tails) { if (tail.expr && modifieds.count(tail.expr) > 0) { return; } if (modifieds.count(tail.block) > 0) { return; } // if we were not modified, then we should be valid for processing assert(!tail.expr || !tail.block || (tail.expr == tail.block->list.back())); } auto getMergeable = [&](const Tail& tail, Index num) -> Expression* { if (!tail.isFallthrough()) { // If there is a branch value, it is the first mergeable item. auto* val = tail.expr->cast()->value; if (val && num == 0) { return val; } if (!val) { // Skip the branch instruction at the end; it is not part of the // merged tail. ++num; } } if (num >= tail.block->list.size()) { return nullptr; } return tail.block->list[tail.block->list.size() - num - 1]; }; // we are going to remove duplicate elements and add a block. 
// so for this to make sense, we need the size of the duplicate // elements to be worth that extra block (although, there is // some chance the block would get merged higher up, see later) std::vector mergeable; // the elements we can merge Index saved = 0; // how much we can save for (Index num = 0; true; ++num) { auto* item = getMergeable(tails[0], num); if (!item) { // The list is too short. break; } Index tail = 1; for (; tail < tails.size(); ++tail) { auto* other = getMergeable(tails[tail], num); if (!other || !ExpressionAnalyzer::equal(item, other)) { // Other tail too short or has a difference. break; } } if (tail != tails.size()) { // We saw a tail without a matching item. break; } // we may have found another one we can merge - can we move it? if (!canMove({item}, curr)) { break; } // we found another one we can merge mergeable.push_back(item); saved += Measurer::measure(item); } if (saved == 0) { return; } // we may be able to save enough. if (saved < WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH) { // it's not obvious we can save enough. see if we get rid // of a block, that would justify this bool willEmptyBlock = false; for (auto& tail : tails) { // it is enough to zero out the block, or leave just one // element, as then the block can be replaced with that if (mergeable.size() >= tail.block->list.size() - 1) { willEmptyBlock = true; break; } } if (!willEmptyBlock) { // last chance, if our parent is a block, then it should be // fine to create a new block here, it will be merged up // we are an if or a block, at the top assert(curr == controlFlowStack.back()); if (controlFlowStack.size() <= 1) { return; // no parent at all // TODO: if we are the toplevel in the function, then in the binary // format we might avoid emitting a block, so the same logic // applies here? 
} auto* parent = controlFlowStack[controlFlowStack.size() - 2]->dynCast(); if (!parent) { return; // parent is not a block } bool isChild = false; for (auto* child : parent->list) { if (child == curr) { isChild = true; break; } } if (!isChild) { return; // not a child, something in between } } } // this is worth doing, do it! for (auto& tail : tails) { // remove the items we are merging / moving // first, mark them as modified, so we don't try to handle them // again in this pass, which might be buggy markAsModified(tail.block); // we must preserve the br if there is one Break* branch = nullptr; if (!tail.isFallthrough()) { branch = tail.block->list.back()->cast(); if (branch->value) { branch->value = nullptr; } else { tail.block->list.pop_back(); } } for (Index i = 0; i < mergeable.size(); ++i) { tail.block->list.pop_back(); } if (tail.isFallthrough()) { // The block now ends in an expression that was previously in the middle // of the block, meaning it must have type none. tail.block->finalize(Type::none); } else { tail.block->list.push_back(branch); // The block still ends with the same branch it previously ended with, // so its type cannot have changed. tail.block->finalize(tail.block->type); } } // since we managed a merge, then it might open up more opportunities later anotherPass = true; // make a block with curr + the merged code Builder builder(*getModule()); auto* block = builder.makeBlock(); if constexpr (T::SpecificId == Expression::IfId) { // If we've moved all the contents out of both arms of the If, then we can // simplify the output by replacing it entirely with just a drop of the // condition. 
auto* iff = curr->template cast(); if (iff->ifTrue->template cast()->list.empty() && iff->ifFalse->template cast()->list.empty()) { block->list.push_back(builder.makeDrop(iff->condition)); } else { block->list.push_back(curr); } } else { block->list.push_back(curr); } while (!mergeable.empty()) { block->list.push_back(mergeable.back()); mergeable.pop_back(); } if constexpr (T::SpecificId == Expression::BlockId) { // If we didn't have a fallthrough tail because the end of the block was // not reachable, then we might have a concrete expression at the end of // the block even though the value produced by the block has been moved // out of it. If so, drop that expression. auto* currBlock = curr->template cast(); currBlock->list.back() = builder.dropIfConcretelyTyped(currBlock->list.back()); } // NB: we template-specialize so that this calls the proper finalizer for // the type curr->finalize(); // ensure the replacement has the same type, so the outside is not surprised block->finalize(oldType); replaceCurrent(block); needEHFixups = true; } // optimize tails that terminate control flow in this function, so we // are (1) merge just a few of them, we don't need all like with the // branches to a block, and (2) we do it on the function body. // num is the depth, i.e., how many tail items we can merge. 0 means // we are just starting; num > 0 means that tails is guaranteed to be // equal in the last num items, so we can merge there, but we look for // deeper merges first. // returns whether we optimized something. bool optimizeTerminatingTails(std::vector& tails, Index num = 0) { if (tails.size() < 2) { return false; } // remove things that are untoward and cannot be optimized tails.erase( std::remove_if(tails.begin(), tails.end(), [&](Tail& tail) { if (tail.expr && modifieds.count(tail.expr) > 0) { return true; } if (tail.block && modifieds.count(tail.block) > 0) { return true; } return false; }), tails.end()); // now let's try to find subsets that are mergeable. 
we don't look hard // for the most optimal; further passes may find more // effectiveSize: TODO: special-case fallthrough, matters for returns auto effectiveSize = [&](Tail& tail) -> Index { if (tail.block) { return tail.block->list.size(); } else { return 1; } }; // getItem: returns the relevant item from the tail. this includes the // final item // TODO: special-case fallthrough, matters for returns auto getItem = [&](Tail& tail, Index num) { if (tail.block) { return tail.block->list[effectiveSize(tail) - num - 1]; } else { return tail.expr; } }; // gets the tail elements of a certain depth auto getTailItems = [&](Index num, std::vector& tails) { std::vector items; for (Index i = 0; i < num; i++) { auto item = getItem(tails[0], i); items.push_back(item); } return items; }; // estimate if a merging is worth the cost auto worthIt = [&](Index num, std::vector& tails) { auto items = getTailItems(num, tails); // the elements we can merge Index saved = 0; // how much we can save for (auto* item : items) { saved += Measurer::measure(item) * (tails.size() - 1); } // compure the cost: in non-fallthroughs, we are replacing the final // element with a br; for a fallthrough, if there is one, we must // add a return element (for the function body, so it doesn't reach us) // TODO: handle fallthroughts for return Index cost = tails.size(); // we also need to add two blocks: for us to break to, and to contain // that block and the merged code. very possibly one of the blocks // can be removed, though cost += WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH; // if we cannot merge to the end, then we definitely need 2 blocks, // and a branch // TODO: efficiency, entire body if (!canMove(items, getFunction()->body)) { cost += 1 + WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH; // TODO: to do this, we need to maintain a map of element=>parent, // so that we can insert the new blocks in the right place // for now, just don't do this optimization return false; } // is it worth it? 
return saved > cost; }; // let's see if we can merge deeper than num, to num + 1 auto next = tails; // remove tails that are too short, or that we hit an item we can't handle next.erase(std::remove_if(next.begin(), next.end(), [&](Tail& tail) { if (effectiveSize(tail) < num + 1) { return true; } auto* newItem = getItem(tail, num); // ignore tails that break to outside blocks. we // want to move code to the very outermost // position, so such code cannot be moved // TODO: this should not be a problem in // *non*-terminating tails, but // double-verify that if (EffectAnalyzer( getPassOptions(), *getModule(), newItem) .hasExternalBreakTargets()) { return true; } return false; }), next.end()); // if we have enough to investigate, do so if (next.size() >= 2) { // now we want to find a mergeable item - any item that is equal among a // subset std::map hashes; // expression => hash value // hash value => expressions with that hash std::map> hashed; for (auto& tail : next) { auto* item = getItem(tail, num); auto hash = hashes[item] = ExpressionAnalyzer::hash(item); hashed[hash].push_back(item); } // look at each hash value exactly once. we do this in a deterministic // order by iterating over a vector retaining insertion order. std::set seen; for (auto& tail : next) { auto* item = getItem(tail, num); auto digest = hashes[item]; if (!seen.emplace(digest).second) { continue; } auto& items = hashed[digest]; if (items.size() == 1) { continue; } assert(items.size() > 0); // look for an item that has another match. 
while (items.size() >= 2) { auto first = items[0]; std::vector others; items.erase( std::remove_if(items.begin(), items.end(), [&](Expression* item) { if (item == first || // don't bother comparing the first ExpressionAnalyzer::equal(item, first)) { // equal, keep it return false; } else { // unequal, look at it later others.push_back(item); return true; } }), items.end()); if (items.size() >= 2) { // possible merge here, investigate it auto* correct = items[0]; auto explore = next; explore.erase(std::remove_if(explore.begin(), explore.end(), [&](Tail& tail) { auto* item = getItem(tail, num); return !ExpressionAnalyzer::equal( item, correct); }), explore.end()); // try to optimize this deeper tail. if we succeed, then stop here, // as the changes may influence us. we leave further opts to further // passes (as this is rare in practice, it's generally not a perf // issue, but TODO optimize) if (optimizeTerminatingTails(explore, num + 1)) { return true; } } items.swap(others); } } } // we explored deeper (higher num) options, but perhaps there // was nothing there while there is something we can do at this level // but if we are at num == 0, then we found nothing at all if (num == 0) { return false; } // if not worth it, stop if (!worthIt(num, tails)) { return false; } // this is worth doing, do it! 
auto mergeable = getTailItems(num, tails); // the elements we can merge // since we managed a merge, then it might open up more opportunities later anotherPass = true; Builder builder(*getModule()); // TODO: don't create one per merge, linear in function size LabelUtils::LabelManager labels(getFunction()); Name innerName = labels.getUnique("folding-inner"); for (auto& tail : tails) { // remove the items we are merging / moving, and add a break // also mark as modified, so we don't try to handle them // again in this pass, which might be buggy if (tail.block) { markAsModified(tail.block); for (Index i = 0; i < mergeable.size(); i++) { tail.block->list.pop_back(); } tail.block->list.push_back(builder.makeBreak(innerName)); tail.block->finalize(tail.block->type); } else { markAsModified(tail.expr); *tail.pointer = builder.makeBreak(innerName); } } // make a block with the old body + the merged code auto* old = getFunction()->body; auto* inner = builder.makeBlock(); inner->name = innerName; if (old->type == Type::unreachable) { // the old body is not flowed out of anyhow, so just put it there inner->list.push_back(old); } else { // otherwise, we must not flow out to the merged code if (old->type == Type::none) { inner->list.push_back(old); inner->list.push_back(builder.makeReturn()); } else { // looks like we must return this. 
but if it's a toplevel block // then it might be marked as having a type, but not actually // returning it (we marked it as such for wasm type-checking // rules, and now it won't be toplevel in the function, it can // change) auto* toplevel = old->dynCast(); if (toplevel) { toplevel->finalize(); } if (old->type != Type::unreachable) { inner->list.push_back(builder.makeReturn(old)); } else { inner->list.push_back(old); } } } inner->finalize(); auto* outer = builder.makeBlock(); outer->list.push_back(inner); while (!mergeable.empty()) { outer->list.push_back(mergeable.back()); mergeable.pop_back(); } // ensure the replacement has the same type, so the outside is not surprised outer->finalize(getFunction()->getResults()); getFunction()->body = outer; needEHFixups = true; return true; } void markAsModified(Expression* curr) { ExpressionMarker marker(modifieds, curr); } }; Pass* createCodeFoldingPass() { return new CodeFolding(); } } // namespace wasm