diff options
Diffstat (limited to 'src/passes/OptimizeInstructions.cpp')
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 153 |
1 files changed, 4 insertions, 149 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 860ea3ac7..7f66ca7ee 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -22,6 +22,7 @@ #include <type_traits> #include <ir/abstract.h> +#include <ir/bits.h> #include <ir/cost.h> #include <ir/effects.h> #include <ir/literal-utils.h> @@ -45,152 +46,6 @@ Name F32_EXPR = "f32.expr"; Name F64_EXPR = "f64.expr"; Name ANY_EXPR = "any.expr"; -// Utilities - -// returns the maximum amount of bits used in an integer expression -// not extremely precise (doesn't look into add operands, etc.) -// LocalInfoProvider is an optional class that can provide answers about -// local.get. -template<typename LocalInfoProvider> -Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) { - if (auto* const_ = curr->dynCast<Const>()) { - switch (curr->type.getSingle()) { - case Type::i32: - return 32 - const_->value.countLeadingZeroes().geti32(); - case Type::i64: - return 64 - const_->value.countLeadingZeroes().geti64(); - default: - WASM_UNREACHABLE("invalid type"); - } - } else if (auto* binary = curr->dynCast<Binary>()) { - switch (binary->op) { - // 32-bit - case AddInt32: - case SubInt32: - case MulInt32: - case DivSInt32: - case DivUInt32: - case RemSInt32: - case RemUInt32: - case RotLInt32: - case RotRInt32: - return 32; - case AndInt32: - return std::min(getMaxBits(binary->left, localInfoProvider), - getMaxBits(binary->right, localInfoProvider)); - case OrInt32: - case XorInt32: - return std::max(getMaxBits(binary->left, localInfoProvider), - getMaxBits(binary->right, localInfoProvider)); - case ShlInt32: { - if (auto* shifts = binary->right->dynCast<Const>()) { - return std::min(Index(32), - getMaxBits(binary->left, localInfoProvider) + - Bits::getEffectiveShifts(shifts)); - } - return 32; - } - case ShrUInt32: { - if (auto* shift = binary->right->dynCast<Const>()) { - auto maxBits = getMaxBits(binary->left, localInfoProvider); - auto shifts = - std::min(Index(Bits::getEffectiveShifts(shift)), - maxBits); // can ignore more shifts than zero us out - return std::max(Index(0), maxBits - shifts); - } - return 32; - } - case ShrSInt32: { - if (auto* shift = binary->right->dynCast<Const>()) { - auto maxBits = getMaxBits(binary->left, localInfoProvider); - if (maxBits == 32) { - return 32; - } - auto shifts = - std::min(Index(Bits::getEffectiveShifts(shift)), - maxBits); // can ignore more shifts than zero us out - return std::max(Index(0), maxBits - shifts); - } - return 32; - } - // 64-bit TODO - // comparisons - case EqInt32: - case NeInt32: - case LtSInt32: - case LtUInt32: - case LeSInt32: - case LeUInt32: - case GtSInt32: - case GtUInt32: - case GeSInt32: - case GeUInt32: - case EqInt64: - case NeInt64: - case LtSInt64: - case LtUInt64: - case LeSInt64: - case LeUInt64: - case GtSInt64: - case GtUInt64: - case GeSInt64: - case GeUInt64: - case EqFloat32: - case NeFloat32: - case LtFloat32: - case LeFloat32: - case GtFloat32: - case GeFloat32: - case EqFloat64: - case NeFloat64: - case LtFloat64: - case LeFloat64: - case GtFloat64: - case GeFloat64: - return 1; - default: {} - } - } else if (auto* unary = curr->dynCast<Unary>()) { - switch (unary->op) { - case ClzInt32: - case CtzInt32: - case PopcntInt32: - return 6; - case ClzInt64: - case CtzInt64: - case PopcntInt64: - return 7; - case EqZInt32: - case EqZInt64: - return 1; - case WrapInt64: - return std::min(Index(32), getMaxBits(unary->value, localInfoProvider)); - default: {} - } - } else if (auto* set = curr->dynCast<LocalSet>()) { - // a tee passes through the value - return getMaxBits(set->value, localInfoProvider); - } else if (auto* get = curr->dynCast<LocalGet>()) { - return localInfoProvider->getMaxBitsForLocal(get); - } else if (auto* load = curr->dynCast<Load>()) { - // if signed, then the sign-extension might fill all the bits - // if unsigned, then we have a limit - if (LoadUtils::isSignRelevant(load) && !load->signed_) { - return 8 * load->bytes; - } - } - switch (curr->type.getSingle()) { - case Type::i32: - return 32; - case Type::i64: - return 64; - case Type::unreachable: - return 64; // not interesting, but don't crash - default: - WASM_UNREACHABLE("invalid type"); - } -} - // Useful information about locals struct LocalInfo { static const Index kUnknown = Index(-1); @@ -243,7 +98,7 @@ struct LocalScanner : PostWalker<LocalScanner> { auto* value = Properties::getFallthrough( curr->value, passOptions, getModule()->features); auto& info = localInfo[curr->index]; - info.maxBits = std::max(info.maxBits, getMaxBits(value, this)); + info.maxBits = std::max(info.maxBits, Bits::getMaxBits(value, this)); auto signExtBits = LocalInfo::kUnknown; if (Properties::getSignExtValue(value)) { signExtBits = Properties::getSignExtBits(value); @@ -373,7 +228,7 @@ struct OptimizeInstructions // if the sign-extend input cannot have a sign bit, we don't need it // we also don't need it if it already has an identical-sized sign // extend - if (getMaxBits(ext, this) + extraShifts < bits || + if (Bits::getMaxBits(ext, this) + extraShifts < bits || isSignExted(ext, bits)) { return removeAlmostSignExt(binary); } @@ -538,7 +393,7 @@ struct OptimizeInstructions return binary->left; } } else if (auto maskedBits = Bits::getMaskedBits(mask)) { - if (getMaxBits(binary->left, this) <= maskedBits) { + if (Bits::getMaxBits(binary->left, this) <= maskedBits) { // a mask of lower bits is not needed if we are already smaller return binary->left; } |