diff options
author | Alon Zakai <alonzakai@gmail.com> | 2017-02-16 21:37:39 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-02-16 21:37:39 -0800 |
commit | a78dddbcf2bf9f23840c7074ce16c04a8d55c3df (patch) | |
tree | dccd48fa4d610267637a754ebfde189f6e9222d9 /src | |
parent | 1d72468f464a725afdee4b28c5f6bd4d7a631c92 (diff) | |
download | binaryen-a78dddbcf2bf9f23840c7074ce16c04a8d55c3df.tar.gz binaryen-a78dddbcf2bf9f23840c7074ce16c04a8d55c3df.tar.bz2 binaryen-a78dddbcf2bf9f23840c7074ce16c04a8d55c3df.zip |
Optimize sign-extends (#902)
* optimize sign-extend output
* optimize sign-extend input
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 157 |
1 files changed, 139 insertions, 18 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 8f7bf63e7..c374a15ac 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -160,9 +160,115 @@ struct Match { }; return ExpressionManipulator::flexibleCopy(pattern.output, wasm, copy); } +}; +// Utilities -}; +// returns the maximum amount of bits used in an integer expression +// not extremely precise (doesn't look into add operands, etc.) +static Index getMaxBits(Expression* curr) { + if (auto* const_ = curr->dynCast<Const>()) { + switch (curr->type) { + case i32: return 32 - const_->value.countLeadingZeroes().geti32(); + case i64: return 64 - const_->value.countLeadingZeroes().geti64(); + default: WASM_UNREACHABLE(); + } + } else if (auto* binary = curr->dynCast<Binary>()) { + switch (binary->op) { + // 32-bit + case AddInt32: case SubInt32: case MulInt32: + case DivSInt32: case DivUInt32: case RemSInt32: + case RemUInt32: case RotLInt32: case RotRInt32: return 32; + case AndInt32: case XorInt32: return std::min(getMaxBits(binary->left), getMaxBits(binary->right)); + case OrInt32: return std::max(getMaxBits(binary->left), getMaxBits(binary->right)); + case ShlInt32: { + if (auto* shifts = binary->right->dynCast<Const>()) { + return std::min(Index(32), getMaxBits(binary->left) + shifts->value.geti32()); + } + return 32; + } + case ShrUInt32: { + if (auto* shift = binary->right->dynCast<Const>()) { + auto maxBits = getMaxBits(binary->left); + auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out + return std::max(Index(0), maxBits - shifts); + } + return 32; + } + case ShrSInt32: { + if (auto* shift = binary->right->dynCast<Const>()) { + auto maxBits = getMaxBits(binary->left); + if (maxBits == 32) return 32; + auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out + return std::max(Index(0), maxBits - shifts); + } + return 32; + } + // 64-bit TODO + // comparisons + case EqInt32: case NeInt32: case LtSInt32: + case LtUInt32: case LeSInt32: case LeUInt32: + case GtSInt32: case GtUInt32: case GeSInt32: + case GeUInt32: + case EqInt64: case NeInt64: case LtSInt64: + case LtUInt64: case LeSInt64: case LeUInt64: + case GtSInt64: case GtUInt64: case GeSInt64: + case GeUInt64: + case EqFloat32: case NeFloat32: + case LtFloat32: case LeFloat32: case GtFloat32: case GeFloat32: + case EqFloat64: case NeFloat64: + case LtFloat64: case LeFloat64: case GtFloat64: case GeFloat64: return 1; + default: {} + } + } else if (auto* unary = curr->dynCast<Unary>()) { + switch (unary->op) { + case ClzInt32: case CtzInt32: case PopcntInt32: return 5; + case ClzInt64: case CtzInt64: case PopcntInt64: return 6; + case EqZInt32: case EqZInt64: return 1; + case WrapInt64: return std::min(Index(32), getMaxBits(unary->value)); + default: {} + } + } + switch (curr->type) { + case i32: return 32; + case i64: return 64; + case unreachable: return 64; // not interesting, but don't crash + default: WASM_UNREACHABLE(); + } +} + +// Check if an expression is a sign-extend, and if so, returns the value +// that is extended, otherwise nullptr +static Expression* getSignExt(Expression* curr) { + if (auto* outer = curr->dynCast<Binary>()) { + if (outer->op == ShrSInt32) { + if (auto* outerConst = outer->right->dynCast<Const>()) { + if (auto* inner = outer->left->dynCast<Binary>()) { + if (inner->op == ShlInt32) { + if (auto* innerConst = inner->right->dynCast<Const>()) { + if (outerConst->value == innerConst->value) { + return inner->left; + } + } + } + } + } + } + } + return nullptr; +} + +// gets the size of the sign-extended value +static Index getSignExtBits(Expression* curr) { + return 32 - curr->cast<Binary>()->right->cast<Const>()->value.geti32(); +} + +// get a mask to keep only the low # of bits +static int32_t lowBitMask(int32_t bits) { + uint32_t ret = -1; + if (bits >= 32) return ret; + return ret >> (32 - bits); +} // Main pass class struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>> { @@ -215,32 +321,42 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, std::swap(binary->left, binary->right); } } - // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc. - if (binary->op == BinaryOp::ShrSInt32 && binary->right->is<Const>()) { - auto shifts = binary->right->cast<Const>()->value.geti32(); - if (shifts == 24 || shifts == 16) { - auto* left = binary->left->dynCast<Binary>(); - if (left && left->op == ShlInt32 && left->right->is<Const>() && left->right->cast<Const>()->value.geti32() == shifts) { - auto* load = left->left->dynCast<Load>(); - if (load && ((load->bytes == 1 && shifts == 24) || (load->bytes == 2 && shifts == 16))) { - load->signed_ = true; - return load; - } - } + if (auto* ext = getSignExt(binary)) { + auto bits = getSignExtBits(binary); + auto* load = ext->dynCast<Load>(); + // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc. + if (load && ((load->bytes == 1 && bits == 8) || (load->bytes == 2 && bits == 16))) { + load->signed_ = true; + return load; + } + // if the sign-extend input cannot have a sign bit, we don't need it + if (getMaxBits(ext) < bits) { + return ext; } } else if (binary->op == EqInt32) { if (auto* c = binary->right->dynCast<Const>()) { + if (auto* ext = getSignExt(binary->left)) { + // we are comparing a sign extend to a constant, which means we can use a cheaper zext + auto bits = getSignExtBits(binary->left); + binary->left = makeZeroExt(ext, bits); + // the const we compare to only needs the relevant bits + c->value = c->value.and_(Literal(lowBitMask(bits))); + return binary; + } if (c->value.geti32() == 0) { // equal 0 => eqz return Builder(*getModule()).makeUnary(EqZInt32, binary->left); } - } - if (auto* c = binary->left->dynCast<Const>()) { - if (c->value.geti32() == 0) { - // equal 0 => eqz - return Builder(*getModule()).makeUnary(EqZInt32, binary->right); + } else if (auto* left = getSignExt(binary->left)) { + if (auto* right = getSignExt(binary->right)) { + // we are comparing two sign-exts, so we may as well replace both with cheaper zexts + auto bits = getSignExtBits(binary->left); + binary->left = makeZeroExt(left, bits); + binary->right = makeZeroExt(right, bits); + return binary; } } + // note that both left and right may be consts, but then we let precompute compute the constant result } else if (binary->op == AndInt32) { if (auto* right = binary->right->dynCast<Const>()) { if (right->type == i32) { @@ -454,6 +570,11 @@ private: offset = 0; } } + + Expression* makeZeroExt(Expression* curr, int32_t bits) { + Builder builder(*getModule()); + return builder.makeBinary(AndInt32, curr, builder.makeConst(Literal(lowBitMask(bits)))); + } }; Pass *createOptimizeInstructionsPass() { |