summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r-- src/passes/RemoveUnusedBrs.cpp 176
-rw-r--r-- src/wasm-builder.h 1
2 files changed, 171 insertions, 6 deletions
diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp
index 33ffc42b6..9145e5e2c 100644
--- a/src/passes/RemoveUnusedBrs.cpp
+++ b/src/passes/RemoveUnusedBrs.cpp
@@ -20,6 +20,7 @@
#include <wasm.h>
#include <pass.h>
+#include <parsing.h>
#include <ir/utils.h>
#include <ir/branch-utils.h>
#include <ir/effects.h>
@@ -444,9 +445,11 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
// perform some final optimizations
struct FinalOptimizer : public PostWalker<FinalOptimizer> {
- bool selectify;
+ bool shrink;
PassOptions& passOptions;
+ bool needUniqify = false;
+
FinalOptimizer(PassOptions& passOptions) : passOptions(passOptions) {}
void visitBlock(Block* curr) {
@@ -479,9 +482,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
}
}
if (list.size() >= 2) {
- if (selectify) {
- // Join adjacent br_ifs to the same target, making one br_if with
- // a "selectified" condition that executes both.
+ // Join adjacent br_ifs to the same target, making one br_if with
+ // a "selectified" condition that executes both.
+ if (shrink) {
for (Index i = 0; i < list.size() - 1; i++) {
auto* br1 = list[i]->dynCast<Break>();
// avoid unreachable brs, as they are dead code anyhow, and after merging
@@ -500,6 +503,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
}
}
}
+ // combine adjacent br_ifs that test the same value into a br_table,
+ // when that makes sense
+ tablify(curr);
// Restructuring of ifs: if we have
// (block $x
// (br_if $x (cond))
@@ -535,7 +541,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
void visitIf(If* curr) {
// we may have simplified ifs enough to turn them into selects
// this is helpful for code size, but can be a tradeoff with performance as we run both code paths
- if (!selectify) return;
+ if (!shrink) return;
if (curr->ifFalse && isConcreteWasmType(curr->ifTrue->type) && isConcreteWasmType(curr->ifFalse->type)) {
// if with else, consider turning it into a select if there is no control flow
// TODO: estimate cost
@@ -556,11 +562,169 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
}
}
}
+
+ // tablify: (br_if)+ => br_table. Rewrites a run of adjacent br_ifs in a block's list, all comparing the same input against distinct i32 constants, into one br_table.
+ // we look for the specific pattern of (shown here with two br_ifs; a run needs at least MIN_NUM of them)
+ // (br_if ..target1..
+ // (i32.eq
+ // (..input..)
+ // (i32.const ..value1..)
+ // )
+ // )
+ // (br_if ..target2..
+ // (i32.eq
+ // (..input..)
+ // (i32.const ..value2..)
+ // )
+ // )
+ // TODO: consider also looking at <= etc. and not just eq
+ void tablify(Block* block) {
+ auto &list = block->list;
+ if (list.size() <= 1) return; // a single element can never form a run
+
+ // Heuristics. These are slightly inspired by the constants from the asm.js backend.
+
+ // How many br_ifs we need to see in a run to consider doing this
+ const uint32_t MIN_NUM = 3;
+ // How much of a range of values (max constant - min constant) is definitely too big
+ const uint32_t MAX_RANGE = 1024;
+ // Multiplied by the number of br_ifs, then compared to the range. When
+ // this is high, we allow larger (sparser) ranges.
+ const uint32_t NUM_TO_RANGE_FACTOR = 3;
+
+ // check if the input is a proper br_if on an i32.eq of a condition value to a const,
+ // and the const is in the proper range, [0, int32_max), to avoid overflow concerns.
+ // returns the br_if if so, or nullptr otherwise
+ auto getProperBrIf = [](Expression* curr) -> Break*{
+ auto* br = curr->dynCast<Break>();
+ if (!br) return nullptr;
+ if (!br->condition || br->value) return nullptr; // must be conditional and must not carry a value
+ if (br->type != none) return nullptr; // with no value the type is either unreachable or none; skip unreachable ones, dce will clean them up
+ auto* binary = br->condition->dynCast<Binary>();
+ if (!binary) return nullptr;
+ if (binary->op != EqInt32) return nullptr;
+ auto* c = binary->right->dynCast<Const>(); // only a constant on the right-hand side is recognized
+ if (!c) return nullptr;
+ uint32_t value = c->value.geti32();
+ if (value >= std::numeric_limits<int32_t>::max()) return nullptr; // compared as unsigned: rejects int32_max and anything with the top bit set
+ return br;
+ };
+
+ // if the input is a proper br_if (see getProperBrIf), returns the value being
+ // compared, i.e. the left operand of the i32.eq; returns nullptr otherwise
+ auto getProperBrIfConditionValue = [&getProperBrIf](Expression* curr) -> Expression* {
+ auto* br = getProperBrIf(curr);
+ if (!br) return nullptr;
+ return br->condition->cast<Binary>()->left;
+ };
+
+ // returns the constant value, as a uint32_t. the input must already be known to be a proper br_if: the result of getProperBrIf is dereferenced unchecked
+ auto getProperBrIfConstant = [&getProperBrIf](Expression* curr) -> uint32_t {
+ return getProperBrIf(curr)->condition->cast<Binary>()->right->cast<Const>()->value.geti32();
+ };
+ Index start = 0; // index of the first element of the candidate run
+ while (start < list.size() - 1) { // size() >= 2 was checked above, so the subtraction cannot wrap
+ auto* conditionValue = getProperBrIfConditionValue(list[start]);
+ if (!conditionValue) {
+ start++;
+ continue;
+ }
+ // if the condition has side effects, we can't replace many appearances of it
+ // with a single one
+ if (EffectAnalyzer(passOptions, conditionValue).hasSideEffects()) {
+ start++;
+ continue;
+ }
+ // look for a "run" of br_ifs with all the same conditionValue, and having
+ // unique constants. (an overlapping constant could be handled, as just the first
+ // branch is taken, but we can't remove the other br_if - it may be the only
+ // branch keeping a block reachable - which may make this bad for code size.)
+ Index end = start + 1; // the run is the half-open range [start, end)
+ std::unordered_set<uint32_t> usedConstants;
+ usedConstants.insert(getProperBrIfConstant(list[start]));
+ while (end < list.size() &&
+ ExpressionAnalyzer::equal(getProperBrIfConditionValue(list[end]), // NOTE(review): this may pass nullptr; assumes equal() then returns false - confirm
+ conditionValue)) {
+ if (!usedConstants.insert(getProperBrIfConstant(list[end])).second) {
+ // this constant already appeared, so the run stops just before it
+ break;
+ }
+ end++;
+ }
+ auto num = end - start;
+ if (num >= 2 && num >= MIN_NUM) { // with MIN_NUM == 3 the num >= 2 test is implied
+ // we found a suitable range, [start, end), containing at least MIN_NUM
+ // elements. let's see if it's worth it
+ auto min = getProperBrIfConstant(list[start]);
+ auto max = min;
+ for (Index i = start + 1; i < end; i++) {
+ auto* curr = list[i];
+ min = std::min(min, getProperBrIfConstant(curr));
+ max = std::max(max, getProperBrIfConstant(curr));
+ }
+ uint32_t range = max - min; // the table built below ends up with range + 1 entries
+ // decision time: the range must be small in absolute terms, and dense enough relative to the number of br_ifs
+ if (range <= MAX_RANGE &&
+ range <= num * NUM_TO_RANGE_FACTOR) {
+ // great! let's do this
+ std::unordered_set<Name> usedNames;
+ for (Index i = start; i < end; i++) {
+ usedNames.insert(getProperBrIf(list[i])->name);
+ }
+ // we need a name for the default too; pick the first "tablify|N" that is not one of the run's own targets
+ Name defaultName;
+ Index i = 0;
+ while (1) {
+ defaultName = "tablify|" + std::to_string(i++);
+ if (usedNames.count(defaultName) == 0) break;
+ }
+ std::vector<Name> table;
+ for (Index i = start; i < end; i++) {
+ auto name = getProperBrIf(list[i])->name;
+ auto index = getProperBrIfConstant(list[i]);
+ index -= min; // table slots are relative to the smallest constant
+ while (table.size() <= index) { // grow the table on demand; gaps point at the default
+ table.push_back(defaultName);
+ }
+ assert(table[index] == defaultName); // we should have made sure there are no overlaps
+ table[index] = name;
+ }
+ Builder builder(*getModule());
+ // the table and condition are offset by the min
+ if (min != 0) {
+ conditionValue = builder.makeBinary(
+ SubInt32,
+ conditionValue,
+ builder.makeConst(Literal(int32_t(min)))
+ );
+ }
+ list[end - 1] = builder.makeBlock( // the last br_if of the run is replaced by a block named defaultName that wraps the switch
+ defaultName,
+ builder.makeSwitch(
+ table,
+ defaultName, // default target is the wrapping block itself
+ conditionValue
+ )
+ );
+ for (Index i = start; i < end - 1; i++) { // the earlier br_ifs of the run are turned into nops
+ ExpressionManipulator::nop(list[i]);
+ }
+ // the defaultName may exist elsewhere in this function,
+ // uniquify it later (the caller checks needUniqify once the walk is done)
+ needUniqify = true;
+ }
+ }
+ start = end; // continue scanning after the run, whether or not it was converted
+ }
+ }
};
FinalOptimizer finalOptimizer(getPassOptions());
finalOptimizer.setModule(getModule());
- finalOptimizer.selectify = getPassRunner()->options.shrinkLevel > 0;
+ finalOptimizer.shrink = getPassRunner()->options.shrinkLevel > 0;
finalOptimizer.walkFunction(func);
+ if (finalOptimizer.needUniqify) {
+ wasm::UniqueNameMapper::uniquify(func->body);
+ }
}
};
diff --git a/src/wasm-builder.h b/src/wasm-builder.h
index 8ab4cfec9..3c200a431 100644
--- a/src/wasm-builder.h
+++ b/src/wasm-builder.h
@@ -82,6 +82,7 @@ public:
Block* makeBlock(Name name, Expression* first = nullptr) { // builds a block via makeBlock(first), then gives it the requested name
auto* ret = makeBlock(first);
ret->name = name;
+ ret->finalize(); // added by this change, after the name is assigned. NOTE(review): presumably recomputes the block's type now that it is named - confirm against Block::finalize
return ret;
}
Block* makeBlock(const std::vector<Expression*>& items) {