Diffstat (limited to 'src/passes/OptimizeInstructions.cpp')
-rw-r--r--  src/passes/OptimizeInstructions.cpp  1057
1 file changed, 550 insertions, 507 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index f5718654e..c12a48b0d 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -201,9 +201,7 @@ template<class Opt> struct Match::Internal::MatchSelf<PureMatcherKind<Opt>> {
// Main pass class
struct OptimizeInstructions
- : public WalkerPass<
- PostWalker<OptimizeInstructions,
- UnifiedExpressionVisitor<OptimizeInstructions>>> {
+ : public WalkerPass<PostWalker<OptimizeInstructions>> {
bool isFunctionParallel() override { return true; }
Pass* create() override { return new OptimizeInstructions; }
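
The hunk above swaps the pass's base class: instead of funneling every node through a single UnifiedExpressionVisitor::visitExpression, the plain PostWalker dispatches on each node's class, so the rules below can move into per-type visit* methods. A minimal sketch of that dispatch difference, with illustrative names rather than Binaryen's real class hierarchy:

    #include <cstdio>

    struct Expr { virtual ~Expr() = default; };
    struct BinaryE : Expr {};
    struct UnaryE : Expr {};

    struct Walker {
      // Stand-in for PostWalker's dispatch-by-node-class.
      void visit(Expr* e) {
        if (auto* b = dynamic_cast<BinaryE*>(e)) { return visitBinary(b); }
        if (auto* u = dynamic_cast<UnaryE*>(e)) { return visitUnary(u); }
      }
      void visitBinary(BinaryE*) { std::puts("binary rules"); }
      void visitUnary(UnaryE*) { std::puts("unary rules"); }
    };

    int main() {
      BinaryE b;
      UnaryE u;
      Walker w;
      w.visit(&b); // runs only the Binary rules
      w.visit(&u); // runs only the Unary rules
    }
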
@@ -226,21 +224,34 @@ struct OptimizeInstructions
}
}
- void visitExpression(Expression* curr) {
- // if this contains dead code, don't bother trying to optimize it, the type
- // might change (if might not be unreachable if just one arm is, for
- // example). this optimization pass focuses on actually executing code. the
- // only exceptions are control flow changes
- if (curr->is<Const>() || curr->is<Call>() || curr->is<Nop>() ||
- (curr->type == Type::unreachable && !curr->is<Break>() &&
- !curr->is<Switch>() && !curr->is<If>())) {
+ // Set to true when one of the visitors makes a change (either replacing the
+ // node or modifying it).
+ bool changed;
+
+ // Used to avoid recursion in replaceCurrent, see below.
+ bool inReplaceCurrent = false;
+
+ void replaceCurrent(Expression* rep) {
+ WalkerPass<PostWalker<OptimizeInstructions>>::replaceCurrent(rep);
+ // We may be able to apply multiple patterns as one may open opportunities
+ // for others. NB: patterns must not have cycles
+
+ // To avoid recursion, this uses the following pattern: the initial call to
+ // this method comes from one of the visit*() methods. We then loop in here,
+ // and if we are called again we set |changed| instead of recursing, so that
+ // we can loop on that value.
+ if (inReplaceCurrent) {
+ // We are in the loop below so just note a change and return to there.
+ changed = true;
return;
}
- // we may be able to apply multiple patterns, one may open opportunities
- // that look deeper NB: patterns must not have cycles
- while ((curr = handOptimize(curr))) {
- replaceCurrent(curr);
- }
+ // Loop on further changes.
+ inReplaceCurrent = true;
+ do {
+ changed = false;
+ visit(getCurrent());
+ } while (changed);
+ inReplaceCurrent = false;
}
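
The overridden replaceCurrent above re-runs the visitors until no rule fires, while the inReplaceCurrent flag turns nested calls into a cheap "go around again" signal instead of recursion. A minimal sketch of that guard pattern outside the walker context (illustrative names, not Binaryen API):

    #include <cstdio>

    struct FixedPoint {
      bool changed = false;
      bool inLoop = false;
      int budget = 3; // stand-in for "another pattern still applies"

      void notifyChange() {
        if (inLoop) {
          changed = true; // already iterating: just request another round
          return;
        }
        inLoop = true;
        do {
          changed = false;
          runOnePass(); // may call notifyChange() again without recursing
        } while (changed);
        inLoop = false;
      }

      void runOnePass() {
        std::puts("pass");
        if (--budget > 0) {
          notifyChange(); // nested call only sets |changed|
        }
      }
    };

    int main() { FixedPoint().notifyChange(); } // prints "pass" three times
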
EffectAnalyzer effects(Expression* expr) {
@@ -257,19 +268,16 @@ struct OptimizeInstructions
getPassOptions(), getModule()->features, a, b);
}
- // Optimizations that don't yet fit in the pattern DSL, but could be
- // eventually maybe
- Expression* handOptimize(Expression* curr) {
- if (curr->is<Const>()) {
- return nullptr;
+ void visitBinary(Binary* curr) {
+ // If this contains dead code, don't bother trying to optimize it: the type
+ // might change (an if might not be unreachable if just one arm is, for
+ // example). This optimization pass focuses on actually executing code.
+ if (curr->type == Type::unreachable) {
+ return;
}
- FeatureSet features = getModule()->features;
-
- if (auto* binary = curr->dynCast<Binary>()) {
- if (shouldCanonicalize(binary)) {
- canonicalize(binary);
- }
+ if (shouldCanonicalize(curr)) {
+ canonicalize(curr);
}
{
@@ -304,7 +312,7 @@ struct OptimizeInstructions
canReorder(x, y)) {
sub->left = y;
sub->right = x;
- return sub;
+ return replaceCurrent(sub);
}
}
{
@@ -326,43 +334,7 @@ struct OptimizeInstructions
if (matches(curr,
binary(Add, any(&y), binary(&sub, Sub, ival(0), any())))) {
sub->left = y;
- return sub;
- }
- }
- {
- // eqz(x - y) => x == y
- Binary* inner;
- if (matches(curr, unary(EqZ, binary(&inner, Sub, any(), any())))) {
- inner->op = Abstract::getBinary(inner->left->type, Eq);
- inner->type = Type::i32;
- return inner;
- }
- }
- {
- // eqz(x + C) => x == -C
- Const* c;
- Binary* inner;
- if (matches(curr, unary(EqZ, binary(&inner, Add, any(), ival(&c))))) {
- c->value = c->value.neg();
- inner->op = Abstract::getBinary(c->type, Eq);
- inner->type = Type::i32;
- return inner;
- }
- }
- {
- // eqz((signed)x % C_pot) => eqz(x & (abs(C_pot) - 1))
- Const* c;
- Binary* inner;
- if (matches(curr, unary(EqZ, binary(&inner, RemS, any(), ival(&c)))) &&
- (c->value.isSignedMin() ||
- Bits::isPowerOf2(c->value.abs().getInteger()))) {
- inner->op = Abstract::getBinary(c->type, And);
- if (c->value.isSignedMin()) {
- c->value = Literal::makeSignedMax(c->type);
- } else {
- c->value = c->value.abs().sub(Literal::makeOne(c->type));
- }
- return curr;
+ return replaceCurrent(sub);
}
}
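
Both add/negation rewrites above rest on ordinary modular arithmetic. A quick numeric check, using unsigned types for well-defined wraparound (illustrative only):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t x = 0xdeadbeef, y = 42;
      assert((0u - x) + y == y - x); // (0 - x) + y  ==>  y - x
      assert(y + (0u - x) == y - x); // y + (0 - x)  ==>  y - x
    }
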
{
@@ -385,19 +357,7 @@ struct OptimizeInstructions
bin->left = x;
bin->right = y;
un->value = bin;
- return un;
- }
- }
- {
- // i32.eqz(i32.wrap_i64(x)) => i64.eqz(x)
- // where maxBits(x) <= 32
- Unary* inner;
- Expression* x;
- if (matches(curr, unary(EqZInt32, unary(&inner, WrapInt64, any(&x)))) &&
- Bits::getMaxBits(x, this) <= 32) {
- inner->op = EqZInt64;
- inner->value = x;
- return inner;
+ return replaceCurrent(un);
}
}
{
@@ -420,7 +380,7 @@ struct OptimizeInstructions
Literal::makeFromInt32(c->type.getByteSize() * 8 - 1, c->type));
// x <<>> 0 ==> x
if (c->value.isZero()) {
- return x;
+ return replaceCurrent(x);
}
}
if (matches(curr,
@@ -431,14 +391,14 @@ struct OptimizeInstructions
if ((c->type == Type::i32 && (c->value.geti32() & 31) == 31) ||
(c->type == Type::i64 && (c->value.geti64() & 63LL) == 63LL)) {
curr->cast<Binary>()->right = y;
- return curr;
+ return replaceCurrent(curr);
}
// i32(x) <<>> (y & C) ==> x, where (C & 31) == 0
// i64(x) <<>> (y & C) ==> x, where (C & 63) == 0
if (((c->type == Type::i32 && (c->value.geti32() & 31) == 0) ||
(c->type == Type::i64 && (c->value.geti64() & 63LL) == 0LL)) &&
!effects(y).hasSideEffects()) {
- return x;
+ return replaceCurrent(x);
}
}
}
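
The rotation rules above rely on wasm using only the low log2(bits) bits of the shift count, so a count C with (C & 31) == 31 acts exactly like 31 on i32, and one with (C & 31) == 0 is a no-op. A spot-check against a hand-rolled rotate (illustrative, not Binaryen code):

    #include <cassert>
    #include <cstdint>

    // Rotate-left with wasm semantics: the count is taken mod 32.
    uint32_t rotl32(uint32_t x, uint32_t n) {
      n &= 31;
      return n ? (x << n) | (x >> (32 - n)) : x;
    }

    int main() {
      uint32_t x = 0x12345678;
      assert(rotl32(x, 63) == rotl32(x, 31)); // (63 & 31) == 31
      assert(rotl32(x, 32) == x);             // (32 & 31) == 0  ==>  x
    }
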
@@ -450,496 +410,579 @@ struct OptimizeInstructions
c->value.isZero()) {
c->value = Literal::makeOne(Type::i32);
c->type = Type::i32;
- return c;
+ return replaceCurrent(c);
}
// unsigned(x) < 0 => i32(0)
if (matches(curr, binary(LtU, pure(&x), ival(&c))) &&
c->value.isZero()) {
c->value = Literal::makeZero(Type::i32);
c->type = Type::i32;
- return c;
- }
- }
- }
-
- if (auto* select = curr->dynCast<Select>()) {
- return optimizeSelect(select);
- }
-
- if (auto* binary = curr->dynCast<Binary>()) {
- if (auto* ext = Properties::getAlmostSignExt(binary)) {
- Index extraLeftShifts;
- auto bits = Properties::getAlmostSignExtBits(binary, extraLeftShifts);
- if (extraLeftShifts == 0) {
- if (auto* load =
- Properties::getFallthrough(ext, getPassOptions(), features)
- ->dynCast<Load>()) {
- // pattern match a load of 8 bits and a sign extend using a shl of
- // 24 then shr_s of 24 as well, etc.
- if (LoadUtils::canBeSigned(load) &&
- ((load->bytes == 1 && bits == 8) ||
- (load->bytes == 2 && bits == 16))) {
- // if the value falls through, we can't alter the load, as it
- // might be captured in a tee
- if (load->signed_ == true || load == ext) {
- load->signed_ = true;
- return ext;
- }
+ return replaceCurrent(c);
+ }
+ }
+ }
+ if (auto* ext = Properties::getAlmostSignExt(curr)) {
+ Index extraLeftShifts;
+ auto bits = Properties::getAlmostSignExtBits(curr, extraLeftShifts);
+ if (extraLeftShifts == 0) {
+ if (auto* load = Properties::getFallthrough(
+ ext, getPassOptions(), getModule()->features)
+ ->dynCast<Load>()) {
+ // pattern match a load of 8 bits and a sign extend using a shl of
+ // 24 then shr_s of 24 as well, etc.
+ if (LoadUtils::canBeSigned(load) &&
+ ((load->bytes == 1 && bits == 8) ||
+ (load->bytes == 2 && bits == 16))) {
+ // if the value falls through, we can't alter the load, as it
+ // might be captured in a tee
+ if (load->signed_ == true || load == ext) {
+ load->signed_ = true;
+ return replaceCurrent(ext);
}
}
}
- // We can in some cases remove part of a sign extend, that is,
- // (x << A) >> B => x << (A - B)
- // If the sign-extend input cannot have a sign bit, we don't need it.
- if (Bits::getMaxBits(ext, this) + extraLeftShifts < bits) {
- return removeAlmostSignExt(binary);
- }
- // We also don't need it if it already has an identical-sized sign
- // extend applied to it. That is, if it is already a sign-extended
- // value, then another sign extend will do nothing. We do need to be
- // careful of the extra shifts, though.
- if (isSignExted(ext, bits) && extraLeftShifts == 0) {
- return removeAlmostSignExt(binary);
- }
- } else if (binary->op == EqInt32 || binary->op == NeInt32) {
- if (auto* c = binary->right->dynCast<Const>()) {
- if (auto* ext = Properties::getSignExtValue(binary->left)) {
- // We are comparing a sign extend to a constant, which means we can
- // use a cheaper zero-extend in some cases. That is,
- // (x << S) >> S ==/!= C => x & T ==/!= C
- // where S and T are the matching values for sign/zero extend of the
- // same size. For example, for an effective 8-bit value:
- // (x << 24) >> 24 ==/!= C => x & 255 ==/!= C
- //
- // The key thing to track here are the upper bits plus the sign bit;
- // call those the "relevant bits". This is crucial because x is
- // sign-extended, that is, its effective sign bit is spread to all
- // the upper bits, which means that the relevant bits on the left
- // side are either all 0, or all 1.
- auto bits = Properties::getSignExtBits(binary->left);
- uint32_t right = c->value.geti32();
- uint32_t numRelevantBits = 32 - bits + 1;
- uint32_t setRelevantBits =
- Bits::popCount(right >> uint32_t(bits - 1));
- // If all the relevant bits on C are zero
- // then we can mask off the high bits instead of sign-extending x.
- // This is valid because if x is negative, then the comparison was
- // false before (negative vs positive), and will still be false
- // as the sign bit will remain to cause a difference. And if x is
- // positive then the upper bits would be zero anyhow.
- if (setRelevantBits == 0) {
- binary->left = makeZeroExt(ext, bits);
- return binary;
- } else if (setRelevantBits == numRelevantBits) {
- // If all those bits are one, then we can do something similar if
- // we also zero-extend on the right as well. This is valid
- // because, as in the previous case, the sign bit differentiates
- // the two sides when they are different, and if the sign bit is
- // identical, then the upper bits don't matter, so masking them
- // off both sides is fine.
- binary->left = makeZeroExt(ext, bits);
- c->value = c->value.and_(Literal(Bits::lowBitMask(bits)));
- return binary;
- } else {
- // Otherwise, C's relevant bits are mixed, and then the two sides
- // can never be equal, as the left side's bits cannot be mixed.
- Builder builder(*getModule());
- // The result is either always true, or always false.
- c->value = Literal::makeFromInt32(binary->op == NeInt32, c->type);
- return builder.makeSequence(builder.makeDrop(ext), c);
- }
+ }
+ // We can in some cases remove part of a sign extend, that is,
+ // (x << A) >> B => x << (A - B)
+ // If the sign-extend input cannot have a sign bit, we don't need it.
+ if (Bits::getMaxBits(ext, this) + extraLeftShifts < bits) {
+ return replaceCurrent(removeAlmostSignExt(curr));
+ }
+ // We also don't need it if it already has an identical-sized sign
+ // extend applied to it. That is, if it is already a sign-extended
+ // value, then another sign extend will do nothing. We do need to be
+ // careful of the extra shifts, though.
+ if (isSignExted(ext, bits) && extraLeftShifts == 0) {
+ return replaceCurrent(removeAlmostSignExt(curr));
+ }
+ } else if (curr->op == EqInt32 || curr->op == NeInt32) {
+ if (auto* c = curr->right->dynCast<Const>()) {
+ if (auto* ext = Properties::getSignExtValue(curr->left)) {
+ // We are comparing a sign extend to a constant, which means we can
+ // use a cheaper zero-extend in some cases. That is,
+ // (x << S) >> S ==/!= C => x & T ==/!= C
+ // where S and T are the matching values for sign/zero extend of the
+ // same size. For example, for an effective 8-bit value:
+ // (x << 24) >> 24 ==/!= C => x & 255 ==/!= C
+ //
+ // The key thing to track here is the upper bits plus the sign bit;
+ // call those the "relevant bits". This is crucial because x is
+ // sign-extended, that is, its effective sign bit is spread to all
+ // the upper bits, which means that the relevant bits on the left
+ // side are either all 0, or all 1.
+ auto bits = Properties::getSignExtBits(curr->left);
+ uint32_t right = c->value.geti32();
+ uint32_t numRelevantBits = 32 - bits + 1;
+ uint32_t setRelevantBits =
+ Bits::popCount(right >> uint32_t(bits - 1));
+ // If all the relevant bits on C are zero
+ // then we can mask off the high bits instead of sign-extending x.
+ // This is valid because if x is negative, then the comparison was
+ // false before (negative vs positive), and will still be false
+ // as the sign bit will remain to cause a difference. And if x is
+ // positive then the upper bits would be zero anyhow.
+ if (setRelevantBits == 0) {
+ curr->left = makeZeroExt(ext, bits);
+ return replaceCurrent(curr);
+ } else if (setRelevantBits == numRelevantBits) {
+ // If all those bits are one, then we can do something similar if
+ // we also zero-extend on the right as well. This is valid
+ // because, as in the previous case, the sign bit differentiates
+ // the two sides when they are different, and if the sign bit is
+ // identical, then the upper bits don't matter, so masking them
+ // off both sides is fine.
+ curr->left = makeZeroExt(ext, bits);
+ c->value = c->value.and_(Literal(Bits::lowBitMask(bits)));
+ return replaceCurrent(curr);
+ } else {
+ // Otherwise, C's relevant bits are mixed, and then the two sides
+ // can never be equal, as the left side's bits cannot be mixed.
+ Builder builder(*getModule());
+ // The result is either always true, or always false.
+ c->value = Literal::makeFromInt32(curr->op == NeInt32, c->type);
+ return replaceCurrent(
+ builder.makeSequence(builder.makeDrop(ext), c));
}
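
A quick numeric check of the "relevant bits" argument above for the all-zero case, with bits = 8 (illustrative; assumes the usual two's complement arithmetic shift):

    #include <cassert>
    #include <cstdint>

    int main() {
      // Sign-extend the low 8 bits of x the way (x << 24) >> 24 does.
      auto sext8 = [](uint32_t x) {
        return (uint32_t)((int32_t)(x << 24) >> 24);
      };
      // C = 0x21: the upper 24 bits plus the sign bit are all zero, so a
      // zero-extend (x & 0xff) compares identically.
      for (uint32_t x : {0u, 0x21u, 0xffu, 0x1234u}) {
        assert((sext8(x) == 0x21u) == ((x & 0xffu) == 0x21u));
      }
    }
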
- } else if (auto* left = Properties::getSignExtValue(binary->left)) {
- if (auto* right = Properties::getSignExtValue(binary->right)) {
- auto bits = Properties::getSignExtBits(binary->left);
- if (Properties::getSignExtBits(binary->right) == bits) {
- // we are comparing two sign-exts with the same bits, so we may as
- // well replace both with cheaper zexts
- binary->left = makeZeroExt(left, bits);
- binary->right = makeZeroExt(right, bits);
- return binary;
- }
- } else if (auto* load = binary->right->dynCast<Load>()) {
- // we are comparing a load to a sign-ext, we may be able to switch
- // to zext
- auto leftBits = Properties::getSignExtBits(binary->left);
- if (load->signed_ && leftBits == load->bytes * 8) {
- load->signed_ = false;
- binary->left = makeZeroExt(left, leftBits);
- return binary;
- }
+ }
+ } else if (auto* left = Properties::getSignExtValue(curr->left)) {
+ if (auto* right = Properties::getSignExtValue(curr->right)) {
+ auto bits = Properties::getSignExtBits(curr->left);
+ if (Properties::getSignExtBits(curr->right) == bits) {
+ // we are comparing two sign-exts with the same bits, so we may as
+ // well replace both with cheaper zexts
+ curr->left = makeZeroExt(left, bits);
+ curr->right = makeZeroExt(right, bits);
+ return replaceCurrent(curr);
}
- } else if (auto* load = binary->left->dynCast<Load>()) {
- if (auto* right = Properties::getSignExtValue(binary->right)) {
- // we are comparing a load to a sign-ext, we may be able to switch
- // to zext
- auto rightBits = Properties::getSignExtBits(binary->right);
- if (load->signed_ && rightBits == load->bytes * 8) {
- load->signed_ = false;
- binary->right = makeZeroExt(right, rightBits);
- return binary;
- }
+ } else if (auto* load = curr->right->dynCast<Load>()) {
+ // we are comparing a load to a sign-ext, we may be able to switch
+ // to zext
+ auto leftBits = Properties::getSignExtBits(curr->left);
+ if (load->signed_ && leftBits == load->bytes * 8) {
+ load->signed_ = false;
+ curr->left = makeZeroExt(left, leftBits);
+ return replaceCurrent(curr);
}
}
- // note that both left and right may be consts, but then we let
- // precompute compute the constant result
- } else if (binary->op == AddInt32 || binary->op == AddInt64 ||
- binary->op == SubInt32 || binary->op == SubInt64) {
- if (auto* ret = optimizeAddedConstants(binary)) {
- return ret;
- }
- } else if (binary->op == MulFloat32 || binary->op == MulFloat64 ||
- binary->op == DivFloat32 || binary->op == DivFloat64) {
- if (binary->left->type == binary->right->type) {
- if (auto* leftUnary = binary->left->dynCast<Unary>()) {
- if (leftUnary->op ==
- Abstract::getUnary(binary->type, Abstract::Abs)) {
- if (auto* rightUnary = binary->right->dynCast<Unary>()) {
- if (leftUnary->op == rightUnary->op) { // both are abs ops
- // abs(x) * abs(y) ==> abs(x * y)
- // abs(x) / abs(y) ==> abs(x / y)
- binary->left = leftUnary->value;
- binary->right = rightUnary->value;
- leftUnary->value = binary;
- return leftUnary;
- }
- }
- }
+ } else if (auto* load = curr->left->dynCast<Load>()) {
+ if (auto* right = Properties::getSignExtValue(curr->right)) {
+ // we are comparing a load to a sign-ext, we may be able to switch
+ // to zext
+ auto rightBits = Properties::getSignExtBits(curr->right);
+ if (load->signed_ && rightBits == load->bytes * 8) {
+ load->signed_ = false;
+ curr->right = makeZeroExt(right, rightBits);
+ return replaceCurrent(curr);
}
}
}
- // a bunch of operations on a constant right side can be simplified
- if (auto* right = binary->right->dynCast<Const>()) {
- if (binary->op == AndInt32) {
- auto mask = right->value.geti32();
- // and with -1 does nothing (common in asm.js output)
- if (mask == -1) {
- return binary->left;
- }
- // small loads do not need to be masked, the load itself masks
- if (auto* load = binary->left->dynCast<Load>()) {
- if ((load->bytes == 1 && mask == 0xff) ||
- (load->bytes == 2 && mask == 0xffff)) {
- load->signed_ = false;
- return binary->left;
- }
- } else if (auto maskedBits = Bits::getMaskedBits(mask)) {
- if (Bits::getMaxBits(binary->left, this) <= maskedBits) {
- // a mask of lower bits is not needed if we are already smaller
- return binary->left;
- }
- }
- }
- // some math operations have trivial results
- if (auto* ret = optimizeWithConstantOnRight(binary)) {
- return ret;
- }
- // the square of some operations can be merged
- if (auto* left = binary->left->dynCast<Binary>()) {
- if (left->op == binary->op) {
- if (auto* leftRight = left->right->dynCast<Const>()) {
- if (left->op == AndInt32 || left->op == AndInt64) {
- leftRight->value = leftRight->value.and_(right->value);
- return left;
- } else if (left->op == OrInt32 || left->op == OrInt64) {
- leftRight->value = leftRight->value.or_(right->value);
- return left;
- } else if (left->op == XorInt32 || left->op == XorInt64) {
- leftRight->value = leftRight->value.xor_(right->value);
- return left;
- } else if (left->op == MulInt32 || left->op == MulInt64) {
- leftRight->value = leftRight->value.mul(right->value);
- return left;
-
- // TODO:
- // handle signed / unsigned divisions. They are more complex
- } else if (left->op == ShlInt32 || left->op == ShrUInt32 ||
- left->op == ShrSInt32 || left->op == ShlInt64 ||
- left->op == ShrUInt64 || left->op == ShrSInt64) {
- // shifts only use an effective amount from the constant, so
- // adding must be done carefully
- auto total = Bits::getEffectiveShifts(leftRight) +
- Bits::getEffectiveShifts(right);
- if (total == Bits::getEffectiveShifts(total, right->type)) {
- // no overflow, we can do this
- leftRight->value = Literal::makeFromInt32(total, right->type);
- return left;
- } // TODO: handle overflows
+ // note that both left and right may be consts, but then we let
+ // precompute compute the constant result
+ } else if (curr->op == AddInt32 || curr->op == AddInt64 ||
+ curr->op == SubInt32 || curr->op == SubInt64) {
+ if (auto* ret = optimizeAddedConstants(curr)) {
+ return replaceCurrent(ret);
+ }
+ } else if (curr->op == MulFloat32 || curr->op == MulFloat64 ||
+ curr->op == DivFloat32 || curr->op == DivFloat64) {
+ if (curr->left->type == curr->right->type) {
+ if (auto* leftUnary = curr->left->dynCast<Unary>()) {
+ if (leftUnary->op == Abstract::getUnary(curr->type, Abstract::Abs)) {
+ if (auto* rightUnary = curr->right->dynCast<Unary>()) {
+ if (leftUnary->op == rightUnary->op) { // both are abs ops
+ // abs(x) * abs(y) ==> abs(x * y)
+ // abs(x) / abs(y) ==> abs(x / y)
+ curr->left = leftUnary->value;
+ curr->right = rightUnary->value;
+ leftUnary->value = curr;
+ return replaceCurrent(leftUnary);
}
}
}
}
- if (right->type == Type::i32) {
- BinaryOp op;
- int32_t c = right->value.geti32();
- // First, try to lower signed operations to unsigned if that is
- // possible. Some unsigned operations like div_u or rem_u are usually
- // faster on VMs. Also this opens more possibilities for further
- // simplifications afterwards.
- if (c >= 0 &&
- (op = makeUnsignedBinaryOp(binary->op)) != InvalidBinary &&
- Bits::getMaxBits(binary->left, this) <= 31) {
- binary->op = op;
- }
- if (c < 0 && c > std::numeric_limits<int32_t>::min() &&
- binary->op == DivUInt32) {
- // u32(x) / C ==> u32(x) >= C iff C > 2^31
- // We avoid applying this for C == 2^31 due to conflict
- // with other rule which transform to more prefereble
- // right shift operation.
- binary->op = c == -1 ? EqInt32 : GeUInt32;
- return binary;
- }
- if (Bits::isPowerOf2((uint32_t)c)) {
- switch (binary->op) {
- case MulInt32:
- return optimizePowerOf2Mul(binary, (uint32_t)c);
- case RemUInt32:
- return optimizePowerOf2URem(binary, (uint32_t)c);
- case DivUInt32:
- return optimizePowerOf2UDiv(binary, (uint32_t)c);
- default:
- break;
- }
- }
+ }
+ }
+ // a bunch of operations on a constant right side can be simplified
+ if (auto* right = curr->right->dynCast<Const>()) {
+ if (curr->op == AndInt32) {
+ auto mask = right->value.geti32();
+ // and with -1 does nothing (common in asm.js output)
+ if (mask == -1) {
+ return replaceCurrent(curr->left);
}
- if (right->type == Type::i64) {
- BinaryOp op;
- int64_t c = right->value.geti64();
- // See description above for Type::i32
- if (c >= 0 &&
- (op = makeUnsignedBinaryOp(binary->op)) != InvalidBinary &&
- Bits::getMaxBits(binary->left, this) <= 63) {
- binary->op = op;
+ // small loads do not need to be masked, the load itself masks
+ if (auto* load = curr->left->dynCast<Load>()) {
+ if ((load->bytes == 1 && mask == 0xff) ||
+ (load->bytes == 2 && mask == 0xffff)) {
+ load->signed_ = false;
+ return replaceCurrent(curr->left);
}
- if (getPassOptions().shrinkLevel == 0 && c < 0 &&
- c > std::numeric_limits<int64_t>::min() &&
- binary->op == DivUInt64) {
- // u64(x) / C ==> u64(u64(x) >= C) iff C > 2^63
- // We avoid applying this for C == 2^31 due to conflict
- // with other rule which transform to more prefereble
- // right shift operation.
- // And apply this only for shrinkLevel == 0 due to it
- // increasing size by one byte.
- binary->op = c == -1LL ? EqInt64 : GeUInt64;
- binary->type = Type::i32;
- return Builder(*getModule()).makeUnary(ExtendUInt32, binary);
+ } else if (auto maskedBits = Bits::getMaskedBits(mask)) {
+ if (Bits::getMaxBits(curr->left, this) <= maskedBits) {
+ // a mask of lower bits is not needed if we are already smaller
+ return replaceCurrent(curr->left);
}
- if (Bits::isPowerOf2((uint64_t)c)) {
- switch (binary->op) {
- case MulInt64:
- return optimizePowerOf2Mul(binary, (uint64_t)c);
- case RemUInt64:
- return optimizePowerOf2URem(binary, (uint64_t)c);
- case DivUInt64:
- return optimizePowerOf2UDiv(binary, (uint64_t)c);
- default:
- break;
+ }
+ }
+ // some math operations have trivial results
+ if (auto* ret = optimizeWithConstantOnRight(curr)) {
+ return replaceCurrent(ret);
+ }
+ // the square of some operations can be merged
+ if (auto* left = curr->left->dynCast<Binary>()) {
+ if (left->op == curr->op) {
+ if (auto* leftRight = left->right->dynCast<Const>()) {
+ if (left->op == AndInt32 || left->op == AndInt64) {
+ leftRight->value = leftRight->value.and_(right->value);
+ return replaceCurrent(left);
+ } else if (left->op == OrInt32 || left->op == OrInt64) {
+ leftRight->value = leftRight->value.or_(right->value);
+ return replaceCurrent(left);
+ } else if (left->op == XorInt32 || left->op == XorInt64) {
+ leftRight->value = leftRight->value.xor_(right->value);
+ return replaceCurrent(left);
+ } else if (left->op == MulInt32 || left->op == MulInt64) {
+ leftRight->value = leftRight->value.mul(right->value);
+ return replaceCurrent(left);
+
+ // TODO:
+ // handle signed / unsigned divisions. They are more complex
+ } else if (left->op == ShlInt32 || left->op == ShrUInt32 ||
+ left->op == ShrSInt32 || left->op == ShlInt64 ||
+ left->op == ShrUInt64 || left->op == ShrSInt64) {
+ // shifts only use an effective amount from the constant, so
+ // adding must be done carefully
+ auto total = Bits::getEffectiveShifts(leftRight) +
+ Bits::getEffectiveShifts(right);
+ if (total == Bits::getEffectiveShifts(total, right->type)) {
+ // no overflow, we can do this
+ leftRight->value = Literal::makeFromInt32(total, right->type);
+ return replaceCurrent(left);
+ } // TODO: handle overflows
}
}
}
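
Merging stacked constant operations as above is plain algebra on the constants; only the shift case needs the effective-count overflow check. Spot-checks (illustrative):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t x = 0xdeadbeef;
      assert(((x & 0xff0f) & 0x00ff) == (x & (0xff0f & 0x00ff))); // and-of-and
      assert(((x << 3) << 4) == (x << 7)); // shift counts add while < 32
    }
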
- if (binary->op == DivFloat32) {
- float c = right->value.getf32();
- if (Bits::isPowerOf2InvertibleFloat(c)) {
- return optimizePowerOf2FDiv(binary, c);
+ }
+ if (right->type == Type::i32) {
+ BinaryOp op;
+ int32_t c = right->value.geti32();
+ // First, try to lower signed operations to unsigned if that is
+ // possible. Some unsigned operations like div_u or rem_u are usually
+ // faster on VMs. Also this opens more possibilities for further
+ // simplifications afterwards.
+ if (c >= 0 && (op = makeUnsignedBinaryOp(curr->op)) != InvalidBinary &&
+ Bits::getMaxBits(curr->left, this) <= 31) {
+ curr->op = op;
+ }
+ if (c < 0 && c > std::numeric_limits<int32_t>::min() &&
+ curr->op == DivUInt32) {
+ // u32(x) / C ==> u32(x) >= C iff C > 2^31
+ // We avoid applying this for C == 2^31 due to a conflict
+ // with another rule which transforms to the preferable
+ // right shift operation.
+ curr->op = c == -1 ? EqInt32 : GeUInt32;
+ return replaceCurrent(curr);
+ }
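
The div_u rewrite above works because dividing a u32 by a constant above 2^31 can only produce 0 or 1, which is exactly the unsigned comparison. A numeric check (illustrative):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t C = 0x90000000u; // > 2^31
      for (uint32_t x : {0u, C - 1, C, 0xffffffffu}) {
        assert(x / C == (x >= C ? 1u : 0u));
      }
    }
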
+ if (Bits::isPowerOf2((uint32_t)c)) {
+ switch (curr->op) {
+ case MulInt32:
+ return replaceCurrent(optimizePowerOf2Mul(curr, (uint32_t)c));
+ case RemUInt32:
+ return replaceCurrent(optimizePowerOf2URem(curr, (uint32_t)c));
+ case DivUInt32:
+ return replaceCurrent(optimizePowerOf2UDiv(curr, (uint32_t)c));
+ default:
+ break;
}
}
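
The power-of-two cases above correspond to the usual strength reductions for unsigned arithmetic. Spot-checks for C = 8 (illustrative):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t x : {0u, 1u, 5u, 0xffffffffu}) {
        assert(x * 8u == x << 3);   // mul   ==> shl
        assert(x % 8u == (x & 7u)); // rem_u ==> and
        assert(x / 8u == x >> 3);   // div_u ==> shr_u
      }
    }
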
- if (binary->op == DivFloat64) {
- double c = right->value.getf64();
- if (Bits::isPowerOf2InvertibleFloat(c)) {
- return optimizePowerOf2FDiv(binary, c);
+ }
+ if (right->type == Type::i64) {
+ BinaryOp op;
+ int64_t c = right->value.geti64();
+ // See description above for Type::i32
+ if (c >= 0 && (op = makeUnsignedBinaryOp(curr->op)) != InvalidBinary &&
+ Bits::getMaxBits(curr->left, this) <= 63) {
+ curr->op = op;
+ }
+ if (getPassOptions().shrinkLevel == 0 && c < 0 &&
+ c > std::numeric_limits<int64_t>::min() && curr->op == DivUInt64) {
+ // u64(x) / C ==> u64(u64(x) >= C) iff C > 2^63
+ // We avoid applying this for C == 2^63 due to a conflict
+ // with another rule which transforms to the preferable
+ // right shift operation.
+ // And we apply this only for shrinkLevel == 0 because it
+ // increases size by one byte.
+ curr->op = c == -1LL ? EqInt64 : GeUInt64;
+ curr->type = Type::i32;
+ return replaceCurrent(
+ Builder(*getModule()).makeUnary(ExtendUInt32, curr));
+ }
+ if (Bits::isPowerOf2((uint64_t)c)) {
+ switch (curr->op) {
+ case MulInt64:
+ return replaceCurrent(optimizePowerOf2Mul(curr, (uint64_t)c));
+ case RemUInt64:
+ return replaceCurrent(optimizePowerOf2URem(curr, (uint64_t)c));
+ case DivUInt64:
+ return replaceCurrent(optimizePowerOf2UDiv(curr, (uint64_t)c));
+ default:
+ break;
}
}
}
- // a bunch of operations on a constant left side can be simplified
- if (binary->left->is<Const>()) {
- if (auto* ret = optimizeWithConstantOnLeft(binary)) {
- return ret;
+ if (curr->op == DivFloat32) {
+ float c = right->value.getf32();
+ if (Bits::isPowerOf2InvertibleFloat(c)) {
+ return replaceCurrent(optimizePowerOf2FDiv(curr, c));
}
}
- // bitwise operations
- // for and and or, we can potentially conditionalize
- if (binary->op == AndInt32 || binary->op == OrInt32) {
- if (auto* ret = conditionalizeExpensiveOnBitwise(binary)) {
- return ret;
+ if (curr->op == DivFloat64) {
+ double c = right->value.getf64();
+ if (Bits::isPowerOf2InvertibleFloat(c)) {
+ return replaceCurrent(optimizePowerOf2FDiv(curr, c));
}
}
- // for or, we can potentially combine
- if (binary->op == OrInt32) {
- if (auto* ret = combineOr(binary)) {
- return ret;
- }
+ }
+ // a bunch of operations on a constant left side can be simplified
+ if (curr->left->is<Const>()) {
+ if (auto* ret = optimizeWithConstantOnLeft(curr)) {
+ return replaceCurrent(ret);
}
- // relation/comparisons allow for math optimizations
- if (binary->isRelational()) {
- if (auto* ret = optimizeRelational(binary)) {
- return ret;
- }
+ }
+ // bitwise operations
+ // for and and or, we can potentially conditionalize
+ if (curr->op == AndInt32 || curr->op == OrInt32) {
+ if (auto* ret = conditionalizeExpensiveOnBitwise(curr)) {
+ return replaceCurrent(ret);
}
- // finally, try more expensive operations on the binary in
- // the case that they have no side effects
- if (!effects(binary->left).hasSideEffects()) {
- if (ExpressionAnalyzer::equal(binary->left, binary->right)) {
- if (auto* ret = optimizeBinaryWithEqualEffectlessChildren(binary)) {
- return ret;
- }
+ }
+ // for or, we can potentially combine
+ if (curr->op == OrInt32) {
+ if (auto* ret = combineOr(curr)) {
+ return replaceCurrent(ret);
+ }
+ }
+ // relation/comparisons allow for math optimizations
+ if (curr->isRelational()) {
+ if (auto* ret = optimizeRelational(curr)) {
+ return replaceCurrent(ret);
+ }
+ }
+ // finally, try more expensive optimizations on the binary, in
+ // the case that its children have no side effects
+ if (!effects(curr->left).hasSideEffects()) {
+ if (ExpressionAnalyzer::equal(curr->left, curr->right)) {
+ if (auto* ret = optimizeBinaryWithEqualEffectlessChildren(curr)) {
+ return replaceCurrent(ret);
}
}
+ }
+
+ if (auto* ret = deduplicateBinary(curr)) {
+ return replaceCurrent(ret);
+ }
+ }
- if (auto* ret = deduplicateBinary(binary)) {
- return ret;
+ void visitUnary(Unary* curr) {
+ if (curr->type == Type::unreachable) {
+ return;
+ }
+
+ {
+ using namespace Match;
+ using namespace Abstract;
+ Builder builder(*getModule());
+ {
+ // eqz(x - y) => x == y
+ Binary* inner;
+ if (matches(curr, unary(EqZ, binary(&inner, Sub, any(), any())))) {
+ inner->op = Abstract::getBinary(inner->left->type, Eq);
+ inner->type = Type::i32;
+ return replaceCurrent(inner);
+ }
}
- } else if (auto* unary = curr->dynCast<Unary>()) {
- if (unary->op == EqZInt32) {
- if (auto* inner = unary->value->dynCast<Binary>()) {
- // Try to invert a relational operation using De Morgan's law
- auto op = invertBinaryOp(inner->op);
- if (op != InvalidBinary) {
- inner->op = op;
- return inner;
- }
+ {
+ // eqz(x + C) => x == -C
+ Const* c;
+ Binary* inner;
+ if (matches(curr, unary(EqZ, binary(&inner, Add, any(), ival(&c))))) {
+ c->value = c->value.neg();
+ inner->op = Abstract::getBinary(c->type, Eq);
+ inner->type = Type::i32;
+ return replaceCurrent(inner);
}
- // eqz of a sign extension can be of zero-extension
- if (auto* ext = Properties::getSignExtValue(unary->value)) {
- // we are comparing a sign extend to a constant, which means we can
- // use a cheaper zext
- auto bits = Properties::getSignExtBits(unary->value);
- unary->value = makeZeroExt(ext, bits);
- return unary;
- }
- } else if (unary->op == AbsFloat32 || unary->op == AbsFloat64) {
- // abs(-x) ==> abs(x)
- if (auto* unaryInner = unary->value->dynCast<Unary>()) {
- if (unaryInner->op ==
- Abstract::getUnary(unaryInner->type, Abstract::Neg)) {
- unary->value = unaryInner->value;
- return unary;
+ }
+ {
+ // eqz((signed)x % C_pot) => eqz(x & (abs(C_pot) - 1))
+ Const* c;
+ Binary* inner;
+ if (matches(curr, unary(EqZ, binary(&inner, RemS, any(), ival(&c)))) &&
+ (c->value.isSignedMin() ||
+ Bits::isPowerOf2(c->value.abs().getInteger()))) {
+ inner->op = Abstract::getBinary(c->type, And);
+ if (c->value.isSignedMin()) {
+ c->value = Literal::makeSignedMax(c->type);
+ } else {
+ c->value = c->value.abs().sub(Literal::makeOne(c->type));
}
+ return replaceCurrent(curr);
}
- // abs(x * x) ==> x * x
- // abs(x / x) ==> x / x
- if (auto* binary = unary->value->dynCast<Binary>()) {
- if ((binary->op == Abstract::getBinary(binary->type, Abstract::Mul) ||
- binary->op ==
- Abstract::getBinary(binary->type, Abstract::DivS)) &&
- ExpressionAnalyzer::equal(binary->left, binary->right)) {
- return binary;
- }
- // abs(0 - x) ==> abs(x),
- // only for fast math
- if (fastMath &&
- binary->op == Abstract::getBinary(binary->type, Abstract::Sub)) {
- if (auto* c = binary->left->dynCast<Const>()) {
- if (c->value.isZero()) {
- unary->value = binary->right;
- return unary;
- }
- }
- }
+ }
+ {
+ // i32.eqz(i32.wrap_i64(x)) => i64.eqz(x)
+ // where maxBits(x) <= 32
+ Unary* inner;
+ Expression* x;
+ if (matches(curr, unary(EqZInt32, unary(&inner, WrapInt64, any(&x)))) &&
+ Bits::getMaxBits(x, this) <= 32) {
+ inner->op = EqZInt64;
+ inner->value = x;
+ return replaceCurrent(inner);
}
}
+ }
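
Numeric spot-checks for the three eqz rewrites above (illustrative; C++'s truncating % matches wasm's rem_s, and & on two's complement values matches wasm's and):

    #include <cassert>
    #include <cstdint>

    int main() {
      const int32_t y = 5, C = 5;
      for (int32_t x : {-8, -5, 0, 5}) {
        assert(((x - y) == 0) == (x == y));     // eqz(x - y)     ==> x == y
        assert(((x + C) == 0) == (x == -C));    // eqz(x + C)     ==> x == -C
        assert((x % 4 == 0) == ((x & 3) == 0)); // eqz(x rem_s 4) ==> eqz(x & 3)
      }
    }
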
- if (auto* ret = deduplicateUnary(unary)) {
- return ret;
+ if (curr->op == EqZInt32) {
+ if (auto* inner = curr->value->dynCast<Binary>()) {
+ // Try to invert a relational operation using De Morgan's law
+ auto op = invertBinaryOp(inner->op);
+ if (op != InvalidBinary) {
+ inner->op = op;
+ return replaceCurrent(inner);
+ }
}
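
The inversion above is valid because every relational operator has an exact complement, so eqz(a < b) can become a >= b and so on. A tiny check (illustrative):

    #include <cassert>

    int main() {
      for (int a : {-1, 0, 2}) {
        for (int b : {-1, 0, 2}) {
          assert(!(a < b) == (a >= b));
          assert(!(a == b) == (a != b));
        }
      }
    }
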
- } else if (auto* set = curr->dynCast<GlobalSet>()) {
- // optimize out a set of a get
- auto* get = set->value->dynCast<GlobalGet>();
- if (get && get->name == set->name) {
- ExpressionManipulator::nop(curr);
+ // an eqz of a sign extension can be done on a zero extension instead
+ if (auto* ext = Properties::getSignExtValue(curr->value)) {
+ // we are comparing a sign extend to a constant, which means we can
+ // use a cheaper zext
+ auto bits = Properties::getSignExtBits(curr->value);
+ curr->value = makeZeroExt(ext, bits);
+ return replaceCurrent(curr);
}
- } else if (auto* iff = curr->dynCast<If>()) {
- iff->condition = optimizeBoolean(iff->condition);
- if (iff->ifFalse) {
- if (auto* unary = iff->condition->dynCast<Unary>()) {
- if (unary->op == EqZInt32) {
- // flip if-else arms to get rid of an eqz
- iff->condition = unary->value;
- std::swap(iff->ifTrue, iff->ifFalse);
- }
+ } else if (curr->op == AbsFloat32 || curr->op == AbsFloat64) {
+ // abs(-x) ==> abs(x)
+ if (auto* unaryInner = curr->value->dynCast<Unary>()) {
+ if (unaryInner->op ==
+ Abstract::getUnary(unaryInner->type, Abstract::Neg)) {
+ curr->value = unaryInner->value;
+ return replaceCurrent(curr);
}
- if (iff->condition->type != Type::unreachable &&
- ExpressionAnalyzer::equal(iff->ifTrue, iff->ifFalse)) {
- // The sides are identical, so fold. If we can replace the If with one
- // arm and there are no side effects in the condition, replace it. But
- // make sure not to change a concrete expression to an unreachable
- // expression because we want to avoid having to refinalize.
- bool needCondition = effects(iff->condition).hasSideEffects();
- bool wouldBecomeUnreachable =
- iff->type.isConcrete() && iff->ifTrue->type == Type::unreachable;
- Builder builder(*getModule());
- if (!wouldBecomeUnreachable && !needCondition) {
- return iff->ifTrue;
- } else if (!wouldBecomeUnreachable) {
- return builder.makeSequence(builder.makeDrop(iff->condition),
- iff->ifTrue);
- } else {
- // Emit a block with the original concrete type.
- auto* ret = builder.makeBlock();
- if (needCondition) {
- ret->list.push_back(builder.makeDrop(iff->condition));
+ }
+ // abs(x * x) ==> x * x
+ // abs(x / x) ==> x / x
+ if (auto* binary = curr->value->dynCast<Binary>()) {
+ if ((binary->op == Abstract::getBinary(binary->type, Abstract::Mul) ||
+ binary->op == Abstract::getBinary(binary->type, Abstract::DivS)) &&
+ ExpressionAnalyzer::equal(binary->left, binary->right)) {
+ return replaceCurrent(binary);
+ }
+ // abs(0 - x) ==> abs(x),
+ // only for fast math
+ if (fastMath &&
+ binary->op == Abstract::getBinary(binary->type, Abstract::Sub)) {
+ if (auto* c = binary->left->dynCast<Const>()) {
+ if (c->value.isZero()) {
+ curr->value = binary->right;
+ return replaceCurrent(curr);
}
- ret->list.push_back(iff->ifTrue);
- ret->finalize(iff->type);
- return ret;
}
}
}
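
Spot-checks for the float abs identities above (illustrative; the abs(0 - x) case is gated on fastMath because exact NaN and negative-zero behavior differ):

    #include <cassert>
    #include <cmath>

    int main() {
      for (double x : {-2.5, 0.0, 3.0}) {
        assert(std::fabs(-x) == std::fabs(x)); // abs(-x)    ==> abs(x)
        assert(std::fabs(x * x) == x * x);     // abs(x * x) ==> x * x
      }
    }
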
- } else if (auto* br = curr->dynCast<Break>()) {
- if (br->condition) {
- br->condition = optimizeBoolean(br->condition);
- }
- } else if (auto* load = curr->dynCast<Load>()) {
- optimizeMemoryAccess(load->ptr, load->offset);
- } else if (auto* store = curr->dynCast<Store>()) {
- optimizeMemoryAccess(store->ptr, store->offset);
- if (store->valueType.isInteger()) {
- // truncates constant values during stores
- // (i32|i64).store(8|16|32)(p, C) ==>
- // (i32|i64).store(8|16|32)(p, C & mask)
- if (auto* c = store->value->dynCast<Const>()) {
- if (store->valueType == Type::i64 && store->bytes == 4) {
- c->value = c->value.and_(Literal(uint64_t(0xffffffff)));
- } else {
- c->value = c->value.and_(Literal::makeFromInt32(
- Bits::lowBitMask(store->bytes * 8), store->valueType));
+ }
+
+ if (auto* ret = deduplicateUnary(curr)) {
+ return replaceCurrent(ret);
+ }
+ }
+
+ void visitSelect(Select* curr) {
+ if (curr->type == Type::unreachable) {
+ return;
+ }
+ if (auto* ret = optimizeSelect(curr)) {
+ return replaceCurrent(ret);
+ }
+ }
+
+ void visitGlobalSet(GlobalSet* curr) {
+ if (curr->type == Type::unreachable) {
+ return;
+ }
+ // optimize out a set of a get
+ auto* get = curr->value->dynCast<GlobalGet>();
+ if (get && get->name == curr->name) {
+ ExpressionManipulator::nop(curr);
+ return replaceCurrent(curr);
+ }
+ }
+
+ void visitIf(If* curr) {
+ curr->condition = optimizeBoolean(curr->condition);
+ if (curr->ifFalse) {
+ if (auto* unary = curr->condition->dynCast<Unary>()) {
+ if (unary->op == EqZInt32) {
+ // flip if-else arms to get rid of an eqz
+ curr->condition = unary->value;
+ std::swap(curr->ifTrue, curr->ifFalse);
+ }
+ }
+ if (curr->condition->type != Type::unreachable &&
+ ExpressionAnalyzer::equal(curr->ifTrue, curr->ifFalse)) {
+ // The sides are identical, so fold. If we can replace the If with one
+ // arm and there are no side effects in the condition, replace it. But
+ // make sure not to change a concrete expression to an unreachable
+ // expression because we want to avoid having to refinalize.
+ bool needCondition = effects(curr->condition).hasSideEffects();
+ bool wouldBecomeUnreachable =
+ curr->type.isConcrete() && curr->ifTrue->type == Type::unreachable;
+ Builder builder(*getModule());
+ if (!wouldBecomeUnreachable && !needCondition) {
+ return replaceCurrent(curr->ifTrue);
+ } else if (!wouldBecomeUnreachable) {
+ return replaceCurrent(builder.makeSequence(
+ builder.makeDrop(curr->condition), curr->ifTrue));
+ } else {
+ // Emit a block with the original concrete type.
+ auto* ret = builder.makeBlock();
+ if (needCondition) {
+ ret->list.push_back(builder.makeDrop(curr->condition));
}
+ ret->list.push_back(curr->ifTrue);
+ ret->finalize(curr->type);
+ return replaceCurrent(ret);
}
}
- // stores of fewer bits truncates anyhow
- if (auto* binary = store->value->dynCast<Binary>()) {
- if (binary->op == AndInt32) {
- if (auto* right = binary->right->dynCast<Const>()) {
- if (right->type == Type::i32) {
- auto mask = right->value.geti32();
- if ((store->bytes == 1 && mask == 0xff) ||
- (store->bytes == 2 && mask == 0xffff)) {
- store->value = binary->left;
- }
+ }
+ }
+
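
A source-level analogue of the identical-arms folding above (illustrative): when both arms are equal, only the condition's side effects need to survive, as a drop, followed by the single arm.

    #include <cassert>

    static int calls = 0;
    static int f() { ++calls; return calls & 1; }

    static int before(int a) { return f() ? a + 1 : a + 1; }
    static int after(int a) { (void)f(); return a + 1; } // drop(cond); arm

    int main() { assert(before(1) == after(1) && calls == 2); }
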
+ void visitBreak(Break* curr) {
+ if (curr->condition) {
+ curr->condition = optimizeBoolean(curr->condition);
+ }
+ }
+
+ void visitLoad(Load* curr) {
+ if (curr->type == Type::unreachable) {
+ return;
+ }
+ optimizeMemoryAccess(curr->ptr, curr->offset);
+ }
+
+ void visitStore(Store* curr) {
+ if (curr->type == Type::unreachable) {
+ return;
+ }
+ optimizeMemoryAccess(curr->ptr, curr->offset);
+ if (curr->valueType.isInteger()) {
+ // truncates constant values during stores
+ // (i32|i64).store(8|16|32)(p, C) ==>
+ // (i32|i64).store(8|16|32)(p, C & mask)
+ if (auto* c = curr->value->dynCast<Const>()) {
+ if (curr->valueType == Type::i64 && curr->bytes == 4) {
+ c->value = c->value.and_(Literal(uint64_t(0xffffffff)));
+ } else {
+ c->value = c->value.and_(Literal::makeFromInt32(
+ Bits::lowBitMask(curr->bytes * 8), curr->valueType));
+ }
+ }
+ }
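
Masking the stored constant is safe because a narrow store keeps only the low bits anyway; C and C & mask write the same bytes. A one-line check for an 8-bit store (illustrative):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t C = 0x12345678;
      assert((uint8_t)C == (uint8_t)(C & 0xff)); // store8 sees one byte either way
    }
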
+ // stores of fewer bits truncate anyhow
+ if (auto* binary = curr->value->dynCast<Binary>()) {
+ if (binary->op == AndInt32) {
+ if (auto* right = binary->right->dynCast<Const>()) {
+ if (right->type == Type::i32) {
+ auto mask = right->value.geti32();
+ if ((curr->bytes == 1 && mask == 0xff) ||
+ (curr->bytes == 2 && mask == 0xffff)) {
+ curr->value = binary->left;
}
}
- } else if (auto* ext = Properties::getSignExtValue(binary)) {
- // if sign extending the exact bit size we store, we can skip the
- // extension if extending something bigger, then we just alter bits we
- // don't save anyhow
- if (Properties::getSignExtBits(binary) >= Index(store->bytes) * 8) {
- store->value = ext;
- }
}
- } else if (auto* unary = store->value->dynCast<Unary>()) {
- if (unary->op == WrapInt64) {
- // instead of wrapping to 32, just store some of the bits in the i64
- store->valueType = Type::i64;
- store->value = unary->value;
+ } else if (auto* ext = Properties::getSignExtValue(binary)) {
+ // if sign-extending the exact bit size we store, we can skip the
+ // extension; if extending something bigger, then we just alter bits we
+ // don't save anyhow
+ if (Properties::getSignExtBits(binary) >= Index(curr->bytes) * 8) {
+ curr->value = ext;
}
}
- } else if (auto* memCopy = curr->dynCast<MemoryCopy>()) {
- assert(features.hasBulkMemory());
- if (auto* ret = optimizeMemoryCopy(memCopy)) {
- return ret;
+ } else if (auto* unary = curr->value->dynCast<Unary>()) {
+ if (unary->op == WrapInt64) {
+ // instead of wrapping to 32, just store some of the bits in the i64
+ curr->valueType = Type::i64;
+ curr->value = unary->value;
}
}
- return nullptr;
+ }
+
+ void visitMemoryCopy(MemoryCopy* curr) {
+ if (curr->type == Type::unreachable) {
+ return;
+ }
+ assert(getModule()->features.hasBulkMemory());
+ if (auto* ret = optimizeMemoryCopy(curr)) {
+ return replaceCurrent(ret);
+ }
}
Index getMaxBitsForLocal(LocalGet* get) {