summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlon Zakai <alonzakai@gmail.com>2017-02-16 21:37:39 -0800
committerGitHub <noreply@github.com>2017-02-16 21:37:39 -0800
commita78dddbcf2bf9f23840c7074ce16c04a8d55c3df (patch)
treedccd48fa4d610267637a754ebfde189f6e9222d9 /src
parent1d72468f464a725afdee4b28c5f6bd4d7a631c92 (diff)
downloadbinaryen-a78dddbcf2bf9f23840c7074ce16c04a8d55c3df.tar.gz
binaryen-a78dddbcf2bf9f23840c7074ce16c04a8d55c3df.tar.bz2
binaryen-a78dddbcf2bf9f23840c7074ce16c04a8d55c3df.zip
Optimize sign-extends (#902)
* optimize sign-extend output * optimize sign-extend input
Diffstat (limited to 'src')
-rw-r--r--src/passes/OptimizeInstructions.cpp157
1 files changed, 139 insertions, 18 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index 8f7bf63e7..c374a15ac 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -160,9 +160,115 @@ struct Match {
};
return ExpressionManipulator::flexibleCopy(pattern.output, wasm, copy);
}
+};
+// Utilities
-};
+// returns the maximum amount of bits used in an integer expression
+// not extremely precise (doesn't look into add operands, etc.)
+static Index getMaxBits(Expression* curr) {
+ if (auto* const_ = curr->dynCast<Const>()) {
+ switch (curr->type) {
+ case i32: return 32 - const_->value.countLeadingZeroes().geti32();
+ case i64: return 64 - const_->value.countLeadingZeroes().geti64();
+ default: WASM_UNREACHABLE();
+ }
+ } else if (auto* binary = curr->dynCast<Binary>()) {
+ switch (binary->op) {
+ // 32-bit
+ case AddInt32: case SubInt32: case MulInt32:
+ case DivSInt32: case DivUInt32: case RemSInt32:
+ case RemUInt32: case RotLInt32: case RotRInt32: return 32;
+ case AndInt32: case XorInt32: return std::min(getMaxBits(binary->left), getMaxBits(binary->right));
+ case OrInt32: return std::max(getMaxBits(binary->left), getMaxBits(binary->right));
+ case ShlInt32: {
+ if (auto* shifts = binary->right->dynCast<Const>()) {
+ return std::min(Index(32), getMaxBits(binary->left) + shifts->value.geti32());
+ }
+ return 32;
+ }
+ case ShrUInt32: {
+ if (auto* shift = binary->right->dynCast<Const>()) {
+ auto maxBits = getMaxBits(binary->left);
+ auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out
+ return std::max(Index(0), maxBits - shifts);
+ }
+ return 32;
+ }
+ case ShrSInt32: {
+ if (auto* shift = binary->right->dynCast<Const>()) {
+ auto maxBits = getMaxBits(binary->left);
+ if (maxBits == 32) return 32;
+ auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out
+ return std::max(Index(0), maxBits - shifts);
+ }
+ return 32;
+ }
+ // 64-bit TODO
+ // comparisons
+ case EqInt32: case NeInt32: case LtSInt32:
+ case LtUInt32: case LeSInt32: case LeUInt32:
+ case GtSInt32: case GtUInt32: case GeSInt32:
+ case GeUInt32:
+ case EqInt64: case NeInt64: case LtSInt64:
+ case LtUInt64: case LeSInt64: case LeUInt64:
+ case GtSInt64: case GtUInt64: case GeSInt64:
+ case GeUInt64:
+ case EqFloat32: case NeFloat32:
+ case LtFloat32: case LeFloat32: case GtFloat32: case GeFloat32:
+ case EqFloat64: case NeFloat64:
+ case LtFloat64: case LeFloat64: case GtFloat64: case GeFloat64: return 1;
+ default: {}
+ }
+ } else if (auto* unary = curr->dynCast<Unary>()) {
+ switch (unary->op) {
+ case ClzInt32: case CtzInt32: case PopcntInt32: return 5;
+ case ClzInt64: case CtzInt64: case PopcntInt64: return 6;
+ case EqZInt32: case EqZInt64: return 1;
+ case WrapInt64: return std::min(Index(32), getMaxBits(unary->value));
+ default: {}
+ }
+ }
+ switch (curr->type) {
+ case i32: return 32;
+ case i64: return 64;
+ case unreachable: return 64; // not interesting, but don't crash
+ default: WASM_UNREACHABLE();
+ }
+}
+
+// Check if an expression is a sign-extend, and if so, returns the value
+// that is extended, otherwise nullptr
+static Expression* getSignExt(Expression* curr) {
+ if (auto* outer = curr->dynCast<Binary>()) {
+ if (outer->op == ShrSInt32) {
+ if (auto* outerConst = outer->right->dynCast<Const>()) {
+ if (auto* inner = outer->left->dynCast<Binary>()) {
+ if (inner->op == ShlInt32) {
+ if (auto* innerConst = inner->right->dynCast<Const>()) {
+ if (outerConst->value == innerConst->value) {
+ return inner->left;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ return nullptr;
+}
+
+// gets the size of the sign-extended value
+static Index getSignExtBits(Expression* curr) {
+ return 32 - curr->cast<Binary>()->right->cast<Const>()->value.geti32();
+}
+
+// get a mask to keep only the low # of bits
+static int32_t lowBitMask(int32_t bits) {
+ uint32_t ret = -1;
+ if (bits >= 32) return ret;
+ return ret >> (32 - bits);
+}
// Main pass class
struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>> {
@@ -215,32 +321,42 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
std::swap(binary->left, binary->right);
}
}
- // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc.
- if (binary->op == BinaryOp::ShrSInt32 && binary->right->is<Const>()) {
- auto shifts = binary->right->cast<Const>()->value.geti32();
- if (shifts == 24 || shifts == 16) {
- auto* left = binary->left->dynCast<Binary>();
- if (left && left->op == ShlInt32 && left->right->is<Const>() && left->right->cast<Const>()->value.geti32() == shifts) {
- auto* load = left->left->dynCast<Load>();
- if (load && ((load->bytes == 1 && shifts == 24) || (load->bytes == 2 && shifts == 16))) {
- load->signed_ = true;
- return load;
- }
- }
+ if (auto* ext = getSignExt(binary)) {
+ auto bits = getSignExtBits(binary);
+ auto* load = ext->dynCast<Load>();
+ // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc.
+ if (load && ((load->bytes == 1 && bits == 8) || (load->bytes == 2 && bits == 16))) {
+ load->signed_ = true;
+ return load;
+ }
+ // if the sign-extend input cannot have a sign bit, we don't need it
+ if (getMaxBits(ext) < bits) {
+ return ext;
}
} else if (binary->op == EqInt32) {
if (auto* c = binary->right->dynCast<Const>()) {
+ if (auto* ext = getSignExt(binary->left)) {
+ // we are comparing a sign extend to a constant, which means we can use a cheaper zext
+ auto bits = getSignExtBits(binary->left);
+ binary->left = makeZeroExt(ext, bits);
+ // the const we compare to only needs the relevant bits
+ c->value = c->value.and_(Literal(lowBitMask(bits)));
+ return binary;
+ }
if (c->value.geti32() == 0) {
// equal 0 => eqz
return Builder(*getModule()).makeUnary(EqZInt32, binary->left);
}
- }
- if (auto* c = binary->left->dynCast<Const>()) {
- if (c->value.geti32() == 0) {
- // equal 0 => eqz
- return Builder(*getModule()).makeUnary(EqZInt32, binary->right);
+ } else if (auto* left = getSignExt(binary->left)) {
+ if (auto* right = getSignExt(binary->right)) {
+ // we are comparing two sign-exts, so we may as well replace both with cheaper zexts
+ auto bits = getSignExtBits(binary->left);
+ binary->left = makeZeroExt(left, bits);
+ binary->right = makeZeroExt(right, bits);
+ return binary;
}
}
+ // note that both left and right may be consts, but then we let precompute compute the constant result
} else if (binary->op == AndInt32) {
if (auto* right = binary->right->dynCast<Const>()) {
if (right->type == i32) {
@@ -454,6 +570,11 @@ private:
offset = 0;
}
}
+
+ Expression* makeZeroExt(Expression* curr, int32_t bits) {
+ Builder builder(*getModule());
+ return builder.makeBinary(AndInt32, curr, builder.makeConst(Literal(lowBitMask(bits))));
+ }
};
Pass *createOptimizeInstructionsPass() {