diff options
Diffstat (limited to 'src/ir/bits.h')
-rw-r--r-- | src/ir/bits.h | 187 |
1 files changed, 172 insertions, 15 deletions
diff --git a/src/ir/bits.h b/src/ir/bits.h index 20d97f13f..132f7ec4d 100644 --- a/src/ir/bits.h +++ b/src/ir/bits.h @@ -128,35 +128,85 @@ struct DummyLocalInfoProvider { template<typename LocalInfoProvider = DummyLocalInfoProvider> Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider = nullptr) { - if (auto* const_ = curr->dynCast<Const>()) { + if (auto* c = curr->dynCast<Const>()) { switch (curr->type.getBasic()) { case Type::i32: - return 32 - const_->value.countLeadingZeroes().geti32(); + return 32 - c->value.countLeadingZeroes().geti32(); case Type::i64: - return 64 - const_->value.countLeadingZeroes().geti64(); + return 64 - c->value.countLeadingZeroes().geti64(); default: WASM_UNREACHABLE("invalid type"); } } else if (auto* binary = curr->dynCast<Binary>()) { switch (binary->op) { // 32-bit - case AddInt32: - case SubInt32: - case MulInt32: - case DivSInt32: - case DivUInt32: - case RemSInt32: - case RemUInt32: case RotLInt32: case RotRInt32: + case SubInt32: + return 32; + case AddInt32: { + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + auto maxBitsRight = getMaxBits(binary->right, localInfoProvider); + return std::min(Index(32), std::max(maxBitsLeft, maxBitsRight) + 1); + } + case MulInt32: { + auto maxBitsRight = getMaxBits(binary->right, localInfoProvider); + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + return std::min(Index(32), maxBitsLeft + maxBitsRight); + } + case DivSInt32: { + if (auto* c = binary->right->dynCast<Const>()) { + int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + // If either side might be negative, then the result will be negative + if (maxBitsLeft == 32 || c->value.geti32() < 0) { + return 32; + } + int32_t bitsRight = getMaxBits(c); + return std::max(0, maxBitsLeft - bitsRight + 1); + } + return 32; + } + case DivUInt32: { + int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + if (auto* c = binary->right->dynCast<Const>()) { + int32_t bitsRight = getMaxBits(c); + return std::max(0, maxBitsLeft - bitsRight + 1); + } + return maxBitsLeft; + } + case RemSInt32: { + if (auto* c = binary->right->dynCast<Const>()) { + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + // if maxBitsLeft is negative + if (maxBitsLeft == 32) { + return 32; + } + auto bitsRight = Index(CeilLog2(c->value.geti32())); + return std::min(maxBitsLeft, bitsRight); + } + return 32; + } + case RemUInt32: { + if (auto* c = binary->right->dynCast<Const>()) { + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + auto bitsRight = Index(CeilLog2(c->value.geti32())); + return std::min(maxBitsLeft, bitsRight); + } return 32; - case AndInt32: + } + case AndInt32: { return std::min(getMaxBits(binary->left, localInfoProvider), getMaxBits(binary->right, localInfoProvider)); + } case OrInt32: - case XorInt32: - return std::max(getMaxBits(binary->left, localInfoProvider), - getMaxBits(binary->right, localInfoProvider)); + case XorInt32: { + auto maxBits = getMaxBits(binary->right, localInfoProvider); + // if maxBits is negative + if (maxBits == 32) { + return 32; + } + return std::max(getMaxBits(binary->left, localInfoProvider), maxBits); + } case ShlInt32: { if (auto* shifts = binary->right->dynCast<Const>()) { return std::min(Index(32), @@ -178,6 +228,7 @@ Index getMaxBits(Expression* curr, case ShrSInt32: { if (auto* shift = binary->right->dynCast<Const>()) { auto maxBits = getMaxBits(binary->left, localInfoProvider); + // if maxBits is negative if (maxBits == 32) { return 32; } @@ -188,7 +239,105 @@ Index getMaxBits(Expression* curr, } return 32; } - // 64-bit TODO + case RotLInt64: + case RotRInt64: + case SubInt64: + return 64; + case AddInt64: { + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + auto maxBitsRight = getMaxBits(binary->right, localInfoProvider); + return std::min(Index(64), std::max(maxBitsLeft, maxBitsRight) + 1); + } + case MulInt64: { + auto maxBitsRight = getMaxBits(binary->right, localInfoProvider); + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + return std::min(Index(64), maxBitsLeft + maxBitsRight); + } + case DivSInt64: { + if (auto* c = binary->right->dynCast<Const>()) { + int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + // if maxBitsLeft or right const value is negative + if (maxBitsLeft == 64 || c->value.geti64() < 0) { + return 64; + } + int32_t bitsRight = getMaxBits(c); + return std::max(0, maxBitsLeft - bitsRight + 1); + } + return 64; + } + case DivUInt64: { + int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + if (auto* c = binary->right->dynCast<Const>()) { + int32_t bitsRight = getMaxBits(c); + return std::max(0, maxBitsLeft - bitsRight + 1); + } + return maxBitsLeft; + } + case RemSInt64: { + if (auto* c = binary->right->dynCast<Const>()) { + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + // if maxBitsLeft is negative + if (maxBitsLeft == 64) { + return 64; + } + auto bitsRight = Index(CeilLog2(c->value.geti64())); + return std::min(maxBitsLeft, bitsRight); + } + return 64; + } + case RemUInt64: { + if (auto* c = binary->right->dynCast<Const>()) { + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + auto bitsRight = Index(CeilLog2(c->value.geti64())); + return std::min(maxBitsLeft, bitsRight); + } + return 64; + } + case AndInt64: { + auto maxBits = getMaxBits(binary->right, localInfoProvider); + return std::min(getMaxBits(binary->left, localInfoProvider), maxBits); + } + case OrInt64: + case XorInt64: { + auto maxBits = getMaxBits(binary->right, localInfoProvider); + // if maxBits is negative + if (maxBits == 64) { + return 64; + } + return std::max(getMaxBits(binary->left, localInfoProvider), maxBits); + } + case ShlInt64: { + if (auto* shifts = binary->right->dynCast<Const>()) { + auto maxBits = getMaxBits(binary->left, localInfoProvider); + return std::min(Index(64), + Bits::getEffectiveShifts(shifts) + maxBits); + } + return 64; + } + case ShrUInt64: { + if (auto* shift = binary->right->dynCast<Const>()) { + auto maxBits = getMaxBits(binary->left, localInfoProvider); + auto shifts = + std::min(Index(Bits::getEffectiveShifts(shift)), + maxBits); // can ignore more shifts than zero us out + return std::max(Index(0), maxBits - shifts); + } + return 64; + } + case ShrSInt64: { + if (auto* shift = binary->right->dynCast<Const>()) { + auto maxBits = getMaxBits(binary->left, localInfoProvider); + // if maxBits is negative + if (maxBits == 64) { + return 64; + } + auto shifts = + std::min(Index(Bits::getEffectiveShifts(shift)), + maxBits); // can ignore more shifts than zero us out + return std::max(Index(0), maxBits - shifts); + } + return 64; + } // comparisons case EqInt32: case NeInt32: @@ -200,6 +349,7 @@ Index getMaxBits(Expression* curr, case GtUInt32: case GeSInt32: case GeUInt32: + case EqInt64: case NeInt64: case LtSInt64: @@ -210,12 +360,14 @@ Index getMaxBits(Expression* curr, case GtUInt64: case GeSInt64: case GeUInt64: + case EqFloat32: case NeFloat32: case LtFloat32: case LeFloat32: case GtFloat32: case GeFloat32: + case EqFloat64: case NeFloat64: case LtFloat64: @@ -240,7 +392,12 @@ Index getMaxBits(Expression* curr, case EqZInt64: return 1; case WrapInt64: + case ExtendUInt32: return std::min(Index(32), getMaxBits(unary->value, localInfoProvider)); + case ExtendSInt32: { + auto maxBits = getMaxBits(unary->value, localInfoProvider); + return maxBits == 32 ? Index(64) : maxBits; + } default: { } } |