summaryrefslogtreecommitdiff
path: root/src/passes/CodeFolding.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/CodeFolding.cpp')
-rw-r--r--src/passes/CodeFolding.cpp234
1 files changed, 111 insertions, 123 deletions
diff --git a/src/passes/CodeFolding.cpp b/src/passes/CodeFolding.cpp
index 0cddec4ca..42331b747 100644
--- a/src/passes/CodeFolding.cpp
+++ b/src/passes/CodeFolding.cpp
@@ -105,19 +105,11 @@ struct CodeFolding
Tail(Block* block) : expr(nullptr), block(block), pointer(nullptr) {}
// For a break
Tail(Expression* expr, Block* block)
- : expr(expr), block(block), pointer(nullptr) {
- validate();
- }
+ : expr(expr), block(block), pointer(nullptr) {}
Tail(Expression* expr, Expression** pointer)
: expr(expr), block(nullptr), pointer(pointer) {}
bool isFallthrough() const { return expr == nullptr; }
-
- void validate() const {
- if (expr && block) {
- assert(block->list.back() == expr);
- }
- }
};
// state
@@ -152,15 +144,13 @@ struct CodeFolding
}
void visitBreak(Break* curr) {
- if (curr->condition || curr->value) {
+ if (curr->condition) {
unoptimizables.insert(curr->name);
} else {
- // we can only optimize if we are at the end of the parent block,
- // and if the parent block does not return a value (we can't move
- // elements out of it if there is a value being returned)
+ // we can only optimize if we are at the end of the parent block.
+ // TODO: Relax this.
Block* parent = controlFlowStack.back()->dynCast<Block>();
- if (parent && curr == parent->list.back() &&
- !parent->list.back()->type.isConcrete()) {
+ if (parent && curr == parent->list.back()) {
breakTails[curr->name].push_back(Tail(curr, parent));
} else {
unoptimizables.insert(curr->name);
@@ -222,24 +212,19 @@ struct CodeFolding
if (unoptimizables.count(curr->name) > 0) {
return;
}
- // we can't optimize a fallthrough value
- if (curr->list.back()->type.isConcrete()) {
- return;
- }
auto iter = breakTails.find(curr->name);
if (iter == breakTails.end()) {
return;
}
- // looks promising
+ // Looks promising.
auto& tails = iter->second;
- // see if there is a fallthrough
- bool hasFallthrough = true;
- for (auto* child : curr->list) {
- if (child->type == Type::unreachable) {
- hasFallthrough = false;
- }
- }
- if (hasFallthrough) {
+ // If the end of the block cannot be reached, then we don't need to include
+ // it in the set of folded tails.
+ bool includeFallthrough =
+ !std::any_of(curr->list.begin(), curr->list.end(), [&](auto* child) {
+ return child->type == Type::unreachable;
+ });
+ if (includeFallthrough) {
tails.push_back({Tail(curr)});
}
optimizeExpressionTails(tails, curr);
@@ -249,48 +234,34 @@ struct CodeFolding
if (!curr->ifFalse) {
return;
}
- // if both sides are identical, this is easy to fold
- if (ExpressionAnalyzer::equal(curr->ifTrue, curr->ifFalse)) {
+ // If both are blocks, look for a tail we can merge.
+ auto* left = curr->ifTrue->dynCast<Block>();
+ auto* right = curr->ifFalse->dynCast<Block>();
+ // If one is a block and the other isn't, and the non-block is a tail of the
+ // other, we can fold that - for our convenience, we just add a block and
+ // run the rest of the optimization mormally.
+ auto maybeAddBlock = [this](Block* block, Expression*& other) -> Block* {
+ // If other is a suffix of the block, wrap it in a block.
+ if (block->list.empty() ||
+ !ExpressionAnalyzer::equal(other, block->list.back())) {
+ return nullptr;
+ }
+ // Do it, assign to the out param `other`, and return the block.
Builder builder(*getModule());
- // remove if (4 bytes), remove one arm, add drop (1), add block (3),
- // so this must be a net savings
- markAsModified(curr);
- auto* ret =
- builder.makeSequence(builder.makeDrop(curr->condition), curr->ifTrue);
- // we must ensure we present the same type as the if had
- ret->finalize(curr->type);
- replaceCurrent(ret);
- needEHFixups = true;
- } else {
- // if both are blocks, look for a tail we can merge
- auto* left = curr->ifTrue->dynCast<Block>();
- auto* right = curr->ifFalse->dynCast<Block>();
- // If one is a block and the other isn't, and the non-block is a tail
- // of the other, we can fold that - for our convenience, we just add
- // a block and run the rest of the optimization mormally.
- auto maybeAddBlock = [this](Block* block, Expression*& other) -> Block* {
- // if other is a suffix of the block, wrap it in a block
- if (block->list.empty() ||
- !ExpressionAnalyzer::equal(other, block->list.back())) {
- return nullptr;
- }
- // do it, assign to the out param `other`, and return the block
- Builder builder(*getModule());
- auto* ret = builder.makeBlock(other);
- other = ret;
- return ret;
- };
- if (left && !right) {
- right = maybeAddBlock(left, curr->ifFalse);
- } else if (!left && right) {
- left = maybeAddBlock(right, curr->ifTrue);
- }
- // we need nameless blocks, as if there is a name, someone might branch
- // to the end, skipping the code we want to merge
- if (left && right && !left->name.is() && !right->name.is()) {
- std::vector<Tail> tails = {Tail(left), Tail(right)};
- optimizeExpressionTails(tails, curr);
- }
+ auto* ret = builder.makeBlock(other);
+ other = ret;
+ return ret;
+ };
+ if (left && !right) {
+ right = maybeAddBlock(left, curr->ifFalse);
+ } else if (!left && right) {
+ left = maybeAddBlock(right, curr->ifTrue);
+ }
+ // We need nameless blocks, as if there is a name, someone might branch to
+ // the end, skipping the code we want to merge.
+ if (left && right && !left->name.is() && !right->name.is()) {
+ std::vector<Tail> tails = {Tail(left), Tail(right)};
+ optimizeExpressionTails(tails, curr);
}
}
@@ -315,10 +286,6 @@ struct CodeFolding
if (needEHFixups) {
EHUtils::handleBlockNestedPops(func, *getModule());
}
- // if we did any work, types may need to be propagated
- if (anotherPass) {
- ReFinalize().walkFunctionInModule(func, getModule());
- }
}
}
@@ -372,6 +339,7 @@ private:
// identical in all paths leading to the block exit can be merged.
template<typename T>
void optimizeExpressionTails(std::vector<Tail>& tails, T* curr) {
+ auto oldType = curr->type;
if (tails.size() < 2) {
return;
}
@@ -384,50 +352,49 @@ private:
return;
}
// if we were not modified, then we should be valid for processing
- tail.validate();
+ assert(!tail.expr || !tail.block ||
+ (tail.expr == tail.block->list.back()));
}
- // we can ignore the final br in a tail
- auto effectiveSize = [&](const Tail& tail) {
- auto ret = tail.block->list.size();
+ auto getMergeable = [&](const Tail& tail, Index num) -> Expression* {
if (!tail.isFallthrough()) {
- ret--;
+ // If there is a branch value, it is the first mergeable item.
+ auto* val = tail.expr->cast<Break>()->value;
+ if (val && num == 0) {
+ return val;
+ }
+ if (!val) {
+ // Skip the branch instruction at the end; it is not part of the
+ // merged tail.
+ ++num;
+ }
}
- return ret;
- };
- // the mergeable items do not include the final br in a tail
- auto getMergeable = [&](const Tail& tail, Index num) {
- return tail.block->list[effectiveSize(tail) - num - 1];
+ if (num >= tail.block->list.size()) {
+ return nullptr;
+ }
+ return tail.block->list[tail.block->list.size() - num - 1];
};
// we are going to remove duplicate elements and add a block.
// so for this to make sense, we need the size of the duplicate
// elements to be worth that extra block (although, there is
// some chance the block would get merged higher up, see later)
std::vector<Expression*> mergeable; // the elements we can merge
- Index num = 0; // how many elements back from the tail to look at
Index saved = 0; // how much we can save
- while (1) {
- // check if this num is still relevant
- bool stop = false;
- for (auto& tail : tails) {
- assert(tail.block);
- if (num >= effectiveSize(tail)) {
- // one of the lists is too short
- stop = true;
- break;
- }
- }
- if (stop) {
+ for (Index num = 0; true; ++num) {
+ auto* item = getMergeable(tails[0], num);
+ if (!item) {
+ // The list is too short.
break;
}
- auto* item = getMergeable(tails[0], num);
- for (auto& tail : tails) {
- if (!ExpressionAnalyzer::equal(item, getMergeable(tail, num))) {
- // one of the lists has a different item
- stop = true;
+ Index tail = 1;
+ for (; tail < tails.size(); ++tail) {
+ auto* other = getMergeable(tails[tail], num);
+ if (!other || !ExpressionAnalyzer::equal(item, other)) {
+ // Other tail too short or has a difference.
break;
}
}
- if (stop) {
+ if (tail != tails.size()) {
+ // We saw a tail without a matching item.
break;
}
// we may have found another one we can merge - can we move it?
@@ -436,7 +403,6 @@ private:
}
// we found another one we can merge
mergeable.push_back(item);
- num++;
saved += Measurer::measure(item);
}
if (saved == 0) {
@@ -450,7 +416,7 @@ private:
for (auto& tail : tails) {
// it is enough to zero out the block, or leave just one
// element, as then the block can be replaced with that
- if (num >= tail.block->list.size() - 1) {
+ if (mergeable.size() >= tail.block->list.size() - 1) {
willEmptyBlock = true;
break;
}
@@ -483,6 +449,7 @@ private:
}
}
}
+
// this is worth doing, do it!
for (auto& tail : tails) {
// remove the items we are merging / moving
@@ -490,37 +457,61 @@ private:
// again in this pass, which might be buggy
markAsModified(tail.block);
// we must preserve the br if there is one
- Expression* last = nullptr;
+ Break* branch = nullptr;
if (!tail.isFallthrough()) {
- last = tail.block->list.back();
- tail.block->list.pop_back();
+ branch = tail.block->list.back()->cast<Break>();
+ if (branch->value) {
+ branch->value = nullptr;
+ } else {
+ tail.block->list.pop_back();
+ }
}
- for (Index i = 0; i < mergeable.size(); i++) {
+ for (Index i = 0; i < mergeable.size(); ++i) {
tail.block->list.pop_back();
}
- if (!tail.isFallthrough()) {
- tail.block->list.push_back(last);
+ if (tail.isFallthrough()) {
+ // The block now ends in an expression that was previously in the middle
+ // of the block, meaning it must have type none.
+ tail.block->finalize(Type::none);
+ } else {
+ tail.block->list.push_back(branch);
+ // The block still ends with the same branch it previously ended with,
+ // so its type cannot have changed.
+ tail.block->finalize(tail.block->type);
}
- // the block type may change if we removed unreachable stuff,
- // but in general it should remain the same, as if it had a
- // forced type it should remain, *and*, we don't have a
- // fallthrough value (we would never get here), so a concrete
- // type was not from that. I.e., any type on the block is
- // either forced and/or from breaks with a value, so the
- // type cannot be changed by moving code out.
- tail.block->finalize(tail.block->type);
}
// since we managed a merge, then it might open up more opportunities later
anotherPass = true;
// make a block with curr + the merged code
Builder builder(*getModule());
auto* block = builder.makeBlock();
- block->list.push_back(curr);
+ if constexpr (T::SpecificId == Expression::IfId) {
+ // If we've moved all the contents out of both arms of the If, then we can
+ // simplify the output by replacing it entirely with just a drop of the
+ // condition.
+ auto* iff = curr->template cast<If>();
+ if (iff->ifTrue->template cast<Block>()->list.empty() &&
+ iff->ifFalse->template cast<Block>()->list.empty()) {
+ block->list.push_back(builder.makeDrop(iff->condition));
+ } else {
+ block->list.push_back(curr);
+ }
+ } else {
+ block->list.push_back(curr);
+ }
while (!mergeable.empty()) {
block->list.push_back(mergeable.back());
mergeable.pop_back();
}
- auto oldType = curr->type;
+ if constexpr (T::SpecificId == Expression::BlockId) {
+ // If we didn't have a fallthrough tail because the end of the block was
+ // not reachable, then we might have a concrete expression at the end of
+ // the block even though the value produced by the block has been moved
+ // out of it. If so, drop that expression.
+ auto* currBlock = curr->template cast<Block>();
+ currBlock->list.back() =
+ builder.dropIfConcretelyTyped(currBlock->list.back());
+ }
// NB: we template-specialize so that this calls the proper finalizer for
// the type
curr->finalize();
@@ -553,9 +544,6 @@ private:
if (tail.block && modifieds.count(tail.block) > 0) {
return true;
}
- // if we were not modified, then we should be valid for
- // processing
- tail.validate();
return false;
}),
tails.end());