summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/ir/bits.h301
-rw-r--r--src/passes/OptimizeInstructions.cpp153
-rw-r--r--test/example/cpp-unit.cpp34
3 files changed, 265 insertions, 223 deletions
diff --git a/src/ir/bits.h b/src/ir/bits.h
index 8b28bb031..e0bca8d87 100644
--- a/src/ir/bits.h
+++ b/src/ir/bits.h
@@ -20,92 +20,255 @@
#include "ir/literal-utils.h"
#include "support/bits.h"
#include "wasm-builder.h"
+#include <ir/load-utils.h>
namespace wasm {
-struct Bits {
- // get a mask to keep only the low # of bits
- static int32_t lowBitMask(int32_t bits) {
- uint32_t ret = -1;
- if (bits >= 32) {
- return ret;
- }
- return ret >> (32 - bits);
+namespace Bits {
+
+// get a mask to keep only the low # of bits
+inline int32_t lowBitMask(int32_t bits) {
+ uint32_t ret = -1;
+ if (bits >= 32) {
+ return ret;
}
+ return ret >> (32 - bits);
+}
- // checks if the input is a mask of lower bits, i.e., all 1s up to some high
- // bit, and all zeros from there. returns the number of masked bits, or 0 if
- // this is not such a mask
- static uint32_t getMaskedBits(uint32_t mask) {
- if (mask == uint32_t(-1)) {
- return 32; // all the bits
- }
- if (mask == 0) {
- return 0; // trivially not a mask
- }
- // otherwise, see if x & (x + 1) turns this into non-zero value
- // 00011111 & (00011111 + 1) => 0
- if (mask & (mask + 1)) {
- return 0;
- }
- // this is indeed a mask
- return 32 - CountLeadingZeroes(mask);
+// checks if the input is a mask of lower bits, i.e., all 1s up to some high
+// bit, and all zeros from there. returns the number of masked bits, or 0 if
+// this is not such a mask
+inline uint32_t getMaskedBits(uint32_t mask) {
+ if (mask == uint32_t(-1)) {
+ return 32; // all the bits
+ }
+ if (mask == 0) {
+ return 0; // trivially not a mask
}
+ // otherwise, see if x & (x + 1) turns this into non-zero value
+ // 00011111 & (00011111 + 1) => 0
+ if (mask & (mask + 1)) {
+ return 0;
+ }
+ // this is indeed a mask
+ return 32 - CountLeadingZeroes(mask);
+}
+
+// gets the number of effective shifts a shift operation does. In
+// wasm, only 5 bits matter for 32-bit shifts, and 6 for 64.
+inline Index getEffectiveShifts(Index amount, Type type) {
+ if (type == Type::i32) {
+ return amount & 31;
+ } else if (type == Type::i64) {
+ return amount & 63;
+ }
+ WASM_UNREACHABLE("unexpected type");
+}
- // gets the number of effective shifts a shift operation does. In
- // wasm, only 5 bits matter for 32-bit shifts, and 6 for 64.
- static Index getEffectiveShifts(Index amount, Type type) {
- if (type == Type::i32) {
- return amount & 31;
- } else if (type == Type::i64) {
- return amount & 63;
+inline Index getEffectiveShifts(Expression* expr) {
+ auto* amount = expr->cast<Const>();
+ if (amount->type == Type::i32) {
+ return getEffectiveShifts(amount->value.geti32(), Type::i32);
+ } else if (amount->type == Type::i64) {
+ return getEffectiveShifts(amount->value.geti64(), Type::i64);
+ }
+ WASM_UNREACHABLE("unexpected type");
+}
+
+inline Expression* makeSignExt(Expression* value, Index bytes, Module& wasm) {
+ if (value->type == Type::i32) {
+ if (bytes == 1 || bytes == 2) {
+ auto shifts = bytes == 1 ? 24 : 16;
+ Builder builder(wasm);
+ return builder.makeBinary(
+ ShrSInt32,
+ builder.makeBinary(
+ ShlInt32,
+ value,
+ LiteralUtils::makeFromInt32(shifts, Type::i32, wasm)),
+ LiteralUtils::makeFromInt32(shifts, Type::i32, wasm));
+ }
+ assert(bytes == 4);
+ return value; // nothing to do
+ } else {
+ assert(value->type == Type::i64);
+ if (bytes == 1 || bytes == 2 || bytes == 4) {
+ auto shifts = bytes == 1 ? 56 : (bytes == 2 ? 48 : 32);
+ Builder builder(wasm);
+ return builder.makeBinary(
+ ShrSInt64,
+ builder.makeBinary(
+ ShlInt64,
+ value,
+ LiteralUtils::makeFromInt32(shifts, Type::i64, wasm)),
+ LiteralUtils::makeFromInt32(shifts, Type::i64, wasm));
}
- WASM_UNREACHABLE("unexpected type");
+ assert(bytes == 8);
+ return value; // nothing to do
}
+}
- static Index getEffectiveShifts(Expression* expr) {
- auto* amount = expr->cast<Const>();
- if (amount->type == Type::i32) {
- return getEffectiveShifts(amount->value.geti32(), Type::i32);
- } else if (amount->type == Type::i64) {
- return getEffectiveShifts(amount->value.geti64(), Type::i64);
+// getMaxBits() helper that has pessimistic results for the bits used in locals.
+struct DummyLocalInfoProvider {
+ Index getMaxBitsForLocal(LocalGet* get) {
+ if (get->type == Type::i32) {
+ return 32;
}
- WASM_UNREACHABLE("unexpected type");
+ if (get->type == Type::i32) {
+ return 64;
+ }
+ WASM_UNREACHABLE("type has no integer bit size");
}
+};
- static Expression* makeSignExt(Expression* value, Index bytes, Module& wasm) {
- if (value->type == Type::i32) {
- if (bytes == 1 || bytes == 2) {
- auto shifts = bytes == 1 ? 24 : 16;
- Builder builder(wasm);
- return builder.makeBinary(
- ShrSInt32,
- builder.makeBinary(
- ShlInt32,
- value,
- LiteralUtils::makeFromInt32(shifts, Type::i32, wasm)),
- LiteralUtils::makeFromInt32(shifts, Type::i32, wasm));
+// Returns the maximum amount of bits used in an integer expression
+// not extremely precise (doesn't look into add operands, etc.)
+// LocalInfoProvider is an optional class that can provide answers about
+// local.get.
+template<typename LocalInfoProvider = DummyLocalInfoProvider>
+Index getMaxBits(Expression* curr,
+ LocalInfoProvider* localInfoProvider = nullptr) {
+ if (auto* const_ = curr->dynCast<Const>()) {
+ switch (curr->type.getSingle()) {
+ case Type::i32:
+ return 32 - const_->value.countLeadingZeroes().geti32();
+ case Type::i64:
+ return 64 - const_->value.countLeadingZeroes().geti64();
+ default:
+ WASM_UNREACHABLE("invalid type");
+ }
+ } else if (auto* binary = curr->dynCast<Binary>()) {
+ switch (binary->op) {
+ // 32-bit
+ case AddInt32:
+ case SubInt32:
+ case MulInt32:
+ case DivSInt32:
+ case DivUInt32:
+ case RemSInt32:
+ case RemUInt32:
+ case RotLInt32:
+ case RotRInt32:
+ return 32;
+ case AndInt32:
+ return std::min(getMaxBits(binary->left, localInfoProvider),
+ getMaxBits(binary->right, localInfoProvider));
+ case OrInt32:
+ case XorInt32:
+ return std::max(getMaxBits(binary->left, localInfoProvider),
+ getMaxBits(binary->right, localInfoProvider));
+ case ShlInt32: {
+ if (auto* shifts = binary->right->dynCast<Const>()) {
+ return std::min(Index(32),
+ getMaxBits(binary->left, localInfoProvider) +
+ Bits::getEffectiveShifts(shifts));
+ }
+ return 32;
}
- assert(bytes == 4);
- return value; // nothing to do
- } else {
- assert(value->type == Type::i64);
- if (bytes == 1 || bytes == 2 || bytes == 4) {
- auto shifts = bytes == 1 ? 56 : (bytes == 2 ? 48 : 32);
- Builder builder(wasm);
- return builder.makeBinary(
- ShrSInt64,
- builder.makeBinary(
- ShlInt64,
- value,
- LiteralUtils::makeFromInt32(shifts, Type::i64, wasm)),
- LiteralUtils::makeFromInt32(shifts, Type::i64, wasm));
+ case ShrUInt32: {
+ if (auto* shift = binary->right->dynCast<Const>()) {
+ auto maxBits = getMaxBits(binary->left, localInfoProvider);
+ auto shifts =
+ std::min(Index(Bits::getEffectiveShifts(shift)),
+ maxBits); // can ignore more shifts than zero us out
+ return std::max(Index(0), maxBits - shifts);
+ }
+ return 32;
+ }
+ case ShrSInt32: {
+ if (auto* shift = binary->right->dynCast<Const>()) {
+ auto maxBits = getMaxBits(binary->left, localInfoProvider);
+ if (maxBits == 32) {
+ return 32;
+ }
+ auto shifts =
+ std::min(Index(Bits::getEffectiveShifts(shift)),
+ maxBits); // can ignore more shifts than zero us out
+ return std::max(Index(0), maxBits - shifts);
+ }
+ return 32;
+ }
+ // 64-bit TODO
+ // comparisons
+ case EqInt32:
+ case NeInt32:
+ case LtSInt32:
+ case LtUInt32:
+ case LeSInt32:
+ case LeUInt32:
+ case GtSInt32:
+ case GtUInt32:
+ case GeSInt32:
+ case GeUInt32:
+ case EqInt64:
+ case NeInt64:
+ case LtSInt64:
+ case LtUInt64:
+ case LeSInt64:
+ case LeUInt64:
+ case GtSInt64:
+ case GtUInt64:
+ case GeSInt64:
+ case GeUInt64:
+ case EqFloat32:
+ case NeFloat32:
+ case LtFloat32:
+ case LeFloat32:
+ case GtFloat32:
+ case GeFloat32:
+ case EqFloat64:
+ case NeFloat64:
+ case LtFloat64:
+ case LeFloat64:
+ case GtFloat64:
+ case GeFloat64:
+ return 1;
+ default: {
}
- assert(bytes == 8);
- return value; // nothing to do
+ }
+ } else if (auto* unary = curr->dynCast<Unary>()) {
+ switch (unary->op) {
+ case ClzInt32:
+ case CtzInt32:
+ case PopcntInt32:
+ return 6;
+ case ClzInt64:
+ case CtzInt64:
+ case PopcntInt64:
+ return 7;
+ case EqZInt32:
+ case EqZInt64:
+ return 1;
+ case WrapInt64:
+ return std::min(Index(32), getMaxBits(unary->value, localInfoProvider));
+ default: {
+ }
+ }
+ } else if (auto* set = curr->dynCast<LocalSet>()) {
+ // a tee passes through the value
+ return getMaxBits(set->value, localInfoProvider);
+ } else if (auto* get = curr->dynCast<LocalGet>()) {
+ return localInfoProvider->getMaxBitsForLocal(get);
+ } else if (auto* load = curr->dynCast<Load>()) {
+ // if signed, then the sign-extension might fill all the bits
+ // if unsigned, then we have a limit
+ if (LoadUtils::isSignRelevant(load) && !load->signed_) {
+ return 8 * load->bytes;
}
}
-};
+ switch (curr->type.getSingle()) {
+ case Type::i32:
+ return 32;
+ case Type::i64:
+ return 64;
+ case Type::unreachable:
+ return 64; // not interesting, but don't crash
+ default:
+ WASM_UNREACHABLE("invalid type");
+ }
+}
+
+} // namespace Bits
} // namespace wasm
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index 860ea3ac7..7f66ca7ee 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -22,6 +22,7 @@
#include <type_traits>
#include <ir/abstract.h>
+#include <ir/bits.h>
#include <ir/cost.h>
#include <ir/effects.h>
#include <ir/literal-utils.h>
@@ -45,152 +46,6 @@ Name F32_EXPR = "f32.expr";
Name F64_EXPR = "f64.expr";
Name ANY_EXPR = "any.expr";
-// Utilities
-
-// returns the maximum amount of bits used in an integer expression
-// not extremely precise (doesn't look into add operands, etc.)
-// LocalInfoProvider is an optional class that can provide answers about
-// local.get.
-template<typename LocalInfoProvider>
-Index getMaxBits(Expression* curr, LocalInfoProvider* localInfoProvider) {
- if (auto* const_ = curr->dynCast<Const>()) {
- switch (curr->type.getSingle()) {
- case Type::i32:
- return 32 - const_->value.countLeadingZeroes().geti32();
- case Type::i64:
- return 64 - const_->value.countLeadingZeroes().geti64();
- default:
- WASM_UNREACHABLE("invalid type");
- }
- } else if (auto* binary = curr->dynCast<Binary>()) {
- switch (binary->op) {
- // 32-bit
- case AddInt32:
- case SubInt32:
- case MulInt32:
- case DivSInt32:
- case DivUInt32:
- case RemSInt32:
- case RemUInt32:
- case RotLInt32:
- case RotRInt32:
- return 32;
- case AndInt32:
- return std::min(getMaxBits(binary->left, localInfoProvider),
- getMaxBits(binary->right, localInfoProvider));
- case OrInt32:
- case XorInt32:
- return std::max(getMaxBits(binary->left, localInfoProvider),
- getMaxBits(binary->right, localInfoProvider));
- case ShlInt32: {
- if (auto* shifts = binary->right->dynCast<Const>()) {
- return std::min(Index(32),
- getMaxBits(binary->left, localInfoProvider) +
- Bits::getEffectiveShifts(shifts));
- }
- return 32;
- }
- case ShrUInt32: {
- if (auto* shift = binary->right->dynCast<Const>()) {
- auto maxBits = getMaxBits(binary->left, localInfoProvider);
- auto shifts =
- std::min(Index(Bits::getEffectiveShifts(shift)),
- maxBits); // can ignore more shifts than zero us out
- return std::max(Index(0), maxBits - shifts);
- }
- return 32;
- }
- case ShrSInt32: {
- if (auto* shift = binary->right->dynCast<Const>()) {
- auto maxBits = getMaxBits(binary->left, localInfoProvider);
- if (maxBits == 32) {
- return 32;
- }
- auto shifts =
- std::min(Index(Bits::getEffectiveShifts(shift)),
- maxBits); // can ignore more shifts than zero us out
- return std::max(Index(0), maxBits - shifts);
- }
- return 32;
- }
- // 64-bit TODO
- // comparisons
- case EqInt32:
- case NeInt32:
- case LtSInt32:
- case LtUInt32:
- case LeSInt32:
- case LeUInt32:
- case GtSInt32:
- case GtUInt32:
- case GeSInt32:
- case GeUInt32:
- case EqInt64:
- case NeInt64:
- case LtSInt64:
- case LtUInt64:
- case LeSInt64:
- case LeUInt64:
- case GtSInt64:
- case GtUInt64:
- case GeSInt64:
- case GeUInt64:
- case EqFloat32:
- case NeFloat32:
- case LtFloat32:
- case LeFloat32:
- case GtFloat32:
- case GeFloat32:
- case EqFloat64:
- case NeFloat64:
- case LtFloat64:
- case LeFloat64:
- case GtFloat64:
- case GeFloat64:
- return 1;
- default: {}
- }
- } else if (auto* unary = curr->dynCast<Unary>()) {
- switch (unary->op) {
- case ClzInt32:
- case CtzInt32:
- case PopcntInt32:
- return 6;
- case ClzInt64:
- case CtzInt64:
- case PopcntInt64:
- return 7;
- case EqZInt32:
- case EqZInt64:
- return 1;
- case WrapInt64:
- return std::min(Index(32), getMaxBits(unary->value, localInfoProvider));
- default: {}
- }
- } else if (auto* set = curr->dynCast<LocalSet>()) {
- // a tee passes through the value
- return getMaxBits(set->value, localInfoProvider);
- } else if (auto* get = curr->dynCast<LocalGet>()) {
- return localInfoProvider->getMaxBitsForLocal(get);
- } else if (auto* load = curr->dynCast<Load>()) {
- // if signed, then the sign-extension might fill all the bits
- // if unsigned, then we have a limit
- if (LoadUtils::isSignRelevant(load) && !load->signed_) {
- return 8 * load->bytes;
- }
- }
- switch (curr->type.getSingle()) {
- case Type::i32:
- return 32;
- case Type::i64:
- return 64;
- case Type::unreachable:
- return 64; // not interesting, but don't crash
- default:
- WASM_UNREACHABLE("invalid type");
- }
-}
-
// Useful information about locals
struct LocalInfo {
static const Index kUnknown = Index(-1);
@@ -243,7 +98,7 @@ struct LocalScanner : PostWalker<LocalScanner> {
auto* value = Properties::getFallthrough(
curr->value, passOptions, getModule()->features);
auto& info = localInfo[curr->index];
- info.maxBits = std::max(info.maxBits, getMaxBits(value, this));
+ info.maxBits = std::max(info.maxBits, Bits::getMaxBits(value, this));
auto signExtBits = LocalInfo::kUnknown;
if (Properties::getSignExtValue(value)) {
signExtBits = Properties::getSignExtBits(value);
@@ -373,7 +228,7 @@ struct OptimizeInstructions
// if the sign-extend input cannot have a sign bit, we don't need it
// we also don't need it if it already has an identical-sized sign
// extend
- if (getMaxBits(ext, this) + extraShifts < bits ||
+ if (Bits::getMaxBits(ext, this) + extraShifts < bits ||
isSignExted(ext, bits)) {
return removeAlmostSignExt(binary);
}
@@ -538,7 +393,7 @@ struct OptimizeInstructions
return binary->left;
}
} else if (auto maskedBits = Bits::getMaskedBits(mask)) {
- if (getMaxBits(binary->left, this) <= maskedBits) {
+ if (Bits::getMaxBits(binary->left, this) <= maskedBits) {
// a mask of lower bits is not needed if we are already smaller
return binary->left;
}
diff --git a/test/example/cpp-unit.cpp b/test/example/cpp-unit.cpp
index e6189d9d1..2ba4388d1 100644
--- a/test/example/cpp-unit.cpp
+++ b/test/example/cpp-unit.cpp
@@ -1,17 +1,41 @@
// test multiple uses of the threadPool
-#include <assert.h>
+#include <iostream>
-#include <wasm.h>
+#include <ir/bits.h>
#include <ir/cost.h>
+#include <wasm.h>
using namespace wasm;
-int main()
-{
+void compare(size_t x, size_t y) {
+ if (x != y) {
+ std::cout << "comparison error!\n" << x << '\n' << y << '\n';
+ abort();
+ }
+}
+
+void test_bits() {
+ Const c;
+ c.type = Type::i32;
+ c.value = Literal(int32_t(1));
+ compare(Bits::getMaxBits(&c), 1);
+ c.value = Literal(int32_t(2));
+ compare(Bits::getMaxBits(&c), 2);
+ c.value = Literal(int32_t(3));
+ compare(Bits::getMaxBits(&c), 2);
+}
+
+void test_cost() {
// Some optimizations assume that the cost of a get is zero, e.g. local-cse.
LocalGet get;
- assert(CostAnalyzer(&get).cost == 0);
+ compare(CostAnalyzer(&get).cost, 0);
+}
+
+int main() {
+ test_bits();
+
+ test_cost();
std::cout << "Success.\n";
return 0;