summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlon Zakai <alonzakai@gmail.com>2016-09-14 21:28:43 -0700
committerGitHub <noreply@github.com>2016-09-14 21:28:43 -0700
commite567fa8675831e79f855cea2181fa58beb107e42 (patch)
tree14f1e37d27244b349e8ee34939119002f742748d /src
parent63b499e3ec9bbdf4e79ab6d9dc198299516e8aec (diff)
parentaf3bea2786fe62070522b7fd7add4290a4cb4e6d (diff)
downloadbinaryen-e567fa8675831e79f855cea2181fa58beb107e42.tar.gz
binaryen-e567fa8675831e79f855cea2181fa58beb107e42.tar.bz2
binaryen-e567fa8675831e79f855cea2181fa58beb107e42.zip
Merge pull request #695 from WebAssembly/opts
Get optimizer on par with emscripten asm.js optimizer
Diffstat (limited to 'src')
-rw-r--r--src/asm2wasm.h39
-rw-r--r--src/ast_utils.h103
-rw-r--r--src/pass.h19
-rw-r--r--src/passes/CMakeLists.txt2
-rw-r--r--src/passes/CoalesceLocals.cpp142
-rw-r--r--src/passes/DuplicateFunctionElimination.cpp10
-rw-r--r--src/passes/ExtractFunction.cpp46
-rw-r--r--src/passes/LowerInt64.cpp196
-rw-r--r--src/passes/Metrics.cpp2
-rw-r--r--src/passes/OptimizeInstructions.cpp80
-rw-r--r--src/passes/OptimizeInstructions.wast21
-rw-r--r--src/passes/OptimizeInstructions.wast.processed21
-rw-r--r--src/passes/RelooperJumpThreading.cpp248
-rw-r--r--src/passes/RemoveImports.cpp14
-rw-r--r--src/passes/RemoveUnusedBrs.cpp283
-rw-r--r--src/passes/SimplifyLocals.cpp67
-rw-r--r--src/passes/Vacuum.cpp73
-rw-r--r--src/passes/pass.cpp26
-rw-r--r--src/passes/passes.h2
-rw-r--r--src/wasm-builder.h48
-rw-r--r--src/wasm-module-building.h2
-rw-r--r--src/wasm-traversal.h2
-rw-r--r--src/wasm-validator.h3
-rw-r--r--src/wasm.h5
24 files changed, 1115 insertions, 339 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h
index b19451a7b..eb5413e20 100644
--- a/src/asm2wasm.h
+++ b/src/asm2wasm.h
@@ -523,6 +523,8 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
optimizingBuilder = make_unique<OptimizingIncrementalModuleBuilder>(&wasm, numFunctions, [&](PassRunner& passRunner) {
// run autodrop first, before optimizations
passRunner.add<AutoDrop>();
+ // optimize relooper label variable usage at the wasm level, where it is easy
+ passRunner.add("relooper-jump-threading");
});
}
@@ -803,13 +805,16 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
add->right = parent->builder.makeConst(Literal((int32_t)parent->functionTableStarts[tableName]));
}
};
+
PassRunner passRunner(&wasm);
passRunner.add<FinalizeCalls>(this);
passRunner.add<ReFinalize>(); // FinalizeCalls changes call types, need to percolate
passRunner.add<AutoDrop>(); // FinalizeCalls may cause us to require additional drops
if (optimize) {
- passRunner.add("vacuum"); // autodrop can add some garbage
- passRunner.add("remove-unused-brs"); // vacuum may open up more opportunities
+ // autodrop can add some garbage
+ passRunner.add("vacuum");
+ passRunner.add("remove-unused-brs");
+ passRunner.add("optimize-instructions");
}
passRunner.run();
@@ -866,19 +871,17 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
#endif
-#if 0 // enable asm2wasm i64 optimizations when browsers have consistent i64 support in wasm
if (udivmoddi4.is() && getTempRet0.is()) {
// generate a wasm-optimized __udivmoddi4 method, which we can do much more efficiently in wasm
// we can only do this if we know getTempRet0 as well since we use it to figure out which minified global is tempRet0
// (getTempRet0 might be an import, if this is a shared module, so we can't optimize that case)
- int tempRet0;
+ Name tempRet0;
{
Expression* curr = wasm.getFunction(getTempRet0)->body;
if (curr->is<Block>()) curr = curr->cast<Block>()->list[0];
- curr = curr->cast<Return>()->value;
- auto* load = curr->cast<Load>();
- auto* ptr = load->ptr->cast<Const>();
- tempRet0 = ptr->value.geti32() + load->offset;
+ if (curr->is<Return>()) curr = curr->cast<Return>()->value;
+ auto* get = curr->cast<GetGlobal>();
+ tempRet0 = get->name;
}
// udivmoddi4 receives xl, xh, yl, yl, r, and
// if r then *r = x % y
@@ -898,13 +901,13 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
return builder.makeSetLocal(
target,
builder.makeBinary(
- Or,
+ OrInt64,
builder.makeUnary(
ExtendUInt32,
builder.makeGetLocal(low, i32)
),
builder.makeBinary(
- Shl,
+ ShlInt64,
builder.makeUnary(
ExtendUInt32,
builder.makeGetLocal(high, i32)
@@ -923,10 +926,11 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
8, 0, 8,
builder.makeGetLocal(r, i32),
builder.makeBinary(
- RemU,
+ RemUInt64,
builder.makeGetLocal(x64, i64),
builder.makeGetLocal(y64, i64)
- )
+ ),
+ i64
)
)
);
@@ -934,20 +938,19 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
builder.makeSetLocal(
x64,
builder.makeBinary(
- DivU,
+ DivUInt64,
builder.makeGetLocal(x64, i64),
builder.makeGetLocal(y64, i64)
)
)
);
body->list.push_back(
- builder.makeStore(
- 4, 0, 4,
- builder.makeConst(Literal(int32_t(tempRet0))),
+ builder.makeSetGlobal(
+ tempRet0,
builder.makeUnary(
WrapInt64,
builder.makeBinary(
- ShrU,
+ ShrUInt64,
builder.makeGetLocal(x64, i64),
builder.makeConst(Literal(int64_t(32)))
)
@@ -960,9 +963,9 @@ void Asm2WasmBuilder::processAsm(Ref ast) {
builder.makeGetLocal(x64, i64)
)
);
+ body->finalize();
func->body = body;
}
-#endif
assert(WasmValidator().validate(wasm));
}
diff --git a/src/ast_utils.h b/src/ast_utils.h
index 9b2ff10cd..4664f22ae 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -27,19 +27,27 @@ namespace wasm {
struct BreakSeeker : public PostWalker<BreakSeeker, Visitor<BreakSeeker>> {
Name target; // look for this one XXX looking by name may fall prey to duplicate names
- size_t found;
+ Index found;
+ WasmType valueType;
- BreakSeeker(Name target) : target(target), found(false) {}
+ BreakSeeker(Name target) : target(target), found(0) {}
+
+ void noteFound(Expression* value) {
+ found++;
+ if (found == 1) valueType = unreachable;
+ if (!value) valueType = none;
+ else if (value->type != unreachable) valueType = value->type;
+ }
void visitBreak(Break *curr) {
- if (curr->name == target) found++;
+ if (curr->name == target) noteFound(curr->value);
}
void visitSwitch(Switch *curr) {
for (auto name : curr->targets) {
- if (name == target) found++;
+ if (name == target) noteFound(curr->value);
}
- if (curr->default_ == target) found++;
+ if (curr->default_ == target) noteFound(curr->value);
}
static bool has(Expression* tree, Name target) {
@@ -47,6 +55,12 @@ struct BreakSeeker : public PostWalker<BreakSeeker, Visitor<BreakSeeker>> {
breakSeeker.walk(tree);
return breakSeeker.found > 0;
}
+
+ static Index count(Expression* tree, Name target) {
+ BreakSeeker breakSeeker(target);
+ breakSeeker.walk(tree);
+ return breakSeeker.found;
+ }
};
// Finds all functions that are reachable via direct calls.
@@ -92,13 +106,16 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer
bool calls = false;
std::set<Index> localsRead;
std::set<Index> localsWritten;
+ std::set<Name> globalsRead;
+ std::set<Name> globalsWritten;
bool readsMemory = false;
bool writesMemory = false;
bool accessesLocal() { return localsRead.size() + localsWritten.size() > 0; }
+ bool accessesGlobal() { return globalsRead.size() + globalsWritten.size() > 0; }
bool accessesMemory() { return calls || readsMemory || writesMemory; }
- bool hasSideEffects() { return calls || localsWritten.size() > 0 || writesMemory || branches; }
- bool hasAnything() { return branches || calls || accessesLocal() || readsMemory || writesMemory; }
+ bool hasSideEffects() { return calls || localsWritten.size() > 0 || writesMemory || branches || globalsWritten.size() > 0; }
+ bool hasAnything() { return branches || calls || accessesLocal() || readsMemory || writesMemory || accessesGlobal(); }
// checks if these effects would invalidate another set (e.g., if we write, we invalidate someone that reads, they can't be moved past us)
bool invalidates(EffectAnalyzer& other) {
@@ -115,6 +132,17 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer
for (auto local : localsRead) {
if (other.localsWritten.count(local)) return true;
}
+ if ((accessesGlobal() && other.calls) || (other.accessesGlobal() && calls)) {
+ return true;
+ }
+ for (auto global : globalsWritten) {
+ if (other.globalsWritten.count(global) || other.globalsRead.count(global)) {
+ return true;
+ }
+ }
+ for (auto global : globalsRead) {
+ if (other.globalsWritten.count(global)) return true;
+ }
return false;
}
@@ -163,8 +191,12 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer
void visitSetLocal(SetLocal *curr) {
localsWritten.insert(curr->index);
}
- void visitGetGlobal(GetGlobal *curr) { readsMemory = true; } // TODO: global-specific
- void visitSetGlobal(SetGlobal *curr) { writesMemory = true; } // stuff?
+ void visitGetGlobal(GetGlobal *curr) {
+ globalsRead.insert(curr->name);
+ }
+ void visitSetGlobal(SetGlobal *curr) {
+ globalsWritten.insert(curr->name);
+ }
void visitLoad(Load *curr) { readsMemory = true; }
void visitStore(Store *curr) { writesMemory = true; }
void visitReturn(Return *curr) { branches = true; }
@@ -340,6 +372,21 @@ struct ExpressionManipulator {
} copier;
return flexibleCopy(original, wasm, copier);
}
+
+ // Splice an item into the middle of a block's list
+ static void spliceIntoBlock(Block* block, Index index, Expression* add) {
+ auto& list = block->list;
+ if (index == list.size()) {
+ list.push_back(add); // simple append
+ } else {
+ // we need to make room
+ list.push_back(nullptr);
+ for (Index i = list.size() - 1; i > index; i--) {
+ list[i] = list[i - 1];
+ }
+ list[index] = add;
+ }
+ }
};
struct ExpressionAnalyzer {
@@ -373,6 +420,25 @@ struct ExpressionAnalyzer {
return func->result != none;
}
+ // Checks if a break is a simple - no condition, no value, just a plain branching
+ static bool isSimple(Break* curr) {
+ return !curr->condition && !curr->value;
+ }
+
+ // Checks if an expression ends with a simple break,
+ // and returns a pointer to it if so.
+ // (It might also have other internal branches.)
+ static Expression* getEndingSimpleBreak(Expression* curr) {
+ if (auto* br = curr->dynCast<Break>()) {
+ if (isSimple(br)) return br;
+ return nullptr;
+ }
+ if (auto* block = curr->dynCast<Block>()) {
+ if (block->list.size() > 0) return getEndingSimpleBreak(block->list.back());
+ }
+ return nullptr;
+ }
+
template<typename T>
static bool flexibleEqual(Expression* left, Expression* right, T& comparer) {
std::vector<Name> nameStack;
@@ -814,6 +880,25 @@ struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop, Visitor<Auto
curr->finalize(); // we may have changed our type
}
+ void visitIf(If* curr) {
+ if (curr->ifFalse) {
+ if (!isConcreteWasmType(curr->type)) {
+ // if either side of an if-else not returning a value is concrete, drop it
+ if (isConcreteWasmType(curr->ifTrue->type)) {
+ curr->ifTrue = Builder(*getModule()).makeDrop(curr->ifTrue);
+ }
+ if (isConcreteWasmType(curr->ifFalse->type)) {
+ curr->ifFalse = Builder(*getModule()).makeDrop(curr->ifFalse);
+ }
+ }
+ } else {
+ // if without else does not return a value, so the body must be dropped if it is concrete
+ if (isConcreteWasmType(curr->ifTrue->type)) {
+ curr->ifTrue = Builder(*getModule()).makeDrop(curr->ifTrue);
+ }
+ }
+ }
+
void visitFunction(Function* curr) {
if (curr->result == none && isConcreteWasmType(curr->body->type)) {
curr->body = Builder(*getModule()).makeDrop(curr->body);
diff --git a/src/pass.h b/src/pass.h
index 0ad8d8c1d..e237b8a98 100644
--- a/src/pass.h
+++ b/src/pass.h
@@ -72,17 +72,17 @@ struct PassRunner {
void add(std::string passName) {
auto pass = PassRegistry::get()->createPass(passName);
if (!pass) Fatal() << "Could not find pass: " << passName << "\n";
- passes.push_back(pass);
+ doAdd(pass);
}
template<class P>
void add() {
- passes.push_back(new P());
+ doAdd(new P());
}
template<class P, class Arg>
void add(Arg arg){
- passes.push_back(new P(arg));
+ doAdd(new P(arg));
}
// Adds the default set of optimization passes; this is
@@ -110,6 +110,8 @@ struct PassRunner {
~PassRunner();
private:
+ void doAdd(Pass* pass);
+
void runPassOnFunction(Pass* pass, Function* func);
};
@@ -121,12 +123,13 @@ public:
virtual ~Pass() {};
// Override this to perform preparation work before the pass runs.
- virtual void prepare(PassRunner* runner, Module* module) {}
+ // This will be called before the pass is run on a module.
+ virtual void prepareToRun(PassRunner* runner, Module* module) {}
+
+ // Implement this with code to run the pass on the whole module
virtual void run(PassRunner* runner, Module* module) = 0;
- // Override this to perform finalization work after the pass runs.
- virtual void finalize(PassRunner* runner, Module* module) {}
- // Run on a single function. This has no prepare/finalize calls.
+ // Implement this with code to run the pass on a single function
virtual void runFunction(PassRunner* runner, Module* module, Function* function) {
WASM_UNREACHABLE(); // by default, passes cannot be run this way
}
@@ -166,9 +169,7 @@ template <typename WalkerType>
class WalkerPass : public Pass, public WalkerType {
public:
void run(PassRunner* runner, Module* module) override {
- prepare(runner, module);
WalkerType::walkModule(module);
- finalize(runner, module);
}
void runFunction(PassRunner* runner, Module* module, Function* func) override {
diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt
index 1b4c65562..30f63c880 100644
--- a/src/passes/CMakeLists.txt
+++ b/src/passes/CMakeLists.txt
@@ -3,6 +3,7 @@ SET(passes_SOURCES
CoalesceLocals.cpp
DeadCodeElimination.cpp
DuplicateFunctionElimination.cpp
+ ExtractFunction.cpp
MergeBlocks.cpp
Metrics.cpp
NameManager.cpp
@@ -11,6 +12,7 @@ SET(passes_SOURCES
PostEmscripten.cpp
Precompute.cpp
Print.cpp
+ RelooperJumpThreading.cpp
RemoveImports.cpp
RemoveMemory.cpp
RemoveUnusedBrs.cpp
diff --git a/src/passes/CoalesceLocals.cpp b/src/passes/CoalesceLocals.cpp
index 2063a20db..ba2c6dd81 100644
--- a/src/passes/CoalesceLocals.cpp
+++ b/src/passes/CoalesceLocals.cpp
@@ -198,6 +198,7 @@ struct CoalesceLocals : public WalkerPass<CFGWalker<CoalesceLocals, Visitor<Coal
void scanLivenessThroughActions(std::vector<Action>& actions, LocalSet& live);
void pickIndicesFromOrder(std::vector<Index>& order, std::vector<Index>& indices);
+ void pickIndicesFromOrder(std::vector<Index>& order, std::vector<Index>& indices, Index& removedCopies);
virtual void pickIndices(std::vector<Index>& indices); // returns a vector of oldIndex => newIndex
@@ -224,14 +225,17 @@ struct CoalesceLocals : public WalkerPass<CFGWalker<CoalesceLocals, Visitor<Coal
// copying state
- std::vector<uint8_t> copies; // canonicalized - accesses should check (low, high)
+ std::vector<uint8_t> copies; // canonicalized - accesses should check (low, high) TODO: use a map for high N, as this tends to be sparse? or don't look at copies at all for big N?
+ std::vector<Index> totalCopies; // total # of copies for each local, with all others
void addCopy(Index i, Index j) {
auto k = std::min(i, j) * numLocals + std::max(i, j);
copies[k] = std::min(copies[k], uint8_t(254)) + 1;
+ totalCopies[i]++;
+ totalCopies[j]++;
}
- bool getCopies(Index i, Index j) {
+ uint8_t getCopies(Index i, Index j) {
return copies[std::min(i, j) * numLocals + std::max(i, j)];
}
};
@@ -240,6 +244,8 @@ void CoalesceLocals::doWalkFunction(Function* func) {
numLocals = func->getNumLocals();
copies.resize(numLocals * numLocals);
std::fill(copies.begin(), copies.end(), 0);
+ totalCopies.resize(numLocals);
+ std::fill(totalCopies.begin(), totalCopies.end(), 0);
// collect initial liveness info
WalkerPass<CFGWalker<CoalesceLocals, Visitor<CoalesceLocals>, Liveness>>::doWalkFunction(func);
// ignore links to dead blocks, so they don't confuse us and we can see their stores are all ineffective
@@ -388,8 +394,43 @@ void CoalesceLocals::calculateInterferences(const LocalSet& locals) {
// Indices decision making
void CoalesceLocals::pickIndicesFromOrder(std::vector<Index>& order, std::vector<Index>& indices) {
- // simple greedy coloring
- // TODO: take into account eliminated copies
+ Index removedCopies;
+ pickIndicesFromOrder(order, indices, removedCopies);
+}
+
+void CoalesceLocals::pickIndicesFromOrder(std::vector<Index>& order, std::vector<Index>& indices, Index& removedCopies) {
+ // mostly-simple greedy coloring
+#if CFG_DEBUG
+ std::cerr << "\npickIndicesFromOrder on " << getFunction()->name << '\n';
+ std::cerr << getFunction()->body << '\n';
+ std::cerr << "order:\n";
+ for (auto i : order) std::cerr << i << ' ';
+ std::cerr << '\n';
+ std::cerr << "interferences:\n";
+ for (Index i = 0; i < numLocals; i++) {
+ for (Index j = 0; j < i + 1; j++) {
+ std::cerr << " ";
+ }
+ for (Index j = i + 1; j < numLocals; j++) {
+ std::cerr << int(interferes(i, j)) << ' ';
+ }
+ std::cerr << " : $" << i << '\n';
+ }
+ std::cerr << "copies:\n";
+ for (Index i = 0; i < numLocals; i++) {
+ for (Index j = 0; j < i + 1; j++) {
+ std::cerr << " ";
+ }
+ for (Index j = i + 1; j < numLocals; j++) {
+ std::cerr << int(getCopies(i, j)) << ' ';
+ }
+ std::cerr << " : $" << i << '\n';
+ }
+ std::cerr << "total copies:\n";
+ for (Index i = 0; i < numLocals; i++) {
+ std::cerr << " $" << i << ": " << totalCopies[i] << '\n';
+ }
+#endif
// TODO: take into account distribution (99-1 is better than 50-50 with two registers, for gzip)
std::vector<WasmType> types;
std::vector<bool> newInterferences; // new index * numLocals => list of all interferences of locals merged to it
@@ -397,12 +438,13 @@ void CoalesceLocals::pickIndicesFromOrder(std::vector<Index>& order, std::vector
indices.resize(numLocals);
types.resize(numLocals);
newInterferences.resize(numLocals * numLocals);
- newCopies.resize(numLocals * numLocals);
std::fill(newInterferences.begin(), newInterferences.end(), 0);
+ auto numParams = getFunction()->getNumParams();
+ newCopies.resize(numParams * numLocals); // start with enough room for the params
std::fill(newCopies.begin(), newCopies.end(), 0);
Index nextFree = 0;
+ removedCopies = 0;
// we can't reorder parameters, they are fixed in order, and cannot coalesce
- auto numParams = getFunction()->getNumParams();
Index i = 0;
for (; i < numParams; i++) {
assert(order[i] == i); // order must leave the params in place
@@ -421,6 +463,7 @@ void CoalesceLocals::pickIndicesFromOrder(std::vector<Index>& order, std::vector
for (Index j = 0; j < nextFree; j++) {
if (!newInterferences[j * numLocals + actual] && getFunction()->getLocalType(actual) == types[j]) {
// this does not interfere, so it might be what we want. but pick the one eliminating the most copies
+ // (we could stop looking forward when there are no more items that have copies anyhow, but it doesn't seem to help)
auto currCopies = newCopies[j * numLocals + actual];
if (found == Index(-1) || currCopies > foundCopies) {
indices[actual] = found = j;
@@ -432,7 +475,14 @@ void CoalesceLocals::pickIndicesFromOrder(std::vector<Index>& order, std::vector
indices[actual] = found = nextFree;
types[found] = getFunction()->getLocalType(actual);
nextFree++;
+ removedCopies += getCopies(found, actual);
+ newCopies.resize(nextFree * numLocals);
+ } else {
+ removedCopies += foundCopies;
}
+#if CFG_DEBUG
+ std::cerr << "set local $" << actual << " to $" << found << '\n';
+#endif
// merge new interferences and copies for the new index
for (Index k = i + 1; k < numLocals; k++) {
auto j = order[k]; // go in the order, we only need to update for those we will see later
@@ -442,31 +492,85 @@ void CoalesceLocals::pickIndicesFromOrder(std::vector<Index>& order, std::vector
}
}
+// Utilities for operating on permutation vectors
+
+static std::vector<Index> makeIdentity(Index num) {
+ std::vector<Index> ret;
+ ret.resize(num);
+ for (Index i = 0; i < num; i++) {
+ ret[i] = i;
+ }
+ return ret;
+}
+
+static void setIdentity(std::vector<Index>& ret) {
+ auto num = ret.size();
+ assert(num > 0); // must already be of the right size
+ for (Index i = 0; i < num; i++) {
+ ret[i] = i;
+ }
+}
+
+static std::vector<Index> makeReversed(std::vector<Index>& original) {
+ std::vector<Index> ret;
+ auto num = original.size();
+ ret.resize(num);
+ for (Index i = 0; i < num; i++) {
+ ret[original[i]] = i;
+ }
+ return ret;
+}
+
+// given a baseline order, adjust it based on an important order of priorities (higher values
+// are higher priority). The priorities take precedence, unless they are equal and then
+// the original order should be kept.
+std::vector<Index> adjustOrderByPriorities(std::vector<Index>& baseline, std::vector<Index>& priorities) {
+ std::vector<Index> ret = baseline;
+ std::vector<Index> reversed = makeReversed(baseline);
+ std::sort(ret.begin(), ret.end(), [&priorities, &reversed](Index x, Index y) {
+ return priorities[x] > priorities[y] || (priorities[x] == priorities[y] && reversed[x] < reversed[y]);
+ });
+ return ret;
+};
+
void CoalesceLocals::pickIndices(std::vector<Index>& indices) {
if (numLocals == 0) return;
if (numLocals == 1) {
indices.push_back(0);
return;
}
+ if (getFunction()->getNumVars() <= 1) {
+ // nothing to think about here, since we can't reorder params
+ indices = makeIdentity(numLocals);
+ return;
+ }
+ // take into account total copies. but we must keep params in place, so give them max priority
+ auto adjustedTotalCopies = totalCopies;
+ auto numParams = getFunction()->getNumParams();
+ for (Index i = 0; i < numParams; i++) {
+ adjustedTotalCopies[i] = std::numeric_limits<Index>::max();
+ }
// first try the natural order. this is less arbitrary than it seems, as the program
// may have a natural order of locals inherent in it.
- std::vector<Index> order;
- order.resize(numLocals);
- for (Index i = 0; i < numLocals; i++) {
- order[i] = i;
- }
- pickIndicesFromOrder(order, indices);
+ auto order = makeIdentity(numLocals);
+ order = adjustOrderByPriorities(order, adjustedTotalCopies);
+ Index removedCopies;
+ pickIndicesFromOrder(order, indices, removedCopies);
auto maxIndex = *std::max_element(indices.begin(), indices.end());
- // next try the reverse order. this both gives us anothe chance at something good,
+ // next try the reverse order. this both gives us another chance at something good,
// and also the very naturalness of the simple order may be quite suboptimal
- auto numParams = getFunction()->getNumParams();
+ setIdentity(order);
for (Index i = numParams; i < numLocals; i++) {
order[i] = numParams + numLocals - 1 - i;
}
+ order = adjustOrderByPriorities(order, adjustedTotalCopies);
std::vector<Index> reverseIndices;
- pickIndicesFromOrder(order, reverseIndices);
+ Index reverseRemovedCopies;
+ pickIndicesFromOrder(order, reverseIndices, reverseRemovedCopies);
auto reverseMaxIndex = *std::max_element(reverseIndices.begin(), reverseIndices.end());
- if (reverseMaxIndex < maxIndex) {
+ // prefer to remove copies foremost, as it matters more for code size (minus gzip), and
+ // improves throughput.
+ if (reverseRemovedCopies > removedCopies || (reverseRemovedCopies == removedCopies && reverseMaxIndex < maxIndex)) {
indices.swap(reverseIndices);
}
}
@@ -553,7 +657,8 @@ void CoalesceLocalsWithLearning::pickIndices(std::vector<Index>& indices) {
void calculateFitness(Order* order) {
// apply the order
std::vector<Index> indices; // the phenotype
- parent->pickIndicesFromOrder(*order, indices);
+ Index removedCopies;
+ parent->pickIndicesFromOrder(*order, indices, removedCopies);
auto maxIndex = *std::max_element(indices.begin(), indices.end());
assert(maxIndex <= parent->numLocals);
// main part of fitness is the number of locals
@@ -563,6 +668,7 @@ void CoalesceLocalsWithLearning::pickIndices(std::vector<Index>& indices) {
for (Index i = 0; i < parent->numLocals; i++) {
if ((*order)[i] == i) fitness += fragment; // boost for each that wasn't moved
}
+ fitness = (100 * fitness) + removedCopies; // removing copies is a secondary concern
order->setFitness(fitness);
}
@@ -577,6 +683,8 @@ void CoalesceLocalsWithLearning::pickIndices(std::vector<Index>& indices) {
// first, there may be an inherent order in the input (frequent indices are lower,
// etc.). second, by ensuring we start with the natural order, we ensure we are at
// least as good as the non-learning variant.
+ // TODO: use ::pickIndices from the parent, so we literally get the simpler approach
+ // as our first option
first = false;
} else {
// leave params alone, shuffle the rest
diff --git a/src/passes/DuplicateFunctionElimination.cpp b/src/passes/DuplicateFunctionElimination.cpp
index 2b8e69b54..cfe2d8565 100644
--- a/src/passes/DuplicateFunctionElimination.cpp
+++ b/src/passes/DuplicateFunctionElimination.cpp
@@ -102,7 +102,17 @@ struct DuplicateFunctionElimination : public Pass {
auto& group = pair.second;
if (group.size() == 1) continue;
// pick a base for each group, and try to replace everyone else to it. TODO: multiple bases per hash group, for collisions
+#if 0
+ // for comparison purposes, pick in a deterministic way based on the names
+ Function* base = nullptr;
+ for (auto* func : group) {
+ if (!base || strcmp(func->name.str, base->name.str) < 0) {
+ base = func;
+ }
+ }
+#else
Function* base = group[0];
+#endif
for (auto* func : group) {
if (func != base && equal(func, base)) {
replacements[func->name] = base->name;
diff --git a/src/passes/ExtractFunction.cpp b/src/passes/ExtractFunction.cpp
new file mode 100644
index 000000000..b342efc3b
--- /dev/null
+++ b/src/passes/ExtractFunction.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2016 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Removes code from all functions but one, leaving a valid module
+// with (mostly) just the code you want to debug (function-parallel,
+// non-lto) passes on.
+
+#include "wasm.h"
+#include "pass.h"
+
+namespace wasm {
+
+Name TO_LEAVE("_bytearray_join"); // TODO: commandline param
+
+struct ExtractFunction : public Pass {
+ void run(PassRunner* runner, Module* module) override {
+ for (auto& func : module->functions) {
+ if (func->name != TO_LEAVE) {
+ // wipe out the body
+ func->body = module->allocator.alloc<Unreachable>();
+ }
+ }
+ }
+};
+
+// declare pass
+
+Pass *createExtractFunctionPass() {
+ return new ExtractFunction();
+}
+
+} // namespace wasm
+
diff --git a/src/passes/LowerInt64.cpp b/src/passes/LowerInt64.cpp
deleted file mode 100644
index 69f6d5ae9..000000000
--- a/src/passes/LowerInt64.cpp
+++ /dev/null
@@ -1,196 +0,0 @@
-/*
- * Copyright 2015 WebAssembly Community Group participants
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-//
-// Lowers 64-bit ints to pairs of 32-bit ints, plus some library routines
-//
-// This is useful for wasm2asm, as JS has no native 64-bit integer support.
-//
-
-#include <memory>
-
-#include <wasm.h>
-#include <pass.h>
-
-namespace wasm {
-
-cashew::IString GET_HIGH("getHigh");
-
-struct LowerInt64 : public Pass {
- MixedArena* allocator;
- std::unique_ptr<NameManager> namer;
-
- void prepare(PassRunner* runner, Module *module) override {
- allocator = runner->allocator;
- namer = make_unique<NameManager>();
- namer->run(runner, module);
- }
-
- std::map<Expression*, Expression*> fixes; // fixed nodes, outputs of lowering, mapped to their high bits
- std::map<Name, Name> locals; // maps locals which were i64->i32 to their high bits
-
- void makeGetHigh() {
- auto ret = allocator->alloc<CallImport>();
- ret->target = GET_HIGH;
- ret->type = i32;
- return ret;
- }
-
- void fixCall(CallBase *call) {
- auto& operands = call->operands;
- for (size_t i = 0; i < operands.size(); i++) {
- auto fix = fixes.find(operands[i]);
- if (fix != fixes.end()) {
- operands.insert(operands.begin() + i + 1, *fix);
- }
- }
- if (curr->type == i64) {
- curr->type = i32;
- fixes[curr] = makeGetHigh(); // called function will setHigh
- }
- }
-
- void visitCall(Call *curr) {
- fixCall(curr);
- }
- void visitCallImport(CallImport *curr) {
- fixCall(curr);
- }
- void visitCallIndirect(CallIndirect *curr) {
- fixCall(curr);
- }
- void visitGetLocal(GetLocal *curr) {
- if (curr->type == i64) {
- if (locals.count(curr->name) == 0) {
- Name highName = namer->getUnique("high");
- locals[curr->name] = highName;
- };
- curr->type = i32;
- auto high = allocator->alloc<GetLocal>();
- high->name = locals[curr->name];
- high->type = i32;
- fixes[curr] = high;
- }
- }
- void visitSetLocal(SetLocal *curr) {
- if (curr->type == i64) {
- Name highName;
- if (locals.count(curr->name) == 0) {
- highName = namer->getUnique("high");
- locals[curr->name] = highName;
- } else {
- highName = locals[curr->name];
- }
- curr->type = i32;
- auto high = allocator->alloc<GetLocal>();
- high->name = highName;
- high->type = i32;
- fixes[curr] = high;
- // Set the high bits
- auto set = allocator.alloc<SetLocal>();
- set->name = highName;
- set->value = fixes[curr->value];
- set->type = i32;
- assert(set->value);
- auto low = allocator->alloc<GetLocal>();
- low->name = curr->name;
- low->type = i32;
- auto ret = allocator.alloc<Block>();
- ret->list.push_back(curr);
- ret->list.push_back(set);
- ret->list.push_back(low); // so the block returns the low bits
- ret->finalize();
- fixes[ret] = high;
- replaceCurrent(ret);
- }
- }
-
- // sets an expression to a local, and returns a block
- Block* setToLocalForBlock(Expression *value, Name& local, Block *ret = nullptr) {
- if (!ret) ret = allocator->alloc<Block>();
- if (value->is<GetLocal>()) {
- local = value->name;
- } else if (value->is<SetLocal>()) {
- local = value->name;
- } else {
- auto set = allocator.alloc<SetLocal>();
- set->name = local = namer->getUnique("temp");
- set->value = value;
- set->type = value->type;
- ret->list.push_back(set);
- }
- ret->finalize();
- return ret;
- }
-
- GetLocal* getLocal(Name name) {
- auto ret = allocator->alloc<GetLocal>();
- ret->name = name;
- ret->type = i32;
- return ret;
- }
-
- void visitLoad(Load *curr) {
- if (curr->type == i64) {
- Name local;
- auto ret = setToLocalForBlock(curr->ptr, local);
- curr->ptr = getLocal(local);
- curr->type = i32;
- curr->bytes = 4;
- auto high = allocator->alloc<Load>();
- *high = *curr;
- high->ptr = getLocal(local);
- high->offset += 4;
- ret->list.push_back(curr);
- fixes[ret] = high;
- replaceCurrent(ret);
- }
- }
- void visitStore(Store *curr) {
- if (curr->type == i64) {
- Name localPtr, localValue;
- auto ret = setToLocalForBlock(curr->ptr, localPtr);
- setToLocalForBlock(curr->value, localValue);
- curr->ptr = getLocal(localPtr);
- curr->value = getLocal(localValue);
- curr->type = i32;
- curr->bytes = 4;
- auto high = allocator->alloc<Load>();
- *high = *curr;
- high->ptr = getLocal(localPtr);
- high->value = getLocal(localValue);
- high->offset += 4;
- ret->list.push_back(high);
- ret->list.push_back(curr);
- fixes[ret] = high;
- replaceCurrent(ret);
- }
- }
- void visitFunction(Function *curr) {
- // TODO: new params
- for (auto localPair : locals) { // TODO: ignore params
- curr->locals.emplace_back(localPair.second, i32);
- }
- fixes.clear();
- locals.clear();
- }
-};
-
-Pass *createLowerInt64Pass() {
- return new LowerInt64();
-}
-
-} // namespace wasm
diff --git a/src/passes/Metrics.cpp b/src/passes/Metrics.cpp
index c5309850e..9181da9bb 100644
--- a/src/passes/Metrics.cpp
+++ b/src/passes/Metrics.cpp
@@ -34,7 +34,7 @@ struct Metrics : public WalkerPass<PostWalker<Metrics, UnifiedExpressionVisitor<
counts[name]++;
}
- void finalize(PassRunner *runner, Module *module) override {
+ void visitModule(Module* module) {
ostream &o = cout;
o << "Counts"
<< "\n";
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index 2a5e7b8c4..669a19b89 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -24,6 +24,7 @@
#include <pass.h>
#include <wasm-s-parser.h>
#include <support/threads.h>
+#include <ast_utils.h>
namespace wasm {
@@ -50,7 +51,6 @@ struct PatternDatabase {
std::map<Expression::Id, std::vector<Pattern>> patternMap; // root expression id => list of all patterns for it TODO optimize more
PatternDatabase() {
- // TODO: do this on first use, with a lock, to avoid startup pause
// generate module
input = strdup(
#include "OptimizeInstructions.wast.processed"
@@ -74,14 +74,12 @@ struct PatternDatabase {
static PatternDatabase* database = nullptr;
-static void ensureDatabase() {
- if (!database) {
- // we must only ever create one database
- static OnlyOnce onlyOnce;
- onlyOnce.verify();
+struct DatabaseEnsurer {
+ DatabaseEnsurer() {
+ assert(!database);
database = new PatternDatabase;
}
-}
+};
// Check for matches and apply them
struct Match {
@@ -161,13 +159,18 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
Pass* create() override { return new OptimizeInstructions; }
- OptimizeInstructions() {
- ensureDatabase();
+ void prepareToRun(PassRunner* runner, Module* module) override {
+ static DatabaseEnsurer ensurer;
}
void visitExpression(Expression* curr) {
// we may be able to apply multiple patterns, one may open opportunities that look deeper NB: patterns must not have cycles
while (1) {
+ auto* handOptimized = handOptimize(curr);
+ if (handOptimized) {
+ curr = handOptimized;
+ replaceCurrent(curr);
+ }
auto iter = database->patternMap.find(curr->_id);
if (iter == database->patternMap.end()) return;
auto& patterns = iter->second;
@@ -184,6 +187,65 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions,
if (!more) break;
}
}
+
+ // Optimizations that don't yet fit in the pattern DSL, but could be eventually maybe
+ Expression* handOptimize(Expression* curr) {
+ if (auto* binary = curr->dynCast<Binary>()) {
+ // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc.
+ if (binary->op == BinaryOp::ShrSInt32 && binary->right->is<Const>()) {
+ auto shifts = binary->right->cast<Const>()->value.geti32();
+ if (shifts == 24 || shifts == 16) {
+ auto* left = binary->left->dynCast<Binary>();
+ if (left && left->op == ShlInt32 && left->right->is<Const>() && left->right->cast<Const>()->value.geti32() == shifts) {
+ auto* load = left->left->dynCast<Load>();
+ if (load && ((load->bytes == 1 && shifts == 24) || (load->bytes == 2 && shifts == 16))) {
+ load->signed_ = true;
+ return load;
+ }
+ }
+ }
+ }
+ } else if (auto* set = curr->dynCast<SetGlobal>()) {
+ // optimize out a set of a get
+ auto* get = set->value->dynCast<GetGlobal>();
+ if (get && get->name == set->name) {
+ ExpressionManipulator::nop(curr);
+ }
+ } else if (auto* iff = curr->dynCast<If>()) {
+ iff->condition = optimizeBoolean(iff->condition);
+ } else if (auto* select = curr->dynCast<Select>()) {
+ select->condition = optimizeBoolean(select->condition);
+ auto* condition = select->condition->dynCast<Unary>();
+ if (condition && condition->op == EqZInt32) {
+ // flip select to remove eqz, if we can reorder
+ EffectAnalyzer ifTrue(select->ifTrue);
+ EffectAnalyzer ifFalse(select->ifFalse);
+ if (!ifTrue.invalidates(ifFalse)) {
+ select->condition = condition->value;
+ std::swap(select->ifTrue, select->ifFalse);
+ }
+ }
+ } else if (auto* br = curr->dynCast<Break>()) {
+ if (br->condition) {
+ br->condition = optimizeBoolean(br->condition);
+ }
+ }
+ return nullptr;
+ }
+
+private:
+
+ Expression* optimizeBoolean(Expression* boolean) {
+ auto* condition = boolean->dynCast<Unary>();
+ if (condition && condition->op == EqZInt32) {
+ auto* condition2 = condition->value->dynCast<Unary>();
+ if (condition2 && condition2->op == EqZInt32) {
+ // double eqz
+ return condition2->value;
+ }
+ }
+ return boolean;
+ }
};
Pass *createOptimizeInstructionsPass() {
diff --git a/src/passes/OptimizeInstructions.wast b/src/passes/OptimizeInstructions.wast
index 7d5c56881..48a4e2c9d 100644
--- a/src/passes/OptimizeInstructions.wast
+++ b/src/passes/OptimizeInstructions.wast
@@ -19,8 +19,8 @@
;; main function. each block here is a pattern pair of input => output
(func $patterns
+ ;; flip if-else arms to get rid of an eqz
(block
- ;; flip if-else arms to get rid of an eqz
(if
(i32.eqz
(call_import $i32.expr (i32.const 0))
@@ -34,6 +34,25 @@
(call_import $any.expr (i32.const 1))
)
)
+ ;; equal 0 => eqz
+ (block
+ (i32.eq
+ (call_import $any.expr (i32.const 0))
+ (i32.const 0)
+ )
+ (i32.eqz
+ (call_import $any.expr (i32.const 0))
+ )
+ )
+ (block
+ (i32.eq
+ (i32.const 0)
+ (call_import $any.expr (i32.const 0))
+ )
+ (i32.eqz
+ (call_import $any.expr (i32.const 0))
+ )
+ )
;; De Morgans Laws
(block
(i32.eqz (i32.eq (call_import $i32.expr (i32.const 0)) (call_import $i32.expr (i32.const 1))))
diff --git a/src/passes/OptimizeInstructions.wast.processed b/src/passes/OptimizeInstructions.wast.processed
index 61fde86c6..13ccc8241 100644
--- a/src/passes/OptimizeInstructions.wast.processed
+++ b/src/passes/OptimizeInstructions.wast.processed
@@ -19,8 +19,8 @@
"\n"
";; main function. each block here is a pattern pair of input => output\n"
"(func $patterns\n"
-"(block\n"
";; flip if-else arms to get rid of an eqz\n"
+"(block\n"
"(if\n"
"(i32.eqz\n"
"(call_import $i32.expr (i32.const 0))\n"
@@ -34,6 +34,25 @@
"(call_import $any.expr (i32.const 1))\n"
")\n"
")\n"
+";; equal 0 => eqz\n"
+"(block\n"
+"(i32.eq\n"
+"(call_import $any.expr (i32.const 0))\n"
+"(i32.const 0)\n"
+")\n"
+"(i32.eqz\n"
+"(call_import $any.expr (i32.const 0))\n"
+")\n"
+")\n"
+"(block\n"
+"(i32.eq\n"
+"(i32.const 0)\n"
+"(call_import $any.expr (i32.const 0))\n"
+")\n"
+"(i32.eqz\n"
+"(call_import $any.expr (i32.const 0))\n"
+")\n"
+")\n"
";; De Morgans Laws\n"
"(block\n"
"(i32.eqz (i32.eq (call_import $i32.expr (i32.const 0)) (call_import $i32.expr (i32.const 1))))\n"
diff --git a/src/passes/RelooperJumpThreading.cpp b/src/passes/RelooperJumpThreading.cpp
new file mode 100644
index 000000000..7f74220d5
--- /dev/null
+++ b/src/passes/RelooperJumpThreading.cpp
@@ -0,0 +1,248 @@
+/*
+ * Copyright 2016 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Optimize relooper-generated label variable usage: add blocks and turn
+// a label-set/break/label-check into a break into the new block.
+// This assumes the very specific output the fastcomp relooper emits,
+// including the name of the 'label' variable.
+
+#include "wasm.h"
+#include "pass.h"
+#include "ast_utils.h"
+
+namespace wasm {
+
+static Name LABEL("label");
+
+// We need to use new label names, which we cannot create in parallel, so pre-create them
+
+const Index MAX_NAME_INDEX = 1000;
+
+std::vector<Name>* innerNames = nullptr;
+std::vector<Name>* outerNames = nullptr;
+
+struct NameEnsurer {
+ NameEnsurer() {
+ assert(!innerNames);
+ assert(!outerNames);
+ innerNames = new std::vector<Name>;
+ outerNames = new std::vector<Name>;
+ for (Index i = 0; i < MAX_NAME_INDEX; i++) {
+ innerNames->push_back(Name(std::string("jumpthreading$inner$") + std::to_string(i)));
+ outerNames->push_back(Name(std::string("jumpthreading$outer$") + std::to_string(i)));
+ }
+ }
+};
+
+static If* isLabelCheckingIf(Expression* curr, Index labelIndex) {
+ if (!curr) return nullptr;
+ auto* iff = curr->dynCast<If>();
+ if (!iff) return nullptr;
+ auto* condition = iff->condition->dynCast<Binary>();
+ if (!(condition && condition->op == EqInt32)) return nullptr;
+ auto* left = condition->left->dynCast<GetLocal>();
+ if (!(left && left->index == labelIndex)) return nullptr;
+ return iff;
+}
+
+static Index getCheckedLabelValue(If* iff) {
+ return iff->condition->cast<Binary>()->right->cast<Const>()->value.geti32();
+}
+
+static SetLocal* isLabelSettingSetLocal(Expression* curr, Index labelIndex) {
+ if (!curr) return nullptr;
+ auto* set = curr->dynCast<SetLocal>();
+ if (!set) return nullptr;
+ if (set->index != labelIndex) return nullptr;
+ return set;
+}
+
+static Index getSetLabelValue(SetLocal* set) {
+ return set->value->cast<Const>()->value.geti32();
+}
+
+struct LabelUseFinder : public PostWalker<LabelUseFinder, Visitor<LabelUseFinder>> {
+ Index labelIndex;
+ std::map<Index, Index>& checks; // label value => number of checks on it
+ std::map<Index, Index>& sets; // label value => number of sets to it
+
+ LabelUseFinder(Index labelIndex, std::map<Index, Index>& checks, std::map<Index, Index>& sets) : labelIndex(labelIndex), checks(checks), sets(sets) {}
+
+ void visitIf(If* curr) {
+ if (isLabelCheckingIf(curr, labelIndex)) {
+ checks[getCheckedLabelValue(curr)]++;
+ }
+ }
+
+ void visitSetLocal(SetLocal* curr) {
+ if (isLabelSettingSetLocal(curr, labelIndex)) {
+ sets[getSetLabelValue(curr)]++;
+ }
+ }
+};
+
+struct RelooperJumpThreading : public WalkerPass<ExpressionStackWalker<RelooperJumpThreading, Visitor<RelooperJumpThreading>>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new RelooperJumpThreading; }
+
+ void prepareToRun(PassRunner* runner, Module* module) override {
+ static NameEnsurer ensurer;
+ }
+
+ std::map<Index, Index> labelChecks;
+ std::map<Index, Index> labelSets;
+
+ Index labelIndex;
+ Index newNameCounter = 0;
+
+ void visitBlock(Block* curr) {
+ // look for the if label == X pattern
+ auto& list = curr->list;
+ if (list.size() == 0) return;
+ for (Index i = 0; i < list.size() - 1; i++) {
+ // once we see something that might be irreducible, we must skip that if and the rest of the dependents
+ bool irreducible = false;
+ Index origin = i;
+ for (Index j = i + 1; j < list.size(); j++) {
+ if (auto* iff = isLabelCheckingIf(list[j], labelIndex)) {
+ irreducible |= hasIrreducibleControlFlow(iff, list[origin]);
+ if (!irreducible) {
+ optimizeJumpsToLabelCheck(list[origin], iff);
+ ExpressionManipulator::nop(iff);
+ }
+ i++;
+ continue;
+ }
+ // if the next element is a block, it may be the holding block of label-checking ifs
+ if (auto* holder = list[j]->dynCast<Block>()) {
+ if (holder->list.size() > 0) {
+ if (If* iff = isLabelCheckingIf(holder->list[0], labelIndex)) {
+ irreducible |= hasIrreducibleControlFlow(iff, list[origin]);
+ if (!irreducible) {
+ // this is indeed a holder. we can process the ifs, and must also move
+ // the block to enclose the origin, so it is properly reachable
+ assert(holder->list.size() == 1); // must be size 1, a relooper multiple will have its own label, and is an if-else sequence and nothing more
+ optimizeJumpsToLabelCheck(list[origin], iff);
+ holder->list[0] = list[origin];
+ list[origin] = holder;
+ // reuse the if as a nop
+ list[j] = iff;
+ ExpressionManipulator::nop(iff);
+ }
+ i++;
+ continue;
+ }
+ }
+ }
+ break; // we didn't see something we like, so stop here
+ }
+ }
+ }
+
+ void doWalkFunction(Function* func) {
+ // if there isn't a label variable, nothing for us to do
+ if (func->localIndices.count(LABEL)) {
+ labelIndex = func->getLocalIndex(LABEL);
+ LabelUseFinder finder(labelIndex, labelChecks, labelSets);
+ finder.walk(func->body);
+ WalkerPass<ExpressionStackWalker<RelooperJumpThreading, Visitor<RelooperJumpThreading>>>::doWalkFunction(func);
+ }
+ }
+
+private:
+
+ bool hasIrreducibleControlFlow(If* iff, Expression* origin) {
+ // Gather the checks in this if chain. If all the label values checked are only set in origin,
+ // then since origin is right before us, this is not irreducible - we can replace all sets
+ // in origin with jumps forward to us, and since there is nothing else, this is safe and complete.
+ // We must also have the property that there is just one check for the label value, as otherwise
+ // node splitting has complicated things.
+ std::map<Index, Index> labelChecksInOrigin;
+ std::map<Index, Index> labelSetsInOrigin;
+ LabelUseFinder finder(labelIndex, labelChecksInOrigin, labelSetsInOrigin);
+ finder.walk(origin);
+ while (iff) {
+ auto num = getCheckedLabelValue(iff);
+ assert(labelChecks[num] > 0);
+ if (labelChecks[num] > 1) return true; // checked more than once, somewhere in function
+ assert(labelChecksInOrigin[num] == 0);
+ if (labelSetsInOrigin[num] != labelSets[num]) {
+ assert(labelSetsInOrigin[num] < labelSets[num]);
+ return true; // label set somewhere outside of origin TODO: if set in the if body here, it might be safe in some cases
+ }
+ iff = isLabelCheckingIf(iff->ifFalse, labelIndex);
+ }
+ return false;
+ }
+
+ // optimizes jumps to a label check
+ // * origin is where the jumps originate, and also where we should write our output
+ // * iff is the if
+ void optimizeJumpsToLabelCheck(Expression*& origin, If* iff) {
+ Index nameCounter = newNameCounter++;
+ if (nameCounter >= MAX_NAME_INDEX) {
+ std::cerr << "too many names in RelooperJumpThreading :(\n";
+ return;
+ }
+ Index num = getCheckedLabelValue(iff);
+ // create a new block for this jump target
+ Builder builder(*getModule());
+ // origin is where all jumps to this target must come from - the element right before this if
+ // we break out of inner to reach the target. instead of flowing out of normally, we break out of the outer, so we skip the target.
+ auto innerName = innerNames->at(nameCounter);
+ auto outerName = outerNames->at(nameCounter);
+ auto* ifFalse = iff->ifFalse;
+ // all assignments of label to the target can be replaced with breaks to the target, via innerName
+ struct JumpUpdater : public PostWalker<JumpUpdater, Visitor<JumpUpdater>> {
+ Index labelIndex;
+ Index targetNum;
+ Name targetName;
+
+ void visitSetLocal(SetLocal* curr) {
+ if (curr->index == labelIndex) {
+ if (Index(curr->value->cast<Const>()->value.geti32()) == targetNum) {
+ replaceCurrent(Builder(*getModule()).makeBreak(targetName));
+ }
+ }
+ }
+ };
+ JumpUpdater updater;
+ updater.labelIndex = labelIndex;
+ updater.targetNum = num;
+ updater.targetName = innerName;
+ updater.setModule(getModule());
+ updater.walk(origin);
+ // restructure code
+ auto* inner = builder.blockifyWithName(origin, innerName, builder.makeBreak(outerName));
+ auto* outer = builder.makeSequence(inner, iff->ifTrue);
+ outer->name = outerName;
+ origin = outer;
+ // if another label value is checked here, handle that too
+ if (ifFalse) {
+ optimizeJumpsToLabelCheck(origin, ifFalse->cast<If>());
+ }
+ }
+};
+
+// declare pass
+
+Pass *createRelooperJumpThreadingPass() {
+ return new RelooperJumpThreading();
+}
+
+} // namespace wasm
+
diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp
index 19d6c3eb1..429203a2e 100644
--- a/src/passes/RemoveImports.cpp
+++ b/src/passes/RemoveImports.cpp
@@ -28,22 +28,14 @@
namespace wasm {
struct RemoveImports : public WalkerPass<PostWalker<RemoveImports, Visitor<RemoveImports>>> {
- MixedArena* allocator;
- Module* module;
-
- void prepare(PassRunner* runner, Module *module_) override {
- allocator = runner->allocator;
- module = module_;
- }
-
void visitCallImport(CallImport *curr) {
- WasmType type = module->getImport(curr->target)->functionType->result;
+ WasmType type = getModule()->getImport(curr->target)->functionType->result;
if (type == none) {
- replaceCurrent(allocator->alloc<Nop>());
+ replaceCurrent(getModule()->allocator.alloc<Nop>());
} else {
Literal nopLiteral;
nopLiteral.type = type;
- replaceCurrent(allocator->alloc<Const>()->set(nopLiteral));
+ replaceCurrent(getModule()->allocator.alloc<Const>()->set(nopLiteral));
}
}
diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp
index 59d3af6fc..86a46374f 100644
--- a/src/passes/RemoveUnusedBrs.cpp
+++ b/src/passes/RemoveUnusedBrs.cpp
@@ -21,9 +21,20 @@
#include <wasm.h>
#include <pass.h>
#include <ast_utils.h>
+#include <wasm-builder.h>
namespace wasm {
+// to turn an if into a br-if, we must be able to reorder the
+// condition and possible value, and the possible value must
+// not have side effects (as they would run unconditionally)
+static bool canTurnIfIntoBrIf(Expression* ifCondition, Expression* brValue) {
+ if (!brValue) return true;
+ EffectAnalyzer value(brValue);
+ if (value.hasSideEffects()) return false;
+ return !EffectAnalyzer(ifCondition).invalidates(value);
+}
+
struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<RemoveUnusedBrs>>> {
bool isFunctionParallel() override { return true; }
@@ -44,6 +55,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R
// a stack for if-else contents, we merge their outputs
std::vector<Flows> ifStack;
+ // list of all loops, so we can optimize them
+ std::vector<Loop*> loops;
+
static void visitAny(RemoveUnusedBrs* self, Expression** currp) {
auto* curr = *currp;
auto& flows = self->flows;
@@ -109,7 +123,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R
// ignore (could be result of a previous cycle)
self->valueCanFlow = false;
} else {
- // anything else stops the flow TODO: optimize loops?
+ // anything else stops the flow
flows.clear();
self->valueCanFlow = false;
}
@@ -123,23 +137,25 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R
self->ifStack.push_back(std::move(self->flows));
}
+ void visitLoop(Loop* curr) {
+ loops.push_back(curr);
+ }
+
void visitIf(If* curr) {
if (!curr->ifFalse) {
// if without an else. try to reduce if (condition) br => br_if (condition)
Break* br = curr->ifTrue->dynCast<Break>();
if (br && !br->condition) { // TODO: if there is a condition, join them
// if the br has a value, then if => br_if means we always execute the value, and also the order is value,condition vs condition,value
- if (br->value) {
- EffectAnalyzer value(br->value);
- if (value.hasSideEffects()) return;
- EffectAnalyzer condition(curr->condition);
- if (condition.invalidates(value)) return;
+ if (canTurnIfIntoBrIf(curr->condition, br->value)) {
+ br->condition = curr->condition;
+ br->finalize();
+ replaceCurrent(br);
+ anotherCycle = true;
}
- br->condition = curr->condition;
- replaceCurrent(br);
- anotherCycle = true;
}
}
+ // TODO: if-else can be turned into a br_if as well, if one of the sides is a dead end
}
// override scan to add a pre and a post check task to all nodes
@@ -163,6 +179,99 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R
}
}
+ // optimizes a loop. returns true if we made changes
+ bool optimizeLoop(Loop* loop) {
+ // if a loop ends in
+ // (loop $in
+ // (block $out
+ // if (..) br $in; else br $out;
+ // )
+ // )
+ // then our normal opts can remove the break out since it flows directly out
+ // (and later passes make the if one-armed). however, the simple analysis
+ // fails on patterns like
+ // if (..) br $out;
+ // br $in;
+ // which is a common way to do a while (1) loop (end it with a jump to the
+ // top), so we handle that here. Specifically we want to conditionalize
+ // breaks to the loop top, i.e., put them behind a condition, so that other
+ // code can flow directly out and thus brs out can be removed. (even if
+ // the change is to let a break somewhere else flow out, that can still be
+ // helpful, as it shortens the logical loop. it is also good to generate
+ // an if-else instead of an if, as it might allow an eqz to be removed
+ // by flipping arms)
+ if (!loop->name.is()) return false;
+ auto* block = loop->body->dynCast<Block>();
+ if (!block) return false;
+ // does the last element break to the top of the loop?
+ auto& list = block->list;
+ if (list.size() <= 1) return false;
+ auto* last = list.back()->dynCast<Break>();
+ if (!last || !ExpressionAnalyzer::isSimple(last) || last->name != loop->name) return false;
+ // last is a simple break to the top of the loop. if we can conditionalize it,
+ // it won't block things from flowing out and not needing breaks to do so.
+ Index i = list.size() - 2;
+ Builder builder(*getModule());
+ while (1) {
+ auto* curr = list[i];
+ if (auto* iff = curr->dynCast<If>()) {
+ // let's try to move the code going to the top of the loop into the if-else
+ if (!iff->ifFalse) {
+ // we need the ifTrue to break, so it cannot reach the code we want to move
+ if (ExpressionAnalyzer::getEndingSimpleBreak(iff->ifTrue)) {
+ iff->ifFalse = builder.stealSlice(block, i + 1, list.size());
+ return true;
+ }
+ } else {
+ // this is already an if-else. if one side is a dead end, we can append to the other, if
+ // there is no returned value to concern us
+ assert(!isConcreteWasmType(iff->type)); // can't be, since in the middle of a block
+ if (ExpressionAnalyzer::getEndingSimpleBreak(iff->ifTrue)) {
+ iff->ifFalse = builder.blockifyMerge(iff->ifFalse, builder.stealSlice(block, i + 1, list.size()));
+ return true;
+ } else if (ExpressionAnalyzer::getEndingSimpleBreak(iff->ifFalse)) {
+ iff->ifTrue = builder.blockifyMerge(iff->ifTrue, builder.stealSlice(block, i + 1, list.size()));
+ return true;
+ }
+ }
+ return false;
+ } else if (auto* brIf = curr->dynCast<Break>()) {
+ // br_if is similar to if.
+ if (brIf->condition && !brIf->value && brIf->name != loop->name) {
+ if (i == list.size() - 2) {
+ // there is the br_if, and then the br to the top, so just flip them and the condition
+ brIf->condition = builder.makeUnary(EqZInt32, brIf->condition);
+ last->name = brIf->name;
+ brIf->name = loop->name;
+ return true;
+ } else {
+ // there are elements in the middle,
+ // br_if $somewhere (condition)
+ // (..more..)
+ // br $in
+ // we can convert the br_if to an if. this has a cost, though,
+ // so only do it if it looks useful, which it definitely is if
+ // (a) $somewhere is straight out (so the br out vanishes), and
+ // (b) this br_if is the only branch to that block (so the block will vanish)
+ if (brIf->name == block->name && BreakSeeker::count(block, block->name) == 1) {
+ // note that we could drop the last element here, it is a br we know for sure is removable,
+ // but telling stealSlice to steal all to the end is more efficient, it can just truncate.
+ list[i] = builder.makeIf(brIf->condition, builder.makeBreak(brIf->name), builder.stealSlice(block, i + 1, list.size()));
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+ // if there is control flow, we must stop looking
+ if (EffectAnalyzer(curr).branches) {
+ return false;
+ }
+ if (i == 0) return false;
+ i--;
+ }
+ }
+
void doWalkFunction(Function* func) {
// multiple cycles may be needed
bool worked = false;
@@ -172,7 +281,8 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R
assert(ifStack.empty());
// flows may contain returns, which are flowing out and so can be optimized
for (size_t i = 0; i < flows.size(); i++) {
- auto* flow = (*flows[i])->cast<Return>(); // cannot be a break
+ auto* flow = (*flows[i])->dynCast<Return>();
+ if (!flow) continue;
if (!flow->value) {
// return => nop
ExpressionManipulator::nop(flow);
@@ -184,11 +294,138 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R
}
}
flows.clear();
+ // optimize loops (we don't do it while tracking flows, as they can interfere)
+ for (auto* loop : loops) {
+ anotherCycle |= optimizeLoop(loop);
+ }
+ loops.clear();
if (anotherCycle) worked = true;
} while (anotherCycle);
- // finally, we may have simplified ifs enough to turn them into selects
- struct Selectifier : public WalkerPass<PostWalker<Selectifier, Visitor<Selectifier>>> {
+
+ if (worked) {
+ // Our work may alter block and if types, they may now return
+ struct TypeUpdater : public WalkerPass<PostWalker<TypeUpdater, Visitor<TypeUpdater>>> {
+ void visitBlock(Block* curr) {
+ curr->finalize();
+ }
+ void visitLoop(Loop* curr) {
+ curr->finalize();
+ }
+ void visitIf(If* curr) {
+ curr->finalize();
+ }
+ };
+ TypeUpdater typeUpdater;
+ typeUpdater.walkFunction(func);
+ }
+
+ // thread trivial jumps
+ struct JumpThreader : public ControlFlowWalker<JumpThreader, Visitor<JumpThreader>> {
+ // map of all value-less breaks going to a block (and not a loop)
+ std::map<Block*, std::vector<Break*>> breaksToBlock;
+
+ // number of definitions of each name - when a name is defined more than once, it is not trivially safe to do this
+ std::map<Name, Index> numDefs;
+
+ // the names to update, when we can (just one def)
+ std::map<Break*, Name> newNames;
+
+ void visitBreak(Break* curr) {
+ if (!curr->value) {
+ if (auto* target = findBreakTarget(curr->name)->dynCast<Block>()) {
+ breaksToBlock[target].push_back(curr);
+ }
+ }
+ }
+ // TODO: Switch?
+ void visitBlock(Block* curr) {
+ if (curr->name.is()) numDefs[curr->name]++;
+
+ auto& list = curr->list;
+ if (list.size() == 1 && curr->name.is()) {
+ // if this block has just one child, a sub-block, then jumps to the former are jumps to us, really
+ if (auto* child = list[0]->dynCast<Block>()) {
+ if (child->name.is() && child->name != curr->name) {
+ auto& breaks = breaksToBlock[child];
+ for (auto* br : breaks) {
+ newNames[br] = curr->name;
+ breaksToBlock[curr].push_back(br); // update the list - we may push it even more later
+ }
+ breaksToBlock.erase(child);
+ }
+ }
+ } else if (list.size() == 2) {
+ // if this block has two children, a child-block and a simple jump, then jumps to child-block can be replaced with jumps to the new target
+ auto* child = list[0]->dynCast<Block>();
+ auto* jump = list[1]->dynCast<Break>();
+ if (child && child->name.is() && jump && ExpressionAnalyzer::isSimple(jump)) {
+ auto& breaks = breaksToBlock[child];
+ for (auto* br : breaks) {
+ newNames[br] = jump->name;
+ }
+ // if the jump is to another block then we can update the list, and maybe push it even more later
+ if (auto* newTarget = findBreakTarget(jump->name)->dynCast<Block>()) {
+ for (auto* br : breaks) {
+ breaksToBlock[newTarget].push_back(br);
+ }
+ }
+ breaksToBlock.erase(child);
+ }
+ }
+ }
+ void visitLoop(Loop* curr) {
+ if (curr->name.is()) numDefs[curr->name]++;
+ }
+
+ void finish() {
+ for (auto& iter : newNames) {
+ auto* br = iter.first;
+ auto name = iter.second;
+ if (numDefs[name] == 1) {
+ br->name = name;
+ }
+ }
+ }
+ };
+ JumpThreader jumpThreader;
+ jumpThreader.setModule(getModule());
+ jumpThreader.walkFunction(func);
+ jumpThreader.finish();
+
+ // perform some final optimizations
+ struct FinalOptimizer : public PostWalker<FinalOptimizer, Visitor<FinalOptimizer>> {
+ void visitBlock(Block* curr) {
+ // if a block has an if br else br, we can un-conditionalize the latter, allowing
+ // the if to become a br_if.
+ // * note that if not in a block already, then we need to create a block for this, so not useful otherwise
+ // * note that this only happens at the end of a block, as code after the if is dead
+ // * note that we do this at the end, because un-conditionalizing can interfere with optimizeLoop()ing.
+ auto& list = curr->list;
+ for (Index i = 0; i < list.size(); i++) {
+ auto* iff = list[i]->dynCast<If>();
+ if (!iff || !iff->ifFalse || isConcreteWasmType(iff->type)) continue; // if it lacked an if-false, it would already be a br_if, as that's the easy case
+ auto* ifTrueBreak = iff->ifTrue->dynCast<Break>();
+ if (ifTrueBreak && !ifTrueBreak->condition && canTurnIfIntoBrIf(iff->condition, ifTrueBreak->value)) {
+ // we are an if-else where the ifTrue is a break without a condition, so we can do this
+ list[i] = ifTrueBreak;
+ ifTrueBreak->condition = iff->condition;
+ ifTrueBreak->finalize();
+ ExpressionManipulator::spliceIntoBlock(curr, i + 1, iff->ifFalse);
+ continue;
+ }
+ // otherwise, perhaps we can flip the if
+ auto* ifFalseBreak = iff->ifFalse->dynCast<Break>();
+ if (ifFalseBreak && !ifFalseBreak->condition && canTurnIfIntoBrIf(iff->condition, ifFalseBreak->value)) {
+ list[i] = ifFalseBreak;
+ ifFalseBreak->condition = Builder(*getModule()).makeUnary(EqZInt32, iff->condition);
+ ifFalseBreak->finalize();
+ ExpressionManipulator::spliceIntoBlock(curr, i + 1, iff->ifTrue);
+ continue;
+ }
+ }
+ }
void visitIf(If* curr) {
+ // we may have simplified ifs enough to turn them into selects
if (curr->ifFalse && isConcreteWasmType(curr->ifTrue->type) && isConcreteWasmType(curr->ifFalse->type)) {
// if with else, consider turning it into a select if there is no control flow
// TODO: estimate cost
@@ -210,25 +447,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R
}
}
};
- Selectifier selectifier;
- selectifier.setModule(getModule());
- selectifier.walkFunction(func);
- if (worked) {
- // Our work may alter block and if types, they may now return
- struct TypeUpdater : public WalkerPass<PostWalker<TypeUpdater, Visitor<TypeUpdater>>> {
- void visitBlock(Block* curr) {
- curr->finalize();
- }
- void visitLoop(Loop* curr) {
- curr->finalize();
- }
- void visitIf(If* curr) {
- curr->finalize();
- }
- };
- TypeUpdater typeUpdater;
- typeUpdater.walkFunction(func);
- }
+ FinalOptimizer finalOptimizer;
+ finalOptimizer.setModule(getModule());
+ finalOptimizer.walkFunction(func);
}
};
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index 5315edac4..e032acc77 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -43,7 +43,7 @@ namespace wasm {
// Helper classes
struct GetLocalCounter : public PostWalker<GetLocalCounter, Visitor<GetLocalCounter>> {
- std::vector<int>* numGetLocals;
+ std::vector<Index>* numGetLocals;
void visitGetLocal(GetLocal *curr) {
(*numGetLocals)[curr->index]++;
@@ -51,7 +51,7 @@ struct GetLocalCounter : public PostWalker<GetLocalCounter, Visitor<GetLocalCoun
};
struct SetLocalRemover : public PostWalker<SetLocalRemover, Visitor<SetLocalRemover>> {
- std::vector<int>* numGetLocals;
+ std::vector<Index>* numGetLocals;
void visitSetLocal(SetLocal *curr) {
if ((*numGetLocals)[curr->index] == 0) {
@@ -102,8 +102,8 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
// block returns
std::map<Name, std::vector<BlockBreak>> blockBreaks;
- // blocks that are the targets of a switch; we need to know this
- // since we can't produce a block return value for them.
+ // blocks that we can't produce a block return value for them.
+ // (switch target, or some other reason)
std::set<Name> unoptimizableBlocks;
// A stack of sinkables from the current traversal state. When
@@ -114,6 +114,12 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
// whether we need to run an additional cycle
bool anotherCycle;
+ // whether this is the first cycle
+ bool firstCycle;
+
+ // local => # of get_locals for it
+ std::vector<Index> numGetLocals;
+
static void doNoteNonLinear(SimplifyLocals* self, Expression** currp) {
auto* curr = *currp;
if (curr->is<Break>()) {
@@ -187,9 +193,15 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
if (found != sinkables.end()) {
// sink it, and nop the origin
auto* set = (*found->second.item)->cast<SetLocal>();
- replaceCurrent(set);
- assert(!set->isTee());
- set->setTee(true);
+ if (firstCycle) {
+ // just one get_local of this, so just sink the value
+ assert(numGetLocals[curr->index] == 1);
+ replaceCurrent(set->value);
+ } else {
+ replaceCurrent(set);
+ assert(!set->isTee());
+ set->setTee(true);
+ }
// reuse the getlocal that is dying
*found->second.item = curr;
ExpressionManipulator::nop(curr);
@@ -259,7 +271,7 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
self->checkInvalidations(effects);
}
- if (set && !set->isTee()) {
+ if (set && !set->isTee() && (!self->firstCycle || self->numGetLocals[set->index] == 1)) {
Index index = set->index;
assert(self->sinkables.count(index) == 0);
self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp)));
@@ -316,9 +328,18 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
for (size_t j = 0; j < breaks.size(); j++) {
// move break set_local's value to the break
auto* breakSetLocalPointer = breaks[j].sinkables.at(sharedIndex).item;
- assert(!breaks[j].br->value);
- breaks[j].br->value = (*breakSetLocalPointer)->cast<SetLocal>()->value;
- ExpressionManipulator::nop(*breakSetLocalPointer);
+ auto* br = breaks[j].br;
+ assert(!br->value);
+ // if the break is conditional, then we must set the value here - if the break is not taken, we must still have the new value in the local
+ auto* set = (*breakSetLocalPointer)->cast<SetLocal>();
+ if (br->condition) {
+ br->value = set;
+ set->setTee(true);
+ *breakSetLocalPointer = getModule()->allocator.alloc<Nop>();
+ } else {
+ br->value = set->value;
+ ExpressionManipulator::nop(set);
+ }
}
// finally, create a set_local on the block itself
auto* newSetLocal = Builder(*getModule()).makeSetLocal(sharedIndex, block);
@@ -397,11 +418,22 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
}
void doWalkFunction(Function* func) {
+ // scan get_locals
+ numGetLocals.resize(func->getNumLocals());
+ std::fill(numGetLocals.begin(), numGetLocals.end(), 0);
+ GetLocalCounter counter;
+ counter.numGetLocals = &numGetLocals;
+ counter.walkFunction(func);
// multiple passes may be required per function, consider this:
// x = load
// y = store
// c(x, y)
- // the load cannot cross the store, but y can be sunk, after which so can x
+ // the load cannot cross the store, but y can be sunk, after which so can x.
+ //
+ // we start with a cycle focusing on single-use locals, which are easy to
+ // sink (we don't need to put a set), and a good match for common compiler
+ // output patterns. further cycles do fully general sinking.
+ firstCycle = true;
do {
anotherCycle = false;
// main operation
@@ -435,15 +467,16 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals,
sinkables.clear();
blockBreaks.clear();
unoptimizableBlocks.clear();
+ if (firstCycle) {
+ firstCycle = false;
+ anotherCycle = true;
+ }
} while (anotherCycle);
// Finally, after optimizing a function, we can see if we have set_locals
// for a local with no remaining gets, in which case, we can
// remove the set.
- // First, count get_locals
- std::vector<int> numGetLocals; // local => # of get_locals for it
- numGetLocals.resize(func->getNumLocals());
- GetLocalCounter counter;
- counter.numGetLocals = &numGetLocals;
+ // First, recount get_locals
+ std::fill(numGetLocals.begin(), numGetLocals.end(), 0);
counter.walkFunction(func);
// Second, remove unneeded sets
SetLocalRemover remover;
diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp
index db42a994a..0fdd12dec 100644
--- a/src/passes/Vacuum.cpp
+++ b/src/passes/Vacuum.cpp
@@ -47,26 +47,31 @@ struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>>
case Expression::Id::CallImportId:
case Expression::Id::CallIndirectId:
case Expression::Id::SetLocalId:
- case Expression::Id::LoadId:
case Expression::Id::StoreId:
case Expression::Id::ReturnId:
- case Expression::Id::GetGlobalId:
case Expression::Id::SetGlobalId:
case Expression::Id::HostId:
case Expression::Id::UnreachableId: return curr; // always needed
+ case Expression::Id::LoadId: {
+ if (!resultUsed) {
+ return curr->cast<Load>()->ptr;
+ }
+ return curr;
+ }
case Expression::Id::ConstId:
case Expression::Id::GetLocalId:
+ case Expression::Id::GetGlobalId: {
+ if (!resultUsed) return nullptr;
+ return curr;
+ }
+
case Expression::Id::UnaryId:
case Expression::Id::BinaryId:
case Expression::Id::SelectId: {
if (resultUsed) {
return curr; // used, keep it
}
- // result is not used, perhaps it is dead
- if (curr->is<Const>() || curr->is<GetLocal>()) {
- return nullptr;
- }
// for unary, binary, and select, we need to check their arguments for side effects
if (auto* unary = curr->dynCast<Unary>()) {
if (EffectAnalyzer(unary->value).hasSideEffects()) {
@@ -132,13 +137,12 @@ struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>>
void visitBlock(Block *curr) {
// compress out nops and other dead code
- bool resultUsed = ExpressionAnalyzer::isResultUsed(expressionStack, getFunction());
int skip = 0;
auto& list = curr->list;
size_t size = list.size();
bool needResize = false;
for (size_t z = 0; z < size; z++) {
- auto* optimized = optimize(list[z], z == size - 1 && resultUsed);
+ auto* optimized = optimize(list[z], z == size - 1 && isConcreteWasmType(curr->type));
if (!optimized) {
skip++;
needResize = true;
@@ -153,7 +157,12 @@ struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>>
Break* br = list[z - skip]->dynCast<Break>();
Switch* sw = list[z - skip]->dynCast<Switch>();
if ((br && !br->condition) || sw) {
+ auto* last = list.back();
list.resize(z - skip + 1);
+ // if we removed the last one, and it was a return value, it must be returned
+ if (list.back() != last && isConcreteWasmType(last->type)) {
+ list.push_back(last);
+ }
needResize = false;
break;
}
@@ -165,7 +174,7 @@ struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>>
if (!curr->name.is()) {
if (list.size() == 1) {
// just one element. replace the block, either with it or with a nop if it's not needed
- if (resultUsed || EffectAnalyzer(list[0]).hasSideEffects()) {
+ if (isConcreteWasmType(curr->type) || EffectAnalyzer(list[0]).hasSideEffects()) {
replaceCurrent(list[0]);
} else {
ExpressionManipulator::nop(curr);
@@ -200,11 +209,53 @@ struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>>
}
void visitDrop(Drop* curr) {
- // if the drop input has no side effects, it can be wiped out
- if (!EffectAnalyzer(curr->value).hasSideEffects()) {
+ // optimize the dropped value, maybe leaving nothing
+ curr->value = optimize(curr->value, false);
+ if (curr->value == nullptr) {
ExpressionManipulator::nop(curr);
return;
}
+ // a drop of a tee is a set
+ if (auto* set = curr->value->dynCast<SetLocal>()) {
+ assert(set->isTee());
+ set->setTee(false);
+ replaceCurrent(set);
+ return;
+ }
+ // if we are dropping a block's return value, we might be able to remove it entirely
+ if (auto* block = curr->value->dynCast<Block>()) {
+ auto* last = block->list.back();
+ if (isConcreteWasmType(last->type)) {
+ assert(block->type == last->type);
+ last = optimize(last, false);
+ if (!last) {
+ // we may be able to remove this, if there are no brs
+ bool canPop = true;
+ if (block->name.is()) {
+ BreakSeeker breakSeeker(block->name);
+ Expression* temp = block;
+ breakSeeker.walk(temp);
+ if (breakSeeker.found && breakSeeker.valueType != none) {
+ canPop = false;
+ }
+ }
+ if (canPop) {
+ block->list.back() = last;
+ block->list.pop_back();
+ block->type = none;
+ // we don't need the drop anymore, let's see what we have left in the block
+ if (block->list.size() > 1) {
+ replaceCurrent(block);
+ } else if (block->list.size() == 1) {
+ replaceCurrent(block->list[0]);
+ } else {
+ ExpressionManipulator::nop(curr);
+ }
+ return;
+ }
+ }
+ }
+ }
// sink a drop into an arm of an if-else if the other arm ends in an unreachable, as it if is a branch, this can make that branch optimizable and more vaccuming possible
auto* iff = curr->value->dynCast<If>();
if (iff && iff->ifFalse && isConcreteWasmType(iff->type)) {
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index 437b023bc..f3bc6d4c7 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -18,6 +18,7 @@
#include <passes/passes.h>
#include <pass.h>
+#include <wasm-validator.h>
namespace wasm {
@@ -65,6 +66,7 @@ void PassRegistry::registerPasses() {
registerPass("coalesce-locals-learning", "reduce # of locals by coalescing and learning", createCoalesceLocalsWithLearningPass);
registerPass("dce", "removes unreachable code", createDeadCodeEliminationPass);
registerPass("duplicate-function-elimination", "removes duplicate functions", createDuplicateFunctionEliminationPass);
+ registerPass("extract-function", "leaves just one function (useful for debugging)", createExtractFunctionPass);
registerPass("merge-blocks", "merges blocks to their parents", createMergeBlocksPass);
registerPass("metrics", "reports metrics", createMetricsPass);
registerPass("nm", "name list", createNameListPass);
@@ -74,6 +76,7 @@ void PassRegistry::registerPasses() {
registerPass("print", "print in s-expression format", createPrinterPass);
registerPass("print-minified", "print in minified s-expression format", createMinifiedPrinterPass);
registerPass("print-full", "print in full s-expression format", createFullPrinterPass);
+ registerPass("relooper-jump-threading", "thread relooper jumps (fastcomp output only)", createRelooperJumpThreadingPass);
registerPass("remove-imports", "removes imports and replaces them with nops", createRemoveImportsPass);
registerPass("remove-memory", "removes memory segments", createRemoveMemoryPass);
registerPass("remove-unused-brs", "removes breaks from locations that are not needed", createRemoveUnusedBrsPass);
@@ -132,10 +135,9 @@ void PassRunner::addDefaultGlobalOptimizationPasses() {
void PassRunner::run() {
if (debug) {
// for debug logging purposes, run each pass in full before running the other
- std::chrono::high_resolution_clock::time_point beforeEverything;
+ auto totalTime = std::chrono::duration<double>(0);
size_t padding = 0;
std::cerr << "[PassRunner] running passes..." << std::endl;
- beforeEverything = std::chrono::high_resolution_clock::now();
for (auto pass : passes) {
padding = std::max(padding, pass->name.size());
}
@@ -158,10 +160,19 @@ void PassRunner::run() {
auto after = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = after - before;
std::cerr << diff.count() << " seconds." << std::endl;
+ totalTime += diff;
+#if 0
+ // validate, ignoring the time
+ std::cerr << "[PassRunner] (validating)\n";
+ if (!WasmValidator().validate(*wasm)) {
+ std::cerr << "last pass (" << pass->name << ") broke validation\n";
+ abort();
+ }
+#endif
}
- auto after = std::chrono::high_resolution_clock::now();
- std::chrono::duration<double> diff = after - beforeEverything;
- std::cerr << "[PassRunner] passes took " << diff.count() << " seconds." << std::endl;
+ std::cerr << "[PassRunner] passes took " << totalTime.count() << " seconds." << std::endl;
+ // validate
+ assert(WasmValidator().validate(*wasm));
} else {
// non-debug normal mode, run them in an optimal manner - for locality it is better
// to run as many passes as possible on a single function before moving to the next
@@ -223,6 +234,11 @@ PassRunner::~PassRunner() {
}
}
+void PassRunner::doAdd(Pass* pass) {
+ passes.push_back(pass);
+ pass->prepareToRun(this, wasm);
+}
+
void PassRunner::runPassOnFunction(Pass* pass, Function* func) {
// function-parallel passes get a new instance per function
if (pass->isFunctionParallel()) {
diff --git a/src/passes/passes.h b/src/passes/passes.h
index 4bb76edad..80fa394e1 100644
--- a/src/passes/passes.h
+++ b/src/passes/passes.h
@@ -26,6 +26,7 @@ Pass *createCoalesceLocalsPass();
Pass *createCoalesceLocalsWithLearningPass();
Pass *createDeadCodeEliminationPass();
Pass *createDuplicateFunctionEliminationPass();
+Pass *createExtractFunctionPass();
Pass *createLowerIfElsePass();
Pass *createMergeBlocksPass();
Pass *createMetricsPass();
@@ -36,6 +37,7 @@ Pass *createPostEmscriptenPass();
Pass *createPrinterPass();
Pass *createMinifiedPrinterPass();
Pass *createFullPrinterPass();
+Pass *createRelooperJumpThreadingPass();
Pass *createRemoveImportsPass();
Pass *createRemoveMemoryPass();
Pass *createRemoveUnusedBrsPass();
diff --git a/src/wasm-builder.h b/src/wasm-builder.h
index 2841585e6..f9b2e73f6 100644
--- a/src/wasm-builder.h
+++ b/src/wasm-builder.h
@@ -283,6 +283,28 @@ public:
return block;
}
+ // ensures the first node is a block, if it isn't already, and merges in the second,
+ // either as a single element or, if a block, by appending to the first block
+ Block* blockifyMerge(Expression* any, Expression* append) {
+ Block* block = nullptr;
+ if (any) block = any->dynCast<Block>();
+ if (!block) {
+ block = makeBlock(any);
+ } else {
+ assert(!isConcreteWasmType(block->type));
+ }
+ auto* other = append->dynCast<Block>();
+ if (!other) {
+ block->list.push_back(append);
+ } else {
+ for (auto* item : other->list) {
+ block->list.push_back(item);
+ }
+ }
+ block->finalize(); // TODO: move out of if
+ return block;
+ }
+
// a helper for the common pattern of a sequence of two expressions. Similar to
// blockify, but does *not* reuse a block if the first is one.
Block* makeSequence(Expression* left, Expression* right) {
@@ -291,6 +313,32 @@ public:
block->finalize();
return block;
}
+
+ // Grab a slice out of a block, replacing it with nops, and returning
+ // either another block with the contents (if more than 1) or a single expression
+ Expression* stealSlice(Block* input, Index from, Index to) {
+ Expression* ret;
+ if (to == from + 1) {
+ // just one
+ ret = input->list[from];
+ } else {
+ auto* block = allocator.alloc<Block>();
+ for (Index i = from; i < to; i++) {
+ block->list.push_back(input->list[i]);
+ }
+ block->finalize();
+ ret = block;
+ }
+ if (to == input->list.size()) {
+ input->list.resize(from);
+ } else {
+ for (Index i = from; i < to; i++) {
+ input->list[i] = allocator.alloc<Nop>();
+ }
+ }
+ input->finalize();
+ return ret;
+ }
};
} // namespace wasm
diff --git a/src/wasm-module-building.h b/src/wasm-module-building.h
index 52a4e7536..92b96d98d 100644
--- a/src/wasm-module-building.h
+++ b/src/wasm-module-building.h
@@ -97,7 +97,7 @@ public:
}
// Before parallelism, create all passes on the main thread here, to ensure
- // constructors run at least once on the main thread, for one-time init things.
+ // prepareToRun() is called for each pass before we start to optimize functions.
{
PassRunner passRunner(wasm);
addPrePasses(passRunner);
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index 4725237dd..47b9d26e8 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -335,8 +335,8 @@ struct PostWalker : public Walker<SubType, VisitorType> {
}
case Expression::Id::SwitchId: {
self->pushTask(SubType::doVisitSwitch, currp);
- self->maybePushTask(SubType::scan, &curr->cast<Switch>()->value);
self->pushTask(SubType::scan, &curr->cast<Switch>()->condition);
+ self->maybePushTask(SubType::scan, &curr->cast<Switch>()->value);
break;
}
case Expression::Id::CallId: {
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index 58a30f9a3..e23221337 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -104,6 +104,9 @@ public:
void visitIf(If *curr) {
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid");
+ if (!curr->ifFalse) {
+ shouldBeFalse(isConcreteWasmType(curr->ifTrue->type), curr, "if without else must not return a value in body");
+ }
}
// override scan to add a pre and a post check task to all nodes
diff --git a/src/wasm.h b/src/wasm.h
index d6bdfe91f..4d5cf4d70 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -981,7 +981,6 @@ public:
Expression *condition, *ifTrue, *ifFalse;
void finalize() {
- assert(ifTrue);
if (ifFalse) {
if (ifTrue->type == ifFalse->type) {
type = ifTrue->type;
@@ -992,6 +991,8 @@ public:
} else {
type = none;
}
+ } else {
+ type = none; // if without else
}
}
};
@@ -1027,6 +1028,8 @@ public:
void finalize() {
if (condition) {
type = none;
+ } else {
+ type = unreachable;
}
}
};