diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ast/bits.h | 47 | ||||
-rw-r--r-- | src/ast/properties.h | 77 | ||||
-rw-r--r-- | src/passes/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 282 | ||||
-rw-r--r-- | src/passes/PickLoadSigns.cpp | 107 | ||||
-rw-r--r-- | src/passes/pass.cpp | 4 | ||||
-rw-r--r-- | src/passes/passes.h | 1 |
7 files changed, 438 insertions, 81 deletions
diff --git a/src/ast/bits.h b/src/ast/bits.h new file mode 100644 index 000000000..d88cb5edb --- /dev/null +++ b/src/ast/bits.h @@ -0,0 +1,47 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ast_bits_h +#define wasm_ast_bits_h + +#include "support/bits.h" + +namespace wasm { + +struct Bits { + // get a mask to keep only the low # of bits + static int32_t lowBitMask(int32_t bits) { + uint32_t ret = -1; + if (bits >= 32) return ret; + return ret >> (32 - bits); + } + + // checks if the input is a mask of lower bits, i.e., all 1s up to some high bit, and all zeros + // from there. returns the number of masked bits, or 0 if this is not such a mask + static uint32_t getMaskedBits(uint32_t mask) { + if (mask == uint32_t(-1)) return 32; // all the bits + if (mask == 0) return 0; // trivially not a mask + // otherwise, see if adding one turns this into a 1-bit thing, 00011111 + 1 => 00100000 + if (PopCount(mask + 1) != 1) return 0; + // this is indeed a mask + return 32 - CountLeadingZeroes(mask); + } +}; + +} // namespace wasm + +#endif // wasm_ast_bits_h + diff --git a/src/ast/properties.h b/src/ast/properties.h index 9834c73d0..9121deba5 100644 --- a/src/ast/properties.h +++ b/src/ast/properties.h @@ -18,6 +18,7 @@ #define wasm_ast_properties_h #include "wasm.h" +#include "ast/bits.h" namespace wasm { @@ -52,6 +53,82 @@ struct Properties { default: return false; } } + + // Check if an expression is a sign-extend, and if so, returns the value + // that is extended, otherwise nullptr + static Expression* getSignExtValue(Expression* curr) { + if (auto* outer = curr->dynCast<Binary>()) { + if (outer->op == ShrSInt32) { + if (auto* outerConst = outer->right->dynCast<Const>()) { + if (auto* inner = outer->left->dynCast<Binary>()) { + if (inner->op == ShlInt32) { + if (auto* innerConst = inner->right->dynCast<Const>()) { + if (outerConst->value == innerConst->value) { + return inner->left; + } + } + } + } + } + } + } + return nullptr; + } + + // gets the size of the sign-extended value + static Index getSignExtBits(Expression* curr) { + return 32 - curr->cast<Binary>()->right->cast<Const>()->value.geti32(); + } + + // Check if an expression is almost a sign-extend: perhaps the inner shift + // is too large. We can split the shifts in that case, which is sometimes + // useful (e.g. if we can remove the signext) + static Expression* getAlmostSignExt(Expression* curr) { + if (auto* outer = curr->dynCast<Binary>()) { + if (outer->op == ShrSInt32) { + if (auto* outerConst = outer->right->dynCast<Const>()) { + if (auto* inner = outer->left->dynCast<Binary>()) { + if (inner->op == ShlInt32) { + if (auto* innerConst = inner->right->dynCast<Const>()) { + if (outerConst->value.leU(innerConst->value).geti32()) { + return inner->left; + } + } + } + } + } + } + } + return nullptr; + } + + // gets the size of the almost sign-extended value, as well as the + // extra shifts, if any + static Index getAlmostSignExtBits(Expression* curr, Index& extraShifts) { + extraShifts = curr->cast<Binary>()->left->cast<Binary>()->right->cast<Const>()->value.geti32() - + curr->cast<Binary>()->right->cast<Const>()->value.geti32(); + return getSignExtBits(curr); + } + + // Check if an expression is a zero-extend, and if so, returns the value + // that is extended, otherwise nullptr + static Expression* getZeroExtValue(Expression* curr) { + if (auto* binary = curr->dynCast<Binary>()) { + if (binary->op == AndInt32) { + if (auto* c = binary->right->dynCast<Const>()) { + if (Bits::getMaskedBits(c->value.geti32())) { + return binary->right; + } + } + } + } + return nullptr; + } + + // gets the size of the sign-extended value + static Index getZeroExtBits(Expression* curr) { + return Bits::getMaskedBits(curr->cast<Binary>()->right->cast<Const>()->value.geti32()); + } }; } // wasm diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index 9db1c66ae..74f120e85 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -13,6 +13,7 @@ SET(passes_SOURCES NameManager.cpp NameList.cpp OptimizeInstructions.cpp + PickLoadSigns.cpp PostEmscripten.cpp Precompute.cpp Print.cpp diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index bb4748a97..79c01dee7 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -166,7 +166,10 @@ struct Match { // returns the maximum amount of bits used in an integer expression // not extremely precise (doesn't look into add operands, etc.) -static Index getMaxBits(Expression* curr) { +// LocalInfoProvider is an optional class that can provide answers about +// get_local. +template<typename LocalInfoProvider> +Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) { if (auto* const_ = curr->dynCast<Const>()) { switch (curr->type) { case i32: return 32 - const_->value.countLeadingZeroes().geti32(); @@ -179,17 +182,17 @@ static Index getMaxBits(Expression* curr) { case AddInt32: case SubInt32: case MulInt32: case DivSInt32: case DivUInt32: case RemSInt32: case RemUInt32: case RotLInt32: case RotRInt32: return 32; - case AndInt32: case XorInt32: return std::min(getMaxBits(binary->left), getMaxBits(binary->right)); - case OrInt32: return std::max(getMaxBits(binary->left), getMaxBits(binary->right)); + case AndInt32: return std::min(getMaxBits(binary->left, localInfoProvider), getMaxBits(binary->right, localInfoProvider)); + case OrInt32: case XorInt32: return std::max(getMaxBits(binary->left, localInfoProvider), getMaxBits(binary->right, localInfoProvider)); case ShlInt32: { if (auto* shifts = binary->right->dynCast<Const>()) { - return std::min(Index(32), getMaxBits(binary->left) + shifts->value.geti32()); + return std::min(Index(32), getMaxBits(binary->left, localInfoProvider) + shifts->value.geti32()); } return 32; } case ShrUInt32: { if (auto* shift = binary->right->dynCast<Const>()) { - auto maxBits = getMaxBits(binary->left); + auto maxBits = getMaxBits(binary->left, localInfoProvider); auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out return std::max(Index(0), maxBits - shifts); } @@ -197,7 +200,7 @@ static Index getMaxBits(Expression* curr) { } case ShrSInt32: { if (auto* shift = binary->right->dynCast<Const>()) { - auto maxBits = getMaxBits(binary->left); + auto maxBits = getMaxBits(binary->left, localInfoProvider); if (maxBits == 32) return 32; auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out return std::max(Index(0), maxBits - shifts); @@ -225,9 +228,19 @@ static Index getMaxBits(Expression* curr) { case ClzInt32: case CtzInt32: case PopcntInt32: return 5; case ClzInt64: case CtzInt64: case PopcntInt64: return 6; case EqZInt32: case EqZInt64: return 1; - case WrapInt64: return std::min(Index(32), getMaxBits(unary->value)); + case WrapInt64: return std::min(Index(32), getMaxBits(unary->value, localInfoProvider)); default: {} } + } else if (auto* set = curr->dynCast<SetLocal>()) { + // a tee passes through the value + return getMaxBits(set->value, localInfoProvider); + } else if (auto* get = curr->dynCast<GetLocal>()) { + return localInfoProvider->getMaxBitsForLocal(get); + } else if (auto* load = curr->dynCast<Load>()) { + // if signed, then the sign-extension might fill all the bits + if (!load->signed_) { + return 8 * load->bytes; + } } switch (curr->type) { case i32: return 32; @@ -237,68 +250,95 @@ static Index getMaxBits(Expression* curr) { } } -// Check if an expression is a sign-extend, and if so, returns the value -// that is extended, otherwise nullptr -static Expression* getSignExt(Expression* curr) { - if (auto* outer = curr->dynCast<Binary>()) { - if (outer->op == ShrSInt32) { - if (auto* outerConst = outer->right->dynCast<Const>()) { - if (auto* inner = outer->left->dynCast<Binary>()) { - if (inner->op == ShlInt32) { - if (auto* innerConst = inner->right->dynCast<Const>()) { - if (outerConst->value == innerConst->value) { - return inner->left; - } - } - } - } - } +// looks through fallthrough operations, like tee_local, block fallthrough, etc. +// too and block fallthroughs, etc. +Expression* getFallthrough(Expression* curr) { + if (auto* set = curr->dynCast<SetLocal>()) { + if (set->isTee()) { + return getFallthrough(set->value); + } + } else if (auto* block = curr->dynCast<Block>()) { + // if no name, we can't be broken to, and then can look at the fallthrough + if (!block->name.is() && block->list.size() > 0) { + return getFallthrough(block->list.back()); } } - return nullptr; + return curr; } -// gets the size of the sign-extended value -static Index getSignExtBits(Expression* curr) { - return 32 - curr->cast<Binary>()->right->cast<Const>()->value.geti32(); -} +// Useful information about locals +struct LocalInfo { + static const Index kUnknown = Index(-1); -// Check if an expression is almost a sign-extend: perhaps the inner shift -// is too large. We can split the shifts in that case, which is sometimes -// useful (e.g. if we can remove the signext) -static Expression* getAlmostSignExt(Expression* curr) { - if (auto* outer = curr->dynCast<Binary>()) { - if (outer->op == ShrSInt32) { - if (auto* outerConst = outer->right->dynCast<Const>()) { - if (auto* inner = outer->left->dynCast<Binary>()) { - if (inner->op == ShlInt32) { - if (auto* innerConst = inner->right->dynCast<Const>()) { - if (outerConst->value.leU(innerConst->value).geti32()) { - return inner->left; - } - } - } - } + Index maxBits; + Index signExtedBits; +}; + +struct LocalScanner : PostWalker<LocalScanner, Visitor<LocalScanner>> { + std::vector<LocalInfo>& localInfo; + + LocalScanner(std::vector<LocalInfo>& localInfo) : localInfo(localInfo) {} + + void doWalkFunction(Function* func) { + // prepare + localInfo.resize(func->getNumLocals()); + for (Index i = 0; i < func->getNumLocals(); i++) { + auto& info = localInfo[i]; + if (func->isParam(i)) { + info.maxBits = getBitsForType(func->getLocalType(i)); // worst-case + info.signExtedBits = LocalInfo::kUnknown; // we will never know anything + } else { + info.maxBits = info.signExtedBits = 0; // we are open to learning + } + } + // walk + PostWalker<LocalScanner, Visitor<LocalScanner>>::doWalkFunction(func); + // finalize + for (Index i = 0; i < func->getNumLocals(); i++) { + auto& info = localInfo[i]; + if (info.signExtedBits == LocalInfo::kUnknown) { + info.signExtedBits = 0; } } } - return nullptr; -} -// gets the size of the almost sign-extended value, as well as the -// extra shifts, if any -static Index getAlmostSignExtBits(Expression* curr, Index& extraShifts) { - extraShifts = curr->cast<Binary>()->left->cast<Binary>()->right->cast<Const>()->value.geti32() - - curr->cast<Binary>()->right->cast<Const>()->value.geti32(); - return getSignExtBits(curr); -} + void visitSetLocal(SetLocal* curr) { + auto* func = getFunction(); + if (func->isParam(curr->index)) return; + auto type = getFunction()->getLocalType(curr->index); + if (type != i32 && type != i64) return; + // an integer var, worth processing + auto* value = getFallthrough(curr->value); + auto& info = localInfo[curr->index]; + info.maxBits = std::max(info.maxBits, getMaxBits(value, this)); + auto signExtBits = LocalInfo::kUnknown; + if (Properties::getSignExtValue(value)) { + signExtBits = Properties::getSignExtBits(value); + } else if (auto* load = value->dynCast<Load>()) { + if (load->signed_) { + signExtBits = load->bytes * 8; + } + } + if (info.signExtedBits == 0) { + info.signExtedBits = signExtBits; // first info we see + } else if (info.signExtedBits != signExtBits) { + info.signExtedBits = LocalInfo::kUnknown; // contradictory information, give up + } + } -// get a mask to keep only the low # of bits -static int32_t lowBitMask(int32_t bits) { - uint32_t ret = -1; - if (bits >= 32) return ret; - return ret >> (32 - bits); -} + // define this for the templated getMaxBits method. we know nothing here yet about locals, so return the maxes + Index getMaxBitsForLocal(GetLocal* get) { + return getBitsForType(get->type); + } + + Index getBitsForType(WasmType type) { + switch (type) { + case i32: return 32; + case i64: return 64; + default: return -1; + } + } +}; // Main pass class struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>> { @@ -312,6 +352,16 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, #endif } + void doWalkFunction(Function* func) { + // first, scan locals + { + LocalScanner scanner(localInfo); + scanner.walkFunction(func); + } + // main walk + WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>>::doWalkFunction(func); + } + void visitExpression(Expression* curr) { // we may be able to apply multiple patterns, one may open opportunities that look deeper NB: patterns must not have cycles while (1) { @@ -351,40 +401,63 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, std::swap(binary->left, binary->right); } } - if (auto* ext = getAlmostSignExt(binary)) { + if (auto* ext = Properties::getAlmostSignExt(binary)) { Index extraShifts; - auto bits = getAlmostSignExtBits(binary, extraShifts); - auto* load = ext->dynCast<Load>(); - // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc. - if (load && ((load->bytes == 1 && bits == 8) || (load->bytes == 2 && bits == 16))) { - load->signed_ = true; - return removeAlmostSignExt(binary); + auto bits = Properties::getAlmostSignExtBits(binary, extraShifts); + if (auto* load = getFallthrough(ext)->dynCast<Load>()) { + // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc. + if ((load->bytes == 1 && bits == 8) || (load->bytes == 2 && bits == 16)) { + // if the value falls through, we can't alter the load, as it might be captured in a tee + if (load->signed_ == true || load == ext) { + load->signed_ = true; + return removeAlmostSignExt(binary); + } + } } // if the sign-extend input cannot have a sign bit, we don't need it - if (getMaxBits(ext) + extraShifts < bits) { + // we also don't need it if it already has an identical-sized sign extend + if (getMaxBits(ext, this) + extraShifts < bits || isSignExted(ext, bits)) { return removeAlmostSignExt(binary); } - } else if (binary->op == EqInt32) { + } else if (binary->op == EqInt32 || binary->op == NeInt32) { if (auto* c = binary->right->dynCast<Const>()) { - if (auto* ext = getSignExt(binary->left)) { + if (auto* ext = Properties::getSignExtValue(binary->left)) { // we are comparing a sign extend to a constant, which means we can use a cheaper zext - auto bits = getSignExtBits(binary->left); + auto bits = Properties::getSignExtBits(binary->left); binary->left = makeZeroExt(ext, bits); // the const we compare to only needs the relevant bits - c->value = c->value.and_(Literal(lowBitMask(bits))); + c->value = c->value.and_(Literal(Bits::lowBitMask(bits))); return binary; } - if (c->value.geti32() == 0) { + if (binary->op == EqInt32 && c->value.geti32() == 0) { // equal 0 => eqz return Builder(*getModule()).makeUnary(EqZInt32, binary->left); } - } else if (auto* left = getSignExt(binary->left)) { - if (auto* right = getSignExt(binary->right)) { + } else if (auto* left = Properties::getSignExtValue(binary->left)) { + if (auto* right = Properties::getSignExtValue(binary->right)) { // we are comparing two sign-exts, so we may as well replace both with cheaper zexts - auto bits = getSignExtBits(binary->left); + auto bits = Properties::getSignExtBits(binary->left); binary->left = makeZeroExt(left, bits); binary->right = makeZeroExt(right, bits); return binary; + } else if (auto* load = binary->right->dynCast<Load>()) { + // we are comparing a load to a sign-ext, we may be able to switch to zext + auto leftBits = Properties::getSignExtBits(binary->left); + if (load->signed_ && leftBits == load->bytes * 8) { + load->signed_ = false; + binary->left = makeZeroExt(left, leftBits); + return binary; + } + } + } else if (auto* load = binary->left->dynCast<Load>()) { + if (auto* right = Properties::getSignExtValue(binary->right)) { + // we are comparing a load to a sign-ext, we may be able to switch to zext + auto rightBits = Properties::getSignExtBits(binary->right); + if (load->signed_ && rightBits == load->bytes * 8) { + load->signed_ = false; + binary->right = makeZeroExt(right, rightBits); + return binary; + } } } // note that both left and right may be consts, but then we let precompute compute the constant result @@ -404,11 +477,13 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, if ((load->bytes == 1 && mask == 0xff) || (load->bytes == 2 && mask == 0xffff)) { load->signed_ = false; - return load; + return binary->left; + } + } else if (auto maskedBits = Bits::getMaskedBits(mask)) { + if (getMaxBits(binary->left, this) <= maskedBits) { + // a mask of lower bits is not needed if we are already smaller + return binary->left; } - } else if (mask == 1 && Properties::emitsBoolean(binary->left)) { - // (bool) & 1 does not need the outer mask - return binary->left; } } // the square of some operations can be merged @@ -468,6 +543,13 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, default: {} } } + // eqz of a sign extension can be of zero-extension + if (auto* ext = Properties::getSignExtValue(unary->value)) { + // we are comparing a sign extend to a constant, which means we can use a cheaper zext + auto bits = Properties::getSignExtBits(unary->value); + unary->value = makeZeroExt(ext, bits); + return unary; + } } } else if (auto* set = curr->dynCast<SetGlobal>()) { // optimize out a set of a get @@ -518,6 +600,12 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, } } } + } else if (auto* ext = Properties::getSignExtValue(binary)) { + // if sign extending the exact bit size we store, we can skip the extension + // if extending something bigger, then we just alter bits we don't save anyhow + if (Properties::getSignExtBits(binary) >= store->bytes * 8) { + store->value = ext; + } } } else if (auto* unary = store->value->dynCast<Unary>()) { if (unary->op == WrapInt64) { @@ -530,7 +618,15 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, return nullptr; } + Index getMaxBitsForLocal(GetLocal* get) { + // check what we know about the local + return localInfo[get->index].maxBits; + } + private: + // Information about our locals + std::vector<LocalInfo> localInfo; + // Optimize given that the expression is flowing into a boolean context Expression* optimizeBoolean(Expression* boolean) { if (auto* unary = boolean->dynCast<Unary>()) { @@ -554,6 +650,10 @@ private: } } } + if (auto* ext = Properties::getSignExtValue(binary)) { + // use a cheaper zero-extent, we just care about the boolean value anyhow + return makeZeroExt(ext, Properties::getSignExtBits(binary)); + } } else if (auto* block = boolean->dynCast<Block>()) { if (block->type == i32 && block->list.size() > 0) { block->list.back() = optimizeBoolean(block->list.back()); @@ -611,7 +711,15 @@ private: }; // find all factors seek(binary, 1); - if (constants.size() <= 1) return nullptr; // nothing to do + if (constants.size() <= 1) { + // nothing much to do, except for the trivial case of adding/subbing a zero + if (auto* c = binary->right->dynCast<Const>()) { + if (c->value.geti32() == 0) { + return binary->left; + } + } + return nullptr; + } // wipe out all constants, we'll replace with a single added one for (auto* c : constants) { c->value = Literal(int32_t(0)); @@ -731,7 +839,7 @@ private: Expression* makeZeroExt(Expression* curr, int32_t bits) { Builder builder(*getModule()); - return builder.makeBinary(AndInt32, curr, builder.makeConst(Literal(lowBitMask(bits)))); + return builder.makeBinary(AndInt32, curr, builder.makeConst(Literal(Bits::lowBitMask(bits)))); } // given an "almost" sign extend - either a proper one, or it @@ -748,6 +856,18 @@ private: innerConst->value = innerConst->value.sub(outerConst->value); return inner; } + + // check if an expression is already sign-extended + bool isSignExted(Expression* curr, Index bits) { + if (Properties::getSignExtValue(curr)) { + return Properties::getSignExtBits(curr) == bits; + } + if (auto* get = curr->dynCast<GetLocal>()) { + // check what we know about the local + return localInfo[get->index].signExtedBits == bits; + } + return false; + } }; Pass *createOptimizeInstructionsPass() { diff --git a/src/passes/PickLoadSigns.cpp b/src/passes/PickLoadSigns.cpp new file mode 100644 index 000000000..6e44fddfe --- /dev/null +++ b/src/passes/PickLoadSigns.cpp @@ -0,0 +1,107 @@ +/* + * Copyright 2017 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <wasm.h> +#include <pass.h> +#include <ast/properties.h> + +namespace wasm { + +// Adjust load signedness based on usage. If a load only has uses that sign or +// unsign it anyhow, then it could be either, and picking the popular one can +// help remove the most sign/unsign operations +// unsigned, then it could be either + +struct PickLoadSigns : public WalkerPass<ExpressionStackWalker<PickLoadSigns, Visitor<PickLoadSigns>>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new PickLoadSigns; } + + struct Usage { + Index signedUsages = 0; + Index signedBits; + Index unsignedUsages = 0; + Index unsignedBits; + Index totalUsages = 0; + }; + std::vector<Usage> usages; // local index => usage + + std::unordered_map<Load*, Index> loads; // loads that write to a local => the local + + void doWalkFunction(Function* func) { + // prepare + usages.resize(func->getNumLocals()); + // walk + ExpressionStackWalker<PickLoadSigns, Visitor<PickLoadSigns>>::doWalkFunction(func); + // optimize based on the info we saw + for (auto& pair : loads) { + auto* load = pair.first; + auto index = pair.second; + auto& usage = usages[index]; + // if we can't optimize, give up + if (usage.totalUsages == 0 || // no usages, so no idea + usage.signedUsages + usage.unsignedUsages != usage.totalUsages || // non-sign/unsigned usages, so cannot change + (usage.signedUsages != 0 && usage.signedBits != load->bytes * 8) || // sign usages exist but the wrong size + (usage.unsignedUsages != 0 && usage.unsignedBits != load->bytes * 8)) { // unsigned usages exist but the wrong size + continue; + } + // we can pick the optimal one. our hope is to remove 2 items per + // signed use (two shifts), so we factor that in + load->signed_ = usage.signedUsages * 2 >= usage.unsignedUsages; + } + } + + void visitGetLocal(GetLocal* curr) { + // this is a use. check from the context what it is, signed or unsigned, etc. + auto& usage = usages[curr->index]; + usage.totalUsages++; + if (expressionStack.size() >= 2) { + auto* parent = expressionStack[expressionStack.size() - 2]; + if (Properties::getZeroExtValue(parent)) { + auto bits = Properties::getZeroExtBits(parent); + if (usage.unsignedUsages == 0) { + usage.unsignedBits = bits; + } else if (usage.unsignedBits != bits) { + usage.unsignedBits = 0; + } + usage.unsignedUsages++; + } else if (expressionStack.size() >= 3) { + auto* grandparent = expressionStack[expressionStack.size() - 3]; + if (Properties::getSignExtValue(grandparent)) { + auto bits = Properties::getSignExtBits(grandparent); + if (usage.signedUsages == 0) { + usage.signedBits = bits; + } else if (usage.signedBits != bits) { + usage.signedBits = 0; + } + usage.signedUsages++; + } + } + } + } + + void visitSetLocal(SetLocal* curr) { + if (auto* load = curr->value->dynCast<Load>()) { + loads[load] = curr->index; + } + } +}; + +Pass *createPickLoadSignsPass() { + return new PickLoadSigns(); +} + +} // namespace wasm diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 32b596eef..df98590c5 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -77,6 +77,7 @@ void PassRegistry::registerPasses() { registerPass("nm", "name list", createNameListPass); registerPass("name-manager", "utility pass to manage names in modules", createNameManagerPass); registerPass("optimize-instructions", "optimizes instruction combinations", createOptimizeInstructionsPass); + registerPass("pick-load-signs", "pick load signs based on their uses", createPickLoadSignsPass); registerPass("post-emscripten", "miscellaneous optimizations for Emscripten-generated code", createPostEmscriptenPass); registerPass("print", "print in s-expression format", createPrinterPass); registerPass("print-minified", "print in minified s-expression format", createMinifiedPrinterPass); @@ -112,6 +113,9 @@ void PassRunner::addDefaultFunctionOptimizationPasses() { add("remove-unused-brs"); add("remove-unused-names"); add("optimize-instructions"); + if (options.optimizeLevel >= 2 || options.shrinkLevel >= 2) { + add("pick-load-signs"); + } add("precompute"); if (options.optimizeLevel >= 2 || options.shrinkLevel >= 2) { add("code-pushing"); diff --git a/src/passes/passes.h b/src/passes/passes.h index cbfc48327..83bf556d8 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -39,6 +39,7 @@ Pass *createMetricsPass(); Pass *createNameListPass(); Pass *createNameManagerPass(); Pass *createOptimizeInstructionsPass(); +Pass *createPickLoadSignsPass(); Pass *createPostEmscriptenPass(); Pass *createPrinterPass(); Pass *createPrintCallGraphPass(); |