summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast/bits.h47
-rw-r--r--src/ast/properties.h77
-rw-r--r--src/passes/CMakeLists.txt1
-rw-r--r--src/passes/OptimizeInstructions.cpp282
-rw-r--r--src/passes/PickLoadSigns.cpp107
-rw-r--r--src/passes/pass.cpp4
-rw-r--r--src/passes/passes.h1
7 files changed, 438 insertions, 81 deletions
diff --git a/src/ast/bits.h b/src/ast/bits.h
new file mode 100644
index 000000000..d88cb5edb
--- /dev/null
+++ b/src/ast/bits.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_ast_bits_h
+#define wasm_ast_bits_h
+
+#include "support/bits.h"
+
+namespace wasm {
+
+struct Bits {
+ // get a mask to keep only the low # of bits
+ static int32_t lowBitMask(int32_t bits) {
+ uint32_t ret = -1;
+ if (bits >= 32) return ret;
+ return ret >> (32 - bits);
+ }
+
+ // checks if the input is a mask of lower bits, i.e., all 1s up to some high bit, and all zeros
+ // from there. returns the number of masked bits, or 0 if this is not such a mask
+ static uint32_t getMaskedBits(uint32_t mask) {
+ if (mask == uint32_t(-1)) return 32; // all the bits
+ if (mask == 0) return 0; // trivially not a mask
+ // otherwise, see if adding one turns this into a 1-bit thing, 00011111 + 1 => 00100000
+ if (PopCount(mask + 1) != 1) return 0;
+ // this is indeed a mask
+ return 32 - CountLeadingZeroes(mask);
+ }
+};
+
+} // namespace wasm
+
+#endif // wasm_ast_bits_h
+
diff --git a/src/ast/properties.h b/src/ast/properties.h
index 9834c73d0..9121deba5 100644
--- a/src/ast/properties.h
+++ b/src/ast/properties.h
@@ -18,6 +18,7 @@
#define wasm_ast_properties_h
#include "wasm.h"
+#include "ast/bits.h"
namespace wasm {
@@ -52,6 +53,82 @@ struct Properties {
default: return false;
}
}
+
+ // Check if an expression is a sign-extend, and if so, returns the value
+ // that is extended, otherwise nullptr
+ static Expression* getSignExtValue(Expression* curr) {
+ if (auto* outer = curr->dynCast<Binary>()) {
+ if (outer->op == ShrSInt32) {
+ if (auto* outerConst = outer->right->dynCast<Const>()) {
+ if (auto* inner = outer->left->dynCast<Binary>()) {
+ if (inner->op == ShlInt32) {
+ if (auto* innerConst = inner->right->dynCast<Const>()) {
+ if (outerConst->value == innerConst->value) {
+ return inner->left;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ return nullptr;
+ }
+
+ // gets the size of the sign-extended value
+ static Index getSignExtBits(Expression* curr) {
+ return 32 - curr->cast<Binary>()->right->cast<Const>()->value.geti32();
+ }
+
+ // Check if an expression is almost a sign-extend: perhaps the inner shift
+ // is too large. We can split the shifts in that case, which is sometimes
+ // useful (e.g. if we can remove the signext)
+ static Expression* getAlmostSignExt(Expression* curr) {
+ if (auto* outer = curr->dynCast<Binary>()) {
+ if (outer->op == ShrSInt32) {
+ if (auto* outerConst = outer->right->dynCast<Const>()) {
+ if (auto* inner = outer->left->dynCast<Binary>()) {
+ if (inner->op == ShlInt32) {
+ if (auto* innerConst = inner->right->dynCast<Const>()) {
+ if (outerConst->value.leU(innerConst->value).geti32()) {
+ return inner->left;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ return nullptr;
+ }
+
+ // gets the size of the almost sign-extended value, as well as the
+ // extra shifts, if any
+ static Index getAlmostSignExtBits(Expression* curr, Index& extraShifts) {
+ extraShifts = curr->cast<Binary>()->left->cast<Binary>()->right->cast<Const>()->value.geti32() -
+ curr->cast<Binary>()->right->cast<Const>()->value.geti32();
+ return getSignExtBits(curr);
+ }
+
+ // Check if an expression is a zero-extend, and if so, returns the value
+ // that is extended, otherwise nullptr
+ static Expression* getZeroExtValue(Expression* curr) {
+ if (auto* binary = curr->dynCast<Binary>()) {
+ if (binary->op == AndInt32) {
+ if (auto* c = binary->right->dynCast<Const>()) {
+ if (Bits::getMaskedBits(c->value.geti32())) {
+ return binary->right;
+ }
+ }
+ }
+ }
+ return nullptr;
+ }
+
+ // gets the size of the sign-extended value
+ static Index getZeroExtBits(Expression* curr) {
+ return Bits::getMaskedBits(curr->cast<Binary>()->right->cast<Const>()->value.geti32());
+ }
};
} // wasm
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index 9db1c66ae..74f120e85 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -13,6 +13,7 @@ SET(passes_SOURCES
NameManager.cpp
NameList.cpp
OptimizeInstructions.cpp
+ PickLoadSigns.cpp
PostEmscripten.cpp
Precompute.cpp
Print.cpp
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index bb4748a97..79c01dee7 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -166,7 +166,10 @@ struct Match {
// returns the maximum amount of bits used in an integer expression
// not extremely precise (doesn't look into add operands, etc.)
-static Index getMaxBits(Expression* curr) {
+// LocalInfoProvider is an optional class that can provide answers about
+// get_local.
+template<typename LocalInfoProvider>
+Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) {
if (auto* const_ = curr->dynCast<Const>()) {
switch (curr->type) {
case i32: return 32 - const_->value.countLeadingZeroes().geti32();
@@ -179,17 +182,17 @@ static Index getMaxBits(Expression* curr) {
case AddInt32: case SubInt32: case MulInt32:
case DivSInt32: case DivUInt32: case RemSInt32:
case RemUInt32: case RotLInt32: case RotRInt32: return 32;
- case AndInt32: case XorInt32: return std::min(getMaxBits(binary->left), getMaxBits(binary->right));
- case OrInt32: return std::max(getMaxBits(binary->left), getMaxBits(binary->right));
+ case AndInt32: return std::min(getMaxBits(binary->left, localInfoProvider), getMaxBits(binary->right, localInfoProvider));
+ case OrInt32: case XorInt32: return std::max(getMaxBits(binary->left, localInfoProvider), getMaxBits(binary->right, localInfoProvider));
case ShlInt32: {
if (auto* shifts = binary->right->dynCast<Const>()) {
- return std::min(Index(32), getMaxBits(binary->left) + shifts->value.geti32());
+ return std::min(Index(32), getMaxBits(binary->left, localInfoProvider) + shifts->value.geti32());
}
return 32;
}
case ShrUInt32: {
if (auto* shift = binary->right->dynCast<Const>()) {
- auto maxBits = getMaxBits(binary->left);
+ auto maxBits = getMaxBits(binary->left, localInfoProvider);
auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out
return std::max(Index(0), maxBits - shifts);
}
@@ -197,7 +200,7 @@ static Index getMaxBits(Expression* curr) {
}
case ShrSInt32: {
if (auto* shift = binary->right->dynCast<Const>()) {
- auto maxBits = getMaxBits(binary->left);
+ auto maxBits = getMaxBits(binary->left, localInfoProvider);
if (maxBits == 32) return 32;
auto shifts = std::min(Index(shift->value.geti32()), maxBits); // can ignore more shifts than zero us out
return std::max(Index(0), maxBits - shifts);
@@ -225,9 +228,19 @@ static Index getMaxBits(Expression* curr) {
case ClzInt32: case CtzInt32: case PopcntInt32: return 5;
case ClzInt64: case CtzInt64: case PopcntInt64: return 6;
case EqZInt32: case EqZInt64: return 1;
- case WrapInt64: return std::min(Index(32), getMaxBits(unary->value));
+ case WrapInt64: return std::min(Index(32), getMaxBits(unary->value, localInfoProvider));
default: {}
}
+ } else if (auto* set = curr->dynCast<SetLocal>()) {
+ // a tee passes through the value
+ return getMaxBits(set->value, localInfoProvider);
+ } else if (auto* get = curr->dynCast<GetLocal>()) {
+ return localInfoProvider->getMaxBitsForLocal(get);
+ } else if (auto* load = curr->dynCast<Load>()) {
+ // if signed, then the sign-extension might fill all the bits
+ if (!load->signed_) {
+ return 8 * load->bytes;
+ }
}
switch (curr->type) {
case i32: return 32;
@@ -237,68 +250,95 @@ static Index getMaxBits(Expression* curr) {
}
}
-// Check if an expression is a sign-extend, and if so, returns the value
-// that is extended, otherwise nullptr
-static Expression* getSignExt(Expression* curr) {
- if (auto* outer = curr->dynCast<Binary>()) {
- if (outer->op == ShrSInt32) {
- if (auto* outerConst = outer->right->dynCast<Const>()) {
- if (auto* inner = outer->left->dynCast<Binary>()) {
- if (inner->op == ShlInt32) {
- if (auto* innerConst = inner->right->dynCast<Const>()) {
- if (outerConst->value == innerConst->value) {
- return inner->left;
- }
- }
- }
- }
- }
+// looks through fallthrough operations, like tee_local, block fallthrough, etc.
+// too and block fallthroughs, etc.
+Expression* getFallthrough(Expression* curr) {
+ if (auto* set = curr->dynCast<SetLocal>()) {
+ if (set->isTee()) {
+ return getFallthrough(set->value);
+ }
+ } else if (auto* block = curr->dynCast<Block>()) {
+ // if no name, we can't be broken to, and then can look at the fallthrough
+ if (!block->name.is() && block->list.size() > 0) {
+ return getFallthrough(block->list.back());
}
}
- return nullptr;
+ return curr;
}
-// gets the size of the sign-extended value
-static Index getSignExtBits(Expression* curr) {
- return 32 - curr->cast<Binary>()->right->cast<Const>()->value.geti32();
-}
+// Useful information about locals
+struct LocalInfo {
+ static const Index kUnknown = Index(-1);
-// Check if an expression is almost a sign-extend: perhaps the inner shift
-// is too large. We can split the shifts in that case, which is sometimes
-// useful (e.g. if we can remove the signext)
-static Expression* getAlmostSignExt(Expression* curr) {
- if (auto* outer = curr->dynCast<Binary>()) {
- if (outer->op == ShrSInt32) {
- if (auto* outerConst = outer->right->dynCast<Const>()) {
- if (auto* inner = outer->left->dynCast<Binary>()) {
- if (inner->op == ShlInt32) {
- if (auto* innerConst = inner->right->dynCast<Const>()) {
- if (outerConst->value.leU(innerConst->value).geti32()) {
- return inner->left;
- }
- }
- }
- }
+ Index maxBits;
+ Index signExtedBits;
+};
+
+struct LocalScanner : PostWalker<LocalScanner, Visitor<LocalScanner>> {
+ std::vector<LocalInfo>& localInfo;
+
+ LocalScanner(std::vector<LocalInfo>& localInfo) : localInfo(localInfo) {}
+
+ void doWalkFunction(Function* func) {
+ // prepare
+ localInfo.resize(func->getNumLocals());
+ for (Index i = 0; i < func->getNumLocals(); i++) {
+ auto& info = localInfo[i];
+ if (func->isParam(i)) {
+ info.maxBits = getBitsForType(func->getLocalType(i)); // worst-case
+ info.signExtedBits = LocalInfo::kUnknown; // we will never know anything
+ } else {
+ info.maxBits = info.signExtedBits = 0; // we are open to learning
+ }
+ }
+ // walk
+ PostWalker<LocalScanner, Visitor<LocalScanner>>::doWalkFunction(func);
+ // finalize
+ for (Index i = 0; i < func->getNumLocals(); i++) {
+ auto& info = localInfo[i];
+ if (info.signExtedBits == LocalInfo::kUnknown) {
+ info.signExtedBits = 0;
}
}
}
- return nullptr;
-}
-// gets the size of the almost sign-extended value, as well as the
-// extra shifts, if any
-static Index getAlmostSignExtBits(Expression* curr, Index& extraShifts) {
- extraShifts = curr->cast<Binary>()->left->cast<Binary>()->right->cast<Const>()->value.geti32() -
- curr->cast<Binary>()->right->cast<Const>()->value.geti32();
- return getSignExtBits(curr);
-}
+ void visitSetLocal(SetLocal* curr) {
+ auto* func = getFunction();
+ if (func->isParam(curr->index)) return;
+ auto type = getFunction()->getLocalType(curr->index);
+ if (type != i32 && type != i64) return;
+ // an integer var, worth processing
+ auto* value = getFallthrough(curr->value);
+ auto& info = localInfo[curr->index];
+ info.maxBits = std::max(info.maxBits, getMaxBits(value, this));
+ auto signExtBits = LocalInfo::kUnknown;
+ if (Properties::getSignExtValue(value)) {
+ signExtBits = Properties::getSignExtBits(value);
+ } else if (auto* load = value->dynCast<Load>()) {
+ if (load->signed_) {
+ signExtBits = load->bytes * 8;
+ }
+ }
+ if (info.signExtedBits == 0) {
+ info.signExtedBits = signExtBits; // first info we see
+ } else if (info.signExtedBits != signExtBits) {
+ info.signExtedBits = LocalInfo::kUnknown; // contradictory information, give up
+ }
+ }
-// get a mask to keep only the low # of bits
-static int32_t lowBitMask(int32_t bits) {
- uint32_t ret = -1;
- if (bits >= 32) return ret;
- return ret >> (32 - bits);
-}
+ // define this for the templated getMaxBits method. we know nothing here yet about locals, so return the maxes
+ Index getMaxBitsForLocal(GetLocal* get) {
+ return getBitsForType(get->type);
+ }
+
+ Index getBitsForType(WasmType type) {
+ switch (type) {
+ case i32: return 32;
+ case i64: return 64;
+ default: return -1;
+ }
+ }
+};
// Main pass class
struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>> {
@@ -312,6 +352,16 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
#endif
}
+ void doWalkFunction(Function* func) {
+ // first, scan locals
+ {
+ LocalScanner scanner(localInfo);
+ scanner.walkFunction(func);
+ }
+ // main walk
+ WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>>::doWalkFunction(func);
+ }
+
void visitExpression(Expression* curr) {
// we may be able to apply multiple patterns, one may open opportunities that look deeper NB: patterns must not have cycles
while (1) {
@@ -351,40 +401,63 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
std::swap(binary->left, binary->right);
}
}
- if (auto* ext = getAlmostSignExt(binary)) {
+ if (auto* ext = Properties::getAlmostSignExt(binary)) {
Index extraShifts;
- auto bits = getAlmostSignExtBits(binary, extraShifts);
- auto* load = ext->dynCast<Load>();
- // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc.
- if (load && ((load->bytes == 1 && bits == 8) || (load->bytes == 2 && bits == 16))) {
- load->signed_ = true;
- return removeAlmostSignExt(binary);
+ auto bits = Properties::getAlmostSignExtBits(binary, extraShifts);
+ if (auto* load = getFallthrough(ext)->dynCast<Load>()) {
+ // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc.
+ if ((load->bytes == 1 && bits == 8) || (load->bytes == 2 && bits == 16)) {
+ // if the value falls through, we can't alter the load, as it might be captured in a tee
+ if (load->signed_ == true || load == ext) {
+ load->signed_ = true;
+ return removeAlmostSignExt(binary);
+ }
+ }
}
// if the sign-extend input cannot have a sign bit, we don't need it
- if (getMaxBits(ext) + extraShifts < bits) {
+ // we also don't need it if it already has an identical-sized sign extend
+ if (getMaxBits(ext, this) + extraShifts < bits || isSignExted(ext, bits)) {
return removeAlmostSignExt(binary);
}
- } else if (binary->op == EqInt32) {
+ } else if (binary->op == EqInt32 || binary->op == NeInt32) {
if (auto* c = binary->right->dynCast<Const>()) {
- if (auto* ext = getSignExt(binary->left)) {
+ if (auto* ext = Properties::getSignExtValue(binary->left)) {
// we are comparing a sign extend to a constant, which means we can use a cheaper zext
- auto bits = getSignExtBits(binary->left);
+ auto bits = Properties::getSignExtBits(binary->left);
binary->left = makeZeroExt(ext, bits);
// the const we compare to only needs the relevant bits
- c->value = c->value.and_(Literal(lowBitMask(bits)));
+ c->value = c->value.and_(Literal(Bits::lowBitMask(bits)));
return binary;
}
- if (c->value.geti32() == 0) {
+ if (binary->op == EqInt32 && c->value.geti32() == 0) {
// equal 0 => eqz
return Builder(*getModule()).makeUnary(EqZInt32, binary->left);
}
- } else if (auto* left = getSignExt(binary->left)) {
- if (auto* right = getSignExt(binary->right)) {
+ } else if (auto* left = Properties::getSignExtValue(binary->left)) {
+ if (auto* right = Properties::getSignExtValue(binary->right)) {
// we are comparing two sign-exts, so we may as well replace both with cheaper zexts
- auto bits = getSignExtBits(binary->left);
+ auto bits = Properties::getSignExtBits(binary->left);
binary->left = makeZeroExt(left, bits);
binary->right = makeZeroExt(right, bits);
return binary;
+ } else if (auto* load = binary->right->dynCast<Load>()) {
+ // we are comparing a load to a sign-ext, we may be able to switch to zext
+ auto leftBits = Properties::getSignExtBits(binary->left);
+ if (load->signed_ && leftBits == load->bytes * 8) {
+ load->signed_ = false;
+ binary->left = makeZeroExt(left, leftBits);
+ return binary;
+ }
+ }
+ } else if (auto* load = binary->left->dynCast<Load>()) {
+ if (auto* right = Properties::getSignExtValue(binary->right)) {
+ // we are comparing a load to a sign-ext, we may be able to switch to zext
+ auto rightBits = Properties::getSignExtBits(binary->right);
+ if (load->signed_ && rightBits == load->bytes * 8) {
+ load->signed_ = false;
+ binary->right = makeZeroExt(right, rightBits);
+ return binary;
+ }
}
}
// note that both left and right may be consts, but then we let precompute compute the constant result
@@ -404,11 +477,13 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
if ((load->bytes == 1 && mask == 0xff) ||
(load->bytes == 2 && mask == 0xffff)) {
load->signed_ = false;
- return load;
+ return binary->left;
+ }
+ } else if (auto maskedBits = Bits::getMaskedBits(mask)) {
+ if (getMaxBits(binary->left, this) <= maskedBits) {
+ // a mask of lower bits is not needed if we are already smaller
+ return binary->left;
}
- } else if (mask == 1 && Properties::emitsBoolean(binary->left)) {
- // (bool) & 1 does not need the outer mask
- return binary->left;
}
}
// the square of some operations can be merged
@@ -468,6 +543,13 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
default: {}
}
}
+ // eqz of a sign extension can be of zero-extension
+ if (auto* ext = Properties::getSignExtValue(unary->value)) {
+ // we are comparing a sign extend to a constant, which means we can use a cheaper zext
+ auto bits = Properties::getSignExtBits(unary->value);
+ unary->value = makeZeroExt(ext, bits);
+ return unary;
+ }
}
} else if (auto* set = curr->dynCast<SetGlobal>()) {
// optimize out a set of a get
@@ -518,6 +600,12 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
}
}
}
+ } else if (auto* ext = Properties::getSignExtValue(binary)) {
+ // if sign extending the exact bit size we store, we can skip the extension
+ // if extending something bigger, then we just alter bits we don't save anyhow
+ if (Properties::getSignExtBits(binary) >= store->bytes * 8) {
+ store->value = ext;
+ }
}
} else if (auto* unary = store->value->dynCast<Unary>()) {
if (unary->op == WrapInt64) {
@@ -530,7 +618,15 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
return nullptr;
}
+ Index getMaxBitsForLocal(GetLocal* get) {
+ // check what we know about the local
+ return localInfo[get->index].maxBits;
+ }
+
private:
+ // Information about our locals
+ std::vector<LocalInfo> localInfo;
+
// Optimize given that the expression is flowing into a boolean context
Expression* optimizeBoolean(Expression* boolean) {
if (auto* unary = boolean->dynCast<Unary>()) {
@@ -554,6 +650,10 @@ private:
}
}
}
+ if (auto* ext = Properties::getSignExtValue(binary)) {
+ // use a cheaper zero-extent, we just care about the boolean value anyhow
+ return makeZeroExt(ext, Properties::getSignExtBits(binary));
+ }
} else if (auto* block = boolean->dynCast<Block>()) {
if (block->type == i32 && block->list.size() > 0) {
block->list.back() = optimizeBoolean(block->list.back());
@@ -611,7 +711,15 @@ private:
};
// find all factors
seek(binary, 1);
- if (constants.size() <= 1) return nullptr; // nothing to do
+ if (constants.size() <= 1) {
+ // nothing much to do, except for the trivial case of adding/subbing a zero
+ if (auto* c = binary->right->dynCast<Const>()) {
+ if (c->value.geti32() == 0) {
+ return binary->left;
+ }
+ }
+ return nullptr;
+ }
// wipe out all constants, we'll replace with a single added one
for (auto* c : constants) {
c->value = Literal(int32_t(0));
@@ -731,7 +839,7 @@ private:
Expression* makeZeroExt(Expression* curr, int32_t bits) {
Builder builder(*getModule());
- return builder.makeBinary(AndInt32, curr, builder.makeConst(Literal(lowBitMask(bits))));
+ return builder.makeBinary(AndInt32, curr, builder.makeConst(Literal(Bits::lowBitMask(bits))));
}
// given an "almost" sign extend - either a proper one, or it
@@ -748,6 +856,18 @@ private:
innerConst->value = innerConst->value.sub(outerConst->value);
return inner;
}
+
+ // check if an expression is already sign-extended
+ bool isSignExted(Expression* curr, Index bits) {
+ if (Properties::getSignExtValue(curr)) {
+ return Properties::getSignExtBits(curr) == bits;
+ }
+ if (auto* get = curr->dynCast<GetLocal>()) {
+ // check what we know about the local
+ return localInfo[get->index].signExtedBits == bits;
+ }
+ return false;
+ }
};
Pass *createOptimizeInstructionsPass() {
diff --git a/src/passes/PickLoadSigns.cpp b/src/passes/PickLoadSigns.cpp
new file mode 100644
index 000000000..6e44fddfe
--- /dev/null
+++ b/src/passes/PickLoadSigns.cpp
@@ -0,0 +1,107 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <wasm.h>
+#include <pass.h>
+#include <ast/properties.h>
+
+namespace wasm {
+
+// Adjust load signedness based on usage. If a load only has uses that sign or
+// unsign it anyhow, then it could be either, and picking the popular one can
+// help remove the most sign/unsign operations
+// unsigned, then it could be either
+
+struct PickLoadSigns : public WalkerPass<ExpressionStackWalker<PickLoadSigns, Visitor<PickLoadSigns>>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new PickLoadSigns; }
+
+ struct Usage {
+ Index signedUsages = 0;
+ Index signedBits;
+ Index unsignedUsages = 0;
+ Index unsignedBits;
+ Index totalUsages = 0;
+ };
+ std::vector<Usage> usages; // local index => usage
+
+ std::unordered_map<Load*, Index> loads; // loads that write to a local => the local
+
+ void doWalkFunction(Function* func) {
+ // prepare
+ usages.resize(func->getNumLocals());
+ // walk
+ ExpressionStackWalker<PickLoadSigns, Visitor<PickLoadSigns>>::doWalkFunction(func);
+ // optimize based on the info we saw
+ for (auto& pair : loads) {
+ auto* load = pair.first;
+ auto index = pair.second;
+ auto& usage = usages[index];
+ // if we can't optimize, give up
+ if (usage.totalUsages == 0 || // no usages, so no idea
+ usage.signedUsages + usage.unsignedUsages != usage.totalUsages || // non-sign/unsigned usages, so cannot change
+ (usage.signedUsages != 0 && usage.signedBits != load->bytes * 8) || // sign usages exist but the wrong size
+ (usage.unsignedUsages != 0 && usage.unsignedBits != load->bytes * 8)) { // unsigned usages exist but the wrong size
+ continue;
+ }
+ // we can pick the optimal one. our hope is to remove 2 items per
+ // signed use (two shifts), so we factor that in
+ load->signed_ = usage.signedUsages * 2 >= usage.unsignedUsages;
+ }
+ }
+
+ void visitGetLocal(GetLocal* curr) {
+ // this is a use. check from the context what it is, signed or unsigned, etc.
+ auto& usage = usages[curr->index];
+ usage.totalUsages++;
+ if (expressionStack.size() >= 2) {
+ auto* parent = expressionStack[expressionStack.size() - 2];
+ if (Properties::getZeroExtValue(parent)) {
+ auto bits = Properties::getZeroExtBits(parent);
+ if (usage.unsignedUsages == 0) {
+ usage.unsignedBits = bits;
+ } else if (usage.unsignedBits != bits) {
+ usage.unsignedBits = 0;
+ }
+ usage.unsignedUsages++;
+ } else if (expressionStack.size() >= 3) {
+ auto* grandparent = expressionStack[expressionStack.size() - 3];
+ if (Properties::getSignExtValue(grandparent)) {
+ auto bits = Properties::getSignExtBits(grandparent);
+ if (usage.signedUsages == 0) {
+ usage.signedBits = bits;
+ } else if (usage.signedBits != bits) {
+ usage.signedBits = 0;
+ }
+ usage.signedUsages++;
+ }
+ }
+ }
+ }
+
+ void visitSetLocal(SetLocal* curr) {
+ if (auto* load = curr->value->dynCast<Load>()) {
+ loads[load] = curr->index;
+ }
+ }
+};
+
+Pass *createPickLoadSignsPass() {
+ return new PickLoadSigns();
+}
+
+} // namespace wasm
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index 32b596eef..df98590c5 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -77,6 +77,7 @@ void PassRegistry::registerPasses() {
registerPass("nm", "name list", createNameListPass);
registerPass("name-manager", "utility pass to manage names in modules", createNameManagerPass);
registerPass("optimize-instructions", "optimizes instruction combinations", createOptimizeInstructionsPass);
+ registerPass("pick-load-signs", "pick load signs based on their uses", createPickLoadSignsPass);
registerPass("post-emscripten", "miscellaneous optimizations for Emscripten-generated code", createPostEmscriptenPass);
registerPass("print", "print in s-expression format", createPrinterPass);
registerPass("print-minified", "print in minified s-expression format", createMinifiedPrinterPass);
@@ -112,6 +113,9 @@ void PassRunner::addDefaultFunctionOptimizationPasses() {
add("remove-unused-brs");
add("remove-unused-names");
add("optimize-instructions");
+ if (options.optimizeLevel >= 2 || options.shrinkLevel >= 2) {
+ add("pick-load-signs");
+ }
add("precompute");
if (options.optimizeLevel >= 2 || options.shrinkLevel >= 2) {
add("code-pushing");
diff --git a/src/passes/passes.h b/src/passes/passes.h
index cbfc48327..83bf556d8 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -39,6 +39,7 @@ Pass *createMetricsPass();
Pass *createNameListPass();
Pass *createNameManagerPass();
Pass *createOptimizeInstructionsPass();
+Pass *createPickLoadSignsPass();
Pass *createPostEmscriptenPass();
Pass *createPrinterPass();
Pass *createPrintCallGraphPass();