diff options
Diffstat (limited to 'src/passes/OptimizeInstructions.cpp')
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 603 |
1 files changed, 391 insertions, 212 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index c098d0ed7..8a9309554 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -20,28 +20,29 @@ #include <algorithm> -#include <wasm.h> -#include <pass.h> -#include <wasm-s-parser.h> -#include <support/threads.h> #include <ir/abstract.h> -#include <ir/utils.h> #include <ir/cost.h> #include <ir/effects.h> -#include <ir/manipulation.h> -#include <ir/properties.h> #include <ir/literal-utils.h> #include <ir/load-utils.h> +#include <ir/manipulation.h> +#include <ir/properties.h> +#include <ir/utils.h> +#include <pass.h> +#include <support/threads.h> +#include <wasm-s-parser.h> +#include <wasm.h> -// TODO: Use the new sign-extension opcodes where appropriate. This needs to be conditionalized on the availability of atomics. +// TODO: Use the new sign-extension opcodes where appropriate. This needs to be +// conditionalized on the availability of atomics. namespace wasm { -Name I32_EXPR = "i32.expr", - I64_EXPR = "i64.expr", - F32_EXPR = "f32.expr", - F64_EXPR = "f64.expr", - ANY_EXPR = "any.expr"; +Name I32_EXPR = "i32.expr"; +Name I64_EXPR = "i64.expr"; +Name F32_EXPR = "f32.expr"; +Name F64_EXPR = "f64.expr"; +Name ANY_EXPR = "any.expr"; // Utilities @@ -53,28 +54,47 @@ template<typename LocalInfoProvider> Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) { if (auto* const_ = curr->dynCast<Const>()) { switch (curr->type) { - case i32: return 32 - const_->value.countLeadingZeroes().geti32(); - case i64: return 64 - const_->value.countLeadingZeroes().geti64(); - default: WASM_UNREACHABLE(); + case i32: + return 32 - const_->value.countLeadingZeroes().geti32(); + case i64: + return 64 - const_->value.countLeadingZeroes().geti64(); + default: + WASM_UNREACHABLE(); } } else if (auto* binary = curr->dynCast<Binary>()) { switch (binary->op) { // 32-bit - case AddInt32: case SubInt32: case MulInt32: - case DivSInt32: case DivUInt32: case RemSInt32: - case RemUInt32: case RotLInt32: case RotRInt32: return 32; - case AndInt32: return std::min(getMaxBits(binary->left, localInfoProvider), getMaxBits(binary->right, localInfoProvider)); - case OrInt32: case XorInt32: return std::max(getMaxBits(binary->left, localInfoProvider), getMaxBits(binary->right, localInfoProvider)); + case AddInt32: + case SubInt32: + case MulInt32: + case DivSInt32: + case DivUInt32: + case RemSInt32: + case RemUInt32: + case RotLInt32: + case RotRInt32: + return 32; + case AndInt32: + return std::min(getMaxBits(binary->left, localInfoProvider), + getMaxBits(binary->right, localInfoProvider)); + case OrInt32: + case XorInt32: + return std::max(getMaxBits(binary->left, localInfoProvider), + getMaxBits(binary->right, localInfoProvider)); case ShlInt32: { if (auto* shifts = binary->right->dynCast<Const>()) { - return std::min(Index(32), getMaxBits(binary->left, localInfoProvider) + Bits::getEffectiveShifts(shifts)); + return std::min(Index(32), + getMaxBits(binary->left, localInfoProvider) + + Bits::getEffectiveShifts(shifts)); } return 32; } case ShrUInt32: { if (auto* shift = binary->right->dynCast<Const>()) { auto maxBits = getMaxBits(binary->left, localInfoProvider); - auto shifts = std::min(Index(Bits::getEffectiveShifts(shift)), maxBits); // can ignore more shifts than zero us out + auto shifts = + std::min(Index(Bits::getEffectiveShifts(shift)), + maxBits); // can ignore more shifts than zero us out return std::max(Index(0), maxBits - shifts); } return 32; @@ -82,34 +102,67 @@ Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) { case ShrSInt32: { if (auto* shift = binary->right->dynCast<Const>()) { auto maxBits = getMaxBits(binary->left, localInfoProvider); - if (maxBits == 32) return 32; - auto shifts = std::min(Index(Bits::getEffectiveShifts(shift)), maxBits); // can ignore more shifts than zero us out + if (maxBits == 32) + return 32; + auto shifts = + std::min(Index(Bits::getEffectiveShifts(shift)), + maxBits); // can ignore more shifts than zero us out return std::max(Index(0), maxBits - shifts); } return 32; } // 64-bit TODO // comparisons - case EqInt32: case NeInt32: case LtSInt32: - case LtUInt32: case LeSInt32: case LeUInt32: - case GtSInt32: case GtUInt32: case GeSInt32: + case EqInt32: + case NeInt32: + case LtSInt32: + case LtUInt32: + case LeSInt32: + case LeUInt32: + case GtSInt32: + case GtUInt32: + case GeSInt32: case GeUInt32: - case EqInt64: case NeInt64: case LtSInt64: - case LtUInt64: case LeSInt64: case LeUInt64: - case GtSInt64: case GtUInt64: case GeSInt64: + case EqInt64: + case NeInt64: + case LtSInt64: + case LtUInt64: + case LeSInt64: + case LeUInt64: + case GtSInt64: + case GtUInt64: + case GeSInt64: case GeUInt64: - case EqFloat32: case NeFloat32: - case LtFloat32: case LeFloat32: case GtFloat32: case GeFloat32: - case EqFloat64: case NeFloat64: - case LtFloat64: case LeFloat64: case GtFloat64: case GeFloat64: return 1; + case EqFloat32: + case NeFloat32: + case LtFloat32: + case LeFloat32: + case GtFloat32: + case GeFloat32: + case EqFloat64: + case NeFloat64: + case LtFloat64: + case LeFloat64: + case GtFloat64: + case GeFloat64: + return 1; default: {} } } else if (auto* unary = curr->dynCast<Unary>()) { switch (unary->op) { - case ClzInt32: case CtzInt32: case PopcntInt32: return 6; - case ClzInt64: case CtzInt64: case PopcntInt64: return 7; - case EqZInt32: case EqZInt64: return 1; - case WrapInt64: return std::min(Index(32), getMaxBits(unary->value, localInfoProvider)); + case ClzInt32: + case CtzInt32: + case PopcntInt32: + return 6; + case ClzInt64: + case CtzInt64: + case PopcntInt64: + return 7; + case EqZInt32: + case EqZInt64: + return 1; + case WrapInt64: + return std::min(Index(32), getMaxBits(unary->value, localInfoProvider)); default: {} } } else if (auto* set = curr->dynCast<SetLocal>()) { @@ -125,10 +178,14 @@ Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) { } } switch (curr->type) { - case i32: return 32; - case i64: return 64; - case unreachable: return 64; // not interesting, but don't crash - default: WASM_UNREACHABLE(); + case i32: + return 32; + case i64: + return 64; + case unreachable: + return 64; // not interesting, but don't crash + default: + WASM_UNREACHABLE(); } } @@ -170,9 +227,11 @@ struct LocalScanner : PostWalker<LocalScanner> { void visitSetLocal(SetLocal* curr) { auto* func = getFunction(); - if (func->isParam(curr->index)) return; + if (func->isParam(curr->index)) + return; auto type = getFunction()->getLocalType(curr->index); - if (type != i32 && type != i64) return; + if (type != i32 && type != i64) + return; // an integer var, worth processing auto* value = Properties::getFallthrough(curr->value); auto& info = localInfo[curr->index]; @@ -188,26 +247,32 @@ struct LocalScanner : PostWalker<LocalScanner> { if (info.signExtedBits == 0) { info.signExtedBits = signExtBits; // first info we see } else if (info.signExtedBits != signExtBits) { - info.signExtedBits = LocalInfo::kUnknown; // contradictory information, give up + // contradictory information, give up + info.signExtedBits = LocalInfo::kUnknown; } } - // define this for the templated getMaxBits method. we know nothing here yet about locals, so return the maxes - Index getMaxBitsForLocal(GetLocal* get) { - return getBitsForType(get->type); - } + // define this for the templated getMaxBits method. we know nothing here yet + // about locals, so return the maxes + Index getMaxBitsForLocal(GetLocal* get) { return getBitsForType(get->type); } Index getBitsForType(Type type) { switch (type) { - case i32: return 32; - case i64: return 64; - default: return -1; + case i32: + return 32; + case i64: + return 64; + default: + return -1; } } }; // Main pass class -struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>> { +struct OptimizeInstructions + : public WalkerPass< + PostWalker<OptimizeInstructions, + UnifiedExpressionVisitor<OptimizeInstructions>>> { bool isFunctionParallel() override { return true; } Pass* create() override { return new OptimizeInstructions; } @@ -229,7 +294,8 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } void visitExpression(Expression* curr) { - // we may be able to apply multiple patterns, one may open opportunities that look deeper NB: patterns must not have cycles + // we may be able to apply multiple patterns, one may open opportunities + // that look deeper NB: patterns must not have cycles while (1) { auto* handOptimized = handOptimize(curr); if (handOptimized) { @@ -258,14 +324,15 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } - // Optimizations that don't yet fit in the pattern DSL, but could be eventually maybe + // Optimizations that don't yet fit in the pattern DSL, but could be + // eventually maybe Expression* handOptimize(Expression* curr) { // if this contains dead code, don't bother trying to optimize it, the type - // might change (if might not be unreachable if just one arm is, for example). - // this optimization pass focuses on actually executing code. the only - // exceptions are control flow changes - if (curr->type == unreachable && - !curr->is<Break>() && !curr->is<Switch>() && !curr->is<If>()) { + // might change (if might not be unreachable if just one arm is, for + // example). this optimization pass focuses on actually executing code. the + // only exceptions are control flow changes + if (curr->type == unreachable && !curr->is<Break>() && + !curr->is<Switch>() && !curr->is<If>()) { return nullptr; } if (auto* binary = curr->dynCast<Binary>()) { @@ -277,10 +344,13 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, auto bits = Properties::getAlmostSignExtBits(binary, extraShifts); if (extraShifts == 0) { if (auto* load = Properties::getFallthrough(ext)->dynCast<Load>()) { - // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc. + // pattern match a load of 8 bits and a sign extend using a shl of + // 24 then shr_s of 24 as well, etc. if (LoadUtils::canBeSigned(load) && - ((load->bytes == 1 && bits == 8) || (load->bytes == 2 && bits == 16))) { - // if the value falls through, we can't alter the load, as it might be captured in a tee + ((load->bytes == 1 && bits == 8) || + (load->bytes == 2 && bits == 16))) { + // if the value falls through, we can't alter the load, as it + // might be captured in a tee if (load->signed_ == true || load == ext) { load->signed_ = true; return ext; @@ -289,8 +359,10 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } // if the sign-extend input cannot have a sign bit, we don't need it - // we also don't need it if it already has an identical-sized sign extend - if (getMaxBits(ext, this) + extraShifts < bits || isSignExted(ext, bits)) { + // we also don't need it if it already has an identical-sized sign + // extend + if (getMaxBits(ext, this) + extraShifts < bits || + isSignExted(ext, bits)) { return removeAlmostSignExt(binary); } } else if (binary->op == EqInt32 || binary->op == NeInt32) { @@ -300,34 +372,44 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, return Builder(*getModule()).makeUnary(EqZInt32, binary->left); } if (auto* ext = Properties::getSignExtValue(binary->left)) { - // we are comparing a sign extend to a constant, which means we can use a cheaper zext + // we are comparing a sign extend to a constant, which means we can + // use a cheaper zext auto bits = Properties::getSignExtBits(binary->left); binary->left = makeZeroExt(ext, bits); - // when we replace the sign-ext of the non-constant with a zero-ext, we are forcing - // the high bits to be all zero, instead of all zero or all one depending on the - // sign bit. so we may be changing the high bits from all one to all zero: - // * if the constant value's higher bits are mixed, then it can't be equal anyhow - // * if they are all zero, we may get a false true if the non-constant's upper bits - // were one. this can only happen if the non-constant's sign bit is set, so this - // false true is a risk only if the constant's sign bit is set (otherwise, false). - // But a constant with a sign bit but with upper bits zero is impossible to be - // equal to a sign-extended value anyhow, so the entire thing is false. - // * if they were all one, we may get a false false, if the only difference is in - // those upper bits. that means we are equal on the other bits, including the sign - // bit. so we can just mask off the upper bits in the constant value, in this - // case, forcing them to zero like we do in the zero-extend. + // when we replace the sign-ext of the non-constant with a zero-ext, + // we are forcing the high bits to be all zero, instead of all zero + // or all one depending on the sign bit. so we may be changing the + // high bits from all one to all zero: + // * if the constant value's higher bits are mixed, then it can't + // be equal anyhow + // * if they are all zero, we may get a false true if the + // non-constant's upper bits were one. this can only happen if + // the non-constant's sign bit is set, so this false true is a + // risk only if the constant's sign bit is set (otherwise, + // false). But a constant with a sign bit but with upper bits + // zero is impossible to be equal to a sign-extended value + // anyhow, so the entire thing is false. + // * if they were all one, we may get a false false, if the only + // difference is in those upper bits. that means we are equal on + // the other bits, including the sign bit. so we can just mask + // off the upper bits in the constant value, in this case, + // forcing them to zero like we do in the zero-extend. int32_t constValue = c->value.geti32(); auto upperConstValue = constValue & ~Bits::lowBitMask(bits); uint32_t count = PopCount(upperConstValue); auto constSignBit = constValue & (1 << (bits - 1)); - if ((count > 0 && count < 32 - bits) || (constSignBit && count == 0)) { - // mixed or [zero upper const bits with sign bit set]; the compared values can never be identical, so - // force something definitely impossible even after zext + if ((count > 0 && count < 32 - bits) || + (constSignBit && count == 0)) { + // mixed or [zero upper const bits with sign bit set]; the + // compared values can never be identical, so force something + // definitely impossible even after zext assert(bits < 32); c->value = Literal(int32_t(0x80000000)); - // TODO: if no side effects, we can just replace it all with 1 or 0 + // TODO: if no side effects, we can just replace it all with 1 or + // 0 } else { - // otherwise, they are all ones, so we can mask them off as mentioned before + // otherwise, they are all ones, so we can mask them off as + // mentioned before c->value = c->value.and_(Literal(Bits::lowBitMask(bits))); } return binary; @@ -336,13 +418,15 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, if (auto* right = Properties::getSignExtValue(binary->right)) { auto bits = Properties::getSignExtBits(binary->left); if (Properties::getSignExtBits(binary->right) == bits) { - // we are comparing two sign-exts with the same bits, so we may as well replace both with cheaper zexts + // we are comparing two sign-exts with the same bits, so we may as + // well replace both with cheaper zexts binary->left = makeZeroExt(left, bits); binary->right = makeZeroExt(right, bits); return binary; } } else if (auto* load = binary->right->dynCast<Load>()) { - // we are comparing a load to a sign-ext, we may be able to switch to zext + // we are comparing a load to a sign-ext, we may be able to switch + // to zext auto leftBits = Properties::getSignExtBits(binary->left); if (load->signed_ && leftBits == load->bytes * 8) { load->signed_ = false; @@ -352,7 +436,8 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } else if (auto* load = binary->left->dynCast<Load>()) { if (auto* right = Properties::getSignExtValue(binary->right)) { - // we are comparing a load to a sign-ext, we may be able to switch to zext + // we are comparing a load to a sign-ext, we may be able to switch + // to zext auto rightBits = Properties::getSignExtBits(binary->right); if (load->signed_ && rightBits == load->bytes * 8) { load->signed_ = false; @@ -361,7 +446,8 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } } - // note that both left and right may be consts, but then we let precompute compute the constant result + // note that both left and right may be consts, but then we let + // precompute compute the constant result } else if (binary->op == AddInt32) { // try to get rid of (0 - ..), that is, a zero only used to negate an // int. an add of a subtract can be flipped in order to remove it: @@ -382,7 +468,8 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, if (sub->op == SubInt32) { if (auto* subZero = sub->left->dynCast<Const>()) { if (subZero->value.geti32() == 0) { - if (EffectAnalyzer::canReorder(getPassOptions(), sub->right, binary->right)) { + if (EffectAnalyzer::canReorder( + getPassOptions(), sub->right, binary->right)) { sub->left = binary->right; return sub; } @@ -414,10 +501,12 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } auto* ret = optimizeAddedConstants(binary); - if (ret) return ret; + if (ret) + return ret; } else if (binary->op == SubInt32) { auto* ret = optimizeAddedConstants(binary); - if (ret) return ret; + if (ret) + return ret; } // a bunch of operations on a constant right side can be simplified if (auto* right = binary->right->dynCast<Const>()) { @@ -443,7 +532,8 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } // some math operations have trivial results Expression* ret = optimizeWithConstantOnRight(binary); - if (ret) return ret; + if (ret) + return ret; // the square of some operations can be merged if (auto* left = binary->left->dynCast<Binary>()) { if (left->op == binary->op) { @@ -454,11 +544,13 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } else if (left->op == OrInt32) { leftRight->value = leftRight->value.or_(right->value); return left; - } else if (left->op == ShlInt32 || left->op == ShrUInt32 || left->op == ShrSInt32 || - left->op == ShlInt64 || left->op == ShrUInt64 || left->op == ShrSInt64) { - // shifts only use an effective amount from the constant, so adding must - // be done carefully - auto total = Bits::getEffectiveShifts(leftRight) + Bits::getEffectiveShifts(right); + } else if (left->op == ShlInt32 || left->op == ShrUInt32 || + left->op == ShrSInt32 || left->op == ShlInt64 || + left->op == ShrUInt64 || left->op == ShrSInt64) { + // shifts only use an effective amount from the constant, so + // adding must be done carefully + auto total = Bits::getEffectiveShifts(leftRight) + + Bits::getEffectiveShifts(right); if (total == Bits::getEffectiveShifts(total, right->type)) { // no overflow, we can do this leftRight->value = Literal::makeFromInt32(total, right->type); @@ -483,7 +575,8 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, // a bunch of operations on a constant left side can be simplified if (binary->left->is<Const>()) { Expression* ret = optimizeWithConstantOnLeft(binary); - if (ret) return ret; + if (ret) + return ret; } // bitwise operations if (binary->op == AndInt32) { @@ -540,40 +633,89 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, if (unary->op == EqZInt32) { if (auto* inner = unary->value->dynCast<Binary>()) { switch (inner->op) { - case EqInt32: inner->op = NeInt32; return inner; - case NeInt32: inner->op = EqInt32; return inner; - case LtSInt32: inner->op = GeSInt32; return inner; - case LtUInt32: inner->op = GeUInt32; return inner; - case LeSInt32: inner->op = GtSInt32; return inner; - case LeUInt32: inner->op = GtUInt32; return inner; - case GtSInt32: inner->op = LeSInt32; return inner; - case GtUInt32: inner->op = LeUInt32; return inner; - case GeSInt32: inner->op = LtSInt32; return inner; - case GeUInt32: inner->op = LtUInt32; return inner; + case EqInt32: + inner->op = NeInt32; + return inner; + case NeInt32: + inner->op = EqInt32; + return inner; + case LtSInt32: + inner->op = GeSInt32; + return inner; + case LtUInt32: + inner->op = GeUInt32; + return inner; + case LeSInt32: + inner->op = GtSInt32; + return inner; + case LeUInt32: + inner->op = GtUInt32; + return inner; + case GtSInt32: + inner->op = LeSInt32; + return inner; + case GtUInt32: + inner->op = LeUInt32; + return inner; + case GeSInt32: + inner->op = LtSInt32; + return inner; + case GeUInt32: + inner->op = LtUInt32; + return inner; - case EqInt64: inner->op = NeInt64; return inner; - case NeInt64: inner->op = EqInt64; return inner; - case LtSInt64: inner->op = GeSInt64; return inner; - case LtUInt64: inner->op = GeUInt64; return inner; - case LeSInt64: inner->op = GtSInt64; return inner; - case LeUInt64: inner->op = GtUInt64; return inner; - case GtSInt64: inner->op = LeSInt64; return inner; - case GtUInt64: inner->op = LeUInt64; return inner; - case GeSInt64: inner->op = LtSInt64; return inner; - case GeUInt64: inner->op = LtUInt64; return inner; + case EqInt64: + inner->op = NeInt64; + return inner; + case NeInt64: + inner->op = EqInt64; + return inner; + case LtSInt64: + inner->op = GeSInt64; + return inner; + case LtUInt64: + inner->op = GeUInt64; + return inner; + case LeSInt64: + inner->op = GtSInt64; + return inner; + case LeUInt64: + inner->op = GtUInt64; + return inner; + case GtSInt64: + inner->op = LeSInt64; + return inner; + case GtUInt64: + inner->op = LeUInt64; + return inner; + case GeSInt64: + inner->op = LtSInt64; + return inner; + case GeUInt64: + inner->op = LtUInt64; + return inner; - case EqFloat32: inner->op = NeFloat32; return inner; - case NeFloat32: inner->op = EqFloat32; return inner; + case EqFloat32: + inner->op = NeFloat32; + return inner; + case NeFloat32: + inner->op = EqFloat32; + return inner; - case EqFloat64: inner->op = NeFloat64; return inner; - case NeFloat64: inner->op = EqFloat64; return inner; + case EqFloat64: + inner->op = NeFloat64; + return inner; + case NeFloat64: + inner->op = EqFloat64; + return inner; default: {} } } // eqz of a sign extension can be of zero-extension if (auto* ext = Properties::getSignExtValue(unary->value)) { - // we are comparing a sign extend to a constant, which means we can use a cheaper zext + // we are comparing a sign extend to a constant, which means we can + // use a cheaper zext auto bits = Properties::getSignExtBits(unary->value); unary->value = makeZeroExt(ext, bits); return unary; @@ -595,24 +737,26 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, std::swap(iff->ifTrue, iff->ifFalse); } } - if (iff->condition->type != unreachable && ExpressionAnalyzer::equal(iff->ifTrue, iff->ifFalse)) { + if (iff->condition->type != unreachable && + ExpressionAnalyzer::equal(iff->ifTrue, iff->ifFalse)) { // sides are identical, fold - // if we can replace the if with one arm, and no side effects in the condition, do that - auto needCondition = EffectAnalyzer(getPassOptions(), iff->condition).hasSideEffects(); + // if we can replace the if with one arm, and no side effects in the + // condition, do that + auto needCondition = + EffectAnalyzer(getPassOptions(), iff->condition).hasSideEffects(); auto typeIsIdentical = iff->ifTrue->type == iff->type; if (typeIsIdentical && !needCondition) { return iff->ifTrue; } else { Builder builder(*getModule()); if (typeIsIdentical) { - return builder.makeSequence( - builder.makeDrop(iff->condition), - iff->ifTrue - ); + return builder.makeSequence(builder.makeDrop(iff->condition), + iff->ifTrue); } else { - // the types diff. as the condition is reachable, that means the if must be - // concrete while the arm is not - assert(isConcreteType(iff->type) && iff->ifTrue->type == unreachable); + // the types diff. as the condition is reachable, that means the + // if must be concrete while the arm is not + assert(isConcreteType(iff->type) && + iff->ifTrue->type == unreachable); // emit a block with a forced type auto* ret = builder.makeBlock(); if (needCondition) { @@ -638,22 +782,24 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } if (auto* c = select->condition->dynCast<Const>()) { - // constant condition, we can just pick the right side (barring side effects) + // constant condition, we can just pick the right side (barring side + // effects) if (c->value.getInteger()) { - if (!EffectAnalyzer(getPassOptions(), select->ifFalse).hasSideEffects()) { + if (!EffectAnalyzer(getPassOptions(), select->ifFalse) + .hasSideEffects()) { return select->ifTrue; } else { - // don't bother - we would need to reverse the order using a temp local, which is bad + // don't bother - we would need to reverse the order using a temp + // local, which is bad } } else { - if (!EffectAnalyzer(getPassOptions(), select->ifTrue).hasSideEffects()) { + if (!EffectAnalyzer(getPassOptions(), select->ifTrue) + .hasSideEffects()) { return select->ifFalse; } else { Builder builder(*getModule()); - return builder.makeSequence( - builder.makeDrop(select->ifTrue), - select->ifFalse - ); + return builder.makeSequence(builder.makeDrop(select->ifTrue), + select->ifFalse); } } } @@ -676,10 +822,8 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, // can reorder if (!condition.invalidates(value)) { Builder builder(*getModule()); - return builder.makeSequence( - builder.makeDrop(select->condition), - select->ifTrue - ); + return builder.makeSequence(builder.makeDrop(select->condition), + select->ifTrue); } } } @@ -705,8 +849,9 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } } else if (auto* ext = Properties::getSignExtValue(binary)) { - // if sign extending the exact bit size we store, we can skip the extension - // if extending something bigger, then we just alter bits we don't save anyhow + // if sign extending the exact bit size we store, we can skip the + // extension if extending something bigger, then we just alter bits we + // don't save anyhow if (Properties::getSignExtBits(binary) >= Index(store->bytes) * 8) { store->value = ext; } @@ -736,11 +881,13 @@ private: void canonicalize(Binary* binary) { assert(Properties::isSymmetric(binary)); auto swap = [&]() { - assert(EffectAnalyzer::canReorder(getPassOptions(), binary->left, binary->right)); + assert(EffectAnalyzer::canReorder( + getPassOptions(), binary->left, binary->right)); std::swap(binary->left, binary->right); }; auto maybeSwap = [&]() { - if (EffectAnalyzer::canReorder(getPassOptions(), binary->left, binary->right)) { + if (EffectAnalyzer::canReorder( + getPassOptions(), binary->left, binary->right)) { swap(); } }; @@ -748,7 +895,8 @@ private: if (binary->left->is<Const>() && !binary->right->is<Const>()) { return swap(); } - if (binary->right->is<Const>()) return; + if (binary->right->is<Const>()) + return; // Prefer a get on the right. if (binary->left->is<GetLocal>() && !binary->right->is<GetLocal>()) { return maybeSwap(); @@ -793,7 +941,8 @@ private: } } else if (auto* binary = boolean->dynCast<Binary>()) { if (binary->op == OrInt32) { - // an or flowing into a boolean context can consider each input as boolean + // an or flowing into a boolean context can consider each input as + // boolean binary->left = optimizeBoolean(binary->left); binary->right = optimizeBoolean(binary->right); } else if (binary->op == NeInt32) { @@ -805,7 +954,8 @@ private: } } if (auto* ext = Properties::getSignExtValue(binary)) { - // use a cheaper zero-extent, we just care about the boolean value anyhow + // use a cheaper zero-extent, we just care about the boolean value + // anyhow return makeZeroExt(ext, Properties::getSignExtBits(binary)); } } else if (auto* block = boolean->dynCast<Block>()) { @@ -822,12 +972,14 @@ private: return boolean; } - // find added constants in an expression tree, including multiplied/shifted, and combine them - // note that we ignore division/shift-right, as rounding makes this nonlinear, so not a valid opt + // find added constants in an expression tree, including multiplied/shifted, + // and combine them note that we ignore division/shift-right, as rounding + // makes this nonlinear, so not a valid opt Expression* optimizeAddedConstants(Binary* binary) { uint32_t constant = 0; std::vector<Const*> constants; - std::function<void (Expression*, int)> seek = [&](Expression* curr, int mul) { + std::function<void(Expression*, int)> seek = [&](Expression* curr, + int mul) { if (auto* c = curr->dynCast<Const>()) { uint32_t value = c->value.geti32(); if (value != 0) { @@ -867,7 +1019,8 @@ private: // find all factors seek(binary, 1); if (constants.size() <= 1) { - // nothing much to do, except for the trivial case of adding/subbing a zero + // nothing much to do, except for the trivial case of adding/subbing a + // zero if (auto* c = binary->right->dynCast<Const>()) { if (c->value.geti32() == 0) { return binary->left; @@ -906,19 +1059,24 @@ private: return; } } else if (curr->op == ShlInt32) { - // shifting a 0 is a 0, or anything by 0 has no effect, all unless the shift has side effects - if (((left && left->value.geti32() == 0) || (right && Bits::getEffectiveShifts(right) == 0)) && + // shifting a 0 is a 0, or anything by 0 has no effect, all unless the + // shift has side effects + if (((left && left->value.geti32() == 0) || + (right && Bits::getEffectiveShifts(right) == 0)) && !EffectAnalyzer(passOptions, curr->right).hasSideEffects()) { replaceCurrent(curr->left); return; } } else if (curr->op == MulInt32) { - // multiplying by zero is a zero, unless the other side has side effects - if (left && left->value.geti32() == 0 && !EffectAnalyzer(passOptions, curr->right).hasSideEffects()) { + // multiplying by zero is a zero, unless the other side has side + // effects + if (left && left->value.geti32() == 0 && + !EffectAnalyzer(passOptions, curr->right).hasSideEffects()) { replaceCurrent(left); return; } - if (right && right->value.geti32() == 0 && !EffectAnalyzer(passOptions, curr->left).hasSideEffects()) { + if (right && right->value.geti32() == 0 && + !EffectAnalyzer(passOptions, curr->left).hasSideEffects()) { replaceCurrent(right); return; } @@ -927,50 +1085,58 @@ private: }; Expression* walked = binary; ZeroRemover(getPassOptions()).walk(walked); - if (constant == 0) return walked; // nothing more to do + if (constant == 0) + return walked; // nothing more to do if (auto* c = walked->dynCast<Const>()) { assert(c->value.geti32() == 0); c->value = Literal(constant); return c; } Builder builder(*getModule()); - return builder.makeBinary(AddInt32, - walked, - builder.makeConst(Literal(constant)) - ); + return builder.makeBinary( + AddInt32, walked, builder.makeConst(Literal(constant))); } - // expensive1 | expensive2 can be turned into expensive1 ? 1 : expensive2, and - // expensive | cheap can be turned into cheap ? 1 : expensive, + // expensive1 | expensive2 can be turned into expensive1 ? 1 : expensive2, + // and expensive | cheap can be turned into cheap ? 1 : expensive, // so that we can avoid one expensive computation, if it has no side effects. Expression* conditionalizeExpensiveOnBitwise(Binary* binary) { // this operation can increase code size, so don't always do it auto& options = getPassRunner()->options; - if (options.optimizeLevel < 2 || options.shrinkLevel > 0) return nullptr; + if (options.optimizeLevel < 2 || options.shrinkLevel > 0) + return nullptr; const auto MIN_COST = 7; assert(binary->op == AndInt32 || binary->op == OrInt32); - if (binary->right->is<Const>()) return nullptr; // trivial - // bitwise logical operator on two non-numerical values, check if they are boolean + if (binary->right->is<Const>()) + return nullptr; // trivial + // bitwise logical operator on two non-numerical values, check if they are + // boolean auto* left = binary->left; auto* right = binary->right; - if (!Properties::emitsBoolean(left) || !Properties::emitsBoolean(right)) return nullptr; + if (!Properties::emitsBoolean(left) || !Properties::emitsBoolean(right)) + return nullptr; auto leftEffects = EffectAnalyzer(getPassOptions(), left); auto rightEffects = EffectAnalyzer(getPassOptions(), right); auto leftHasSideEffects = leftEffects.hasSideEffects(); auto rightHasSideEffects = rightEffects.hasSideEffects(); - if (leftHasSideEffects && rightHasSideEffects) return nullptr; // both must execute + if (leftHasSideEffects && rightHasSideEffects) + return nullptr; // both must execute // canonicalize with side effects, if any, happening on the left if (rightHasSideEffects) { - if (CostAnalyzer(left).cost < MIN_COST) return nullptr; // avoidable code is too cheap - if (leftEffects.invalidates(rightEffects)) return nullptr; // cannot reorder + if (CostAnalyzer(left).cost < MIN_COST) + return nullptr; // avoidable code is too cheap + if (leftEffects.invalidates(rightEffects)) + return nullptr; // cannot reorder std::swap(left, right); } else if (leftHasSideEffects) { - if (CostAnalyzer(right).cost < MIN_COST) return nullptr; // avoidable code is too cheap + if (CostAnalyzer(right).cost < MIN_COST) + return nullptr; // avoidable code is too cheap } else { // no side effects, reorder based on cost estimation auto leftCost = CostAnalyzer(left).cost; auto rightCost = CostAnalyzer(right).cost; - if (std::max(leftCost, rightCost) < MIN_COST) return nullptr; // avoidable code is too cheap + if (std::max(leftCost, rightCost) < MIN_COST) + return nullptr; // avoidable code is too cheap // canonicalize with expensive code on the right if (leftCost > rightCost) { std::swap(left, right); @@ -979,9 +1145,11 @@ private: // worth it! perform conditionalization Builder builder(*getModule()); if (binary->op == OrInt32) { - return builder.makeIf(left, builder.makeConst(Literal(int32_t(1))), right); + return builder.makeIf( + left, builder.makeConst(Literal(int32_t(1))), right); } else { // & - return builder.makeIf(left, right, builder.makeConst(Literal(int32_t(0)))); + return builder.makeIf( + left, right, builder.makeConst(Literal(int32_t(0)))); } } @@ -1015,8 +1183,9 @@ private: // fold constant factors into the offset void optimizeMemoryAccess(Expression*& ptr, Address& offset) { - // ptr may be a const, but it isn't worth folding that in (we still have a const); in fact, - // it's better to do the opposite for gzip purposes as well as for readability. + // ptr may be a const, but it isn't worth folding that in (we still have a + // const); in fact, it's better to do the opposite for gzip purposes as well + // as for readability. auto* last = ptr->dynCast<Const>(); if (last) { // don't do this if it would wrap the pointer @@ -1058,7 +1227,8 @@ private: Expression* makeZeroExt(Expression* curr, int32_t bits) { Builder builder(*getModule()); - return builder.makeBinary(AndInt32, curr, builder.makeConst(Literal(Bits::lowBitMask(bits)))); + return builder.makeBinary( + AndInt32, curr, builder.makeConst(Literal(Bits::lowBitMask(bits)))); } // given an "almost" sign extend - either a proper one, or it @@ -1070,7 +1240,8 @@ private: auto* outerConst = outer->right->cast<Const>(); auto* innerConst = inner->right->cast<Const>(); auto* value = inner->left; - if (outerConst->value == innerConst->value) return value; + if (outerConst->value == innerConst->value) + return value; // add a shift, by reusing the existing node innerConst->value = innerConst->value.sub(outerConst->value); return inner; @@ -1105,7 +1276,8 @@ private: return binary->left; } else if ((binary->op == Abstract::getBinary(type, Abstract::Mul) || binary->op == Abstract::getBinary(type, Abstract::And)) && - !EffectAnalyzer(getPassOptions(), binary->left).hasSideEffects()) { + !EffectAnalyzer(getPassOptions(), binary->left) + .hasSideEffects()) { return binary->right; } } @@ -1116,7 +1288,8 @@ private: if (binary->op == Abstract::getBinary(type, Abstract::And)) { return binary->left; } else if (binary->op == Abstract::getBinary(type, Abstract::Or) && - !EffectAnalyzer(getPassOptions(), binary->left).hasSideEffects()) { + !EffectAnalyzer(getPassOptions(), binary->left) + .hasSideEffects()) { return binary->right; } } @@ -1129,15 +1302,10 @@ private: if (binary->op == Abstract::getBinary(type, Abstract::Add) || binary->op == Abstract::getBinary(type, Abstract::Sub)) { auto value = right->value.getInteger(); - if (value == 0x40 || - value == 0x2000 || - value == 0x100000 || - value == 0x8000000 || - value == 0x400000000LL || - value == 0x20000000000LL || - value == 0x1000000000000LL || - value == 0x80000000000000LL || - value == 0x4000000000000000LL) { + if (value == 0x40 || value == 0x2000 || value == 0x100000 || + value == 0x8000000 || value == 0x400000000LL || + value == 0x20000000000LL || value == 0x1000000000000LL || + value == 0x80000000000000LL || value == 0x4000000000000000LL) { right->value = right->value.neg(); if (binary->op == Abstract::getBinary(type, Abstract::Add)) { binary->op = Abstract::getBinary(type, Abstract::Sub); @@ -1202,12 +1370,16 @@ private: left->op == Abstract::getBinary(type, Abstract::Sub)) { if (auto* leftConst = left->right->dynCast<Const>()) { if (auto* rightConst = binary->right->dynCast<Const>()) { - return combineRelationalConstants(binary, left, leftConst, nullptr, rightConst); + return combineRelationalConstants( + binary, left, leftConst, nullptr, rightConst); } else if (auto* rightBinary = binary->right->dynCast<Binary>()) { - if (rightBinary->op == Abstract::getBinary(type, Abstract::Add) || - rightBinary->op == Abstract::getBinary(type, Abstract::Sub)) { + if (rightBinary->op == + Abstract::getBinary(type, Abstract::Add) || + rightBinary->op == + Abstract::getBinary(type, Abstract::Sub)) { if (auto* rightConst = rightBinary->right->dynCast<Const>()) { - return combineRelationalConstants(binary, left, leftConst, rightBinary, rightConst); + return combineRelationalConstants( + binary, left, leftConst, rightBinary, rightConst); } } } @@ -1220,9 +1392,13 @@ private: } // given a relational binary with a const on both sides, combine the constants - // left is also a binary, and has a constant; right may be just a constant, in which - // case right is nullptr - Expression* combineRelationalConstants(Binary* binary, Binary* left, Const* leftConst, Binary* right, Const* rightConst) { + // left is also a binary, and has a constant; right may be just a constant, in + // which case right is nullptr + Expression* combineRelationalConstants(Binary* binary, + Binary* left, + Const* leftConst, + Binary* right, + Const* rightConst) { auto type = binary->right->type; // we fold constants to the right Literal extra = leftConst->value; @@ -1237,8 +1413,8 @@ private: return binary; } - // given a binary expression with equal children and no side effects in either, - // we can fold various things + // given a binary expression with equal children and no side effects in + // either, we can fold various things // TODO: trinaries, things like (x & (y & x)) ? Expression* optimizeBinaryWithEqualEffectlessChildren(Binary* binary) { // TODO add: perhaps worth doing 2*x if x is quite large? @@ -1246,7 +1422,8 @@ private: case SubInt32: case XorInt32: case SubInt64: - case XorInt64: return LiteralUtils::makeZero(binary->left->type, *getModule()); + case XorInt64: + return LiteralUtils::makeZero(binary->left->type, *getModule()); case NeInt64: case LtSInt64: case LtUInt64: @@ -1256,11 +1433,13 @@ private: case LtSInt32: case LtUInt32: case GtSInt32: - case GtUInt32: return LiteralUtils::makeZero(i32, *getModule()); + case GtUInt32: + return LiteralUtils::makeZero(i32, *getModule()); case AndInt32: case OrInt32: case AndInt64: - case OrInt64: return binary->left; + case OrInt64: + return binary->left; case EqInt32: case LeSInt32: case LeUInt32: @@ -1270,14 +1449,14 @@ private: case LeSInt64: case LeUInt64: case GeSInt64: - case GeUInt64: return LiteralUtils::makeFromInt32(1, i32, *getModule()); - default: return nullptr; + case GeUInt64: + return LiteralUtils::makeFromInt32(1, i32, *getModule()); + default: + return nullptr; } } }; -Pass *createOptimizeInstructionsPass() { - return new OptimizeInstructions(); -} +Pass* createOptimizeInstructionsPass() { return new OptimizeInstructions(); } } // namespace wasm |