summaryrefslogtreecommitdiff
path: root/src/ir/bits.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/ir/bits.h')
-rw-r--r--src/ir/bits.h187
1 files changed, 172 insertions, 15 deletions
diff --git a/src/ir/bits.h b/src/ir/bits.h
index 20d97f13f..132f7ec4d 100644
--- a/src/ir/bits.h
+++ b/src/ir/bits.h
@@ -128,35 +128,85 @@ struct DummyLocalInfoProvider {
template<typename LocalInfoProvider = DummyLocalInfoProvider>
Index getMaxBits(Expression* curr,
LocalInfoProvider* localInfoProvider = nullptr) {
- if (auto* const_ = curr->dynCast<Const>()) {
+ if (auto* c = curr->dynCast<Const>()) {
switch (curr->type.getBasic()) {
case Type::i32:
- return 32 - const_->value.countLeadingZeroes().geti32();
+ return 32 - c->value.countLeadingZeroes().geti32();
case Type::i64:
- return 64 - const_->value.countLeadingZeroes().geti64();
+ return 64 - c->value.countLeadingZeroes().geti64();
default:
WASM_UNREACHABLE("invalid type");
}
} else if (auto* binary = curr->dynCast<Binary>()) {
switch (binary->op) {
// 32-bit
- case AddInt32:
- case SubInt32:
- case MulInt32:
- case DivSInt32:
- case DivUInt32:
- case RemSInt32:
- case RemUInt32:
case RotLInt32:
case RotRInt32:
+ case SubInt32:
+ return 32;
+ case AddInt32: {
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
+ return std::min(Index(32), std::max(maxBitsLeft, maxBitsRight) + 1);
+ }
+ case MulInt32: {
+ auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ return std::min(Index(32), maxBitsLeft + maxBitsRight);
+ }
+ case DivSInt32: {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ // If either side might be negative, then the result will be negative
+ if (maxBitsLeft == 32 || c->value.geti32() < 0) {
+ return 32;
+ }
+ int32_t bitsRight = getMaxBits(c);
+ return std::max(0, maxBitsLeft - bitsRight + 1);
+ }
+ return 32;
+ }
+ case DivUInt32: {
+ int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ if (auto* c = binary->right->dynCast<Const>()) {
+ int32_t bitsRight = getMaxBits(c);
+ return std::max(0, maxBitsLeft - bitsRight + 1);
+ }
+ return maxBitsLeft;
+ }
+ case RemSInt32: {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ // if maxBitsLeft is negative
+ if (maxBitsLeft == 32) {
+ return 32;
+ }
+ auto bitsRight = Index(CeilLog2(c->value.geti32()));
+ return std::min(maxBitsLeft, bitsRight);
+ }
+ return 32;
+ }
+ case RemUInt32: {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ auto bitsRight = Index(CeilLog2(c->value.geti32()));
+ return std::min(maxBitsLeft, bitsRight);
+ }
return 32;
- case AndInt32:
+ }
+ case AndInt32: {
return std::min(getMaxBits(binary->left, localInfoProvider),
getMaxBits(binary->right, localInfoProvider));
+ }
case OrInt32:
- case XorInt32:
- return std::max(getMaxBits(binary->left, localInfoProvider),
- getMaxBits(binary->right, localInfoProvider));
+ case XorInt32: {
+ auto maxBits = getMaxBits(binary->right, localInfoProvider);
+ // if maxBits is negative
+ if (maxBits == 32) {
+ return 32;
+ }
+ return std::max(getMaxBits(binary->left, localInfoProvider), maxBits);
+ }
case ShlInt32: {
if (auto* shifts = binary->right->dynCast<Const>()) {
return std::min(Index(32),
@@ -178,6 +228,7 @@ Index getMaxBits(Expression* curr,
case ShrSInt32: {
if (auto* shift = binary->right->dynCast<Const>()) {
auto maxBits = getMaxBits(binary->left, localInfoProvider);
+ // if maxBits is negative
if (maxBits == 32) {
return 32;
}
@@ -188,7 +239,105 @@ Index getMaxBits(Expression* curr,
}
return 32;
}
- // 64-bit TODO
+ case RotLInt64:
+ case RotRInt64:
+ case SubInt64:
+ return 64;
+ case AddInt64: {
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
+ return std::min(Index(64), std::max(maxBitsLeft, maxBitsRight) + 1);
+ }
+ case MulInt64: {
+ auto maxBitsRight = getMaxBits(binary->right, localInfoProvider);
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ return std::min(Index(64), maxBitsLeft + maxBitsRight);
+ }
+ case DivSInt64: {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ // if maxBitsLeft or right const value is negative
+ if (maxBitsLeft == 64 || c->value.geti64() < 0) {
+ return 64;
+ }
+ int32_t bitsRight = getMaxBits(c);
+ return std::max(0, maxBitsLeft - bitsRight + 1);
+ }
+ return 64;
+ }
+ case DivUInt64: {
+ int32_t maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ if (auto* c = binary->right->dynCast<Const>()) {
+ int32_t bitsRight = getMaxBits(c);
+ return std::max(0, maxBitsLeft - bitsRight + 1);
+ }
+ return maxBitsLeft;
+ }
+ case RemSInt64: {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ // if maxBitsLeft is negative
+ if (maxBitsLeft == 64) {
+ return 64;
+ }
+ auto bitsRight = Index(CeilLog2(c->value.geti64()));
+ return std::min(maxBitsLeft, bitsRight);
+ }
+ return 64;
+ }
+ case RemUInt64: {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ auto maxBitsLeft = getMaxBits(binary->left, localInfoProvider);
+ auto bitsRight = Index(CeilLog2(c->value.geti64()));
+ return std::min(maxBitsLeft, bitsRight);
+ }
+ return 64;
+ }
+ case AndInt64: {
+ auto maxBits = getMaxBits(binary->right, localInfoProvider);
+ return std::min(getMaxBits(binary->left, localInfoProvider), maxBits);
+ }
+ case OrInt64:
+ case XorInt64: {
+ auto maxBits = getMaxBits(binary->right, localInfoProvider);
+ // if maxBits is negative
+ if (maxBits == 64) {
+ return 64;
+ }
+ return std::max(getMaxBits(binary->left, localInfoProvider), maxBits);
+ }
+ case ShlInt64: {
+ if (auto* shifts = binary->right->dynCast<Const>()) {
+ auto maxBits = getMaxBits(binary->left, localInfoProvider);
+ return std::min(Index(64),
+ Bits::getEffectiveShifts(shifts) + maxBits);
+ }
+ return 64;
+ }
+ case ShrUInt64: {
+ if (auto* shift = binary->right->dynCast<Const>()) {
+ auto maxBits = getMaxBits(binary->left, localInfoProvider);
+ auto shifts =
+ std::min(Index(Bits::getEffectiveShifts(shift)),
+ maxBits); // can ignore more shifts than zero us out
+ return std::max(Index(0), maxBits - shifts);
+ }
+ return 64;
+ }
+ case ShrSInt64: {
+ if (auto* shift = binary->right->dynCast<Const>()) {
+ auto maxBits = getMaxBits(binary->left, localInfoProvider);
+ // if maxBits is negative
+ if (maxBits == 64) {
+ return 64;
+ }
+ auto shifts =
+ std::min(Index(Bits::getEffectiveShifts(shift)),
+ maxBits); // can ignore more shifts than zero us out
+ return std::max(Index(0), maxBits - shifts);
+ }
+ return 64;
+ }
// comparisons
case EqInt32:
case NeInt32:
@@ -200,6 +349,7 @@ Index getMaxBits(Expression* curr,
case GtUInt32:
case GeSInt32:
case GeUInt32:
+
case EqInt64:
case NeInt64:
case LtSInt64:
@@ -210,12 +360,14 @@ Index getMaxBits(Expression* curr,
case GtUInt64:
case GeSInt64:
case GeUInt64:
+
case EqFloat32:
case NeFloat32:
case LtFloat32:
case LeFloat32:
case GtFloat32:
case GeFloat32:
+
case EqFloat64:
case NeFloat64:
case LtFloat64:
@@ -240,7 +392,12 @@ Index getMaxBits(Expression* curr,
case EqZInt64:
return 1;
case WrapInt64:
+ case ExtendUInt32:
return std::min(Index(32), getMaxBits(unary->value, localInfoProvider));
+ case ExtendSInt32: {
+ auto maxBits = getMaxBits(unary->value, localInfoProvider);
+ return maxBits == 32 ? Index(64) : maxBits;
+ }
default: {
}
}