summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMax Graey <maxgraey@gmail.com>2020-09-17 23:05:55 +0300
committerGitHub <noreply@github.com>2020-09-17 13:05:55 -0700
commit2d47c0b8ae7b72e710b982abce83429c50c6de30 (patch)
tree45eb34e10b508f6f5d1585b53345776cdde26a07 /src
parent6116553a91b5da4fd877480bb27fc88b264b737f (diff)
downloadbinaryen-2d47c0b8ae7b72e710b982abce83429c50c6de30.tar.gz
binaryen-2d47c0b8ae7b72e710b982abce83429c50c6de30.tar.bz2
binaryen-2d47c0b8ae7b72e710b982abce83429c50c6de30.zip
Implement more cases for getMaxBits (#2879)
- Complete 64-bit cases in range `AddInt64` ... `ShrSInt64` - `ExtendSInt32` and `ExtendUInt32` for unary cases - For binary cases - `AddInt32` / `AddInt64` - `MulInt32` / `MulInt64` - `RemUInt32` / `RemUInt64` - `RemSInt32` / `RemSInt64` - `DivUInt32` / `DivUInt64` - `DivSInt32` / `DivSInt64` - and more Also more fast paths for some getMaxBits calculations
Diffstat (limited to 'src')
-rw-r--r--src/ir/bits.h187
-rw-r--r--src/support/bits.cpp8
-rw-r--r--src/support/bits.h6
3 files changed, 186 insertions, 15 deletions
diff --git a/src/ir/bits.h b/src/ir/bits.h
index 20d97f13f..132f7ec4d 100644
--- a/src/ir/bits.h
+++ b/src/ir/bits.h
@@ -128,35 +128,85 @@ struct DummyLocalInfoProvider {
template<typename LocalInfoProvider = DummyLocalInfoProvider>
Index getMaxBits(Expression* curr,
LocalInfoProvider* localInfoProvider = nullptr) {
- if (auto* const_ = curr->dynCast<Const>()) {
+ if (auto* c = curr->dynCast<Const>()) {
switch (curr->type.getBasic()) {
case Type::i32:
- return 32 - const_->value.countLeadingZeroes().geti32();
+ return 32 - c->value.countLeadingZeroes().geti32();
case Type::i64:
- return 64 - const_->value.countLeadingZeroes().geti64();
+ return 64 - c->value.countLeadingZeroes().geti64();
default:
WASM_UNREACHABLE("invalid type");
}
} else if (auto* binary = curr->dynCast<Binary>()) {
switch (binary->op) {
// 32-bit
- case AddInt32:
- case SubInt32:
- case MulInt32:
- case DivSInt32:
- case DivUInt32:
- case RemSInt32:
- case RemUInt32:
case RotLInt32:
case RotRInt32:
+ case SubInt32:
+ return 32;
+ case AddInt32: {
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
+ return std::min(Index(32), std::max(maxBitsLeft, maxBitsRight) + 1);
+ }
+ case MulInt32: {
+ auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ return std::min(Index(32), maxBitsLeft + maxBitsRight);
+ }
+ case DivSInt32: {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ // If either side might be negative, then the result will be negative
+ if (maxBitsLeft == 32 || c->value.geti32() < 0) {
+ return 32;
+ }
+ int32_t bitsRight = getMaxBits(c);
+ return std::max(0, maxBitsLeft - bitsRight + 1);
+ }
+ return 32;
+ }
+ case DivUInt32: {
+ int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ if (auto* c = binary->right->dynCast<Const>()) {
+ int32_t bitsRight = getMaxBits(c);
+ return std::max(0, maxBitsLeft - bitsRight + 1);
+ }
+ return maxBitsLeft;
+ }
+ case RemSInt32: {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ // if maxBitsLeft is negative
+ if (maxBitsLeft == 32) {
+ return 32;
+ }
+ auto bitsRight = Index(CeilLog2(c->value.geti32()));
+ return std::min(maxBitsLeft, bitsRight);
+ }
+ return 32;
+ }
+ case RemUInt32: {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ auto bitsRight = Index(CeilLog2(c->value.geti32()));
+ return std::min(maxBitsLeft, bitsRight);
+ }
return 32;
- case AndInt32:
+ }
+ case AndInt32: {
return std::min(getMaxBits(binary->left, localInfoProvider),
getMaxBits(binary->right, localInfoProvider));
+ }
case OrInt32:
- case XorInt32:
- return std::max(getMaxBits(binary->left, localInfoProvider),
- getMaxBits(binary->right, localInfoProvider));
+ case XorInt32: {
+ auto maxBits = getMaxBits(binary->right, localInfoProvider);
+ // if maxBits is negative
+ if (maxBits == 32) {
+ return 32;
+ }
+ return std::max(getMaxBits(binary->left, localInfoProvider), maxBits);
+ }
case ShlInt32: {
if (auto* shifts = binary->right->dynCast<Const>()) {
return std::min(Index(32),
@@ -178,6 +228,7 @@ Index getMaxBits(Expression* curr,
case ShrSInt32: {
if (auto* shift = binary->right->dynCast<Const>()) {
auto maxBits = getMaxBits(binary->left, localInfoProvider);
+ // if maxBits is negative
if (maxBits == 32) {
return 32;
}
@@ -188,7 +239,105 @@ Index getMaxBits(Expression* curr,
}
return 32;
}
- // 64-bit TODO
+ case RotLInt64:
+ case RotRInt64:
+ case SubInt64:
+ return 64;
+ case AddInt64: {
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
+ return std::min(Index(64), std::max(maxBitsLeft, maxBitsRight) + 1);
+ }
+ case MulInt64: {
+ auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ return std::min(Index(64), maxBitsLeft + maxBitsRight);
+ }
+ case DivSInt64: {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ // if maxBitsLeft or right const value is negative
+ if (maxBitsLeft == 64 || c->value.geti64() < 0) {
+ return 64;
+ }
+ int32_t bitsRight = getMaxBits(c);
+ return std::max(0, maxBitsLeft - bitsRight + 1);
+ }
+ return 64;
+ }
+ case DivUInt64: {
+ int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ if (auto* c = binary->right->dynCast<Const>()) {
+ int32_t bitsRight = getMaxBits(c);
+ return std::max(0, maxBitsLeft - bitsRight + 1);
+ }
+ return maxBitsLeft;
+ }
+ case RemSInt64: {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ // if maxBitsLeft is negative
+ if (maxBitsLeft == 64) {
+ return 64;
+ }
+ auto bitsRight = Index(CeilLog2(c->value.geti64()));
+ return std::min(maxBitsLeft, bitsRight);
+ }
+ return 64;
+ }
+ case RemUInt64: {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ auto bitsRight = Index(CeilLog2(c->value.geti64()));
+ return std::min(maxBitsLeft, bitsRight);
+ }
+ return 64;
+ }
+ case AndInt64: {
+ auto maxBits = getMaxBits(binary->right, localInfoProvider);
+ return std::min(getMaxBits(binary->left, localInfoProvider), maxBits);
+ }
+ case OrInt64:
+ case XorInt64: {
+ auto maxBits = getMaxBits(binary->right, localInfoProvider);
+ // if maxBits is negative
+ if (maxBits == 64) {
+ return 64;
+ }
+ return std::max(getMaxBits(binary->left, localInfoProvider), maxBits);
+ }
+ case ShlInt64: {
+ if (auto* shifts = binary->right->dynCast<Const>()) {
+ auto maxBits = getMaxBits(binary->left, localInfoProvider);
+ return std::min(Index(64),
+ Bits::getEffectiveShifts(shifts) + maxBits);
+ }
+ return 64;
+ }
+ case ShrUInt64: {
+ if (auto* shift = binary->right->dynCast<Const>()) {
+ auto maxBits = getMaxBits(binary->left, localInfoProvider);
+ auto shifts =
+ std::min(Index(Bits::getEffectiveShifts(shift)),
+ maxBits); // can ignore more shifts than zero us out
+ return std::max(Index(0), maxBits - shifts);
+ }
+ return 64;
+ }
+ case ShrSInt64: {
+ if (auto* shift = binary->right->dynCast<Const>()) {
+ auto maxBits = getMaxBits(binary->left, localInfoProvider);
+ // if maxBits is negative
+ if (maxBits == 64) {
+ return 64;
+ }
+ auto shifts =
+ std::min(Index(Bits::getEffectiveShifts(shift)),
+ maxBits); // can ignore more shifts than zero us out
+ return std::max(Index(0), maxBits - shifts);
+ }
+ return 64;
+ }
// comparisons
case EqInt32:
case NeInt32:
@@ -200,6 +349,7 @@ Index getMaxBits(Expression* curr,
case GtUInt32:
case GeSInt32:
case GeUInt32:
+
case EqInt64:
case NeInt64:
case LtSInt64:
@@ -210,12 +360,14 @@ Index getMaxBits(Expression* curr,
case GtUInt64:
case GeSInt64:
case GeUInt64:
+
case EqFloat32:
case NeFloat32:
case LtFloat32:
case LeFloat32:
case GtFloat32:
case GeFloat32:
+
case EqFloat64:
case NeFloat64:
case LtFloat64:
@@ -240,7 +392,12 @@ Index getMaxBits(Expression* curr,
case EqZInt64:
return 1;
case WrapInt64:
+ case ExtendUInt32:
return std::min(Index(32), getMaxBits(unary->value, localInfoProvider));
+ case ExtendSInt32: {
+ auto maxBits = getMaxBits(unary->value, localInfoProvider);
+ return maxBits == 32 ? Index(64) : maxBits;
+ }
default: {
}
}
diff --git a/src/support/bits.cpp b/src/support/bits.cpp
index 992b97955..e60f2365b 100644
--- a/src/support/bits.cpp
+++ b/src/support/bits.cpp
@@ -152,6 +152,14 @@ template<> int CountLeadingZeroes<uint64_t>(uint64_t v) {
#endif
}
+template<> int CeilLog2<uint32_t>(uint32_t v) {
+ return 32 - CountLeadingZeroes(v - 1);
+}
+
+template<> int CeilLog2<uint64_t>(uint64_t v) {
+ return 64 - CountLeadingZeroes(v - 1);
+}
+
uint32_t Log2(uint32_t v) {
switch (v) {
default:
diff --git a/src/support/bits.h b/src/support/bits.h
index bd91fdec6..a927a2832 100644
--- a/src/support/bits.h
+++ b/src/support/bits.h
@@ -40,6 +40,7 @@ template<typename T> int PopCount(T);
template<typename T> uint32_t BitReverse(T);
template<typename T> int CountTrailingZeroes(T);
template<typename T> int CountLeadingZeroes(T);
+template<typename T> int CeilLog2(T);
#ifndef wasm_support_bits_definitions
// The template specializations are provided elsewhere.
@@ -52,6 +53,8 @@ extern template int CountTrailingZeroes(uint32_t);
extern template int CountTrailingZeroes(uint64_t);
extern template int CountLeadingZeroes(uint32_t);
extern template int CountLeadingZeroes(uint64_t);
+extern template int CeilLog2(uint32_t);
+extern template int CeilLog2(uint64_t);
#endif
// Convenience signed -> unsigned. It usually doesn't make much sense to use bit
@@ -65,6 +68,9 @@ template<typename T> int CountTrailingZeroes(T v) {
template<typename T> int CountLeadingZeroes(T v) {
return CountLeadingZeroes(typename std::make_unsigned<T>::type(v));
}
+template<typename T> int CeilLog2(T v) {
+ return CeilLog2(typename std::make_unsigned<T>::type(v));
+}
template<typename T> bool IsPowerOf2(T v) {
return v != 0 && (v & (v - 1)) == 0;
}