diff options
Diffstat (limited to 'src/passes/OptimizeInstructions.cpp')
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 1057 |
1 files changed, 550 insertions, 507 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index f5718654e..c12a48b0d 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -201,9 +201,7 @@ template<class Opt> struct Match::Internal::MatchSelf<PureMatcherKind<Opt>> { // Main pass class struct OptimizeInstructions - : public WalkerPass< - PostWalker<OptimizeInstructions, - UnifiedExpressionVisitor<OptimizeInstructions>>> { + : public WalkerPass<PostWalker<OptimizeInstructions>> { bool isFunctionParallel() override { return true; } Pass* create() override { return new OptimizeInstructions; } @@ -226,21 +224,34 @@ struct OptimizeInstructions } } - void visitExpression(Expression* curr) { - // if this contains dead code, don't bother trying to optimize it, the type - // might change (if might not be unreachable if just one arm is, for - // example). this optimization pass focuses on actually executing code. the - // only exceptions are control flow changes - if (curr->is<Const>() || curr->is<Call>() || curr->is<Nop>() || - (curr->type == Type::unreachable && !curr->is<Break>() && - !curr->is<Switch>() && !curr->is<If>())) { + // Set to true when one of the visitors makes a change (either replacing the + // node or modifying it). + bool changed; + + // Used to avoid recursion in replaceCurrent, see below. + bool inReplaceCurrent = false; + + void replaceCurrent(Expression* rep) { + WalkerPass<PostWalker<OptimizeInstructions>>::replaceCurrent(rep); + // We may be able to apply multiple patterns as one may open opportunities + // for others. NB: patterns must not have cycles + + // To avoid recursion, this uses the following pattern: the initial call to + // this method comes from one of the visit*() methods. We then loop in here, + // and if we are called again we set |changed| instead of recursing, so that + // we can loop on that value. + if (inReplaceCurrent) { + // We are in the loop below so just note a change and return to there. + changed = true; return; } - // we may be able to apply multiple patterns, one may open opportunities - // that look deeper NB: patterns must not have cycles - while ((curr = handOptimize(curr))) { - replaceCurrent(curr); - } + // Loop on further changes. + inReplaceCurrent = true; + do { + changed = false; + visit(getCurrent()); + } while (changed); + inReplaceCurrent = false; } EffectAnalyzer effects(Expression* expr) { @@ -257,19 +268,16 @@ struct OptimizeInstructions getPassOptions(), getModule()->features, a, b); } - // Optimizations that don't yet fit in the pattern DSL, but could be - // eventually maybe - Expression* handOptimize(Expression* curr) { - if (curr->is<Const>()) { - return nullptr; + void visitBinary(Binary* curr) { + // If this contains dead code, don't bother trying to optimize it, the type + // might change (if might not be unreachable if just one arm is, for + // example). This optimization pass focuses on actually executing code. + if (curr->type == Type::unreachable) { + return; } - FeatureSet features = getModule()->features; - - if (auto* binary = curr->dynCast<Binary>()) { - if (shouldCanonicalize(binary)) { - canonicalize(binary); - } + if (shouldCanonicalize(curr)) { + canonicalize(curr); } { @@ -304,7 +312,7 @@ struct OptimizeInstructions canReorder(x, y)) { sub->left = y; sub->right = x; - return sub; + return replaceCurrent(sub); } } { @@ -326,43 +334,7 @@ struct OptimizeInstructions if (matches(curr, binary(Add, any(&y), binary(&sub, Sub, ival(0), any())))) { sub->left = y; - return sub; - } - } - { - // eqz(x - y) => x == y - Binary* inner; - if (matches(curr, unary(EqZ, binary(&inner, Sub, any(), any())))) { - inner->op = Abstract::getBinary(inner->left->type, Eq); - inner->type = Type::i32; - return inner; - } - } - { - // eqz(x + C) => x == -C - Const* c; - Binary* inner; - if (matches(curr, unary(EqZ, binary(&inner, Add, any(), ival(&c))))) { - c->value = c->value.neg(); - inner->op = Abstract::getBinary(c->type, Eq); - inner->type = Type::i32; - return inner; - } - } - { - // eqz((signed)x % C_pot) => eqz(x & (abs(C_pot) - 1)) - Const* c; - Binary* inner; - if (matches(curr, unary(EqZ, binary(&inner, RemS, any(), ival(&c)))) && - (c->value.isSignedMin() || - Bits::isPowerOf2(c->value.abs().getInteger()))) { - inner->op = Abstract::getBinary(c->type, And); - if (c->value.isSignedMin()) { - c->value = Literal::makeSignedMax(c->type); - } else { - c->value = c->value.abs().sub(Literal::makeOne(c->type)); - } - return curr; + return replaceCurrent(sub); } } { @@ -385,19 +357,7 @@ struct OptimizeInstructions bin->left = x; bin->right = y; un->value = bin; - return un; - } - } - { - // i32.eqz(i32.wrap_i64(x)) => i64.eqz(x) - // where maxBits(x) <= 32 - Unary* inner; - Expression* x; - if (matches(curr, unary(EqZInt32, unary(&inner, WrapInt64, any(&x)))) && - Bits::getMaxBits(x, this) <= 32) { - inner->op = EqZInt64; - inner->value = x; - return inner; + return replaceCurrent(un); } } { @@ -420,7 +380,7 @@ struct OptimizeInstructions Literal::makeFromInt32(c->type.getByteSize() * 8 - 1, c->type)); // x <<>> 0 ==> x if (c->value.isZero()) { - return x; + return replaceCurrent(x); } } if (matches(curr, @@ -431,14 +391,14 @@ struct OptimizeInstructions if ((c->type == Type::i32 && (c->value.geti32() & 31) == 31) || (c->type == Type::i64 && (c->value.geti64() & 63LL) == 63LL)) { curr->cast<Binary>()->right = y; - return curr; + return replaceCurrent(curr); } // i32(x) <<>> (y & C) ==> x, where (C & 31) == 0 // i64(x) <<>> (y & C) ==> x, where (C & 63) == 0 if (((c->type == Type::i32 && (c->value.geti32() & 31) == 0) || (c->type == Type::i64 && (c->value.geti64() & 63LL) == 0LL)) && !effects(y).hasSideEffects()) { - return x; + return replaceCurrent(x); } } } @@ -450,496 +410,579 @@ struct OptimizeInstructions c->value.isZero()) { c->value = Literal::makeOne(Type::i32); c->type = Type::i32; - return c; + return replaceCurrent(c); } // unsigned(x) < 0 => i32(0) if (matches(curr, binary(LtU, pure(&x), ival(&c))) && c->value.isZero()) { c->value = Literal::makeZero(Type::i32); c->type = Type::i32; - return c; - } - } - } - - if (auto* select = curr->dynCast<Select>()) { - return optimizeSelect(select); - } - - if (auto* binary = curr->dynCast<Binary>()) { - if (auto* ext = Properties::getAlmostSignExt(binary)) { - Index extraLeftShifts; - auto bits = Properties::getAlmostSignExtBits(binary, extraLeftShifts); - if (extraLeftShifts == 0) { - if (auto* load = - Properties::getFallthrough(ext, getPassOptions(), features) - ->dynCast<Load>()) { - // pattern match a load of 8 bits and a sign extend using a shl of - // 24 then shr_s of 24 as well, etc. - if (LoadUtils::canBeSigned(load) && - ((load->bytes == 1 && bits == 8) || - (load->bytes == 2 && bits == 16))) { - // if the value falls through, we can't alter the load, as it - // might be captured in a tee - if (load->signed_ == true || load == ext) { - load->signed_ = true; - return ext; - } + return replaceCurrent(c); + } + } + } + if (auto* ext = Properties::getAlmostSignExt(curr)) { + Index extraLeftShifts; + auto bits = Properties::getAlmostSignExtBits(curr, extraLeftShifts); + if (extraLeftShifts == 0) { + if (auto* load = Properties::getFallthrough( + ext, getPassOptions(), getModule()->features) + ->dynCast<Load>()) { + // pattern match a load of 8 bits and a sign extend using a shl of + // 24 then shr_s of 24 as well, etc. + if (LoadUtils::canBeSigned(load) && + ((load->bytes == 1 && bits == 8) || + (load->bytes == 2 && bits == 16))) { + // if the value falls through, we can't alter the load, as it + // might be captured in a tee + if (load->signed_ == true || load == ext) { + load->signed_ = true; + return replaceCurrent(ext); } } } - // We can in some cases remove part of a sign extend, that is, - // (x << A) >> B => x << (A - B) - // If the sign-extend input cannot have a sign bit, we don't need it. - if (Bits::getMaxBits(ext, this) + extraLeftShifts < bits) { - return removeAlmostSignExt(binary); - } - // We also don't need it if it already has an identical-sized sign - // extend applied to it. That is, if it is already a sign-extended - // value, then another sign extend will do nothing. We do need to be - // careful of the extra shifts, though. - if (isSignExted(ext, bits) && extraLeftShifts == 0) { - return removeAlmostSignExt(binary); - } - } else if (binary->op == EqInt32 || binary->op == NeInt32) { - if (auto* c = binary->right->dynCast<Const>()) { - if (auto* ext = Properties::getSignExtValue(binary->left)) { - // We are comparing a sign extend to a constant, which means we can - // use a cheaper zero-extend in some cases. That is, - // (x << S) >> S ==/!= C => x & T ==/!= C - // where S and T are the matching values for sign/zero extend of the - // same size. For example, for an effective 8-bit value: - // (x << 24) >> 24 ==/!= C => x & 255 ==/!= C - // - // The key thing to track here are the upper bits plus the sign bit; - // call those the "relevant bits". This is crucial because x is - // sign-extended, that is, its effective sign bit is spread to all - // the upper bits, which means that the relevant bits on the left - // side are either all 0, or all 1. - auto bits = Properties::getSignExtBits(binary->left); - uint32_t right = c->value.geti32(); - uint32_t numRelevantBits = 32 - bits + 1; - uint32_t setRelevantBits = - Bits::popCount(right >> uint32_t(bits - 1)); - // If all the relevant bits on C are zero - // then we can mask off the high bits instead of sign-extending x. - // This is valid because if x is negative, then the comparison was - // false before (negative vs positive), and will still be false - // as the sign bit will remain to cause a difference. And if x is - // positive then the upper bits would be zero anyhow. - if (setRelevantBits == 0) { - binary->left = makeZeroExt(ext, bits); - return binary; - } else if (setRelevantBits == numRelevantBits) { - // If all those bits are one, then we can do something similar if - // we also zero-extend on the right as well. This is valid - // because, as in the previous case, the sign bit differentiates - // the two sides when they are different, and if the sign bit is - // identical, then the upper bits don't matter, so masking them - // off both sides is fine. - binary->left = makeZeroExt(ext, bits); - c->value = c->value.and_(Literal(Bits::lowBitMask(bits))); - return binary; - } else { - // Otherwise, C's relevant bits are mixed, and then the two sides - // can never be equal, as the left side's bits cannot be mixed. - Builder builder(*getModule()); - // The result is either always true, or always false. - c->value = Literal::makeFromInt32(binary->op == NeInt32, c->type); - return builder.makeSequence(builder.makeDrop(ext), c); - } + } + // We can in some cases remove part of a sign extend, that is, + // (x << A) >> B => x << (A - B) + // If the sign-extend input cannot have a sign bit, we don't need it. + if (Bits::getMaxBits(ext, this) + extraLeftShifts < bits) { + return replaceCurrent(removeAlmostSignExt(curr)); + } + // We also don't need it if it already has an identical-sized sign + // extend applied to it. That is, if it is already a sign-extended + // value, then another sign extend will do nothing. We do need to be + // careful of the extra shifts, though. + if (isSignExted(ext, bits) && extraLeftShifts == 0) { + return replaceCurrent(removeAlmostSignExt(curr)); + } + } else if (curr->op == EqInt32 || curr->op == NeInt32) { + if (auto* c = curr->right->dynCast<Const>()) { + if (auto* ext = Properties::getSignExtValue(curr->left)) { + // We are comparing a sign extend to a constant, which means we can + // use a cheaper zero-extend in some cases. That is, + // (x << S) >> S ==/!= C => x & T ==/!= C + // where S and T are the matching values for sign/zero extend of the + // same size. For example, for an effective 8-bit value: + // (x << 24) >> 24 ==/!= C => x & 255 ==/!= C + // + // The key thing to track here are the upper bits plus the sign bit; + // call those the "relevant bits". This is crucial because x is + // sign-extended, that is, its effective sign bit is spread to all + // the upper bits, which means that the relevant bits on the left + // side are either all 0, or all 1. + auto bits = Properties::getSignExtBits(curr->left); + uint32_t right = c->value.geti32(); + uint32_t numRelevantBits = 32 - bits + 1; + uint32_t setRelevantBits = + Bits::popCount(right >> uint32_t(bits - 1)); + // If all the relevant bits on C are zero + // then we can mask off the high bits instead of sign-extending x. + // This is valid because if x is negative, then the comparison was + // false before (negative vs positive), and will still be false + // as the sign bit will remain to cause a difference. And if x is + // positive then the upper bits would be zero anyhow. + if (setRelevantBits == 0) { + curr->left = makeZeroExt(ext, bits); + return replaceCurrent(curr); + } else if (setRelevantBits == numRelevantBits) { + // If all those bits are one, then we can do something similar if + // we also zero-extend on the right as well. This is valid + // because, as in the previous case, the sign bit differentiates + // the two sides when they are different, and if the sign bit is + // identical, then the upper bits don't matter, so masking them + // off both sides is fine. + curr->left = makeZeroExt(ext, bits); + c->value = c->value.and_(Literal(Bits::lowBitMask(bits))); + return replaceCurrent(curr); + } else { + // Otherwise, C's relevant bits are mixed, and then the two sides + // can never be equal, as the left side's bits cannot be mixed. + Builder builder(*getModule()); + // The result is either always true, or always false. + c->value = Literal::makeFromInt32(curr->op == NeInt32, c->type); + return replaceCurrent( + builder.makeSequence(builder.makeDrop(ext), c)); } - } else if (auto* left = Properties::getSignExtValue(binary->left)) { - if (auto* right = Properties::getSignExtValue(binary->right)) { - auto bits = Properties::getSignExtBits(binary->left); - if (Properties::getSignExtBits(binary->right) == bits) { - // we are comparing two sign-exts with the same bits, so we may as - // well replace both with cheaper zexts - binary->left = makeZeroExt(left, bits); - binary->right = makeZeroExt(right, bits); - return binary; - } - } else if (auto* load = binary->right->dynCast<Load>()) { - // we are comparing a load to a sign-ext, we may be able to switch - // to zext - auto leftBits = Properties::getSignExtBits(binary->left); - if (load->signed_ && leftBits == load->bytes * 8) { - load->signed_ = false; - binary->left = makeZeroExt(left, leftBits); - return binary; - } + } + } else if (auto* left = Properties::getSignExtValue(curr->left)) { + if (auto* right = Properties::getSignExtValue(curr->right)) { + auto bits = Properties::getSignExtBits(curr->left); + if (Properties::getSignExtBits(curr->right) == bits) { + // we are comparing two sign-exts with the same bits, so we may as + // well replace both with cheaper zexts + curr->left = makeZeroExt(left, bits); + curr->right = makeZeroExt(right, bits); + return replaceCurrent(curr); } - } else if (auto* load = binary->left->dynCast<Load>()) { - if (auto* right = Properties::getSignExtValue(binary->right)) { - // we are comparing a load to a sign-ext, we may be able to switch - // to zext - auto rightBits = Properties::getSignExtBits(binary->right); - if (load->signed_ && rightBits == load->bytes * 8) { - load->signed_ = false; - binary->right = makeZeroExt(right, rightBits); - return binary; - } + } else if (auto* load = curr->right->dynCast<Load>()) { + // we are comparing a load to a sign-ext, we may be able to switch + // to zext + auto leftBits = Properties::getSignExtBits(curr->left); + if (load->signed_ && leftBits == load->bytes * 8) { + load->signed_ = false; + curr->left = makeZeroExt(left, leftBits); + return replaceCurrent(curr); } } - // note that both left and right may be consts, but then we let - // precompute compute the constant result - } else if (binary->op == AddInt32 || binary->op == AddInt64 || - binary->op == SubInt32 || binary->op == SubInt64) { - if (auto* ret = optimizeAddedConstants(binary)) { - return ret; - } - } else if (binary->op == MulFloat32 || binary->op == MulFloat64 || - binary->op == DivFloat32 || binary->op == DivFloat64) { - if (binary->left->type == binary->right->type) { - if (auto* leftUnary = binary->left->dynCast<Unary>()) { - if (leftUnary->op == - Abstract::getUnary(binary->type, Abstract::Abs)) { - if (auto* rightUnary = binary->right->dynCast<Unary>()) { - if (leftUnary->op == rightUnary->op) { // both are abs ops - // abs(x) * abs(y) ==> abs(x * y) - // abs(x) / abs(y) ==> abs(x / y) - binary->left = leftUnary->value; - binary->right = rightUnary->value; - leftUnary->value = binary; - return leftUnary; - } - } - } + } else if (auto* load = curr->left->dynCast<Load>()) { + if (auto* right = Properties::getSignExtValue(curr->right)) { + // we are comparing a load to a sign-ext, we may be able to switch + // to zext + auto rightBits = Properties::getSignExtBits(curr->right); + if (load->signed_ && rightBits == load->bytes * 8) { + load->signed_ = false; + curr->right = makeZeroExt(right, rightBits); + return replaceCurrent(curr); } } } - // a bunch of operations on a constant right side can be simplified - if (auto* right = binary->right->dynCast<Const>()) { - if (binary->op == AndInt32) { - auto mask = right->value.geti32(); - // and with -1 does nothing (common in asm.js output) - if (mask == -1) { - return binary->left; - } - // small loads do not need to be masked, the load itself masks - if (auto* load = binary->left->dynCast<Load>()) { - if ((load->bytes == 1 && mask == 0xff) || - (load->bytes == 2 && mask == 0xffff)) { - load->signed_ = false; - return binary->left; - } - } else if (auto maskedBits = Bits::getMaskedBits(mask)) { - if (Bits::getMaxBits(binary->left, this) <= maskedBits) { - // a mask of lower bits is not needed if we are already smaller - return binary->left; - } - } - } - // some math operations have trivial results - if (auto* ret = optimizeWithConstantOnRight(binary)) { - return ret; - } - // the square of some operations can be merged - if (auto* left = binary->left->dynCast<Binary>()) { - if (left->op == binary->op) { - if (auto* leftRight = left->right->dynCast<Const>()) { - if (left->op == AndInt32 || left->op == AndInt64) { - leftRight->value = leftRight->value.and_(right->value); - return left; - } else if (left->op == OrInt32 || left->op == OrInt64) { - leftRight->value = leftRight->value.or_(right->value); - return left; - } else if (left->op == XorInt32 || left->op == XorInt64) { - leftRight->value = leftRight->value.xor_(right->value); - return left; - } else if (left->op == MulInt32 || left->op == MulInt64) { - leftRight->value = leftRight->value.mul(right->value); - return left; - - // TODO: - // handle signed / unsigned divisions. They are more complex - } else if (left->op == ShlInt32 || left->op == ShrUInt32 || - left->op == ShrSInt32 || left->op == ShlInt64 || - left->op == ShrUInt64 || left->op == ShrSInt64) { - // shifts only use an effective amount from the constant, so - // adding must be done carefully - auto total = Bits::getEffectiveShifts(leftRight) + - Bits::getEffectiveShifts(right); - if (total == Bits::getEffectiveShifts(total, right->type)) { - // no overflow, we can do this - leftRight->value = Literal::makeFromInt32(total, right->type); - return left; - } // TODO: handle overflows + // note that both left and right may be consts, but then we let + // precompute compute the constant result + } else if (curr->op == AddInt32 || curr->op == AddInt64 || + curr->op == SubInt32 || curr->op == SubInt64) { + if (auto* ret = optimizeAddedConstants(curr)) { + return replaceCurrent(ret); + } + } else if (curr->op == MulFloat32 || curr->op == MulFloat64 || + curr->op == DivFloat32 || curr->op == DivFloat64) { + if (curr->left->type == curr->right->type) { + if (auto* leftUnary = curr->left->dynCast<Unary>()) { + if (leftUnary->op == Abstract::getUnary(curr->type, Abstract::Abs)) { + if (auto* rightUnary = curr->right->dynCast<Unary>()) { + if (leftUnary->op == rightUnary->op) { // both are abs ops + // abs(x) * abs(y) ==> abs(x * y) + // abs(x) / abs(y) ==> abs(x / y) + curr->left = leftUnary->value; + curr->right = rightUnary->value; + leftUnary->value = curr; + return replaceCurrent(leftUnary); } } } } - if (right->type == Type::i32) { - BinaryOp op; - int32_t c = right->value.geti32(); - // First, try to lower signed operations to unsigned if that is - // possible. Some unsigned operations like div_u or rem_u are usually - // faster on VMs. Also this opens more possibilities for further - // simplifications afterwards. - if (c >= 0 && - (op = makeUnsignedBinaryOp(binary->op)) != InvalidBinary && - Bits::getMaxBits(binary->left, this) <= 31) { - binary->op = op; - } - if (c < 0 && c > std::numeric_limits<int32_t>::min() && - binary->op == DivUInt32) { - // u32(x) / C ==> u32(x) >= C iff C > 2^31 - // We avoid applying this for C == 2^31 due to conflict - // with other rule which transform to more prefereble - // right shift operation. - binary->op = c == -1 ? EqInt32 : GeUInt32; - return binary; - } - if (Bits::isPowerOf2((uint32_t)c)) { - switch (binary->op) { - case MulInt32: - return optimizePowerOf2Mul(binary, (uint32_t)c); - case RemUInt32: - return optimizePowerOf2URem(binary, (uint32_t)c); - case DivUInt32: - return optimizePowerOf2UDiv(binary, (uint32_t)c); - default: - break; - } - } + } + } + // a bunch of operations on a constant right side can be simplified + if (auto* right = curr->right->dynCast<Const>()) { + if (curr->op == AndInt32) { + auto mask = right->value.geti32(); + // and with -1 does nothing (common in asm.js output) + if (mask == -1) { + return replaceCurrent(curr->left); } - if (right->type == Type::i64) { - BinaryOp op; - int64_t c = right->value.geti64(); - // See description above for Type::i32 - if (c >= 0 && - (op = makeUnsignedBinaryOp(binary->op)) != InvalidBinary && - Bits::getMaxBits(binary->left, this) <= 63) { - binary->op = op; + // small loads do not need to be masked, the load itself masks + if (auto* load = curr->left->dynCast<Load>()) { + if ((load->bytes == 1 && mask == 0xff) || + (load->bytes == 2 && mask == 0xffff)) { + load->signed_ = false; + return replaceCurrent(curr->left); } - if (getPassOptions().shrinkLevel == 0 && c < 0 && - c > std::numeric_limits<int64_t>::min() && - binary->op == DivUInt64) { - // u64(x) / C ==> u64(u64(x) >= C) iff C > 2^63 - // We avoid applying this for C == 2^31 due to conflict - // with other rule which transform to more prefereble - // right shift operation. - // And apply this only for shrinkLevel == 0 due to it - // increasing size by one byte. - binary->op = c == -1LL ? EqInt64 : GeUInt64; - binary->type = Type::i32; - return Builder(*getModule()).makeUnary(ExtendUInt32, binary); + } else if (auto maskedBits = Bits::getMaskedBits(mask)) { + if (Bits::getMaxBits(curr->left, this) <= maskedBits) { + // a mask of lower bits is not needed if we are already smaller + return replaceCurrent(curr->left); } - if (Bits::isPowerOf2((uint64_t)c)) { - switch (binary->op) { - case MulInt64: - return optimizePowerOf2Mul(binary, (uint64_t)c); - case RemUInt64: - return optimizePowerOf2URem(binary, (uint64_t)c); - case DivUInt64: - return optimizePowerOf2UDiv(binary, (uint64_t)c); - default: - break; + } + } + // some math operations have trivial results + if (auto* ret = optimizeWithConstantOnRight(curr)) { + return replaceCurrent(ret); + } + // the square of some operations can be merged + if (auto* left = curr->left->dynCast<Binary>()) { + if (left->op == curr->op) { + if (auto* leftRight = left->right->dynCast<Const>()) { + if (left->op == AndInt32 || left->op == AndInt64) { + leftRight->value = leftRight->value.and_(right->value); + return replaceCurrent(left); + } else if (left->op == OrInt32 || left->op == OrInt64) { + leftRight->value = leftRight->value.or_(right->value); + return replaceCurrent(left); + } else if (left->op == XorInt32 || left->op == XorInt64) { + leftRight->value = leftRight->value.xor_(right->value); + return replaceCurrent(left); + } else if (left->op == MulInt32 || left->op == MulInt64) { + leftRight->value = leftRight->value.mul(right->value); + return replaceCurrent(left); + + // TODO: + // handle signed / unsigned divisions. They are more complex + } else if (left->op == ShlInt32 || left->op == ShrUInt32 || + left->op == ShrSInt32 || left->op == ShlInt64 || + left->op == ShrUInt64 || left->op == ShrSInt64) { + // shifts only use an effective amount from the constant, so + // adding must be done carefully + auto total = Bits::getEffectiveShifts(leftRight) + + Bits::getEffectiveShifts(right); + if (total == Bits::getEffectiveShifts(total, right->type)) { + // no overflow, we can do this + leftRight->value = Literal::makeFromInt32(total, right->type); + return replaceCurrent(left); + } // TODO: handle overflows } } } - if (binary->op == DivFloat32) { - float c = right->value.getf32(); - if (Bits::isPowerOf2InvertibleFloat(c)) { - return optimizePowerOf2FDiv(binary, c); + } + if (right->type == Type::i32) { + BinaryOp op; + int32_t c = right->value.geti32(); + // First, try to lower signed operations to unsigned if that is + // possible. Some unsigned operations like div_u or rem_u are usually + // faster on VMs. Also this opens more possibilities for further + // simplifications afterwards. + if (c >= 0 && (op = makeUnsignedBinaryOp(curr->op)) != InvalidBinary && + Bits::getMaxBits(curr->left, this) <= 31) { + curr->op = op; + } + if (c < 0 && c > std::numeric_limits<int32_t>::min() && + curr->op == DivUInt32) { + // u32(x) / C ==> u32(x) >= C iff C > 2^31 + // We avoid applying this for C == 2^31 due to conflict + // with other rule which transform to more prefereble + // right shift operation. + curr->op = c == -1 ? EqInt32 : GeUInt32; + return replaceCurrent(curr); + } + if (Bits::isPowerOf2((uint32_t)c)) { + switch (curr->op) { + case MulInt32: + return replaceCurrent(optimizePowerOf2Mul(curr, (uint32_t)c)); + case RemUInt32: + return replaceCurrent(optimizePowerOf2URem(curr, (uint32_t)c)); + case DivUInt32: + return replaceCurrent(optimizePowerOf2UDiv(curr, (uint32_t)c)); + default: + break; } } - if (binary->op == DivFloat64) { - double c = right->value.getf64(); - if (Bits::isPowerOf2InvertibleFloat(c)) { - return optimizePowerOf2FDiv(binary, c); + } + if (right->type == Type::i64) { + BinaryOp op; + int64_t c = right->value.geti64(); + // See description above for Type::i32 + if (c >= 0 && (op = makeUnsignedBinaryOp(curr->op)) != InvalidBinary && + Bits::getMaxBits(curr->left, this) <= 63) { + curr->op = op; + } + if (getPassOptions().shrinkLevel == 0 && c < 0 && + c > std::numeric_limits<int64_t>::min() && curr->op == DivUInt64) { + // u64(x) / C ==> u64(u64(x) >= C) iff C > 2^63 + // We avoid applying this for C == 2^31 due to conflict + // with other rule which transform to more prefereble + // right shift operation. + // And apply this only for shrinkLevel == 0 due to it + // increasing size by one byte. + curr->op = c == -1LL ? EqInt64 : GeUInt64; + curr->type = Type::i32; + return replaceCurrent( + Builder(*getModule()).makeUnary(ExtendUInt32, curr)); + } + if (Bits::isPowerOf2((uint64_t)c)) { + switch (curr->op) { + case MulInt64: + return replaceCurrent(optimizePowerOf2Mul(curr, (uint64_t)c)); + case RemUInt64: + return replaceCurrent(optimizePowerOf2URem(curr, (uint64_t)c)); + case DivUInt64: + return replaceCurrent(optimizePowerOf2UDiv(curr, (uint64_t)c)); + default: + break; } } } - // a bunch of operations on a constant left side can be simplified - if (binary->left->is<Const>()) { - if (auto* ret = optimizeWithConstantOnLeft(binary)) { - return ret; + if (curr->op == DivFloat32) { + float c = right->value.getf32(); + if (Bits::isPowerOf2InvertibleFloat(c)) { + return replaceCurrent(optimizePowerOf2FDiv(curr, c)); } } - // bitwise operations - // for and and or, we can potentially conditionalize - if (binary->op == AndInt32 || binary->op == OrInt32) { - if (auto* ret = conditionalizeExpensiveOnBitwise(binary)) { - return ret; + if (curr->op == DivFloat64) { + double c = right->value.getf64(); + if (Bits::isPowerOf2InvertibleFloat(c)) { + return replaceCurrent(optimizePowerOf2FDiv(curr, c)); } } - // for or, we can potentially combine - if (binary->op == OrInt32) { - if (auto* ret = combineOr(binary)) { - return ret; - } + } + // a bunch of operations on a constant left side can be simplified + if (curr->left->is<Const>()) { + if (auto* ret = optimizeWithConstantOnLeft(curr)) { + return replaceCurrent(ret); } - // relation/comparisons allow for math optimizations - if (binary->isRelational()) { - if (auto* ret = optimizeRelational(binary)) { - return ret; - } + } + // bitwise operations + // for and and or, we can potentially conditionalize + if (curr->op == AndInt32 || curr->op == OrInt32) { + if (auto* ret = conditionalizeExpensiveOnBitwise(curr)) { + return replaceCurrent(ret); } - // finally, try more expensive operations on the binary in - // the case that they have no side effects - if (!effects(binary->left).hasSideEffects()) { - if (ExpressionAnalyzer::equal(binary->left, binary->right)) { - if (auto* ret = optimizeBinaryWithEqualEffectlessChildren(binary)) { - return ret; - } + } + // for or, we can potentially combine + if (curr->op == OrInt32) { + if (auto* ret = combineOr(curr)) { + return replaceCurrent(ret); + } + } + // relation/comparisons allow for math optimizations + if (curr->isRelational()) { + if (auto* ret = optimizeRelational(curr)) { + return replaceCurrent(ret); + } + } + // finally, try more expensive operations on the curr in + // the case that they have no side effects + if (!effects(curr->left).hasSideEffects()) { + if (ExpressionAnalyzer::equal(curr->left, curr->right)) { + if (auto* ret = optimizeBinaryWithEqualEffectlessChildren(curr)) { + return replaceCurrent(ret); } } + } + + if (auto* ret = deduplicateBinary(curr)) { + return replaceCurrent(ret); + } + } - if (auto* ret = deduplicateBinary(binary)) { - return ret; + void visitUnary(Unary* curr) { + if (curr->type == Type::unreachable) { + return; + } + + { + using namespace Match; + using namespace Abstract; + Builder builder(*getModule()); + { + // eqz(x - y) => x == y + Binary* inner; + if (matches(curr, unary(EqZ, binary(&inner, Sub, any(), any())))) { + inner->op = Abstract::getBinary(inner->left->type, Eq); + inner->type = Type::i32; + return replaceCurrent(inner); + } } - } else if (auto* unary = curr->dynCast<Unary>()) { - if (unary->op == EqZInt32) { - if (auto* inner = unary->value->dynCast<Binary>()) { - // Try to invert a relational operation using De Morgan's law - auto op = invertBinaryOp(inner->op); - if (op != InvalidBinary) { - inner->op = op; - return inner; - } + { + // eqz(x + C) => x == -C + Const* c; + Binary* inner; + if (matches(curr, unary(EqZ, binary(&inner, Add, any(), ival(&c))))) { + c->value = c->value.neg(); + inner->op = Abstract::getBinary(c->type, Eq); + inner->type = Type::i32; + return replaceCurrent(inner); } - // eqz of a sign extension can be of zero-extension - if (auto* ext = Properties::getSignExtValue(unary->value)) { - // we are comparing a sign extend to a constant, which means we can - // use a cheaper zext - auto bits = Properties::getSignExtBits(unary->value); - unary->value = makeZeroExt(ext, bits); - return unary; - } - } else if (unary->op == AbsFloat32 || unary->op == AbsFloat64) { - // abs(-x) ==> abs(x) - if (auto* unaryInner = unary->value->dynCast<Unary>()) { - if (unaryInner->op == - Abstract::getUnary(unaryInner->type, Abstract::Neg)) { - unary->value = unaryInner->value; - return unary; + } + { + // eqz((signed)x % C_pot) => eqz(x & (abs(C_pot) - 1)) + Const* c; + Binary* inner; + if (matches(curr, unary(EqZ, binary(&inner, RemS, any(), ival(&c)))) && + (c->value.isSignedMin() || + Bits::isPowerOf2(c->value.abs().getInteger()))) { + inner->op = Abstract::getBinary(c->type, And); + if (c->value.isSignedMin()) { + c->value = Literal::makeSignedMax(c->type); + } else { + c->value = c->value.abs().sub(Literal::makeOne(c->type)); } + return replaceCurrent(curr); } - // abs(x * x) ==> x * x - // abs(x / x) ==> x / x - if (auto* binary = unary->value->dynCast<Binary>()) { - if ((binary->op == Abstract::getBinary(binary->type, Abstract::Mul) || - binary->op == - Abstract::getBinary(binary->type, Abstract::DivS)) && - ExpressionAnalyzer::equal(binary->left, binary->right)) { - return binary; - } - // abs(0 - x) ==> abs(x), - // only for fast math - if (fastMath && - binary->op == Abstract::getBinary(binary->type, Abstract::Sub)) { - if (auto* c = binary->left->dynCast<Const>()) { - if (c->value.isZero()) { - unary->value = binary->right; - return unary; - } - } - } + } + { + // i32.eqz(i32.wrap_i64(x)) => i64.eqz(x) + // where maxBits(x) <= 32 + Unary* inner; + Expression* x; + if (matches(curr, unary(EqZInt32, unary(&inner, WrapInt64, any(&x)))) && + Bits::getMaxBits(x, this) <= 32) { + inner->op = EqZInt64; + inner->value = x; + return replaceCurrent(inner); } } + } - if (auto* ret = deduplicateUnary(unary)) { - return ret; + if (curr->op == EqZInt32) { + if (auto* inner = curr->value->dynCast<Binary>()) { + // Try to invert a relational operation using De Morgan's law + auto op = invertBinaryOp(inner->op); + if (op != InvalidBinary) { + inner->op = op; + return replaceCurrent(inner); + } } - } else if (auto* set = curr->dynCast<GlobalSet>()) { - // optimize out a set of a get - auto* get = set->value->dynCast<GlobalGet>(); - if (get && get->name == set->name) { - ExpressionManipulator::nop(curr); + // eqz of a sign extension can be of zero-extension + if (auto* ext = Properties::getSignExtValue(curr->value)) { + // we are comparing a sign extend to a constant, which means we can + // use a cheaper zext + auto bits = Properties::getSignExtBits(curr->value); + curr->value = makeZeroExt(ext, bits); + return replaceCurrent(curr); } - } else if (auto* iff = curr->dynCast<If>()) { - iff->condition = optimizeBoolean(iff->condition); - if (iff->ifFalse) { - if (auto* unary = iff->condition->dynCast<Unary>()) { - if (unary->op == EqZInt32) { - // flip if-else arms to get rid of an eqz - iff->condition = unary->value; - std::swap(iff->ifTrue, iff->ifFalse); - } + } else if (curr->op == AbsFloat32 || curr->op == AbsFloat64) { + // abs(-x) ==> abs(x) + if (auto* unaryInner = curr->value->dynCast<Unary>()) { + if (unaryInner->op == + Abstract::getUnary(unaryInner->type, Abstract::Neg)) { + curr->value = unaryInner->value; + return replaceCurrent(curr); } - if (iff->condition->type != Type::unreachable && - ExpressionAnalyzer::equal(iff->ifTrue, iff->ifFalse)) { - // The sides are identical, so fold. If we can replace the If with one - // arm and there are no side effects in the condition, replace it. But - // make sure not to change a concrete expression to an unreachable - // expression because we want to avoid having to refinalize. - bool needCondition = effects(iff->condition).hasSideEffects(); - bool wouldBecomeUnreachable = - iff->type.isConcrete() && iff->ifTrue->type == Type::unreachable; - Builder builder(*getModule()); - if (!wouldBecomeUnreachable && !needCondition) { - return iff->ifTrue; - } else if (!wouldBecomeUnreachable) { - return builder.makeSequence(builder.makeDrop(iff->condition), - iff->ifTrue); - } else { - // Emit a block with the original concrete type. - auto* ret = builder.makeBlock(); - if (needCondition) { - ret->list.push_back(builder.makeDrop(iff->condition)); + } + // abs(x * x) ==> x * x + // abs(x / x) ==> x / x + if (auto* binary = curr->value->dynCast<Binary>()) { + if ((binary->op == Abstract::getBinary(binary->type, Abstract::Mul) || + binary->op == Abstract::getBinary(binary->type, Abstract::DivS)) && + ExpressionAnalyzer::equal(binary->left, binary->right)) { + return replaceCurrent(binary); + } + // abs(0 - x) ==> abs(x), + // only for fast math + if (fastMath && + binary->op == Abstract::getBinary(binary->type, Abstract::Sub)) { + if (auto* c = binary->left->dynCast<Const>()) { + if (c->value.isZero()) { + curr->value = binary->right; + return replaceCurrent(curr); } - ret->list.push_back(iff->ifTrue); - ret->finalize(iff->type); - return ret; } } } - } else if (auto* br = curr->dynCast<Break>()) { - if (br->condition) { - br->condition = optimizeBoolean(br->condition); - } - } else if (auto* load = curr->dynCast<Load>()) { - optimizeMemoryAccess(load->ptr, load->offset); - } else if (auto* store = curr->dynCast<Store>()) { - optimizeMemoryAccess(store->ptr, store->offset); - if (store->valueType.isInteger()) { - // truncates constant values during stores - // (i32|i64).store(8|16|32)(p, C) ==> - // (i32|i64).store(8|16|32)(p, C & mask) - if (auto* c = store->value->dynCast<Const>()) { - if (store->valueType == Type::i64 && store->bytes == 4) { - c->value = c->value.and_(Literal(uint64_t(0xffffffff))); - } else { - c->value = c->value.and_(Literal::makeFromInt32( - Bits::lowBitMask(store->bytes * 8), store->valueType)); + } + + if (auto* ret = deduplicateUnary(curr)) { + return replaceCurrent(ret); + } + } + + void visitSelect(Select* curr) { + if (curr->type == Type::unreachable) { + return; + } + if (auto* ret = optimizeSelect(curr)) { + return replaceCurrent(ret); + } + } + + void visitGlobalSet(GlobalSet* curr) { + if (curr->type == Type::unreachable) { + return; + } + // optimize out a set of a get + auto* get = curr->value->dynCast<GlobalGet>(); + if (get && get->name == curr->name) { + ExpressionManipulator::nop(curr); + return replaceCurrent(curr); + } + } + + void visitIf(If* curr) { + curr->condition = optimizeBoolean(curr->condition); + if (curr->ifFalse) { + if (auto* unary = curr->condition->dynCast<Unary>()) { + if (unary->op == EqZInt32) { + // flip if-else arms to get rid of an eqz + curr->condition = unary->value; + std::swap(curr->ifTrue, curr->ifFalse); + } + } + if (curr->condition->type != Type::unreachable && + ExpressionAnalyzer::equal(curr->ifTrue, curr->ifFalse)) { + // The sides are identical, so fold. If we can replace the If with one + // arm and there are no side effects in the condition, replace it. But + // make sure not to change a concrete expression to an unreachable + // expression because we want to avoid having to refinalize. + bool needCondition = effects(curr->condition).hasSideEffects(); + bool wouldBecomeUnreachable = + curr->type.isConcrete() && curr->ifTrue->type == Type::unreachable; + Builder builder(*getModule()); + if (!wouldBecomeUnreachable && !needCondition) { + return replaceCurrent(curr->ifTrue); + } else if (!wouldBecomeUnreachable) { + return replaceCurrent(builder.makeSequence( + builder.makeDrop(curr->condition), curr->ifTrue)); + } else { + // Emit a block with the original concrete type. + auto* ret = builder.makeBlock(); + if (needCondition) { + ret->list.push_back(builder.makeDrop(curr->condition)); } + ret->list.push_back(curr->ifTrue); + ret->finalize(curr->type); + return replaceCurrent(ret); } } - // stores of fewer bits truncates anyhow - if (auto* binary = store->value->dynCast<Binary>()) { - if (binary->op == AndInt32) { - if (auto* right = binary->right->dynCast<Const>()) { - if (right->type == Type::i32) { - auto mask = right->value.geti32(); - if ((store->bytes == 1 && mask == 0xff) || - (store->bytes == 2 && mask == 0xffff)) { - store->value = binary->left; - } + } + } + + void visitBreak(Break* curr) { + if (curr->condition) { + curr->condition = optimizeBoolean(curr->condition); + } + } + + void visitLoad(Load* curr) { + if (curr->type == Type::unreachable) { + return; + } + optimizeMemoryAccess(curr->ptr, curr->offset); + } + + void visitStore(Store* curr) { + if (curr->type == Type::unreachable) { + return; + } + optimizeMemoryAccess(curr->ptr, curr->offset); + if (curr->valueType.isInteger()) { + // truncates constant values during stores + // (i32|i64).store(8|16|32)(p, C) ==> + // (i32|i64).store(8|16|32)(p, C & mask) + if (auto* c = curr->value->dynCast<Const>()) { + if (curr->valueType == Type::i64 && curr->bytes == 4) { + c->value = c->value.and_(Literal(uint64_t(0xffffffff))); + } else { + c->value = c->value.and_(Literal::makeFromInt32( + Bits::lowBitMask(curr->bytes * 8), curr->valueType)); + } + } + } + // stores of fewer bits truncates anyhow + if (auto* binary = curr->value->dynCast<Binary>()) { + if (binary->op == AndInt32) { + if (auto* right = binary->right->dynCast<Const>()) { + if (right->type == Type::i32) { + auto mask = right->value.geti32(); + if ((curr->bytes == 1 && mask == 0xff) || + (curr->bytes == 2 && mask == 0xffff)) { + curr->value = binary->left; } } - } else if (auto* ext = Properties::getSignExtValue(binary)) { - // if sign extending the exact bit size we store, we can skip the - // extension if extending something bigger, then we just alter bits we - // don't save anyhow - if (Properties::getSignExtBits(binary) >= Index(store->bytes) * 8) { - store->value = ext; - } } - } else if (auto* unary = store->value->dynCast<Unary>()) { - if (unary->op == WrapInt64) { - // instead of wrapping to 32, just store some of the bits in the i64 - store->valueType = Type::i64; - store->value = unary->value; + } else if (auto* ext = Properties::getSignExtValue(binary)) { + // if sign extending the exact bit size we store, we can skip the + // extension if extending something bigger, then we just alter bits we + // don't save anyhow + if (Properties::getSignExtBits(binary) >= Index(curr->bytes) * 8) { + curr->value = ext; } } - } else if (auto* memCopy = curr->dynCast<MemoryCopy>()) { - assert(features.hasBulkMemory()); - if (auto* ret = optimizeMemoryCopy(memCopy)) { - return ret; + } else if (auto* unary = curr->value->dynCast<Unary>()) { + if (unary->op == WrapInt64) { + // instead of wrapping to 32, just store some of the bits in the i64 + curr->valueType = Type::i64; + curr->value = unary->value; } } - return nullptr; + } + + void visitMemoryCopy(MemoryCopy* curr) { + if (curr->type == Type::unreachable) { + return; + } + assert(getModule()->features.hasBulkMemory()); + if (auto* ret = optimizeMemoryCopy(curr)) { + return replaceCurrent(ret); + } } Index getMaxBitsForLocal(LocalGet* get) { |