summaryrefslogtreecommitdiff
path: root/src/passes/RelooperJumpThreading.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/RelooperJumpThreading.cpp')
-rw-r--r--src/passes/RelooperJumpThreading.cpp119
1 files changed, 97 insertions, 22 deletions
diff --git a/src/passes/RelooperJumpThreading.cpp b/src/passes/RelooperJumpThreading.cpp
index a0c33811a..7f74220d5 100644
--- a/src/passes/RelooperJumpThreading.cpp
+++ b/src/passes/RelooperJumpThreading.cpp
@@ -47,6 +47,53 @@ struct NameEnsurer {
}
};
+static If* isLabelCheckingIf(Expression* curr, Index labelIndex) {
+ if (!curr) return nullptr;
+ auto* iff = curr->dynCast<If>();
+ if (!iff) return nullptr;
+ auto* condition = iff->condition->dynCast<Binary>();
+ if (!(condition && condition->op == EqInt32)) return nullptr;
+ auto* left = condition->left->dynCast<GetLocal>();
+ if (!(left && left->index == labelIndex)) return nullptr;
+ return iff;
+}
+
+static Index getCheckedLabelValue(If* iff) {
+ return iff->condition->cast<Binary>()->right->cast<Const>()->value.geti32();
+}
+
+static SetLocal* isLabelSettingSetLocal(Expression* curr, Index labelIndex) {
+ if (!curr) return nullptr;
+ auto* set = curr->dynCast<SetLocal>();
+ if (!set) return nullptr;
+ if (set->index != labelIndex) return nullptr;
+ return set;
+}
+
+static Index getSetLabelValue(SetLocal* set) {
+ return set->value->cast<Const>()->value.geti32();
+}
+
+struct LabelUseFinder : public PostWalker<LabelUseFinder, Visitor<LabelUseFinder>> {
+ Index labelIndex;
+ std::map<Index, Index>& checks; // label value => number of checks on it
+ std::map<Index, Index>& sets; // label value => number of sets to it
+
+ LabelUseFinder(Index labelIndex, std::map<Index, Index>& checks, std::map<Index, Index>& sets) : labelIndex(labelIndex), checks(checks), sets(sets) {}
+
+ void visitIf(If* curr) {
+ if (isLabelCheckingIf(curr, labelIndex)) {
+ checks[getCheckedLabelValue(curr)]++;
+ }
+ }
+
+ void visitSetLocal(SetLocal* curr) {
+ if (isLabelSettingSetLocal(curr, labelIndex)) {
+ sets[getSetLabelValue(curr)]++;
+ }
+ }
+};
+
struct RelooperJumpThreading : public WalkerPass<ExpressionStackWalker<RelooperJumpThreading, Visitor<RelooperJumpThreading>>> {
bool isFunctionParallel() override { return true; }
@@ -56,6 +103,9 @@ struct RelooperJumpThreading : public WalkerPass<ExpressionStackWalker<RelooperJ
static NameEnsurer ensurer;
}
+ std::map<Index, Index> labelChecks;
+ std::map<Index, Index> labelSets;
+
Index labelIndex;
Index newNameCounter = 0;
@@ -64,27 +114,35 @@ struct RelooperJumpThreading : public WalkerPass<ExpressionStackWalker<RelooperJ
auto& list = curr->list;
if (list.size() == 0) return;
for (Index i = 0; i < list.size() - 1; i++) {
+ // once we see something that might be irreducible, we must skip that if and the rest of the dependents
+ bool irreducible = false;
Index origin = i;
for (Index j = i + 1; j < list.size(); j++) {
- if (auto* iff = isLabelCheckingIf(list[j])) {
- optimizeJumpsToLabelCheck(list[origin], iff);
- ExpressionManipulator::nop(iff);
+ if (auto* iff = isLabelCheckingIf(list[j], labelIndex)) {
+ irreducible |= hasIrreducibleControlFlow(iff, list[origin]);
+ if (!irreducible) {
+ optimizeJumpsToLabelCheck(list[origin], iff);
+ ExpressionManipulator::nop(iff);
+ }
i++;
continue;
}
// if the next element is a block, it may be the holding block of label-checking ifs
if (auto* holder = list[j]->dynCast<Block>()) {
if (holder->list.size() > 0) {
- if (If* iff = isLabelCheckingIf(holder->list[0])) {
- // this is indeed a holder. we can process the ifs, and must also move
- // the block to enclose the origin, so it is properly reachable
- assert(holder->list.size() == 1); // must be size 1, a relooper multiple will have its own label, and is an if-else sequence and nothing more
- optimizeJumpsToLabelCheck(list[origin], iff);
- holder->list[0] = list[origin];
- list[origin] = holder;
- // reuse the if as a nop
- list[j] = iff;
- ExpressionManipulator::nop(iff);
+ if (If* iff = isLabelCheckingIf(holder->list[0], labelIndex)) {
+ irreducible |= hasIrreducibleControlFlow(iff, list[origin]);
+ if (!irreducible) {
+ // this is indeed a holder. we can process the ifs, and must also move
+ // the block to enclose the origin, so it is properly reachable
+ assert(holder->list.size() == 1); // must be size 1, a relooper multiple will have its own label, and is an if-else sequence and nothing more
+ optimizeJumpsToLabelCheck(list[origin], iff);
+ holder->list[0] = list[origin];
+ list[origin] = holder;
+ // reuse the if as a nop
+ list[j] = iff;
+ ExpressionManipulator::nop(iff);
+ }
i++;
continue;
}
@@ -99,19 +157,36 @@ struct RelooperJumpThreading : public WalkerPass<ExpressionStackWalker<RelooperJ
// if there isn't a label variable, nothing for us to do
if (func->localIndices.count(LABEL)) {
labelIndex = func->getLocalIndex(LABEL);
+ LabelUseFinder finder(labelIndex, labelChecks, labelSets);
+ finder.walk(func->body);
WalkerPass<ExpressionStackWalker<RelooperJumpThreading, Visitor<RelooperJumpThreading>>>::doWalkFunction(func);
}
}
private:
- If* isLabelCheckingIf(Expression* curr) {
- auto* iff = curr->dynCast<If>();
- if (!iff) return nullptr;
- auto* condition = iff->condition->dynCast<Binary>();
- if (!(condition && condition->op == EqInt32)) return nullptr;
- auto* left = condition->left->dynCast<GetLocal>();
- if (!(left && left->index == labelIndex)) return nullptr;
- return iff;
+
+ bool hasIrreducibleControlFlow(If* iff, Expression* origin) {
+ // Gather the checks in this if chain. If all the label values checked are only set in origin,
+ // then since origin is right before us, this is not irreducible - we can replace all sets
+ // in origin with jumps forward to us, and since there is nothing else, this is safe and complete.
+ // We must also have the property that there is just one check for the label value, as otherwise
+ // node splitting has complicated things.
+ std::map<Index, Index> labelChecksInOrigin;
+ std::map<Index, Index> labelSetsInOrigin;
+ LabelUseFinder finder(labelIndex, labelChecksInOrigin, labelSetsInOrigin);
+ finder.walk(origin);
+ while (iff) {
+ auto num = getCheckedLabelValue(iff);
+ assert(labelChecks[num] > 0);
+ if (labelChecks[num] > 1) return true; // checked more than once, somewhere in function
+ assert(labelChecksInOrigin[num] == 0);
+ if (labelSetsInOrigin[num] != labelSets[num]) {
+ assert(labelSetsInOrigin[num] < labelSets[num]);
+ return true; // label set somewhere outside of origin TODO: if set in the if body here, it might be safe in some cases
+ }
+ iff = isLabelCheckingIf(iff->ifFalse, labelIndex);
+ }
+ return false;
}
// optimizes jumps to a label check
@@ -123,7 +198,7 @@ private:
std::cerr << "too many names in RelooperJumpThreading :(\n";
return;
}
- Index num = iff->condition->cast<Binary>()->right->cast<Const>()->value.geti32();
+ Index num = getCheckedLabelValue(iff);
// create a new block for this jump target
Builder builder(*getModule());
// origin is where all jumps to this target must come from - the element right before this if