diff options
Diffstat (limited to 'src/passes/OptimizeInstructions.cpp')
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 131 |
1 files changed, 120 insertions, 11 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 844ed06ac..8e5ff6c12 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -166,7 +166,10 @@ struct Match { // returns the maximum amount of bits used in an integer expression // not extremely precise (doesn't look into add operands, etc.) -static Index getMaxBits(Expression* curr) { +// LocalInfoProvider is an optional class that can provide answers about +// get_local. +template<typename LocalInfoProvider> +Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) { if (auto* const_ = curr->dynCast<Const>()) { switch (curr->type) { case i32: return 32 - const_->value.countLeadingZeroes().geti32(); @@ -179,17 +182,17 @@ static Index getMaxBits(Expression* curr) { case AddInt32: case SubInt32: case MulInt32: case DivSInt32: case DivUInt32: case RemSInt32: case RemUInt32: case RotLInt32: case RotRInt32: return 32; - case AndInt32: return std::min(getMaxBits(binary->left), getMaxBits(binary->right)); - case OrInt32: case XorInt32: return std::max(getMaxBits(binary->left), getMaxBits(binary->right)); + case AndInt32: return std::min(getMaxBits(binary->left, localInfoProvider), getMaxBits(binary->right, localInfoProvider)); + case OrInt32: case XorInt32: return std::max(getMaxBits(binary->left, localInfoProvider), getMaxBits(binary->right, localInfoProvider)); case ShlInt32: { if (auto* shifts = binary->right->dynCast<Const>()) { - return std::min(Index(32), getMaxBits(binary->left) + shifts->value.geti32()); + return std::min(Index(32), getMaxBits(binary->left, localInfoProvider) + shifts->value.geti32()); } return 32; } case ShrUInt32: { if (auto* shift = binary->right->dynCast<Const>()) { - auto maxBits = getMaxBits(binary->left); + auto maxBits = getMaxBits(binary->left, localInfoProvider); auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out return std::max(Index(0), maxBits - shifts); } @@ -197,7 +200,7 @@ static Index getMaxBits(Expression* curr) { } case ShrSInt32: { if (auto* shift = binary->right->dynCast<Const>()) { - auto maxBits = getMaxBits(binary->left); + auto maxBits = getMaxBits(binary->left, localInfoProvider); if (maxBits == 32) return 32; auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out return std::max(Index(0), maxBits - shifts); @@ -225,14 +228,19 @@ static Index getMaxBits(Expression* curr) { case ClzInt32: case CtzInt32: case PopcntInt32: return 5; case ClzInt64: case CtzInt64: case PopcntInt64: return 6; case EqZInt32: case EqZInt64: return 1; - case WrapInt64: return std::min(Index(32), getMaxBits(unary->value)); + case WrapInt64: return std::min(Index(32), getMaxBits(unary->value, localInfoProvider)); default: {} } } else if (auto* set = curr->dynCast<SetLocal>()) { // a tee passes through the value - return getMaxBits(set->value); + return getMaxBits(set->value, localInfoProvider); + } else if (auto* get = curr->dynCast<GetLocal>()) { + return localInfoProvider->getMaxBitsForLocal(get); } else if (auto* load = curr->dynCast<Load>()) { - return 8 * load->bytes; + // if signed, then the sign-extension might fill all the bits + if (!load->signed_) { + return 8 * load->bytes; + } } switch (curr->type) { case i32: return 32; @@ -334,6 +342,76 @@ T* getFallthroughDynCast(Expression* curr) { return nullptr; } +// Useful information about locals +struct LocalInfo { + static const Index kUnknown = Index(-1); + + Index maxBits; + Index signExtedBits; +}; + +struct LocalScanner : PostWalker<LocalScanner, Visitor<LocalScanner>> { + std::vector<LocalInfo>& localInfo; + + LocalScanner(std::vector<LocalInfo>& localInfo) : localInfo(localInfo) {} + + void doWalkFunction(Function* func) { + // prepare + localInfo.resize(func->getNumLocals()); + for (Index i = 0; i < func->getNumLocals(); i++) { + auto& info = localInfo[i]; + if (func->isParam(i)) { + info.maxBits = getBitsForType(func->getLocalType(i)); // worst-case + info.signExtedBits = LocalInfo::kUnknown; // we will never know anything + } else { + info.maxBits = info.signExtedBits = 0; // we are open to learning + } + } + // walk + PostWalker<LocalScanner, Visitor<LocalScanner>>::doWalkFunction(func); + // finalize + for (Index i = 0; i < func->getNumLocals(); i++) { + auto& info = localInfo[i]; + if (info.signExtedBits == LocalInfo::kUnknown) { + info.signExtedBits = 0; + } + } + } + + void visitSetLocal(SetLocal* curr) { + auto* func = getFunction(); + if (func->isParam(curr->index)) return; + auto type = getFunction()->getLocalType(curr->index); + if (type != i32 && type != i64) return; + // an integer var, worth processing + auto& info = localInfo[curr->index]; + info.maxBits = std::max(info.maxBits, getMaxBits(curr->value, this)); + if (getSignExt(curr->value)) { + auto bits = getSignExtBits(curr->value); + if (info.signExtedBits == 0) { + info.signExtedBits = bits; // first info we see + } else if (info.signExtedBits != bits) { + info.signExtedBits = LocalInfo::kUnknown; // contradictory information, give up + } + } else { + info.signExtedBits = LocalInfo::kUnknown; // an input which isn't even a sign ext, give up + } + } + + // define this for the templated getMaxBits method. we know nothing here yet about locals, so return the maxes + Index getMaxBitsForLocal(GetLocal* get) { + return getBitsForType(get->type); + } + + Index getBitsForType(WasmType type) { + switch (type) { + case i32: return 32; + case i64: return 64; + default: return -1; + } + } +}; + // Main pass class struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>> { bool isFunctionParallel() override { return true; } @@ -346,6 +424,16 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, #endif } + void doWalkFunction(Function* func) { + // first, scan locals + { + LocalScanner scanner(localInfo); + scanner.walkFunction(func); + } + // main walk + WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>>::doWalkFunction(func); + } + void visitExpression(Expression* curr) { // we may be able to apply multiple patterns, one may open opportunities that look deeper NB: patterns must not have cycles while (1) { @@ -399,7 +487,8 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } // if the sign-extend input cannot have a sign bit, we don't need it - if (getMaxBits(ext) + extraShifts < bits) { + // we also don't need it if it already has an identical-sized sign extend + if (getMaxBits(ext, this) + extraShifts < bits || isSignExted(ext, bits)) { return removeAlmostSignExt(binary); } } else if (binary->op == EqInt32 || binary->op == NeInt32) { @@ -448,7 +537,7 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } } else if (auto maskedBits = getMaskedBits(mask)) { - if (getMaxBits(binary->left) <= maskedBits) { + if (getMaxBits(binary->left, this) <= maskedBits) { // a mask of lower bits is not needed if we are already smaller return binary->left; } @@ -586,7 +675,15 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, return nullptr; } + Index getMaxBitsForLocal(GetLocal* get) { + // check what we know about the local + return localInfo[get->index].maxBits; + } + private: + // Information about our locals + std::vector<LocalInfo> localInfo; + // Optimize given that the expression is flowing into a boolean context Expression* optimizeBoolean(Expression* boolean) { if (auto* unary = boolean->dynCast<Unary>()) { @@ -816,6 +913,18 @@ private: innerConst->value = innerConst->value.sub(outerConst->value); return inner; } + + // check if an expression is already sign-extended + bool isSignExted(Expression* curr, Index bits) { + if (getSignExt(curr)) { + return getSignExtBits(curr) == bits; + } + if (auto* get = curr->dynCast<GetLocal>()) { + // check what we know about the local + return localInfo[get->index].signExtedBits == bits; + } + return false; + } }; Pass *createOptimizeInstructionsPass() { |