diff options
author | Alon Zakai <alonzakai@gmail.com> | 2017-06-28 22:05:05 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-06-28 22:05:05 -0700 |
commit | e488da5adbef2613c08fe205db5b79b1765a4af3 (patch) | |
tree | e3cab840dcbf7d8d4ccf1f47a742fbfc41e1d5ef /src | |
parent | e2c08d42ab0ffc05d980ae2d34fee0e77b201134 (diff) | |
download | binaryen-e488da5adbef2613c08fe205db5b79b1765a4af3.tar.gz binaryen-e488da5adbef2613c08fe205db5b79b1765a4af3.tar.bz2 binaryen-e488da5adbef2613c08fe205db5b79b1765a4af3.zip |
Code folding (#1076)
Adds a pass that folds code, i.e. merges duplicate code where possible. See the comment at the top of the pass implementation (src/passes/CodeFolding.cpp) for details.
This is enabled by default in -Os and -Oz. It seems risky to enable it anywhere else, as it does add branches - likely predictable ones, so there may be no slowdown, but there is still some risk.
Code size numbers:
wasm-backend: 196331
+ binaryen -Os (before): 182598
+ binaryen -Os (with folding): 181943
asm2wasm -Os (before): 172463
asm2wasm -Os (with folding): 168774
So this reduces wasm-backend output by an additional 0.5% beyond what we could achieve before. The gain there is small mainly because the wasm backend already has code folding, whereas on asm2wasm output, where we didn't have folding before, this saves over 2%. The 0.5% improvement on the wasm backend's output might be because this can fold more types of code than LLVM can (it can fold nested control flow, in particular).
Diffstat (limited to 'src')
-rw-r--r-- | src/ast/branch-utils.h | 55 | ||||
-rw-r--r-- | src/ast/label-utils.h | 62 | ||||
-rw-r--r-- | src/mixed_arena.h | 4 | ||||
-rw-r--r-- | src/pass.h | 23 | ||||
-rw-r--r-- | src/passes/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/passes/CodeFolding.cpp | 603 | ||||
-rw-r--r-- | src/passes/NameManager.cpp | 80 | ||||
-rw-r--r-- | src/passes/pass.cpp | 7 | ||||
-rw-r--r-- | src/passes/passes.h | 2 | ||||
-rw-r--r-- | src/wasm-traversal.h | 2 |
10 files changed, 732 insertions, 108 deletions
diff --git a/src/ast/branch-utils.h b/src/ast/branch-utils.h index a54b8151f..bdf52d36a 100644 --- a/src/ast/branch-utils.h +++ b/src/ast/branch-utils.h @@ -18,6 +18,7 @@ #define wasm_ast_branch_h #include "wasm.h" +#include "wasm-traversal.h" namespace wasm { @@ -36,6 +37,60 @@ inline bool isBranchTaken(Switch* sw) { sw->condition->type != unreachable; } +// returns the set of targets to which we branch that are +// outside of a node (break and switch targets whose name is not that of a block or loop visited in the node) +inline std::set<Name> getExitingBranches(Expression* ast) { + struct Scanner : public PostWalker<Scanner> { + std::set<Name> targets; + + void visitBreak(Break* curr) { + targets.insert(curr->name); + } + void visitSwitch(Switch* curr) { + for (auto target : curr->targets) { // the switch's own targets, not our result set + targets.insert(target); + } + targets.insert(curr->default_); + } + void visitBlock(Block* curr) { + if (curr->name.is()) { + targets.erase(curr->name); + } + } + void visitLoop(Loop* curr) { + if (curr->name.is()) { + targets.erase(curr->name); + } + } + }; + Scanner scanner; + scanner.walk(ast); + // anything not erased is a branch out + return scanner.targets; +} + +// returns the list of all branch targets in a node + +inline std::set<Name> getBranchTargets(Expression* ast) { + struct Scanner : public PostWalker<Scanner> { + std::set<Name> targets; + + void visitBlock(Block* curr) { + if (curr->name.is()) { + targets.insert(curr->name); + } + } + void visitLoop(Loop* curr) { + if (curr->name.is()) { + targets.insert(curr->name); + } + } + }; + Scanner scanner; + scanner.walk(ast); + return scanner.targets; +} + } // namespace BranchUtils } // namespace wasm diff --git a/src/ast/label-utils.h b/src/ast/label-utils.h new file mode 100644 index 000000000..6ec9ecf5d --- /dev/null +++ b/src/ast/label-utils.h @@ -0,0 +1,62 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ast_label_h +#define wasm_ast_label_h + +#include "wasm.h" +#include "wasm-traversal.h" + +namespace wasm { + +namespace LabelUtils { + +// Handles branch/loop labels in a function; makes it easy to add new +// ones without duplicates +class LabelManager : public PostWalker<LabelManager> { +public: + LabelManager(Function* func) { + walkFunction(func); + } + + Name getUnique(std::string prefix) { + while (1) { + auto curr = Name(prefix + std::to_string(counter++)); + if (labels.find(curr) == labels.end()) { + labels.insert(curr); + return curr; + } + } + } + + void visitBlock(Block* curr) { + labels.insert(curr->name); + } + void visitLoop(Loop* curr) { + labels.insert(curr->name); + } + +private: + std::set<Name> labels; + size_t counter = 0; +}; + +} // namespace LabelUtils + +} // namespace wasm + +#endif // wasm_ast_label_h + diff --git a/src/mixed_arena.h b/src/mixed_arena.h index af585092e..47e718454 100644 --- a/src/mixed_arena.h +++ b/src/mixed_arena.h @@ -172,6 +172,10 @@ public: return usedElements; } + bool empty() const { + return size() == 0; + } + void resize(size_t size) { if (size > allocatedElements) { reallocate(size); diff --git a/src/pass.h b/src/pass.h index 198d5dcb5..21836ebf9 100644 --- a/src/pass.h +++ b/src/pass.h @@ -237,29 +237,6 @@ public: // but registering them here in addition allows them to communicate // e.g. 
through PassRunner::getLast -// Handles names in a module, in particular adding names without duplicates -class NameManager : public WalkerPass<PostWalker<NameManager>> { - public: - Name getUnique(std::string prefix); - // TODO: getUniqueInFunction - - // visitors - void visitBlock(Block* curr); - void visitLoop(Loop* curr); - void visitBreak(Break* curr); - void visitSwitch(Switch* curr); - void visitCall(Call* curr); - void visitCallImport(CallImport* curr); - void visitFunctionType(FunctionType* curr); - void visitFunction(Function* curr); - void visitImport(Import* curr); - void visitExport(Export* curr); - -private: - std::set<Name> names; - size_t counter = 0; -}; - // Prints out a module class Printer : public Pass { protected: diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index d05390d45..7c0166786 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -2,6 +2,7 @@ SET(passes_SOURCES pass.cpp CoalesceLocals.cpp CodePushing.cpp + CodeFolding.cpp DeadCodeElimination.cpp DuplicateFunctionElimination.cpp ExtractFunction.cpp @@ -15,7 +16,6 @@ SET(passes_SOURCES MemoryPacking.cpp MergeBlocks.cpp Metrics.cpp - NameManager.cpp NameList.cpp OptimizeInstructions.cpp PickLoadSigns.cpp diff --git a/src/passes/CodeFolding.cpp b/src/passes/CodeFolding.cpp new file mode 100644 index 000000000..ae2f81283 --- /dev/null +++ b/src/passes/CodeFolding.cpp @@ -0,0 +1,603 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Folds duplicate code together, saving space. +// +// We fold tails of code where they merge and moving the code +// to the merge point is helpful. There are two cases here: (1) expressions, +// in which we merge to right after the expression itself, in these cases: +// * blocks, we merge the fallthrough + the breaks +// * if-else, we merge the arms +// and (2) the function body as a whole, in which we can merge returns or +// unreachables, putting the merged code at the end of the function body. +// +// For example, with an if-else, we might merge this: +// (if (condition) +// (block +// A +// C +// ) +// (block +// B +// C +// ) +// ) +// to +// (if (condition) +// (block +// A +// ) +// (block +// B +// ) +// ) +// C +// +// Note that the merged code, C in the example above, can be anything, +// including code with control flow. If C is identical in all the locations, +// then it must be safe to merge (if it contains a branch to something +// higher up, then since our branch target names are unique, it must be +// to the same thing, and after merging it can still reach it). 
+// + +#include <iterator> + +#include "wasm.h" +#include "pass.h" +#include "wasm-builder.h" +#include "ast_utils.h" +#include "ast/branch-utils.h" +#include "ast/label-utils.h" + +namespace wasm { + +static const Index WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH = 3; + +struct ExpressionMarker : public PostWalker<ExpressionMarker, UnifiedExpressionVisitor<ExpressionMarker>> { + std::set<Expression*>& marked; + + ExpressionMarker(std::set<Expression*>& marked, Expression* expr) : marked(marked) { + walk(expr); + } + + void visitExpression(Expression* expr) { + marked.insert(expr); + } +}; + +struct CodeFolding : public WalkerPass<ControlFlowWalker<CodeFolding>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new CodeFolding; } + + // information about a "tail" - code that reaches a point that we can + // merge (e.g., a branch and some code leading up to it) + struct Tail { + Expression* expr; // nullptr if this is a fallthrough + Block* block; // the enclosing block of code we hope to merge at its tail + Expression** pointer; // for an expr with no parent block, the location it is at, so we can replace it + + // For a fallthrough + Tail(Block* block) : expr(nullptr), block(block), pointer(nullptr) {} + // For a break + Tail(Expression* expr, Block* block) : expr(expr), block(block), pointer(nullptr) { + validate(); + } + Tail(Expression* expr, Expression** pointer) : expr(expr), block(nullptr), pointer(pointer) {} + + bool isFallthrough() const { return expr == nullptr; } + + void validate() const { + if (expr && block) { + assert(block->list.back() == expr); + } + } + }; + + // state + + bool anotherPass; + + // pass state + + std::map<Name, std::vector<Tail>> breakTails; // break target name => tails that reach it + std::vector<Tail> unreachableTails; // tails leading to (unreachable) + std::vector<Tail> returnTails; // tails leading to (return) + std::set<Name> unoptimizables; // break target names that we can't handle + 
std::set<Expression*> modifieds; // modified code should not be processed again, wait for next pass + + // walking + + void visitBreak(Break* curr) { + if (curr->condition || curr->value) { + unoptimizables.insert(curr->name); + } else { + // we can only optimize if we are at the end of the parent block + Block* parent = controlFlowStack.back()->dynCast<Block>(); + if (parent && curr == parent->list.back()) { + breakTails[curr->name].push_back(Tail(curr, parent)); + } else { + unoptimizables.insert(curr->name); + } + } + } + + void visitSwitch(Switch* curr) { + for (auto target : curr->targets) { + unoptimizables.insert(target); + } + unoptimizables.insert(curr->default_); + } + + void visitUnreachable(Unreachable* curr) { + // we can only optimize if we are at the end of the parent block + if (!controlFlowStack.empty()) { + Block* parent = controlFlowStack.back()->dynCast<Block>(); + if (parent && curr == parent->list.back()) { + unreachableTails.push_back(Tail(curr, parent)); + } + } + } + + void visitReturn(Return* curr) { + if (!controlFlowStack.empty()) { + // we can easily optimize if we are at the end of the parent block + Block* parent = controlFlowStack.back()->dynCast<Block>(); + if (parent && curr == parent->list.back()) { + returnTails.push_back(Tail(curr, parent)); + return; + } + } + // otherwise, if we have a large value, it might be worth optimizing us as well + returnTails.push_back(Tail(curr, getCurrentPointer())); + } + + void visitBlock(Block* curr) { + if (!curr->name.is()) return; + if (unoptimizables.count(curr->name) > 0) return; + auto iter = breakTails.find(curr->name); + if (iter == breakTails.end()) return; + // looks promising + auto& tails = iter->second; + // see if there is a fallthrough + bool hasFallthrough = true; + for (auto* child : curr->list) { + if (child->type == unreachable) { + hasFallthrough = false; + } + } + if (hasFallthrough) { + tails.push_back({ Tail(curr) }); + } + optimizeExpressionTails(tails, curr); + } + + void 
visitIf(If* curr) { + if (!curr->ifFalse) return; + // if both sides are identical, this is easy to fold + // (except if the condition is unreachable and we return a value, then we can't just replace + // outselves with a drop + if (ExpressionAnalyzer::equal(curr->ifTrue, curr->ifFalse)) { + Builder builder(*getModule()); + // remove if (4 bytes), remove one arm, add drop (1), add block (3), + // so this must be a net savings + markAsModified(curr); + auto* ret = builder.makeSequence( + builder.makeDrop(curr->condition), + curr->ifTrue + ); + // we must ensure we present the same type as the if had + ret->finalize(curr->type); + replaceCurrent(ret); + } else { + // if both are blocks, look for a tail we can merge + auto* left = curr->ifTrue->dynCast<Block>(); + auto* right = curr->ifFalse->dynCast<Block>(); + // we need nameless blocks, as if there is a name, someone might branch + // to the end, skipping the code we want to merge + if (left && right && + !left->name.is() && !right->name.is()) { + std::vector<Tail> tails = { Tail(left), Tail(right) }; + optimizeExpressionTails(tails, curr); + } + } + } + + void doWalkFunction(Function* func) { + anotherPass = true; + while (anotherPass) { + anotherPass = false; + WalkerPass<ControlFlowWalker<CodeFolding>>::doWalkFunction(func); + optimizeTerminatingTails(unreachableTails); + // optimize returns at the end, so we can benefit from a fallthrough if there is a value TODO: separate passes for them? + optimizeTerminatingTails(returnTails); + // TODO add fallthrough for returns + // TODO optimzier returns not in blocks, a big return value can be worth it + // clean up + breakTails.clear(); + unreachableTails.clear(); + returnTails.clear(); + unoptimizables.clear(); + modifieds.clear(); + } + } + +private: + // check if we can move a list of items out of another item. 
we can't do so + // if one of the items has a branch to something inside outOf that is not + // inside that item + bool canMove(const std::vector<Expression*>& items, Expression* outOf) { + auto allTargets = BranchUtils::getBranchTargets(outOf); + for (auto* item : items) { + auto exiting = BranchUtils::getExitingBranches(item); + std::vector<Name> intersection; + std::set_intersection(allTargets.begin(), allTargets.end(), exiting.begin(), exiting.end(), + std::back_inserter(intersection)); + if (intersection.size() > 0) { + // anything exiting that is in all targets is something bad + return false; + } + } + return true; + } + + // optimize tails that reach the outside of an expression. code that is identical in all + // paths leading to the block exit can be merged. + template<typename T> + void optimizeExpressionTails(std::vector<Tail>& tails, T* curr) { + if (tails.size() < 2) return; + // see if anything is untoward, and we should not do this + for (auto& tail : tails) { + if (tail.expr && modifieds.count(tail.expr) > 0) return; + if (modifieds.count(tail.block) > 0) return; + // if we were not modified, then we should be valid for processing + tail.validate(); + } + // we can ignore the final br in a tail + auto effectiveSize = [&](const Tail& tail) { + auto ret = tail.block->list.size(); + if (!tail.isFallthrough()) { + ret--; + } + return ret; + }; + // the mergeable items do not include the final br in a tail + auto getMergeable = [&](const Tail& tail, Index num) { + return tail.block->list[effectiveSize(tail) - num - 1]; + }; + // we are going to remove duplicate elements and add a block. 
+ // so for this to make sense, we need the size of the duplicate + // elements to be worth that extra block (although, there is + // some chance the block would get merged higher up, see later) + std::vector<Expression*> mergeable; // the elements we can merge + Index num = 0; // how many elements back from the tail to look at + Index saved = 0; // how much we can save + while (1) { + // check if this num is still relevant + bool stop = false; + for (auto& tail : tails) { + assert(tail.block); + if (num >= effectiveSize(tail)) { + // one of the lists is too short + stop = true; + break; + } + } + if (stop) break; + auto* item = getMergeable(tails[0], num); + for (auto& tail : tails) { + if (!ExpressionAnalyzer::equal(item, getMergeable(tail, num))) { + // one of the lists has a different item + stop = true; + break; + } + } + if (stop) break; + // we may have found another one we can merge - can we move it? + if (!canMove({ item }, curr)) break; + // we found another one we can merge + mergeable.push_back(item); + num++; + saved += Measurer::measure(item); + } + if (saved == 0) return; + // we may be able to save enough. + if (saved < WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH) { + // it's not obvious we can save enough. 
see if we get rid + // of a block, that would justify this + bool willEmptyBlock = false; + for (auto& tail : tails) { + // it is enough to zero out the block, or leave just one + // element, as then the block can be replaced with that + if (num >= tail.block->list.size() - 1) { + willEmptyBlock = true; + break; + } + } + if (!willEmptyBlock) { + // last chance, if our parent is a block, then it should be + // fine to create a new block here, it will be merged up + assert(curr == controlFlowStack.back()); // we are an if or a block, at the top + if (controlFlowStack.size() <= 1) { + return; // no parent at all + // TODO: if we are the toplevel in the function, then in the binary format + // we might avoid emitting a block, so the same logic applies here? + } + auto* parent = controlFlowStack[controlFlowStack.size() - 2]->dynCast<Block>(); + if (!parent) { + return; // parent is not a block + } + bool isChild = false; + for (auto* child : parent->list) { + if (child == curr) { + isChild = true; + break; + } + } + if (!isChild) { + return; // not a child, something in between + } + } + } + // this is worth doing, do it! 
+ for (auto& tail : tails) { + // remove the items we are merging / moving + // first, mark them as modified, so we don't try to handle them + // again in this pass, which might be buggy + markAsModified(tail.block); + // we must preserve the br if there is one + Expression* last = nullptr; + if (!tail.isFallthrough()) { + last = tail.block->list.back(); + tail.block->list.pop_back(); + } + for (Index i = 0; i < mergeable.size(); i++) { + tail.block->list.pop_back(); + } + if (!tail.isFallthrough()) { + tail.block->list.push_back(last); + } + // the blocks lose their endings, so any values are gone, and the blocks + // are now either none or unreachable + tail.block->finalize(); + } + // since we managed a merge, then it might open up more opportunities later + anotherPass = true; + // make a block with curr + the merged code + Builder builder(*getModule()); + auto* block = builder.makeBlock(); + block->list.push_back(curr); + while (!mergeable.empty()) { + block->list.push_back(mergeable.back()); + mergeable.pop_back(); + } + auto oldType = curr->type; + // NB: we template-specialize so that this calls the proper finalizer for + // the type + curr->finalize(); + // ensure the replacement has the same type, so the outside is not surprised + block->finalize(oldType); + replaceCurrent(block); + } + + // optimize tails that terminate control flow in this function, so we + // are (1) merge just a few of them, we don't need all like with the + // branches to a block, and (2) we do it on the function body. + // num is the depth, i.e., how many tail items we can merge. 0 means + // we are just starting; num > 0 means that tails is guaranteed to be + // equal in the last num items, so we can merge there, but we look for + // deeper merges first. + // returns whether we optimized something. 
+ bool optimizeTerminatingTails(std::vector<Tail>& tails, Index num = 0) { + if (tails.size() < 2) return false; + // remove things that are untoward and cannot be optimized + tails.erase(std::remove_if(tails.begin(), tails.end(), [&](Tail& tail) { + if (tail.expr && modifieds.count(tail.expr) > 0) return true; + if (tail.block && modifieds.count(tail.block) > 0) return true; + // if we were not modified, then we should be valid for processing + tail.validate(); + return false; + }), tails.end()); + // now let's try to find subsets that are mergeable. we don't look hard + // for the most optimal; further passes may find more + // effectiveSize: TODO: special-case fallthrough, matters for returns + auto effectiveSize = [&](Tail& tail) -> Index { + if (tail.block) { + return tail.block->list.size(); + } else { + return 1; + } + }; + // getItem: returns the relevant item from the tail. this includes the + // final item + // TODO: special-case fallthrough, matters for returns + auto getItem = [&](Tail& tail, Index num) { + if (tail.block) { + return tail.block->list[effectiveSize(tail) - num - 1]; + } else { + return tail.expr; + } + }; + // gets the tail elements of a certain depth + auto getTailItems = [&](Index num, std::vector<Tail>& tails) { + std::vector<Expression*> items; + for (Index i = 0; i < num; i++) { + auto item = getItem(tails[0], i); + items.push_back(item); + } + return items; + }; + // estimate if a merging is worth the cost + auto worthIt = [&](Index num, std::vector<Tail>& tails) { + auto items = getTailItems(num, tails); // the elements we can merge + Index saved = 0; // how much we can save + for (auto* item : items) { + saved += Measurer::measure(item) * (tails.size() - 1); + } + // compure the cost: in non-fallthroughs, we are replacing the final + // element with a br; for a fallthrough, if there is one, we must + // add a return element (for the function body, so it doesn't reach us) + // TODO: handle fallthroughts for return + Index cost = 
tails.size(); + // we also need to add two blocks: for us to break to, and to contain + // that block and the merged code. very possibly one of the blocks + // can be removed, though + cost += WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH; + // if we cannot merge to the end, then we definitely need 2 blocks, + // and a branch + if (!canMove(items, getFunction()->body)) { // TODO: efficiency, entire body + cost += 1 + WORTH_ADDING_BLOCK_TO_REMOVE_THIS_MUCH; + // TODO: to do this, we need to maintain a map of element=>parent, + // so that we can insert the new blocks in the right place + // for now, just don't do this optimization + return false; + } + // is it worth it? + return saved > cost; + }; + // let's see if we can merge deeper than num, to num + 1 + auto next = tails; + // remove tails that are too short + next.erase(std::remove_if(next.begin(), next.end(), [&](Tail& tail) { + return effectiveSize(tail) < num + 1; + }), next.end()); + // if we have enough to investigate, do so + if (next.size() >= 2) { + // now we want to find a mergeable item - any item that is equal among a subset + std::map<uint32_t, std::vector<Expression*>> hashed; // hash value => expressions with that hash + for (auto& tail : next) { + auto* item = getItem(tail, num); + hashed[ExpressionAnalyzer::hash(item)].push_back(item); + } + for (auto& iter : hashed) { + auto& items = iter.second; + if (items.size() == 1) continue; + assert(items.size() > 0); + // look for an item that has another match. 
+ while (items.size() >= 2) { + auto first = items[0]; + std::vector<Expression*> others; + items.erase(std::remove_if(items.begin(), items.end(), [&](Expression* item) { + if (item == first || // don't bother comparing the first + ExpressionAnalyzer::equal(item, first)) { + // equal, keep it + return false; + } else { + // unequal, look at it later + others.push_back(item); + return true; + } + }), items.end()); + if (items.size() >= 2) { + // possible merge here, investigate it + auto* correct = items[0]; + auto explore = next; + explore.erase(std::remove_if(explore.begin(), explore.end(), [&](Tail& tail) { + auto* item = getItem(tail, num); + return !ExpressionAnalyzer::equal(item, correct); + }), explore.end()); + // try to optimize this deeper tail. if we succeed, then stop here, as the + // changes may influence us. we leave further opts to further passes (as this + // is rare in practice, it's generally not a perf issue, but TODO optimize) + if (optimizeTerminatingTails(explore, num + 1)) { + return true; + } + } + items.swap(others); + } + } + } + // we explored deeper (higher num) options, but perhaps there + // was nothing there while there is something we can do at this level + // but if we are at num == 0, then we found nothing at all + if (num == 0) return false; + // if not worth it, stop + if (!worthIt(num, tails)) return false; + // this is worth doing, do it! 
+ auto mergeable = getTailItems(num, tails); // the elements we can merge + // since we managed a merge, then it might open up more opportunities later + anotherPass = true; + Builder builder(*getModule()); + LabelUtils::LabelManager labels(getFunction()); // TODO: don't create one per merge, linear in function size + Name innerName = labels.getUnique("folding-inner"); + for (auto& tail : tails) { + // remove the items we are merging / moving, and add a break + // also mark as modified, so we don't try to handle them + // again in this pass, which might be buggy + if (tail.block) { + markAsModified(tail.block); + for (Index i = 0; i < mergeable.size(); i++) { + tail.block->list.pop_back(); + } + tail.block->list.push_back(builder.makeBreak(innerName)); + tail.block->finalize(tail.block->type); + } else { + markAsModified(tail.expr); + *tail.pointer = builder.makeBreak(innerName); + } + } + // make a block with the old body + the merged code + auto* old = getFunction()->body; + auto* inner = builder.makeBlock(); + inner->name = innerName; + if (old->type == unreachable) { + // the old body is not flowed out of anyhow, so just put it there + inner->list.push_back(old); + } else { + // otherwise, we must not flow out to the merged code + if (old->type == none) { + inner->list.push_back(old); + inner->list.push_back(builder.makeReturn()); + } else { + // looks like we must return this. 
but if it's a toplevel block + // then it might be marked as having a type, but not actually + // returning it (we marked it as such for wasm type-checking + // rules, and now it won't be toplevel in the function, it can + // change) + auto* toplevel = old->dynCast<Block>(); + if (toplevel) toplevel->finalize(); + if (old->type != unreachable) { + inner->list.push_back(builder.makeReturn(old)); + } else { + inner->list.push_back(old); + } + } + } + inner->finalize(); + auto* outer = builder.makeBlock(); + outer->list.push_back(inner); + while (!mergeable.empty()) { + outer->list.push_back(mergeable.back()); + mergeable.pop_back(); + } + // ensure the replacement has the same type, so the outside is not surprised + outer->finalize(getFunction()->result); + getFunction()->body = outer; + return true; + } + + void markAsModified(Expression* curr) { + ExpressionMarker marker(modifieds, curr); + } +}; + +Pass *createCodeFoldingPass() { + return new CodeFolding(); +} + +} // namespace wasm + diff --git a/src/passes/NameManager.cpp b/src/passes/NameManager.cpp deleted file mode 100644 index 035586a77..000000000 --- a/src/passes/NameManager.cpp +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright 2015 WebAssembly Community Group participants - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -// -// NameManager -// - -#include <wasm.h> -#include <pass.h> - -namespace wasm { - -Name NameManager::getUnique(std::string prefix) { - while (1) { - Name curr = cashew::IString((prefix + std::to_string(counter++)).c_str(), false); - if (names.find(curr) == names.end()) { - names.insert(curr); - return curr; - } - } -} - -void NameManager::visitBlock(Block* curr) { - names.insert(curr->name); -} -void NameManager::visitLoop(Loop* curr) { - names.insert(curr->name); -} -void NameManager::visitBreak(Break* curr) { - names.insert(curr->name); -} -void NameManager::visitSwitch(Switch* curr) { - names.insert(curr->default_); - for (auto& target : curr->targets) { - names.insert(target); - } -} -void NameManager::visitCall(Call* curr) { - names.insert(curr->target); -} -void NameManager::visitCallImport(CallImport* curr) { - names.insert(curr->target); -} -void NameManager::visitFunctionType(FunctionType* curr) { - names.insert(curr->name); -} -void NameManager::visitFunction(Function* curr) { - names.insert(curr->name); - for (Index i = 0; i < curr->getNumLocals(); i++) { - Name name = curr->getLocalNameOrDefault(i); - if (name.is()) { - names.insert(name); - } - } -} -void NameManager::visitImport(Import* curr) { - names.insert(curr->name); -} -void NameManager::visitExport(Export* curr) { - names.insert(curr->name); -} - -Pass *createNameManagerPass() { - return new NameManager(); -} - -} // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 903675b8c..fc0a583fb 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -67,6 +67,7 @@ void PassRegistry::registerPasses() { registerPass("coalesce-locals", "reduce # of locals by coalescing", createCoalesceLocalsPass); registerPass("coalesce-locals-learning", "reduce # of locals by coalescing and learning", createCoalesceLocalsWithLearningPass); registerPass("code-pushing", "push code forward, potentially making it not always execute", createCodePushingPass); + 
registerPass("code-folding", "fold code, merging duplicates", createCodeFoldingPass); registerPass("dce", "removes unreachable code", createDeadCodeEliminationPass); registerPass("duplicate-function-elimination", "removes duplicate functions", createDuplicateFunctionEliminationPass); registerPass("extract-function", "leaves just one function (useful for debugging)", createExtractFunctionPass); @@ -81,7 +82,6 @@ void PassRegistry::registerPasses() { registerPass("merge-blocks", "merges blocks to their parents", createMergeBlocksPass); registerPass("metrics", "reports metrics", createMetricsPass); registerPass("nm", "name list", createNameListPass); - registerPass("name-manager", "utility pass to manage names in modules", createNameManagerPass); registerPass("optimize-instructions", "optimizes instruction combinations", createOptimizeInstructionsPass); registerPass("pick-load-signs", "pick load signs based on their uses", createPickLoadSignsPass); registerPass("post-emscripten", "miscellaneous optimizations for Emscripten-generated code", createPostEmscriptenPass); @@ -139,6 +139,9 @@ void PassRunner::addDefaultFunctionOptimizationPasses() { add("simplify-locals"); add("vacuum"); // previous pass creates garbage add("reorder-locals"); + if (options.shrinkLevel >= 1) { + add("code-folding"); + } add("merge-blocks"); // makes remove-unused-brs more effective add("remove-unused-brs"); // coalesce-locals opens opportunities for optimizations add("merge-blocks"); // clean up remove-unused-brs new blocks @@ -148,7 +151,7 @@ void PassRunner::addDefaultFunctionOptimizationPasses() { add("local-cse"); // TODO: run this early, before first coalesce-locals. 
right now doing so uncovers some deficiencies we need to fix first add("coalesce-locals"); // just for localCSE } - add("vacuum"); // should not be needed, last few passes do not create garbage, but just to be safe + add("vacuum"); // just to be safe } void PassRunner::addDefaultGlobalOptimizationPasses() { diff --git a/src/passes/passes.h b/src/passes/passes.h index 43bdd3efe..7a40b7da4 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -24,6 +24,7 @@ class Pass; // All passes: Pass *createCoalesceLocalsPass(); Pass *createCoalesceLocalsWithLearningPass(); +Pass *createCodeFoldingPass(); Pass *createCodePushingPass(); Pass *createDeadCodeEliminationPass(); Pass *createDuplicateFunctionEliminationPass(); @@ -41,7 +42,6 @@ Pass *createMergeBlocksPass(); Pass *createMinifiedPrinterPass(); Pass *createMetricsPass(); Pass *createNameListPass(); -Pass *createNameManagerPass(); Pass *createOptimizeInstructionsPass(); Pass *createPickLoadSignsPass(); Pass *createPostEmscriptenPass(); diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index aa8e008ad..3b1de2e32 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -513,7 +513,7 @@ struct ControlFlowWalker : public PostWalker<SubType, VisitorType> { } static void doPostVisitControlFlow(SubType* self, Expression** currp) { - assert(self->controlFlowStack.back() == *currp); + // note that we might be popping something else, as we may have been replaced self->controlFlowStack.pop_back(); } |