summaryrefslogtreecommitdiff
path: root/src/passes/OptimizeInstructions.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/OptimizeInstructions.cpp')
-rw-r--r--src/passes/OptimizeInstructions.cpp153
1 files changed, 4 insertions, 149 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index 860ea3ac7..7f66ca7ee 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -22,6 +22,7 @@
#include <type_traits>
#include <ir/abstract.h>
+#include <ir/bits.h>
#include <ir/cost.h>
#include <ir/effects.h>
#include <ir/literal-utils.h>
@@ -45,152 +46,6 @@ Name F32_EXPR = "f32.expr";
Name F64_EXPR = "f64.expr";
Name ANY_EXPR = "any.expr";
-// Utilities
-
-// returns the maximum amount of bits used in an integer expression
-// not extremely precise (doesn't look into add operands, etc.)
-// LocalInfoProvider is an optional class that can provide answers about
-// local.get.
-template<typename LocalInfoProvider>
-Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) {
- if (auto* const_ = curr->dynCast<Const>()) {
- switch (curr->type.getSingle()) {
- case Type::i32:
- return 32 - const_->value.countLeadingZeroes().geti32();
- case Type::i64:
- return 64 - const_->value.countLeadingZeroes().geti64();
- default:
- WASM_UNREACHABLE("invalid type");
- }
- } else if (auto* binary = curr->dynCast<Binary>()) {
- switch (binary->op) {
- // 32-bit
- case AddInt32:
- case SubInt32:
- case MulInt32:
- case DivSInt32:
- case DivUInt32:
- case RemSInt32:
- case RemUInt32:
- case RotLInt32:
- case RotRInt32:
- return 32;
- case AndInt32:
- return std::min(getMaxBits(binary->left, localInfoProvider),
- getMaxBits(binary->right, localInfoProvider));
- case OrInt32:
- case XorInt32:
- return std::max(getMaxBits(binary->left, localInfoProvider),
- getMaxBits(binary->right, localInfoProvider));
- case ShlInt32: {
- if (auto* shifts = binary->right->dynCast<Const>()) {
- return std::min(Index(32),
- getMaxBits(binary->left, localInfoProvider) +
- Bits::getEffectiveShifts(shifts));
- }
- return 32;
- }
- case ShrUInt32: {
- if (auto* shift = binary->right->dynCast<Const>()) {
- auto maxBits = getMaxBits(binary->left, localInfoProvider);
- auto shifts =
- std::min(Index(Bits::getEffectiveShifts(shift)),
- maxBits); // can ignore more shifts than zero us out
- return std::max(Index(0), maxBits - shifts);
- }
- return 32;
- }
- case ShrSInt32: {
- if (auto* shift = binary->right->dynCast<Const>()) {
- auto maxBits = getMaxBits(binary->left, localInfoProvider);
- if (maxBits == 32) {
- return 32;
- }
- auto shifts =
- std::min(Index(Bits::getEffectiveShifts(shift)),
- maxBits); // can ignore more shifts than zero us out
- return std::max(Index(0), maxBits - shifts);
- }
- return 32;
- }
- // 64-bit TODO
- // comparisons
- case EqInt32:
- case NeInt32:
- case LtSInt32:
- case LtUInt32:
- case LeSInt32:
- case LeUInt32:
- case GtSInt32:
- case GtUInt32:
- case GeSInt32:
- case GeUInt32:
- case EqInt64:
- case NeInt64:
- case LtSInt64:
- case LtUInt64:
- case LeSInt64:
- case LeUInt64:
- case GtSInt64:
- case GtUInt64:
- case GeSInt64:
- case GeUInt64:
- case EqFloat32:
- case NeFloat32:
- case LtFloat32:
- case LeFloat32:
- case GtFloat32:
- case GeFloat32:
- case EqFloat64:
- case NeFloat64:
- case LtFloat64:
- case LeFloat64:
- case GtFloat64:
- case GeFloat64:
- return 1;
- default: {}
- }
- } else if (auto* unary = curr->dynCast<Unary>()) {
- switch (unary->op) {
- case ClzInt32:
- case CtzInt32:
- case PopcntInt32:
- return 6;
- case ClzInt64:
- case CtzInt64:
- case PopcntInt64:
- return 7;
- case EqZInt32:
- case EqZInt64:
- return 1;
- case WrapInt64:
- return std::min(Index(32), getMaxBits(unary->value, localInfoProvider));
- default: {}
- }
- } else if (auto* set = curr->dynCast<LocalSet>()) {
- // a tee passes through the value
- return getMaxBits(set->value, localInfoProvider);
- } else if (auto* get = curr->dynCast<LocalGet>()) {
- return localInfoProvider->getMaxBitsForLocal(get);
- } else if (auto* load = curr->dynCast<Load>()) {
- // if signed, then the sign-extension might fill all the bits
- // if unsigned, then we have a limit
- if (LoadUtils::isSignRelevant(load) && !load->signed_) {
- return 8 * load->bytes;
- }
- }
- switch (curr->type.getSingle()) {
- case Type::i32:
- return 32;
- case Type::i64:
- return 64;
- case Type::unreachable:
- return 64; // not interesting, but don't crash
- default:
- WASM_UNREACHABLE("invalid type");
- }
-}
-
// Useful information about locals
struct LocalInfo {
static const Index kUnknown = Index(-1);
@@ -243,7 +98,7 @@ struct LocalScanner : PostWalker<LocalScanner> {
auto* value = Properties::getFallthrough(
curr->value, passOptions, getModule()->features);
auto& info = localInfo[curr->index];
- info.maxBits = std::max(info.maxBits, getMaxBits(value, this));
+ info.maxBits = std::max(info.maxBits, Bits::getMaxBits(value, this));
auto signExtBits = LocalInfo::kUnknown;
if (Properties::getSignExtValue(value)) {
signExtBits = Properties::getSignExtBits(value);
@@ -373,7 +228,7 @@ struct OptimizeInstructions
// if the sign-extend input cannot have a sign bit, we don't need it
// we also don't need it if it already has an identical-sized sign
// extend
- if (getMaxBits(ext, this) + extraShifts < bits ||
+ if (Bits::getMaxBits(ext, this) + extraShifts < bits ||
isSignExted(ext, bits)) {
return removeAlmostSignExt(binary);
}
@@ -538,7 +393,7 @@ struct OptimizeInstructions
return binary->left;
}
} else if (auto maskedBits = Bits::getMaskedBits(mask)) {
- if (getMaxBits(binary->left, this) <= maskedBits) {
+ if (Bits::getMaxBits(binary->left, this) <= maskedBits) {
// a mask of lower bits is not needed if we are already smaller
return binary->left;
}