diff options
Diffstat (limited to 'src/passes/OptimizeInstructions.cpp')
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 120 |
1 files changed, 23 insertions, 97 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index e8e7a5f4a..79c01dee7 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -250,80 +250,6 @@ Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) { } } -// Check if an expression is a sign-extend, and if so, returns the value -// that is extended, otherwise nullptr -static Expression* getSignExt(Expression* curr) { - if (auto* outer = curr->dynCast<Binary>()) { - if (outer->op == ShrSInt32) { - if (auto* outerConst = outer->right->dynCast<Const>()) { - if (auto* inner = outer->left->dynCast<Binary>()) { - if (inner->op == ShlInt32) { - if (auto* innerConst = inner->right->dynCast<Const>()) { - if (outerConst->value == innerConst->value) { - return inner->left; - } - } - } - } - } - } - } - return nullptr; -} - -// gets the size of the sign-extended value -static Index getSignExtBits(Expression* curr) { - return 32 - curr->cast<Binary>()->right->cast<Const>()->value.geti32(); -} - -// Check if an expression is almost a sign-extend: perhaps the inner shift -// is too large. We can split the shifts in that case, which is sometimes -// useful (e.g. if we can remove the signext) -static Expression* getAlmostSignExt(Expression* curr) { - if (auto* outer = curr->dynCast<Binary>()) { - if (outer->op == ShrSInt32) { - if (auto* outerConst = outer->right->dynCast<Const>()) { - if (auto* inner = outer->left->dynCast<Binary>()) { - if (inner->op == ShlInt32) { - if (auto* innerConst = inner->right->dynCast<Const>()) { - if (outerConst->value.leU(innerConst->value).geti32()) { - return inner->left; - } - } - } - } - } - } - } - return nullptr; -} - -// gets the size of the almost sign-extended value, as well as the -// extra shifts, if any -static Index getAlmostSignExtBits(Expression* curr, Index& extraShifts) { - extraShifts = curr->cast<Binary>()->left->cast<Binary>()->right->cast<Const>()->value.geti32() - - curr->cast<Binary>()->right->cast<Const>()->value.geti32(); - return getSignExtBits(curr); -} - -// get a mask to keep only the low # of bits -static int32_t lowBitMask(int32_t bits) { - uint32_t ret = -1; - if (bits >= 32) return ret; - return ret >> (32 - bits); -} - -// checks if the input is a mask of lower bits, i.e., all 1s up to some high bit, and all zeros -// from there. returns the number of masked bits, or 0 if this is not such a mask -static uint32_t getMaskedBits(int32_t mask) { - if (mask == -1) return 32; // all the bits - if (mask == 0) return 0; // trivially not a mask - // otherwise, see if adding one turns this into a 1-bit thing, 00011111 + 1 => 00100000 - if (PopCount(mask + 1) != 1) return 0; - // this is indeed a mask - return 32 - CountLeadingZeroes(mask); -} - // looks through fallthrough operations, like tee_local, block fallthrough, etc. // too and block fallthroughs, etc. Expression* getFallthrough(Expression* curr) { @@ -386,8 +312,8 @@ struct LocalScanner : PostWalker<LocalScanner, Visitor<LocalScanner>> { auto& info = localInfo[curr->index]; info.maxBits = std::max(info.maxBits, getMaxBits(value, this)); auto signExtBits = LocalInfo::kUnknown; - if (getSignExt(value)) { - signExtBits = getSignExtBits(value); + if (Properties::getSignExtValue(value)) { + signExtBits = Properties::getSignExtBits(value); } else if (auto* load = value->dynCast<Load>()) { if (load->signed_) { signExtBits = load->bytes * 8; @@ -475,9 +401,9 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, std::swap(binary->left, binary->right); } } - if (auto* ext = getAlmostSignExt(binary)) { + if (auto* ext = Properties::getAlmostSignExt(binary)) { Index extraShifts; - auto bits = getAlmostSignExtBits(binary, extraShifts); + auto bits = Properties::getAlmostSignExtBits(binary, extraShifts); if (auto* load = getFallthrough(ext)->dynCast<Load>()) { // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc. if ((load->bytes == 1 && bits == 8) || (load->bytes == 2 && bits == 16)) { @@ -495,28 +421,28 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } else if (binary->op == EqInt32 || binary->op == NeInt32) { if (auto* c = binary->right->dynCast<Const>()) { - if (auto* ext = getSignExt(binary->left)) { + if (auto* ext = Properties::getSignExtValue(binary->left)) { // we are comparing a sign extend to a constant, which means we can use a cheaper zext - auto bits = getSignExtBits(binary->left); + auto bits = Properties::getSignExtBits(binary->left); binary->left = makeZeroExt(ext, bits); // the const we compare to only needs the relevant bits - c->value = c->value.and_(Literal(lowBitMask(bits))); + c->value = c->value.and_(Literal(Bits::lowBitMask(bits))); return binary; } if (binary->op == EqInt32 && c->value.geti32() == 0) { // equal 0 => eqz return Builder(*getModule()).makeUnary(EqZInt32, binary->left); } - } else if (auto* left = getSignExt(binary->left)) { - if (auto* right = getSignExt(binary->right)) { + } else if (auto* left = Properties::getSignExtValue(binary->left)) { + if (auto* right = Properties::getSignExtValue(binary->right)) { // we are comparing two sign-exts, so we may as well replace both with cheaper zexts - auto bits = getSignExtBits(binary->left); + auto bits = Properties::getSignExtBits(binary->left); binary->left = makeZeroExt(left, bits); binary->right = makeZeroExt(right, bits); return binary; } else if (auto* load = binary->right->dynCast<Load>()) { // we are comparing a load to a sign-ext, we may be able to switch to zext - auto leftBits = getSignExtBits(binary->left); + auto leftBits = Properties::getSignExtBits(binary->left); if (load->signed_ && leftBits == load->bytes * 8) { load->signed_ = false; binary->left = makeZeroExt(left, leftBits); @@ -524,9 +450,9 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } } else if (auto* load = binary->left->dynCast<Load>()) { - if (auto* right = getSignExt(binary->right)) { + if (auto* right = Properties::getSignExtValue(binary->right)) { // we are comparing a load to a sign-ext, we may be able to switch to zext - auto rightBits = getSignExtBits(binary->right); + auto rightBits = Properties::getSignExtBits(binary->right); if (load->signed_ && rightBits == load->bytes * 8) { load->signed_ = false; binary->right = makeZeroExt(right, rightBits); @@ -553,7 +479,7 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, load->signed_ = false; return binary->left; } - } else if (auto maskedBits = getMaskedBits(mask)) { + } else if (auto maskedBits = Bits::getMaskedBits(mask)) { if (getMaxBits(binary->left, this) <= maskedBits) { // a mask of lower bits is not needed if we are already smaller return binary->left; @@ -618,9 +544,9 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } // eqz of a sign extension can be of zero-extension - if (auto* ext = getSignExt(unary->value)) { + if (auto* ext = Properties::getSignExtValue(unary->value)) { // we are comparing a sign extend to a constant, which means we can use a cheaper zext - auto bits = getSignExtBits(unary->value); + auto bits = Properties::getSignExtBits(unary->value); unary->value = makeZeroExt(ext, bits); return unary; } @@ -674,10 +600,10 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } } - } else if (auto* ext = getSignExt(binary)) { + } else if (auto* ext = Properties::getSignExtValue(binary)) { // if sign extending the exact bit size we store, we can skip the extension // if extending something bigger, then we just alter bits we don't save anyhow - if (getSignExtBits(binary) >= store->bytes * 8) { + if (Properties::getSignExtBits(binary) >= store->bytes * 8) { store->value = ext; } } @@ -724,9 +650,9 @@ private: } } } - if (auto* ext = getSignExt(binary)) { + if (auto* ext = Properties::getSignExtValue(binary)) { // use a cheaper zero-extent, we just care about the boolean value anyhow - return makeZeroExt(ext, getSignExtBits(binary)); + return makeZeroExt(ext, Properties::getSignExtBits(binary)); } } else if (auto* block = boolean->dynCast<Block>()) { if (block->type == i32 && block->list.size() > 0) { @@ -913,7 +839,7 @@ private: Expression* makeZeroExt(Expression* curr, int32_t bits) { Builder builder(*getModule()); - return builder.makeBinary(AndInt32, curr, builder.makeConst(Literal(lowBitMask(bits)))); + return builder.makeBinary(AndInt32, curr, builder.makeConst(Literal(Bits::lowBitMask(bits)))); } // given an "almost" sign extend - either a proper one, or it @@ -933,8 +859,8 @@ private: // check if an expression is already sign-extended bool isSignExted(Expression* curr, Index bits) { - if (getSignExt(curr)) { - return getSignExtBits(curr) == bits; + if (Properties::getSignExtValue(curr)) { + return Properties::getSignExtBits(curr) == bits; } if (auto* get = curr->dynCast<GetLocal>()) { // check what we know about the local |