summaryrefslogtreecommitdiff
path: root/src/passes/OptimizeInstructions.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/OptimizeInstructions.cpp')
-rw-r--r--src/passes/OptimizeInstructions.cpp120
1 files changed, 23 insertions, 97 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index e8e7a5f4a..79c01dee7 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -250,80 +250,6 @@ Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) {
}
}
-// Check if an expression is a sign-extend, and if so, returns the value
-// that is extended, otherwise nullptr
-static Expression* getSignExt(Expression* curr) {
- if (auto* outer = curr->dynCast<Binary>()) {
- if (outer->op == ShrSInt32) {
- if (auto* outerConst = outer->right->dynCast<Const>()) {
- if (auto* inner = outer->left->dynCast<Binary>()) {
- if (inner->op == ShlInt32) {
- if (auto* innerConst = inner->right->dynCast<Const>()) {
- if (outerConst->value == innerConst->value) {
- return inner->left;
- }
- }
- }
- }
- }
- }
- }
- return nullptr;
-}
-
-// gets the size of the sign-extended value
-static Index getSignExtBits(Expression* curr) {
- return 32 - curr->cast<Binary>()->right->cast<Const>()->value.geti32();
-}
-
-// Check if an expression is almost a sign-extend: perhaps the inner shift
-// is too large. We can split the shifts in that case, which is sometimes
-// useful (e.g. if we can remove the signext)
-static Expression* getAlmostSignExt(Expression* curr) {
- if (auto* outer = curr->dynCast<Binary>()) {
- if (outer->op == ShrSInt32) {
- if (auto* outerConst = outer->right->dynCast<Const>()) {
- if (auto* inner = outer->left->dynCast<Binary>()) {
- if (inner->op == ShlInt32) {
- if (auto* innerConst = inner->right->dynCast<Const>()) {
- if (outerConst->value.leU(innerConst->value).geti32()) {
- return inner->left;
- }
- }
- }
- }
- }
- }
- }
- return nullptr;
-}
-
-// gets the size of the almost sign-extended value, as well as the
-// extra shifts, if any
-static Index getAlmostSignExtBits(Expression* curr, Index& extraShifts) {
- extraShifts = curr->cast<Binary>()->left->cast<Binary>()->right->cast<Const>()->value.geti32() -
- curr->cast<Binary>()->right->cast<Const>()->value.geti32();
- return getSignExtBits(curr);
-}
-
-// get a mask to keep only the low # of bits
-static int32_t lowBitMask(int32_t bits) {
- uint32_t ret = -1;
- if (bits >= 32) return ret;
- return ret >> (32 - bits);
-}
-
-// checks if the input is a mask of lower bits, i.e., all 1s up to some high bit, and all zeros
-// from there. returns the number of masked bits, or 0 if this is not such a mask
-static uint32_t getMaskedBits(int32_t mask) {
- if (mask == -1) return 32; // all the bits
- if (mask == 0) return 0; // trivially not a mask
- // otherwise, see if adding one turns this into a 1-bit thing, 00011111 + 1 => 00100000
- if (PopCount(mask + 1) != 1) return 0;
- // this is indeed a mask
- return 32 - CountLeadingZeroes(mask);
-}
-
// looks through fallthrough operations, like tee_local, block fallthrough, etc.
// too and block fallthroughs, etc.
Expression* getFallthrough(Expression* curr) {
@@ -386,8 +312,8 @@ struct LocalScanner : PostWalker<LocalScanner, Visitor<LocalScanner>> {
auto& info = localInfo[curr->index];
info.maxBits = std::max(info.maxBits, getMaxBits(value, this));
auto signExtBits = LocalInfo::kUnknown;
- if (getSignExt(value)) {
- signExtBits = getSignExtBits(value);
+ if (Properties::getSignExtValue(value)) {
+ signExtBits = Properties::getSignExtBits(value);
} else if (auto* load = value->dynCast<Load>()) {
if (load->signed_) {
signExtBits = load->bytes * 8;
@@ -475,9 +401,9 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
std::swap(binary->left, binary->right);
}
}
- if (auto* ext = getAlmostSignExt(binary)) {
+ if (auto* ext = Properties::getAlmostSignExt(binary)) {
Index extraShifts;
- auto bits = getAlmostSignExtBits(binary, extraShifts);
+ auto bits = Properties::getAlmostSignExtBits(binary, extraShifts);
if (auto* load = getFallthrough(ext)->dynCast<Load>()) {
// pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc.
if ((load->bytes == 1 && bits == 8) || (load->bytes == 2 && bits == 16)) {
@@ -495,28 +421,28 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
}
} else if (binary->op == EqInt32 || binary->op == NeInt32) {
if (auto* c = binary->right->dynCast<Const>()) {
- if (auto* ext = getSignExt(binary->left)) {
+ if (auto* ext = Properties::getSignExtValue(binary->left)) {
// we are comparing a sign extend to a constant, which means we can use a cheaper zext
- auto bits = getSignExtBits(binary->left);
+ auto bits = Properties::getSignExtBits(binary->left);
binary->left = makeZeroExt(ext, bits);
// the const we compare to only needs the relevant bits
- c->value = c->value.and_(Literal(lowBitMask(bits)));
+ c->value = c->value.and_(Literal(Bits::lowBitMask(bits)));
return binary;
}
if (binary->op == EqInt32 && c->value.geti32() == 0) {
// equal 0 => eqz
return Builder(*getModule()).makeUnary(EqZInt32, binary->left);
}
- } else if (auto* left = getSignExt(binary->left)) {
- if (auto* right = getSignExt(binary->right)) {
+ } else if (auto* left = Properties::getSignExtValue(binary->left)) {
+ if (auto* right = Properties::getSignExtValue(binary->right)) {
// we are comparing two sign-exts, so we may as well replace both with cheaper zexts
- auto bits = getSignExtBits(binary->left);
+ auto bits = Properties::getSignExtBits(binary->left);
binary->left = makeZeroExt(left, bits);
binary->right = makeZeroExt(right, bits);
return binary;
} else if (auto* load = binary->right->dynCast<Load>()) {
// we are comparing a load to a sign-ext, we may be able to switch to zext
- auto leftBits = getSignExtBits(binary->left);
+ auto leftBits = Properties::getSignExtBits(binary->left);
if (load->signed_ && leftBits == load->bytes * 8) {
load->signed_ = false;
binary->left = makeZeroExt(left, leftBits);
@@ -524,9 +450,9 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
}
}
} else if (auto* load = binary->left->dynCast<Load>()) {
- if (auto* right = getSignExt(binary->right)) {
+ if (auto* right = Properties::getSignExtValue(binary->right)) {
// we are comparing a load to a sign-ext, we may be able to switch to zext
- auto rightBits = getSignExtBits(binary->right);
+ auto rightBits = Properties::getSignExtBits(binary->right);
if (load->signed_ && rightBits == load->bytes * 8) {
load->signed_ = false;
binary->right = makeZeroExt(right, rightBits);
@@ -553,7 +479,7 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
load->signed_ = false;
return binary->left;
}
- } else if (auto maskedBits = getMaskedBits(mask)) {
+ } else if (auto maskedBits = Bits::getMaskedBits(mask)) {
if (getMaxBits(binary->left, this) <= maskedBits) {
// a mask of lower bits is not needed if we are already smaller
return binary->left;
@@ -618,9 +544,9 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
}
}
// eqz of a sign extension can be of zero-extension
- if (auto* ext = getSignExt(unary->value)) {
+ if (auto* ext = Properties::getSignExtValue(unary->value)) {
// we are comparing a sign extend to a constant, which means we can use a cheaper zext
- auto bits = getSignExtBits(unary->value);
+ auto bits = Properties::getSignExtBits(unary->value);
unary->value = makeZeroExt(ext, bits);
return unary;
}
@@ -674,10 +600,10 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
}
}
}
- } else if (auto* ext = getSignExt(binary)) {
+ } else if (auto* ext = Properties::getSignExtValue(binary)) {
// if sign extending the exact bit size we store, we can skip the extension
// if extending something bigger, then we just alter bits we don't save anyhow
- if (getSignExtBits(binary) >= store->bytes * 8) {
+ if (Properties::getSignExtBits(binary) >= store->bytes * 8) {
store->value = ext;
}
}
@@ -724,9 +650,9 @@ private:
}
}
}
- if (auto* ext = getSignExt(binary)) {
+ if (auto* ext = Properties::getSignExtValue(binary)) {
// use a cheaper zero-extent, we just care about the boolean value anyhow
- return makeZeroExt(ext, getSignExtBits(binary));
+ return makeZeroExt(ext, Properties::getSignExtBits(binary));
}
} else if (auto* block = boolean->dynCast<Block>()) {
if (block->type == i32 && block->list.size() > 0) {
@@ -913,7 +839,7 @@ private:
Expression* makeZeroExt(Expression* curr, int32_t bits) {
Builder builder(*getModule());
- return builder.makeBinary(AndInt32, curr, builder.makeConst(Literal(lowBitMask(bits))));
+ return builder.makeBinary(AndInt32, curr, builder.makeConst(Literal(Bits::lowBitMask(bits))));
}
// given an "almost" sign extend - either a proper one, or it
@@ -933,8 +859,8 @@ private:
// check if an expression is already sign-extended
bool isSignExted(Expression* curr, Index bits) {
- if (getSignExt(curr)) {
- return getSignExtBits(curr) == bits;
+ if (Properties::getSignExtValue(curr)) {
+ return Properties::getSignExtBits(curr) == bits;
}
if (auto* get = curr->dynCast<GetLocal>()) {
// check what we know about the local