From a78dddbcf2bf9f23840c7074ce16c04a8d55c3df Mon Sep 17 00:00:00 2001 From: Alon Zakai Date: Thu, 16 Feb 2017 21:37:39 -0800 Subject: Optimize sign-extends (#902) * optimize sign-extend output * optimize sign-extend input --- src/passes/OptimizeInstructions.cpp | 157 +++++++++++++++++++++++++++++++----- 1 file changed, 139 insertions(+), 18 deletions(-) (limited to 'src') diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 8f7bf63e7..c374a15ac 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -160,9 +160,115 @@ struct Match { }; return ExpressionManipulator::flexibleCopy(pattern.output, wasm, copy); } +}; +// Utilities -}; +// returns the maximum amount of bits used in an integer expression +// not extremely precise (doesn't look into add operands, etc.) +static Index getMaxBits(Expression* curr) { + if (auto* const_ = curr->dynCast()) { + switch (curr->type) { + case i32: return 32 - const_->value.countLeadingZeroes().geti32(); + case i64: return 64 - const_->value.countLeadingZeroes().geti64(); + default: WASM_UNREACHABLE(); + } + } else if (auto* binary = curr->dynCast()) { + switch (binary->op) { + // 32-bit + case AddInt32: case SubInt32: case MulInt32: + case DivSInt32: case DivUInt32: case RemSInt32: + case RemUInt32: case RotLInt32: case RotRInt32: return 32; + case AndInt32: case XorInt32: return std::min(getMaxBits(binary->left), getMaxBits(binary->right)); + case OrInt32: return std::max(getMaxBits(binary->left), getMaxBits(binary->right)); + case ShlInt32: { + if (auto* shifts = binary->right->dynCast()) { + return std::min(Index(32), getMaxBits(binary->left) + shifts->value.geti32()); + } + return 32; + } + case ShrUInt32: { + if (auto* shift = binary->right->dynCast()) { + auto maxBits = getMaxBits(binary->left); + auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out + return std::max(Index(0), maxBits - shifts); + } + return 32; + } + case ShrSInt32: { + if (auto* shift = binary->right->dynCast()) { + auto maxBits = getMaxBits(binary->left); + if (maxBits == 32) return 32; + auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out + return std::max(Index(0), maxBits - shifts); + } + return 32; + } + // 64-bit TODO + // comparisons + case EqInt32: case NeInt32: case LtSInt32: + case LtUInt32: case LeSInt32: case LeUInt32: + case GtSInt32: case GtUInt32: case GeSInt32: + case GeUInt32: + case EqInt64: case NeInt64: case LtSInt64: + case LtUInt64: case LeSInt64: case LeUInt64: + case GtSInt64: case GtUInt64: case GeSInt64: + case GeUInt64: + case EqFloat32: case NeFloat32: + case LtFloat32: case LeFloat32: case GtFloat32: case GeFloat32: + case EqFloat64: case NeFloat64: + case LtFloat64: case LeFloat64: case GtFloat64: case GeFloat64: return 1; + default: {} + } + } else if (auto* unary = curr->dynCast()) { + switch (unary->op) { + case ClzInt32: case CtzInt32: case PopcntInt32: return 5; + case ClzInt64: case CtzInt64: case PopcntInt64: return 6; + case EqZInt32: case EqZInt64: return 1; + case WrapInt64: return std::min(Index(32), getMaxBits(unary->value)); + default: {} + } + } + switch (curr->type) { + case i32: return 32; + case i64: return 64; + case unreachable: return 64; // not interesting, but don't crash + default: WASM_UNREACHABLE(); + } +} + +// Check if an expression is a sign-extend, and if so, returns the value +// that is extended, otherwise nullptr +static Expression* getSignExt(Expression* curr) { + if (auto* outer = curr->dynCast()) { + if (outer->op == ShrSInt32) { + if (auto* outerConst = outer->right->dynCast()) { + if (auto* inner = outer->left->dynCast()) { + if (inner->op == ShlInt32) { + if (auto* innerConst = inner->right->dynCast()) { + if (outerConst->value == innerConst->value) { + return inner->left; + } + } + } + } + } + } + } + return nullptr; +} + +// gets the size of the sign-extended value +static Index getSignExtBits(Expression* curr) { + return 32 - curr->cast()->right->cast()->value.geti32(); +} + +// get a mask to keep only the low # of bits +static int32_t lowBitMask(int32_t bits) { + uint32_t ret = -1; + if (bits >= 32) return ret; + return ret >> (32 - bits); +} // Main pass class struct OptimizeInstructions : public WalkerPass>> { @@ -215,32 +321,42 @@ struct OptimizeInstructions : public WalkerPassleft, binary->right); } } - // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc. - if (binary->op == BinaryOp::ShrSInt32 && binary->right->is()) { - auto shifts = binary->right->cast()->value.geti32(); - if (shifts == 24 || shifts == 16) { - auto* left = binary->left->dynCast(); - if (left && left->op == ShlInt32 && left->right->is() && left->right->cast()->value.geti32() == shifts) { - auto* load = left->left->dynCast(); - if (load && ((load->bytes == 1 && shifts == 24) || (load->bytes == 2 && shifts == 16))) { - load->signed_ = true; - return load; - } - } + if (auto* ext = getSignExt(binary)) { + auto bits = getSignExtBits(binary); + auto* load = ext->dynCast(); + // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc. + if (load && ((load->bytes == 1 && bits == 8) || (load->bytes == 2 && bits == 16))) { + load->signed_ = true; + return load; + } + // if the sign-extend input cannot have a sign bit, we don't need it + if (getMaxBits(ext) < bits) { + return ext; } } else if (binary->op == EqInt32) { if (auto* c = binary->right->dynCast()) { + if (auto* ext = getSignExt(binary->left)) { + // we are comparing a sign extend to a constant, which means we can use a cheaper zext + auto bits = getSignExtBits(binary->left); + binary->left = makeZeroExt(ext, bits); + // the const we compare to only needs the relevant bits + c->value = c->value.and_(Literal(lowBitMask(bits))); + return binary; + } if (c->value.geti32() == 0) { // equal 0 => eqz return Builder(*getModule()).makeUnary(EqZInt32, binary->left); } - } - if (auto* c = binary->left->dynCast()) { - if (c->value.geti32() == 0) { - // equal 0 => eqz - return Builder(*getModule()).makeUnary(EqZInt32, binary->right); + } else if (auto* left = getSignExt(binary->left)) { + if (auto* right = getSignExt(binary->right)) { + // we are comparing two sign-exts, so we may as well replace both with cheaper zexts + auto bits = getSignExtBits(binary->left); + binary->left = makeZeroExt(left, bits); + binary->right = makeZeroExt(right, bits); + return binary; } } + // note that both left and right may be consts, but then we let precompute compute the constant result } else if (binary->op == AndInt32) { if (auto* right = binary->right->dynCast()) { if (right->type == i32) { @@ -454,6 +570,11 @@ private: offset = 0; } } + + Expression* makeZeroExt(Expression* curr, int32_t bits) { + Builder builder(*getModule()); + return builder.makeBinary(AndInt32, curr, builder.makeConst(Literal(lowBitMask(bits)))); + } }; Pass *createOptimizeInstructionsPass() { -- cgit v1.2.3