diff options
author | Max Graey <maxgraey@gmail.com> | 2020-09-17 23:05:55 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-17 13:05:55 -0700 |
commit | 2d47c0b8ae7b72e710b982abce83429c50c6de30 (patch) | |
tree | 45eb34e10b508f6f5d1585b53345776cdde26a07 /src | |
parent | 6116553a91b5da4fd877480bb27fc88b264b737f (diff) | |
download | binaryen-2d47c0b8ae7b72e710b982abce83429c50c6de30.tar.gz binaryen-2d47c0b8ae7b72e710b982abce83429c50c6de30.tar.bz2 binaryen-2d47c0b8ae7b72e710b982abce83429c50c6de30.zip |
Implement more cases for getMaxBits (#2879)
- Complete 64-bit cases in range `AddInt64` ... `ShrSInt64`
- `ExtendSInt32` and `ExtendUInt32` for unary cases
- For binary cases
- `AddInt32` / `AddInt64`
- `MulInt32` / `MulInt64`
- `RemUInt32` / `RemUInt64`
- `RemSInt32` / `RemSInt64`
- `DivUInt32` / `DivUInt64`
- `DivSInt32` / `DivSInt64`
- and more
Also more fast paths for some getMaxBits calculations
Diffstat (limited to 'src')
-rw-r--r-- | src/ir/bits.h | 187 | ||||
-rw-r--r-- | src/support/bits.cpp | 8 | ||||
-rw-r--r-- | src/support/bits.h | 6 |
3 files changed, 186 insertions, 15 deletions
diff --git a/src/ir/bits.h b/src/ir/bits.h index 20d97f13f..132f7ec4d 100644 --- a/src/ir/bits.h +++ b/src/ir/bits.h @@ -128,35 +128,85 @@ struct DummyLocalInfoProvider { template<typename LocalInfoProvider = DummyLocalInfoProvider> Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider = nullptr) { - if (auto* const_ = curr->dynCast<Const>()) { + if (auto* c = curr->dynCast<Const>()) { switch (curr->type.getBasic()) { case Type::i32: - return 32 - const_->value.countLeadingZeroes().geti32(); + return 32 - c->value.countLeadingZeroes().geti32(); case Type::i64: - return 64 - const_->value.countLeadingZeroes().geti64(); + return 64 - c->value.countLeadingZeroes().geti64(); default: WASM_UNREACHABLE("invalid type"); } } else if (auto* binary = curr->dynCast<Binary>()) { switch (binary->op) { // 32-bit - case AddInt32: - case SubInt32: - case MulInt32: - case DivSInt32: - case DivUInt32: - case RemSInt32: - case RemUInt32: case RotLInt32: case RotRInt32: + case SubInt32: + return 32; + case AddInt32: { + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + auto maxBitsRight = getMaxBits(binary->right, localInfoProvider); + return std::min(Index(32), std::max(maxBitsLeft, maxBitsRight) + 1); + } + case MulInt32: { + auto maxBitsRight = getMaxBits(binary->right, localInfoProvider); + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + return std::min(Index(32), maxBitsLeft + maxBitsRight); + } + case DivSInt32: { + if (auto* c = binary->right->dynCast<Const>()) { + int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + // If either side might be negative, then the result will be negative + if (maxBitsLeft == 32 || c->value.geti32() < 0) { + return 32; + } + int32_t bitsRight = getMaxBits(c); + return std::max(0, maxBitsLeft - bitsRight + 1); + } + return 32; + } + case DivUInt32: { + int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + if (auto* c = binary->right->dynCast<Const>()) { + int32_t bitsRight = getMaxBits(c); + return std::max(0, maxBitsLeft - bitsRight + 1); + } + return maxBitsLeft; + } + case RemSInt32: { + if (auto* c = binary->right->dynCast<Const>()) { + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + // if maxBitsLeft is negative + if (maxBitsLeft == 32) { + return 32; + } + auto bitsRight = Index(CeilLog2(c->value.geti32())); + return std::min(maxBitsLeft, bitsRight); + } + return 32; + } + case RemUInt32: { + if (auto* c = binary->right->dynCast<Const>()) { + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + auto bitsRight = Index(CeilLog2(c->value.geti32())); + return std::min(maxBitsLeft, bitsRight); + } return 32; - case AndInt32: + } + case AndInt32: { return std::min(getMaxBits(binary->left, localInfoProvider), getMaxBits(binary->right, localInfoProvider)); + } case OrInt32: - case XorInt32: - return std::max(getMaxBits(binary->left, localInfoProvider), - getMaxBits(binary->right, localInfoProvider)); + case XorInt32: { + auto maxBits = getMaxBits(binary->right, localInfoProvider); + // if maxBits is negative + if (maxBits == 32) { + return 32; + } + return std::max(getMaxBits(binary->left, localInfoProvider), maxBits); + } case ShlInt32: { if (auto* shifts = binary->right->dynCast<Const>()) { return std::min(Index(32), @@ -178,6 +228,7 @@ Index getMaxBits(Expression* curr, case ShrSInt32: { if (auto* shift = binary->right->dynCast<Const>()) { auto maxBits = getMaxBits(binary->left, localInfoProvider); + // if maxBits is negative if (maxBits == 32) { return 32; } @@ -188,7 +239,105 @@ Index getMaxBits(Expression* curr, } return 32; } - // 64-bit TODO + case RotLInt64: + case RotRInt64: + case SubInt64: + return 64; + case AddInt64: { + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + auto maxBitsRight = getMaxBits(binary->right, localInfoProvider); + return std::min(Index(64), std::max(maxBitsLeft, maxBitsRight) + 1); + } + case MulInt64: { + auto maxBitsRight = getMaxBits(binary->right, localInfoProvider); + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + return std::min(Index(64), maxBitsLeft + maxBitsRight); + } + case DivSInt64: { + if (auto* c = binary->right->dynCast<Const>()) { + int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + // if maxBitsLeft or right const value is negative + if (maxBitsLeft == 64 || c->value.geti64() < 0) { + return 64; + } + int32_t bitsRight = getMaxBits(c); + return std::max(0, maxBitsLeft - bitsRight + 1); + } + return 64; + } + case DivUInt64: { + int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + if (auto* c = binary->right->dynCast<Const>()) { + int32_t bitsRight = getMaxBits(c); + return std::max(0, maxBitsLeft - bitsRight + 1); + } + return maxBitsLeft; + } + case RemSInt64: { + if (auto* c = binary->right->dynCast<Const>()) { + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + // if maxBitsLeft is negative + if (maxBitsLeft == 64) { + return 64; + } + auto bitsRight = Index(CeilLog2(c->value.geti64())); + return std::min(maxBitsLeft, bitsRight); + } + return 64; + } + case RemUInt64: { + if (auto* c = binary->right->dynCast<Const>()) { + auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider); + auto bitsRight = Index(CeilLog2(c->value.geti64())); + return std::min(maxBitsLeft, bitsRight); + } + return 64; + } + case AndInt64: { + auto maxBits = getMaxBits(binary->right, localInfoProvider); + return std::min(getMaxBits(binary->left, localInfoProvider), maxBits); + } + case OrInt64: + case XorInt64: { + auto maxBits = getMaxBits(binary->right, localInfoProvider); + // if maxBits is negative + if (maxBits == 64) { + return 64; + } + return std::max(getMaxBits(binary->left, localInfoProvider), maxBits); + } + case ShlInt64: { + if (auto* shifts = binary->right->dynCast<Const>()) { + auto maxBits = getMaxBits(binary->left, localInfoProvider); + return std::min(Index(64), + Bits::getEffectiveShifts(shifts) + maxBits); + } + return 64; + } + case ShrUInt64: { + if (auto* shift = binary->right->dynCast<Const>()) { + auto maxBits = getMaxBits(binary->left, localInfoProvider); + auto shifts = + std::min(Index(Bits::getEffectiveShifts(shift)), + maxBits); // can ignore more shifts than zero us out + return std::max(Index(0), maxBits - shifts); + } + return 64; + } + case ShrSInt64: { + if (auto* shift = binary->right->dynCast<Const>()) { + auto maxBits = getMaxBits(binary->left, localInfoProvider); + // if maxBits is negative + if (maxBits == 64) { + return 64; + } + auto shifts = + std::min(Index(Bits::getEffectiveShifts(shift)), + maxBits); // can ignore more shifts than zero us out + return std::max(Index(0), maxBits - shifts); + } + return 64; + } // comparisons case EqInt32: case NeInt32: @@ -200,6 +349,7 @@ Index getMaxBits(Expression* curr, case GtUInt32: case GeSInt32: case GeUInt32: + case EqInt64: case NeInt64: case LtSInt64: @@ -210,12 +360,14 @@ Index getMaxBits(Expression* curr, case GtUInt64: case GeSInt64: case GeUInt64: + case EqFloat32: case NeFloat32: case LtFloat32: case LeFloat32: case GtFloat32: case GeFloat32: + case EqFloat64: case NeFloat64: case LtFloat64: @@ -240,7 +392,12 @@ Index getMaxBits(Expression* curr, case EqZInt64: return 1; case WrapInt64: + case ExtendUInt32: return std::min(Index(32), getMaxBits(unary->value, localInfoProvider)); + case ExtendSInt32: { + auto maxBits = getMaxBits(unary->value, localInfoProvider); + return maxBits == 32 ? Index(64) : maxBits; + } default: { } } diff --git a/src/support/bits.cpp b/src/support/bits.cpp index 992b97955..e60f2365b 100644 --- a/src/support/bits.cpp +++ b/src/support/bits.cpp @@ -152,6 +152,14 @@ template<> int CountLeadingZeroes<uint64_t>(uint64_t v) { #endif } +template<> int CeilLog2<uint32_t>(uint32_t v) { + return 32 - CountLeadingZeroes(v - 1); +} + +template<> int CeilLog2<uint64_t>(uint64_t v) { + return 64 - CountLeadingZeroes(v - 1); +} + uint32_t Log2(uint32_t v) { switch (v) { default: diff --git a/src/support/bits.h b/src/support/bits.h index bd91fdec6..a927a2832 100644 --- a/src/support/bits.h +++ b/src/support/bits.h @@ -40,6 +40,7 @@ template<typename T> int PopCount(T); template<typename T> uint32_t BitReverse(T); template<typename T> int CountTrailingZeroes(T); template<typename T> int CountLeadingZeroes(T); +template<typename T> int CeilLog2(T); #ifndef wasm_support_bits_definitions // The template specializations are provided elsewhere. @@ -52,6 +53,8 @@ extern template int CountTrailingZeroes(uint32_t); extern template int CountTrailingZeroes(uint64_t); extern template int CountLeadingZeroes(uint32_t); extern template int CountLeadingZeroes(uint64_t); +extern template int CeilLog2(uint32_t); +extern template int CeilLog2(uint64_t); #endif // Convenience signed -> unsigned. It usually doesn't make much sense to use bit @@ -65,6 +68,9 @@ template<typename T> int CountTrailingZeroes(T v) { template<typename T> int CountLeadingZeroes(T v) { return CountLeadingZeroes(typename std::make_unsigned<T>::type(v)); } +template<typename T> int CeilLog2(T v) { + return CeilLog2(typename std::make_unsigned<T>::type(v)); +} template<typename T> bool IsPowerOf2(T v) { return v != 0 && (v & (v - 1)) == 0; } |