diff options
author | Alon Zakai <alonzakai@gmail.com> | 2016-09-14 21:28:43 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-09-14 21:28:43 -0700 |
commit | e567fa8675831e79f855cea2181fa58beb107e42 (patch) | |
tree | 14f1e37d27244b349e8ee34939119002f742748d /src | |
parent | 63b499e3ec9bbdf4e79ab6d9dc198299516e8aec (diff) | |
parent | af3bea2786fe62070522b7fd7add4290a4cb4e6d (diff) | |
download | binaryen-e567fa8675831e79f855cea2181fa58beb107e42.tar.gz binaryen-e567fa8675831e79f855cea2181fa58beb107e42.tar.bz2 binaryen-e567fa8675831e79f855cea2181fa58beb107e42.zip |
Merge pull request #695 from WebAssembly/opts
Get optimizer on par with emscripten asm.js optimizer
Diffstat (limited to 'src')
-rw-r--r-- | src/asm2wasm.h | 39 | ||||
-rw-r--r-- | src/ast_utils.h | 103 | ||||
-rw-r--r-- | src/pass.h | 19 | ||||
-rw-r--r-- | src/passes/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/passes/CoalesceLocals.cpp | 142 | ||||
-rw-r--r-- | src/passes/DuplicateFunctionElimination.cpp | 10 | ||||
-rw-r--r-- | src/passes/ExtractFunction.cpp | 46 | ||||
-rw-r--r-- | src/passes/LowerInt64.cpp | 196 | ||||
-rw-r--r-- | src/passes/Metrics.cpp | 2 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 80 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.wast | 21 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.wast.processed | 21 | ||||
-rw-r--r-- | src/passes/RelooperJumpThreading.cpp | 248 | ||||
-rw-r--r-- | src/passes/RemoveImports.cpp | 14 | ||||
-rw-r--r-- | src/passes/RemoveUnusedBrs.cpp | 283 | ||||
-rw-r--r-- | src/passes/SimplifyLocals.cpp | 67 | ||||
-rw-r--r-- | src/passes/Vacuum.cpp | 73 | ||||
-rw-r--r-- | src/passes/pass.cpp | 26 | ||||
-rw-r--r-- | src/passes/passes.h | 2 | ||||
-rw-r--r-- | src/wasm-builder.h | 48 | ||||
-rw-r--r-- | src/wasm-module-building.h | 2 | ||||
-rw-r--r-- | src/wasm-traversal.h | 2 | ||||
-rw-r--r-- | src/wasm-validator.h | 3 | ||||
-rw-r--r-- | src/wasm.h | 5 |
24 files changed, 1115 insertions, 339 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h index b19451a7b..eb5413e20 100644 --- a/src/asm2wasm.h +++ b/src/asm2wasm.h @@ -523,6 +523,8 @@ void Asm2WasmBuilder::processAsm(Ref ast) { optimizingBuilder = make_unique<OptimizingIncrementalModuleBuilder>(&wasm, numFunctions, [&](PassRunner& passRunner) { // run autodrop first, before optimizations passRunner.add<AutoDrop>(); + // optimize relooper label variable usage at the wasm level, where it is easy + passRunner.add("relooper-jump-threading"); }); } @@ -803,13 +805,16 @@ void Asm2WasmBuilder::processAsm(Ref ast) { add->right = parent->builder.makeConst(Literal((int32_t)parent->functionTableStarts[tableName])); } }; + PassRunner passRunner(&wasm); passRunner.add<FinalizeCalls>(this); passRunner.add<ReFinalize>(); // FinalizeCalls changes call types, need to percolate passRunner.add<AutoDrop>(); // FinalizeCalls may cause us to require additional drops if (optimize) { - passRunner.add("vacuum"); // autodrop can add some garbage - passRunner.add("remove-unused-brs"); // vacuum may open up more opportunities + // autodrop can add some garbage + passRunner.add("vacuum"); + passRunner.add("remove-unused-brs"); + passRunner.add("optimize-instructions"); } passRunner.run(); @@ -866,19 +871,17 @@ void Asm2WasmBuilder::processAsm(Ref ast) { #endif -#if 0 // enable asm2wasm i64 optimizations when browsers have consistent i64 support in wasm if (udivmoddi4.is() && getTempRet0.is()) { // generate a wasm-optimized __udivmoddi4 method, which we can do much more efficiently in wasm // we can only do this if we know getTempRet0 as well since we use it to figure out which minified global is tempRet0 // (getTempRet0 might be an import, if this is a shared module, so we can't optimize that case) - int tempRet0; + Name tempRet0; { Expression* curr = wasm.getFunction(getTempRet0)->body; if (curr->is<Block>()) curr = curr->cast<Block>()->list[0]; - curr = curr->cast<Return>()->value; - auto* load = curr->cast<Load>(); - auto* ptr = load->ptr->cast<Const>(); - tempRet0 = ptr->value.geti32() + load->offset; + if (curr->is<Return>()) curr = curr->cast<Return>()->value; + auto* get = curr->cast<GetGlobal>(); + tempRet0 = get->name; } // udivmoddi4 receives xl, xh, yl, yl, r, and // if r then *r = x % y @@ -898,13 +901,13 @@ void Asm2WasmBuilder::processAsm(Ref ast) { return builder.makeSetLocal( target, builder.makeBinary( - Or, + OrInt64, builder.makeUnary( ExtendUInt32, builder.makeGetLocal(low, i32) ), builder.makeBinary( - Shl, + ShlInt64, builder.makeUnary( ExtendUInt32, builder.makeGetLocal(high, i32) @@ -923,10 +926,11 @@ void Asm2WasmBuilder::processAsm(Ref ast) { 8, 0, 8, builder.makeGetLocal(r, i32), builder.makeBinary( - RemU, + RemUInt64, builder.makeGetLocal(x64, i64), builder.makeGetLocal(y64, i64) - ) + ), + i64 ) ) ); @@ -934,20 +938,19 @@ void Asm2WasmBuilder::processAsm(Ref ast) { builder.makeSetLocal( x64, builder.makeBinary( - DivU, + DivUInt64, builder.makeGetLocal(x64, i64), builder.makeGetLocal(y64, i64) ) ) ); body->list.push_back( - builder.makeStore( - 4, 0, 4, - builder.makeConst(Literal(int32_t(tempRet0))), + builder.makeSetGlobal( + tempRet0, builder.makeUnary( WrapInt64, builder.makeBinary( - ShrU, + ShrUInt64, builder.makeGetLocal(x64, i64), builder.makeConst(Literal(int64_t(32))) ) @@ -960,9 +963,9 @@ void Asm2WasmBuilder::processAsm(Ref ast) { builder.makeGetLocal(x64, i64) ) ); + body->finalize(); func->body = body; } -#endif assert(WasmValidator().validate(wasm)); } diff --git a/src/ast_utils.h b/src/ast_utils.h index 9b2ff10cd..4664f22ae 100644 --- a/src/ast_utils.h +++ b/src/ast_utils.h @@ -27,19 +27,27 @@ namespace wasm { struct BreakSeeker : public PostWalker<BreakSeeker, Visitor<BreakSeeker>> { Name target; // look for this one XXX looking by name may fall prey to duplicate names - size_t found; + Index found; + WasmType valueType; - BreakSeeker(Name target) : target(target), found(false) {} + BreakSeeker(Name target) : target(target), found(0) {} + + void noteFound(Expression* value) { + found++; + if (found == 1) valueType = unreachable; + if (!value) valueType = none; + else if (value->type != unreachable) valueType = value->type; + } void visitBreak(Break *curr) { - if (curr->name == target) found++; + if (curr->name == target) noteFound(curr->value); } void visitSwitch(Switch *curr) { for (auto name : curr->targets) { - if (name == target) found++; + if (name == target) noteFound(curr->value); } - if (curr->default_ == target) found++; + if (curr->default_ == target) noteFound(curr->value); } static bool has(Expression* tree, Name target) { @@ -47,6 +55,12 @@ struct BreakSeeker : public PostWalker<BreakSeeker, Visitor<BreakSeeker>> { breakSeeker.walk(tree); return breakSeeker.found > 0; } + + static Index count(Expression* tree, Name target) { + BreakSeeker breakSeeker(target); + breakSeeker.walk(tree); + return breakSeeker.found; + } }; // Finds all functions that are reachable via direct calls. @@ -92,13 +106,16 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer bool calls = false; std::set<Index> localsRead; std::set<Index> localsWritten; + std::set<Name> globalsRead; + std::set<Name> globalsWritten; bool readsMemory = false; bool writesMemory = false; bool accessesLocal() { return localsRead.size() + localsWritten.size() > 0; } + bool accessesGlobal() { return globalsRead.size() + globalsWritten.size() > 0; } bool accessesMemory() { return calls || readsMemory || writesMemory; } - bool hasSideEffects() { return calls || localsWritten.size() > 0 || writesMemory || branches; } - bool hasAnything() { return branches || calls || accessesLocal() || readsMemory || writesMemory; } + bool hasSideEffects() { return calls || localsWritten.size() > 0 || writesMemory || branches || globalsWritten.size() > 0; } + bool hasAnything() { return branches || calls || accessesLocal() || readsMemory || writesMemory || accessesGlobal(); } // checks if these effects would invalidate another set (e.g., if we write, we invalidate someone that reads, they can't be moved past us) bool invalidates(EffectAnalyzer& other) { @@ -115,6 +132,17 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer for (auto local : localsRead) { if (other.localsWritten.count(local)) return true; } + if ((accessesGlobal() && other.calls) || (other.accessesGlobal() && calls)) { + return true; + } + for (auto global : globalsWritten) { + if (other.globalsWritten.count(global) || other.globalsRead.count(global)) { + return true; + } + } + for (auto global : globalsRead) { + if (other.globalsWritten.count(global)) return true; + } return false; } @@ -163,8 +191,12 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer, Visitor<EffectAnalyzer void visitSetLocal(SetLocal *curr) { localsWritten.insert(curr->index); } - void visitGetGlobal(GetGlobal *curr) { readsMemory = true; } // TODO: global-specific - void visitSetGlobal(SetGlobal *curr) { writesMemory = true; } // stuff? + void visitGetGlobal(GetGlobal *curr) { + globalsRead.insert(curr->name); + } + void visitSetGlobal(SetGlobal *curr) { + globalsWritten.insert(curr->name); + } void visitLoad(Load *curr) { readsMemory = true; } void visitStore(Store *curr) { writesMemory = true; } void visitReturn(Return *curr) { branches = true; } @@ -340,6 +372,21 @@ struct ExpressionManipulator { } copier; return flexibleCopy(original, wasm, copier); } + + // Splice an item into the middle of a block's list + static void spliceIntoBlock(Block* block, Index index, Expression* add) { + auto& list = block->list; + if (index == list.size()) { + list.push_back(add); // simple append + } else { + // we need to make room + list.push_back(nullptr); + for (Index i = list.size() - 1; i > index; i--) { + list[i] = list[i - 1]; + } + list[index] = add; + } + } }; struct ExpressionAnalyzer { @@ -373,6 +420,25 @@ struct ExpressionAnalyzer { return func->result != none; } + // Checks if a break is a simple - no condition, no value, just a plain branching + static bool isSimple(Break* curr) { + return !curr->condition && !curr->value; + } + + // Checks if an expression ends with a simple break, + // and returns a pointer to it if so. + // (It might also have other internal branches.) + static Expression* getEndingSimpleBreak(Expression* curr) { + if (auto* br = curr->dynCast<Break>()) { + if (isSimple(br)) return br; + return nullptr; + } + if (auto* block = curr->dynCast<Block>()) { + if (block->list.size() > 0) return getEndingSimpleBreak(block->list.back()); + } + return nullptr; + } + template<typename T> static bool flexibleEqual(Expression* left, Expression* right, T& comparer) { std::vector<Name> nameStack; @@ -814,6 +880,25 @@ struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop, Visitor<Auto curr->finalize(); // we may have changed our type } + void visitIf(If* curr) { + if (curr->ifFalse) { + if (!isConcreteWasmType(curr->type)) { + // if either side of an if-else not returning a value is concrete, drop it + if (isConcreteWasmType(curr->ifTrue->type)) { + curr->ifTrue = Builder(*getModule()).makeDrop(curr->ifTrue); + } + if (isConcreteWasmType(curr->ifFalse->type)) { + curr->ifFalse = Builder(*getModule()).makeDrop(curr->ifFalse); + } + } + } else { + // if without else does not return a value, so the body must be dropped if it is concrete + if (isConcreteWasmType(curr->ifTrue->type)) { + curr->ifTrue = Builder(*getModule()).makeDrop(curr->ifTrue); + } + } + } + void visitFunction(Function* curr) { if (curr->result == none && isConcreteWasmType(curr->body->type)) { curr->body = Builder(*getModule()).makeDrop(curr->body); diff --git a/src/pass.h b/src/pass.h index 0ad8d8c1d..e237b8a98 100644 --- a/src/pass.h +++ b/src/pass.h @@ -72,17 +72,17 @@ struct PassRunner { void add(std::string passName) { auto pass = PassRegistry::get()->createPass(passName); if (!pass) Fatal() << "Could not find pass: " << passName << "\n"; - passes.push_back(pass); + doAdd(pass); } template<class P> void add() { - passes.push_back(new P()); + doAdd(new P()); } template<class P, class Arg> void add(Arg arg){ - passes.push_back(new P(arg)); + doAdd(new P(arg)); } // Adds the default set of optimization passes; this is @@ -110,6 +110,8 @@ struct PassRunner { ~PassRunner(); private: + void doAdd(Pass* pass); + void runPassOnFunction(Pass* pass, Function* func); }; @@ -121,12 +123,13 @@ public: virtual ~Pass() {}; // Override this to perform preparation work before the pass runs. - virtual void prepare(PassRunner* runner, Module* module) {} + // This will be called before the pass is run on a module. + virtual void prepareToRun(PassRunner* runner, Module* module) {} + + // Implement this with code to run the pass on the whole module virtual void run(PassRunner* runner, Module* module) = 0; - // Override this to perform finalization work after the pass runs. - virtual void finalize(PassRunner* runner, Module* module) {} - // Run on a single function. This has no prepare/finalize calls. + // Implement this with code to run the pass on a single function virtual void runFunction(PassRunner* runner, Module* module, Function* function) { WASM_UNREACHABLE(); // by default, passes cannot be run this way } @@ -166,9 +169,7 @@ template <typename WalkerType> class WalkerPass : public Pass, public WalkerType { public: void run(PassRunner* runner, Module* module) override { - prepare(runner, module); WalkerType::walkModule(module); - finalize(runner, module); } void runFunction(PassRunner* runner, Module* module, Function* func) override { diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index 1b4c65562..30f63c880 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -3,6 +3,7 @@ SET(passes_SOURCES CoalesceLocals.cpp DeadCodeElimination.cpp DuplicateFunctionElimination.cpp + ExtractFunction.cpp MergeBlocks.cpp Metrics.cpp NameManager.cpp @@ -11,6 +12,7 @@ SET(passes_SOURCES PostEmscripten.cpp Precompute.cpp Print.cpp + RelooperJumpThreading.cpp RemoveImports.cpp RemoveMemory.cpp RemoveUnusedBrs.cpp diff --git a/src/passes/CoalesceLocals.cpp b/src/passes/CoalesceLocals.cpp index 2063a20db..ba2c6dd81 100644 --- a/src/passes/CoalesceLocals.cpp +++ b/src/passes/CoalesceLocals.cpp @@ -198,6 +198,7 @@ struct CoalesceLocals : public WalkerPass<CFGWalker<CoalesceLocals, Visitor<Coal void scanLivenessThroughActions(std::vector<Action>& actions, LocalSet& live); void pickIndicesFromOrder(std::vector<Index>& order, std::vector<Index>& indices); + void pickIndicesFromOrder(std::vector<Index>& order, std::vector<Index>& indices, Index& removedCopies); virtual void pickIndices(std::vector<Index>& indices); // returns a vector of oldIndex => newIndex @@ -224,14 +225,17 @@ struct CoalesceLocals : public WalkerPass<CFGWalker<CoalesceLocals, Visitor<Coal // copying state - std::vector<uint8_t> copies; // canonicalized - accesses should check (low, high) + std::vector<uint8_t> copies; // canonicalized - accesses should check (low, high) TODO: use a map for high N, as this tends to be sparse? or don't look at copies at all for big N? + std::vector<Index> totalCopies; // total # of copies for each local, with all others void addCopy(Index i, Index j) { auto k = std::min(i, j) * numLocals + std::max(i, j); copies[k] = std::min(copies[k], uint8_t(254)) + 1; + totalCopies[i]++; + totalCopies[j]++; } - bool getCopies(Index i, Index j) { + uint8_t getCopies(Index i, Index j) { return copies[std::min(i, j) * numLocals + std::max(i, j)]; } }; @@ -240,6 +244,8 @@ void CoalesceLocals::doWalkFunction(Function* func) { numLocals = func->getNumLocals(); copies.resize(numLocals * numLocals); std::fill(copies.begin(), copies.end(), 0); + totalCopies.resize(numLocals); + std::fill(totalCopies.begin(), totalCopies.end(), 0); // collect initial liveness info WalkerPass<CFGWalker<CoalesceLocals, Visitor<CoalesceLocals>, Liveness>>::doWalkFunction(func); // ignore links to dead blocks, so they don't confuse us and we can see their stores are all ineffective @@ -388,8 +394,43 @@ void CoalesceLocals::calculateInterferences(const LocalSet& locals) { // Indices decision making void CoalesceLocals::pickIndicesFromOrder(std::vector<Index>& order, std::vector<Index>& indices) { - // simple greedy coloring - // TODO: take into account eliminated copies + Index removedCopies; + pickIndicesFromOrder(order, indices, removedCopies); +} + +void CoalesceLocals::pickIndicesFromOrder(std::vector<Index>& order, std::vector<Index>& indices, Index& removedCopies) { + // mostly-simple greedy coloring +#if CFG_DEBUG + std::cerr << "\npickIndicesFromOrder on " << getFunction()->name << '\n'; + std::cerr << getFunction()->body << '\n'; + std::cerr << "order:\n"; + for (auto i : order) std::cerr << i << ' '; + std::cerr << '\n'; + std::cerr << "interferences:\n"; + for (Index i = 0; i < numLocals; i++) { + for (Index j = 0; j < i + 1; j++) { + std::cerr << " "; + } + for (Index j = i + 1; j < numLocals; j++) { + std::cerr << int(interferes(i, j)) << ' '; + } + std::cerr << " : $" << i << '\n'; + } + std::cerr << "copies:\n"; + for (Index i = 0; i < numLocals; i++) { + for (Index j = 0; j < i + 1; j++) { + std::cerr << " "; + } + for (Index j = i + 1; j < numLocals; j++) { + std::cerr << int(getCopies(i, j)) << ' '; + } + std::cerr << " : $" << i << '\n'; + } + std::cerr << "total copies:\n"; + for (Index i = 0; i < numLocals; i++) { + std::cerr << " $" << i << ": " << totalCopies[i] << '\n'; + } +#endif // TODO: take into account distribution (99-1 is better than 50-50 with two registers, for gzip) std::vector<WasmType> types; std::vector<bool> newInterferences; // new index * numLocals => list of all interferences of locals merged to it @@ -397,12 +438,13 @@ void CoalesceLocals::pickIndicesFromOrder(std::vector<Index>& order, std::vector indices.resize(numLocals); types.resize(numLocals); newInterferences.resize(numLocals * numLocals); - newCopies.resize(numLocals * numLocals); std::fill(newInterferences.begin(), newInterferences.end(), 0); + auto numParams = getFunction()->getNumParams(); + newCopies.resize(numParams * numLocals); // start with enough room for the params std::fill(newCopies.begin(), newCopies.end(), 0); Index nextFree = 0; + removedCopies = 0; // we can't reorder parameters, they are fixed in order, and cannot coalesce - auto numParams = getFunction()->getNumParams(); Index i = 0; for (; i < numParams; i++) { assert(order[i] == i); // order must leave the params in place @@ -421,6 +463,7 @@ void CoalesceLocals::pickIndicesFromOrder(std::vector<Index>& order, std::vector for (Index j = 0; j < nextFree; j++) { if (!newInterferences[j * numLocals + actual] && getFunction()->getLocalType(actual) == types[j]) { // this does not interfere, so it might be what we want. but pick the one eliminating the most copies + // (we could stop looking forward when there are no more items that have copies anyhow, but it doesn't seem to help) auto currCopies = newCopies[j * numLocals + actual]; if (found == Index(-1) || currCopies > foundCopies) { indices[actual] = found = j; @@ -432,7 +475,14 @@ void CoalesceLocals::pickIndicesFromOrder(std::vector<Index>& order, std::vector indices[actual] = found = nextFree; types[found] = getFunction()->getLocalType(actual); nextFree++; + removedCopies += getCopies(found, actual); + newCopies.resize(nextFree * numLocals); + } else { + removedCopies += foundCopies; } +#if CFG_DEBUG + std::cerr << "set local $" << actual << " to $" << found << '\n'; +#endif // merge new interferences and copies for the new index for (Index k = i + 1; k < numLocals; k++) { auto j = order[k]; // go in the order, we only need to update for those we will see later @@ -442,31 +492,85 @@ void CoalesceLocals::pickIndicesFromOrder(std::vector<Index>& order, std::vector } } +// Utilities for operating on permutation vectors + +static std::vector<Index> makeIdentity(Index num) { + std::vector<Index> ret; + ret.resize(num); + for (Index i = 0; i < num; i++) { + ret[i] = i; + } + return ret; +} + +static void setIdentity(std::vector<Index>& ret) { + auto num = ret.size(); + assert(num > 0); // must already be of the right size + for (Index i = 0; i < num; i++) { + ret[i] = i; + } +} + +static std::vector<Index> makeReversed(std::vector<Index>& original) { + std::vector<Index> ret; + auto num = original.size(); + ret.resize(num); + for (Index i = 0; i < num; i++) { + ret[original[i]] = i; + } + return ret; +} + +// given a baseline order, adjust it based on an important order of priorities (higher values +// are higher priority). The priorities take precedence, unless they are equal and then +// the original order should be kept. +std::vector<Index> adjustOrderByPriorities(std::vector<Index>& baseline, std::vector<Index>& priorities) { + std::vector<Index> ret = baseline; + std::vector<Index> reversed = makeReversed(baseline); + std::sort(ret.begin(), ret.end(), [&priorities, &reversed](Index x, Index y) { + return priorities[x] > priorities[y] || (priorities[x] == priorities[y] && reversed[x] < reversed[y]); + }); + return ret; +}; + void CoalesceLocals::pickIndices(std::vector<Index>& indices) { if (numLocals == 0) return; if (numLocals == 1) { indices.push_back(0); return; } + if (getFunction()->getNumVars() <= 1) { + // nothing to think about here, since we can't reorder params + indices = makeIdentity(numLocals); + return; + } + // take into account total copies. but we must keep params in place, so give them max priority + auto adjustedTotalCopies = totalCopies; + auto numParams = getFunction()->getNumParams(); + for (Index i = 0; i < numParams; i++) { + adjustedTotalCopies[i] = std::numeric_limits<Index>::max(); + } // first try the natural order. this is less arbitrary than it seems, as the program // may have a natural order of locals inherent in it. - std::vector<Index> order; - order.resize(numLocals); - for (Index i = 0; i < numLocals; i++) { - order[i] = i; - } - pickIndicesFromOrder(order, indices); + auto order = makeIdentity(numLocals); + order = adjustOrderByPriorities(order, adjustedTotalCopies); + Index removedCopies; + pickIndicesFromOrder(order, indices, removedCopies); auto maxIndex = *std::max_element(indices.begin(), indices.end()); - // next try the reverse order. this both gives us anothe chance at something good, + // next try the reverse order. this both gives us another chance at something good, // and also the very naturalness of the simple order may be quite suboptimal - auto numParams = getFunction()->getNumParams(); + setIdentity(order); for (Index i = numParams; i < numLocals; i++) { order[i] = numParams + numLocals - 1 - i; } + order = adjustOrderByPriorities(order, adjustedTotalCopies); std::vector<Index> reverseIndices; - pickIndicesFromOrder(order, reverseIndices); + Index reverseRemovedCopies; + pickIndicesFromOrder(order, reverseIndices, reverseRemovedCopies); auto reverseMaxIndex = *std::max_element(reverseIndices.begin(), reverseIndices.end()); - if (reverseMaxIndex < maxIndex) { + // prefer to remove copies foremost, as it matters more for code size (minus gzip), and + // improves throughput. + if (reverseRemovedCopies > removedCopies || (reverseRemovedCopies == removedCopies && reverseMaxIndex < maxIndex)) { indices.swap(reverseIndices); } } @@ -553,7 +657,8 @@ void CoalesceLocalsWithLearning::pickIndices(std::vector<Index>& indices) { void calculateFitness(Order* order) { // apply the order std::vector<Index> indices; // the phenotype - parent->pickIndicesFromOrder(*order, indices); + Index removedCopies; + parent->pickIndicesFromOrder(*order, indices, removedCopies); auto maxIndex = *std::max_element(indices.begin(), indices.end()); assert(maxIndex <= parent->numLocals); // main part of fitness is the number of locals @@ -563,6 +668,7 @@ void CoalesceLocalsWithLearning::pickIndices(std::vector<Index>& indices) { for (Index i = 0; i < parent->numLocals; i++) { if ((*order)[i] == i) fitness += fragment; // boost for each that wasn't moved } + fitness = (100 * fitness) + removedCopies; // removing copies is a secondary concern order->setFitness(fitness); } @@ -577,6 +683,8 @@ void CoalesceLocalsWithLearning::pickIndices(std::vector<Index>& indices) { // first, there may be an inherent order in the input (frequent indices are lower, // etc.). second, by ensuring we start with the natural order, we ensure we are at // least as good as the non-learning variant. + // TODO: use ::pickIndices from the parent, so we literally get the simpler approach + // as our first option first = false; } else { // leave params alone, shuffle the rest diff --git a/src/passes/DuplicateFunctionElimination.cpp b/src/passes/DuplicateFunctionElimination.cpp index 2b8e69b54..cfe2d8565 100644 --- a/src/passes/DuplicateFunctionElimination.cpp +++ b/src/passes/DuplicateFunctionElimination.cpp @@ -102,7 +102,17 @@ struct DuplicateFunctionElimination : public Pass { auto& group = pair.second; if (group.size() == 1) continue; // pick a base for each group, and try to replace everyone else to it. TODO: multiple bases per hash group, for collisions +#if 0 + // for comparison purposes, pick in a deterministic way based on the names + Function* base = nullptr; + for (auto* func : group) { + if (!base || strcmp(func->name.str, base->name.str) < 0) { + base = func; + } + } +#else Function* base = group[0]; +#endif for (auto* func : group) { if (func != base && equal(func, base)) { replacements[func->name] = base->name; diff --git a/src/passes/ExtractFunction.cpp b/src/passes/ExtractFunction.cpp new file mode 100644 index 000000000..b342efc3b --- /dev/null +++ b/src/passes/ExtractFunction.cpp @@ -0,0 +1,46 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Removes code from all functions but one, leaving a valid module +// with (mostly) just the code you want to debug (function-parallel, +// non-lto) passes on. + +#include "wasm.h" +#include "pass.h" + +namespace wasm { + +Name TO_LEAVE("_bytearray_join"); // TODO: commandline param + +struct ExtractFunction : public Pass { + void run(PassRunner* runner, Module* module) override { + for (auto& func : module->functions) { + if (func->name != TO_LEAVE) { + // wipe out the body + func->body = module->allocator.alloc<Unreachable>(); + } + } + } +}; + +// declare pass + +Pass *createExtractFunctionPass() { + return new ExtractFunction(); +} + +} // namespace wasm + diff --git a/src/passes/LowerInt64.cpp b/src/passes/LowerInt64.cpp deleted file mode 100644 index 69f6d5ae9..000000000 --- a/src/passes/LowerInt64.cpp +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright 2015 WebAssembly Community Group participants - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// -// Lowers 64-bit ints to pairs of 32-bit ints, plus some library routines -// -// This is useful for wasm2asm, as JS has no native 64-bit integer support. -// - -#include <memory> - -#include <wasm.h> -#include <pass.h> - -namespace wasm { - -cashew::IString GET_HIGH("getHigh"); - -struct LowerInt64 : public Pass { - MixedArena* allocator; - std::unique_ptr<NameManager> namer; - - void prepare(PassRunner* runner, Module *module) override { - allocator = runner->allocator; - namer = make_unique<NameManager>(); - namer->run(runner, module); - } - - std::map<Expression*, Expression*> fixes; // fixed nodes, outputs of lowering, mapped to their high bits - std::map<Name, Name> locals; // maps locals which were i64->i32 to their high bits - - void makeGetHigh() { - auto ret = allocator->alloc<CallImport>(); - ret->target = GET_HIGH; - ret->type = i32; - return ret; - } - - void fixCall(CallBase *call) { - auto& operands = call->operands; - for (size_t i = 0; i < operands.size(); i++) { - auto fix = fixes.find(operands[i]); - if (fix != fixes.end()) { - operands.insert(operands.begin() + i + 1, *fix); - } - } - if (curr->type == i64) { - curr->type = i32; - fixes[curr] = makeGetHigh(); // called function will setHigh - } - } - - void visitCall(Call *curr) { - fixCall(curr); - } - void visitCallImport(CallImport *curr) { - fixCall(curr); - } - void visitCallIndirect(CallIndirect *curr) { - fixCall(curr); - } - void visitGetLocal(GetLocal *curr) { - if (curr->type == i64) { - if (locals.count(curr->name) == 0) { - Name highName = namer->getUnique("high"); - locals[curr->name] = highName; - }; - curr->type = i32; - auto high = allocator->alloc<GetLocal>(); - high->name = locals[curr->name]; - high->type = i32; - fixes[curr] = high; - } - } - void visitSetLocal(SetLocal *curr) { - if (curr->type == i64) { - Name highName; - if (locals.count(curr->name) == 0) { - highName = namer->getUnique("high"); - locals[curr->name] = highName; - } else { - highName = locals[curr->name]; - } - curr->type = i32; - auto high = allocator->alloc<GetLocal>(); - high->name = highName; - high->type = i32; - fixes[curr] = high; - // Set the high bits - auto set = allocator.alloc<SetLocal>(); - set->name = highName; - set->value = fixes[curr->value]; - set->type = i32; - assert(set->value); - auto low = allocator->alloc<GetLocal>(); - low->name = curr->name; - low->type = i32; - auto ret = allocator.alloc<Block>(); - ret->list.push_back(curr); - ret->list.push_back(set); - ret->list.push_back(low); // so the block returns the low bits - ret->finalize(); - fixes[ret] = high; - replaceCurrent(ret); - } - } - - // sets an expression to a local, and returns a block - Block* setToLocalForBlock(Expression *value, Name& local, Block *ret = nullptr) { - if (!ret) ret = allocator->alloc<Block>(); - if (value->is<GetLocal>()) { - local = value->name; - } else if (value->is<SetLocal>()) { - local = value->name; - } else { - auto set = allocator.alloc<SetLocal>(); - set->name = local = namer->getUnique("temp"); - set->value = value; - set->type = value->type; - ret->list.push_back(set); - } - ret->finalize(); - return ret; - } - - GetLocal* getLocal(Name name) { - auto ret = allocator->alloc<GetLocal>(); - ret->name = name; - ret->type = i32; - return ret; - } - - void visitLoad(Load *curr) { - if (curr->type == i64) { - Name local; - auto ret = setToLocalForBlock(curr->ptr, local); - curr->ptr = getLocal(local); - curr->type = i32; - curr->bytes = 4; - auto high = allocator->alloc<Load>(); - *high = *curr; - high->ptr = getLocal(local); - high->offset += 4; - ret->list.push_back(curr); - fixes[ret] = high; - replaceCurrent(ret); - } - } - void visitStore(Store *curr) { - if (curr->type == i64) { - Name localPtr, localValue; - auto ret = setToLocalForBlock(curr->ptr, localPtr); - setToLocalForBlock(curr->value, localValue); - curr->ptr = getLocal(localPtr); - curr->value = getLocal(localValue); - curr->type = i32; - curr->bytes = 4; - auto high = allocator->alloc<Load>(); - *high = *curr; - high->ptr = getLocal(localPtr); - high->value = getLocal(localValue); - high->offset += 4; - ret->list.push_back(high); - ret->list.push_back(curr); - fixes[ret] = high; - replaceCurrent(ret); - } - } - void visitFunction(Function *curr) { - // TODO: new params - for (auto localPair : locals) { // TODO: ignore params - curr->locals.emplace_back(localPair.second, i32); - } - fixes.clear(); - locals.clear(); - } -}; - -Pass *createLowerInt64Pass() { - return new LowerInt64(); -} - -} // namespace wasm diff --git a/src/passes/Metrics.cpp b/src/passes/Metrics.cpp index c5309850e..9181da9bb 100644 --- a/src/passes/Metrics.cpp +++ b/src/passes/Metrics.cpp @@ -34,7 +34,7 @@ struct Metrics : public WalkerPass<PostWalker<Metrics, UnifiedExpressionVisitor< counts[name]++; } - void finalize(PassRunner *runner, Module *module) override { + void visitModule(Module* module) { ostream &o = cout; o << "Counts" << "\n"; diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 2a5e7b8c4..669a19b89 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -24,6 +24,7 @@ #include <pass.h> #include <wasm-s-parser.h> #include <support/threads.h> +#include <ast_utils.h> namespace wasm { @@ -50,7 +51,6 @@ struct PatternDatabase { std::map<Expression::Id, std::vector<Pattern>> patternMap; // root expression id => list of all patterns for it TODO optimize more PatternDatabase() { - // TODO: do this on first use, with a lock, to avoid startup pause // generate module input = strdup( #include "OptimizeInstructions.wast.processed" @@ -74,14 +74,12 @@ struct PatternDatabase { static PatternDatabase* database = nullptr; -static void ensureDatabase() { - if (!database) { - // we must only ever create one database - static OnlyOnce onlyOnce; - onlyOnce.verify(); +struct DatabaseEnsurer { + DatabaseEnsurer() { + assert(!database); database = new PatternDatabase; } -} +}; // Check for matches and apply them struct Match { @@ -161,13 +159,18 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, Pass* create() override { return new OptimizeInstructions; } - OptimizeInstructions() { - ensureDatabase(); + void prepareToRun(PassRunner* runner, Module* module) override { + static DatabaseEnsurer ensurer; } void visitExpression(Expression* curr) { // we may be able to apply multiple patterns, one may open opportunities that look deeper NB: patterns must not have cycles while (1) { + auto* handOptimized = handOptimize(curr); + if (handOptimized) { + curr = handOptimized; + replaceCurrent(curr); + } auto iter = database->patternMap.find(curr->_id); if (iter == database->patternMap.end()) return; auto& patterns = iter->second; @@ -184,6 +187,65 @@ struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, if (!more) break; } } + + // Optimizations that don't yet fit in the pattern DSL, but could be eventually maybe + Expression* handOptimize(Expression* curr) { + if (auto* binary = curr->dynCast<Binary>()) { + // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc. + if (binary->op == BinaryOp::ShrSInt32 && binary->right->is<Const>()) { + auto shifts = binary->right->cast<Const>()->value.geti32(); + if (shifts == 24 || shifts == 16) { + auto* left = binary->left->dynCast<Binary>(); + if (left && left->op == ShlInt32 && left->right->is<Const>() && left->right->cast<Const>()->value.geti32() == shifts) { + auto* load = left->left->dynCast<Load>(); + if (load && ((load->bytes == 1 && shifts == 24) || (load->bytes == 2 && shifts == 16))) { + load->signed_ = true; + return load; + } + } + } + } + } else if (auto* set = curr->dynCast<SetGlobal>()) { + // optimize out a set of a get + auto* get = set->value->dynCast<GetGlobal>(); + if (get && get->name == set->name) { + ExpressionManipulator::nop(curr); + } + } else if (auto* iff = curr->dynCast<If>()) { + iff->condition = optimizeBoolean(iff->condition); + } else if (auto* select = curr->dynCast<Select>()) { + select->condition = optimizeBoolean(select->condition); + auto* condition = select->condition->dynCast<Unary>(); + if (condition && condition->op == EqZInt32) { + // flip select to remove eqz, if we can reorder + EffectAnalyzer ifTrue(select->ifTrue); + EffectAnalyzer ifFalse(select->ifFalse); + if (!ifTrue.invalidates(ifFalse)) { + select->condition = condition->value; + std::swap(select->ifTrue, select->ifFalse); + } + } + } else if (auto* br = curr->dynCast<Break>()) { + if (br->condition) { + br->condition = optimizeBoolean(br->condition); + } + } + return nullptr; + } + +private: + + Expression* optimizeBoolean(Expression* boolean) { + auto* condition = boolean->dynCast<Unary>(); + if (condition && condition->op == EqZInt32) { + auto* condition2 = condition->value->dynCast<Unary>(); + if (condition2 && condition2->op == EqZInt32) { + // double eqz + return condition2->value; + } + } + return boolean; + } }; Pass *createOptimizeInstructionsPass() { diff --git a/src/passes/OptimizeInstructions.wast b/src/passes/OptimizeInstructions.wast index 7d5c56881..48a4e2c9d 100644 --- a/src/passes/OptimizeInstructions.wast +++ b/src/passes/OptimizeInstructions.wast @@ -19,8 +19,8 @@ ;; main function. each block here is a pattern pair of input => output (func $patterns + ;; flip if-else arms to get rid of an eqz (block - ;; flip if-else arms to get rid of an eqz (if (i32.eqz (call_import $i32.expr (i32.const 0)) @@ -34,6 +34,25 @@ (call_import $any.expr (i32.const 1)) ) ) + ;; equal 0 => eqz + (block + (i32.eq + (call_import $any.expr (i32.const 0)) + (i32.const 0) + ) + (i32.eqz + (call_import $any.expr (i32.const 0)) + ) + ) + (block + (i32.eq + (i32.const 0) + (call_import $any.expr (i32.const 0)) + ) + (i32.eqz + (call_import $any.expr (i32.const 0)) + ) + ) ;; De Morgans Laws (block (i32.eqz (i32.eq (call_import $i32.expr (i32.const 0)) (call_import $i32.expr (i32.const 1)))) diff --git a/src/passes/OptimizeInstructions.wast.processed b/src/passes/OptimizeInstructions.wast.processed index 61fde86c6..13ccc8241 100644 --- a/src/passes/OptimizeInstructions.wast.processed +++ b/src/passes/OptimizeInstructions.wast.processed @@ -19,8 +19,8 @@ "\n" ";; main function. each block here is a pattern pair of input => output\n" "(func $patterns\n" -"(block\n" ";; flip if-else arms to get rid of an eqz\n" +"(block\n" "(if\n" "(i32.eqz\n" "(call_import $i32.expr (i32.const 0))\n" @@ -34,6 +34,25 @@ "(call_import $any.expr (i32.const 1))\n" ")\n" ")\n" +";; equal 0 => eqz\n" +"(block\n" +"(i32.eq\n" +"(call_import $any.expr (i32.const 0))\n" +"(i32.const 0)\n" +")\n" +"(i32.eqz\n" +"(call_import $any.expr (i32.const 0))\n" +")\n" +")\n" +"(block\n" +"(i32.eq\n" +"(i32.const 0)\n" +"(call_import $any.expr (i32.const 0))\n" +")\n" +"(i32.eqz\n" +"(call_import $any.expr (i32.const 0))\n" +")\n" +")\n" ";; De Morgans Laws\n" "(block\n" "(i32.eqz (i32.eq (call_import $i32.expr (i32.const 0)) (call_import $i32.expr (i32.const 1))))\n" diff --git a/src/passes/RelooperJumpThreading.cpp b/src/passes/RelooperJumpThreading.cpp new file mode 100644 index 000000000..7f74220d5 --- /dev/null +++ b/src/passes/RelooperJumpThreading.cpp @@ -0,0 +1,248 @@ +/* + * Copyright 2016 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Optimize relooper-generated label variable usage: add blocks and turn +// a label-set/break/label-check into a break into the new block. +// This assumes the very specific output the fastcomp relooper emits, +// including the name of the 'label' variable. + +#include "wasm.h" +#include "pass.h" +#include "ast_utils.h" + +namespace wasm { + +static Name LABEL("label"); + +// We need to use new label names, which we cannot create in parallel, so pre-create them + +const Index MAX_NAME_INDEX = 1000; + +std::vector<Name>* innerNames = nullptr; +std::vector<Name>* outerNames = nullptr; + +struct NameEnsurer { + NameEnsurer() { + assert(!innerNames); + assert(!outerNames); + innerNames = new std::vector<Name>; + outerNames = new std::vector<Name>; + for (Index i = 0; i < MAX_NAME_INDEX; i++) { + innerNames->push_back(Name(std::string("jumpthreading$inner$") + std::to_string(i))); + outerNames->push_back(Name(std::string("jumpthreading$outer$") + std::to_string(i))); + } + } +}; + +static If* isLabelCheckingIf(Expression* curr, Index labelIndex) { + if (!curr) return nullptr; + auto* iff = curr->dynCast<If>(); + if (!iff) return nullptr; + auto* condition = iff->condition->dynCast<Binary>(); + if (!(condition && condition->op == EqInt32)) return nullptr; + auto* left = condition->left->dynCast<GetLocal>(); + if (!(left && left->index == labelIndex)) return nullptr; + return iff; +} + +static Index getCheckedLabelValue(If* iff) { + return iff->condition->cast<Binary>()->right->cast<Const>()->value.geti32(); +} + +static SetLocal* isLabelSettingSetLocal(Expression* curr, Index labelIndex) { + if (!curr) return nullptr; + auto* set = curr->dynCast<SetLocal>(); + if (!set) return nullptr; + if (set->index != labelIndex) return nullptr; + return set; +} + +static Index getSetLabelValue(SetLocal* set) { + return set->value->cast<Const>()->value.geti32(); +} + +struct LabelUseFinder : public PostWalker<LabelUseFinder, Visitor<LabelUseFinder>> { + Index labelIndex; + std::map<Index, Index>& checks; // label value => number of checks on it + std::map<Index, Index>& sets; // label value => number of sets to it + + LabelUseFinder(Index labelIndex, std::map<Index, Index>& checks, std::map<Index, Index>& sets) : labelIndex(labelIndex), checks(checks), sets(sets) {} + + void visitIf(If* curr) { + if (isLabelCheckingIf(curr, labelIndex)) { + checks[getCheckedLabelValue(curr)]++; + } + } + + void visitSetLocal(SetLocal* curr) { + if (isLabelSettingSetLocal(curr, labelIndex)) { + sets[getSetLabelValue(curr)]++; + } + } +}; + +struct RelooperJumpThreading : public WalkerPass<ExpressionStackWalker<RelooperJumpThreading, Visitor<RelooperJumpThreading>>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new RelooperJumpThreading; } + + void prepareToRun(PassRunner* runner, Module* module) override { + static NameEnsurer ensurer; + } + + std::map<Index, Index> labelChecks; + std::map<Index, Index> labelSets; + + Index labelIndex; + Index newNameCounter = 0; + + void visitBlock(Block* curr) { + // look for the if label == X pattern + auto& list = curr->list; + if (list.size() == 0) return; + for (Index i = 0; i < list.size() - 1; i++) { + // once we see something that might be irreducible, we must skip that if and the rest of the dependents + bool irreducible = false; + Index origin = i; + for (Index j = i + 1; j < list.size(); j++) { + if (auto* iff = isLabelCheckingIf(list[j], labelIndex)) { + irreducible |= hasIrreducibleControlFlow(iff, list[origin]); + if (!irreducible) { + optimizeJumpsToLabelCheck(list[origin], iff); + ExpressionManipulator::nop(iff); + } + i++; + continue; + } + // if the next element is a block, it may be the holding block of label-checking ifs + if (auto* holder = list[j]->dynCast<Block>()) { + if (holder->list.size() > 0) { + if (If* iff = isLabelCheckingIf(holder->list[0], labelIndex)) { + irreducible |= hasIrreducibleControlFlow(iff, list[origin]); + if (!irreducible) { + // this is indeed a holder. we can process the ifs, and must also move + // the block to enclose the origin, so it is properly reachable + assert(holder->list.size() == 1); // must be size 1, a relooper multiple will have its own label, and is an if-else sequence and nothing more + optimizeJumpsToLabelCheck(list[origin], iff); + holder->list[0] = list[origin]; + list[origin] = holder; + // reuse the if as a nop + list[j] = iff; + ExpressionManipulator::nop(iff); + } + i++; + continue; + } + } + } + break; // we didn't see something we like, so stop here + } + } + } + + void doWalkFunction(Function* func) { + // if there isn't a label variable, nothing for us to do + if (func->localIndices.count(LABEL)) { + labelIndex = func->getLocalIndex(LABEL); + LabelUseFinder finder(labelIndex, labelChecks, labelSets); + finder.walk(func->body); + WalkerPass<ExpressionStackWalker<RelooperJumpThreading, Visitor<RelooperJumpThreading>>>::doWalkFunction(func); + } + } + +private: + + bool hasIrreducibleControlFlow(If* iff, Expression* origin) { + // Gather the checks in this if chain. If all the label values checked are only set in origin, + // then since origin is right before us, this is not irreducible - we can replace all sets + // in origin with jumps forward to us, and since there is nothing else, this is safe and complete. + // We must also have the property that there is just one check for the label value, as otherwise + // node splitting has complicated things. + std::map<Index, Index> labelChecksInOrigin; + std::map<Index, Index> labelSetsInOrigin; + LabelUseFinder finder(labelIndex, labelChecksInOrigin, labelSetsInOrigin); + finder.walk(origin); + while (iff) { + auto num = getCheckedLabelValue(iff); + assert(labelChecks[num] > 0); + if (labelChecks[num] > 1) return true; // checked more than once, somewhere in function + assert(labelChecksInOrigin[num] == 0); + if (labelSetsInOrigin[num] != labelSets[num]) { + assert(labelSetsInOrigin[num] < labelSets[num]); + return true; // label set somewhere outside of origin TODO: if set in the if body here, it might be safe in some cases + } + iff = isLabelCheckingIf(iff->ifFalse, labelIndex); + } + return false; + } + + // optimizes jumps to a label check + // * origin is where the jumps originate, and also where we should write our output + // * iff is the if + void optimizeJumpsToLabelCheck(Expression*& origin, If* iff) { + Index nameCounter = newNameCounter++; + if (nameCounter >= MAX_NAME_INDEX) { + std::cerr << "too many names in RelooperJumpThreading :(\n"; + return; + } + Index num = getCheckedLabelValue(iff); + // create a new block for this jump target + Builder builder(*getModule()); + // origin is where all jumps to this target must come from - the element right before this if + // we break out of inner to reach the target. instead of flowing out of normally, we break out of the outer, so we skip the target. + auto innerName = innerNames->at(nameCounter); + auto outerName = outerNames->at(nameCounter); + auto* ifFalse = iff->ifFalse; + // all assignments of label to the target can be replaced with breaks to the target, via innerName + struct JumpUpdater : public PostWalker<JumpUpdater, Visitor<JumpUpdater>> { + Index labelIndex; + Index targetNum; + Name targetName; + + void visitSetLocal(SetLocal* curr) { + if (curr->index == labelIndex) { + if (Index(curr->value->cast<Const>()->value.geti32()) == targetNum) { + replaceCurrent(Builder(*getModule()).makeBreak(targetName)); + } + } + } + }; + JumpUpdater updater; + updater.labelIndex = labelIndex; + updater.targetNum = num; + updater.targetName = innerName; + updater.setModule(getModule()); + updater.walk(origin); + // restructure code + auto* inner = builder.blockifyWithName(origin, innerName, builder.makeBreak(outerName)); + auto* outer = builder.makeSequence(inner, iff->ifTrue); + outer->name = outerName; + origin = outer; + // if another label value is checked here, handle that too + if (ifFalse) { + optimizeJumpsToLabelCheck(origin, ifFalse->cast<If>()); + } + } +}; + +// declare pass + +Pass *createRelooperJumpThreadingPass() { + return new RelooperJumpThreading(); +} + +} // namespace wasm + diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp index 19d6c3eb1..429203a2e 100644 --- a/src/passes/RemoveImports.cpp +++ b/src/passes/RemoveImports.cpp @@ -28,22 +28,14 @@ namespace wasm { struct RemoveImports : public WalkerPass<PostWalker<RemoveImports, Visitor<RemoveImports>>> { - MixedArena* allocator; - Module* module; - - void prepare(PassRunner* runner, Module *module_) override { - allocator = runner->allocator; - module = module_; - } - void visitCallImport(CallImport *curr) { - WasmType type = module->getImport(curr->target)->functionType->result; + WasmType type = getModule()->getImport(curr->target)->functionType->result; if (type == none) { - replaceCurrent(allocator->alloc<Nop>()); + replaceCurrent(getModule()->allocator.alloc<Nop>()); } else { Literal nopLiteral; nopLiteral.type = type; - replaceCurrent(allocator->alloc<Const>()->set(nopLiteral)); + replaceCurrent(getModule()->allocator.alloc<Const>()->set(nopLiteral)); } } diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp index 59d3af6fc..86a46374f 100644 --- a/src/passes/RemoveUnusedBrs.cpp +++ b/src/passes/RemoveUnusedBrs.cpp @@ -21,9 +21,20 @@ #include <wasm.h> #include <pass.h> #include <ast_utils.h> +#include <wasm-builder.h> namespace wasm { +// to turn an if into a br-if, we must be able to reorder the +// condition and possible value, and the possible value must +// not have side effects (as they would run unconditionally) +static bool canTurnIfIntoBrIf(Expression* ifCondition, Expression* brValue) { + if (!brValue) return true; + EffectAnalyzer value(brValue); + if (value.hasSideEffects()) return false; + return !EffectAnalyzer(ifCondition).invalidates(value); +} + struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<RemoveUnusedBrs>>> { bool isFunctionParallel() override { return true; } @@ -44,6 +55,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R // a stack for if-else contents, we merge their outputs std::vector<Flows> ifStack; + // list of all loops, so we can optimize them + std::vector<Loop*> loops; + static void visitAny(RemoveUnusedBrs* self, Expression** currp) { auto* curr = *currp; auto& flows = self->flows; @@ -109,7 +123,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R // ignore (could be result of a previous cycle) self->valueCanFlow = false; } else { - // anything else stops the flow TODO: optimize loops? + // anything else stops the flow flows.clear(); self->valueCanFlow = false; } @@ -123,23 +137,25 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R self->ifStack.push_back(std::move(self->flows)); } + void visitLoop(Loop* curr) { + loops.push_back(curr); + } + void visitIf(If* curr) { if (!curr->ifFalse) { // if without an else. try to reduce if (condition) br => br_if (condition) Break* br = curr->ifTrue->dynCast<Break>(); if (br && !br->condition) { // TODO: if there is a condition, join them // if the br has a value, then if => br_if means we always execute the value, and also the order is value,condition vs condition,value - if (br->value) { - EffectAnalyzer value(br->value); - if (value.hasSideEffects()) return; - EffectAnalyzer condition(curr->condition); - if (condition.invalidates(value)) return; + if (canTurnIfIntoBrIf(curr->condition, br->value)) { + br->condition = curr->condition; + br->finalize(); + replaceCurrent(br); + anotherCycle = true; } - br->condition = curr->condition; - replaceCurrent(br); - anotherCycle = true; } } + // TODO: if-else can be turned into a br_if as well, if one of the sides is a dead end } // override scan to add a pre and a post check task to all nodes @@ -163,6 +179,99 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R } } + // optimizes a loop. returns true if we made changes + bool optimizeLoop(Loop* loop) { + // if a loop ends in + // (loop $in + // (block $out + // if (..) br $in; else br $out; + // ) + // ) + // then our normal opts can remove the break out since it flows directly out + // (and later passes make the if one-armed). however, the simple analysis + // fails on patterns like + // if (..) br $out; + // br $in; + // which is a common way to do a while (1) loop (end it with a jump to the + // top), so we handle that here. Specifically we want to conditionalize + // breaks to the loop top, i.e., put them behind a condition, so that other + // code can flow directly out and thus brs out can be removed. (even if + // the change is to let a break somewhere else flow out, that can still be + // helpful, as it shortens the logical loop. it is also good to generate + // an if-else instead of an if, as it might allow an eqz to be removed + // by flipping arms) + if (!loop->name.is()) return false; + auto* block = loop->body->dynCast<Block>(); + if (!block) return false; + // does the last element break to the top of the loop? + auto& list = block->list; + if (list.size() <= 1) return false; + auto* last = list.back()->dynCast<Break>(); + if (!last || !ExpressionAnalyzer::isSimple(last) || last->name != loop->name) return false; + // last is a simple break to the top of the loop. if we can conditionalize it, + // it won't block things from flowing out and not needing breaks to do so. + Index i = list.size() - 2; + Builder builder(*getModule()); + while (1) { + auto* curr = list[i]; + if (auto* iff = curr->dynCast<If>()) { + // let's try to move the code going to the top of the loop into the if-else + if (!iff->ifFalse) { + // we need the ifTrue to break, so it cannot reach the code we want to move + if (ExpressionAnalyzer::getEndingSimpleBreak(iff->ifTrue)) { + iff->ifFalse = builder.stealSlice(block, i + 1, list.size()); + return true; + } + } else { + // this is already an if-else. if one side is a dead end, we can append to the other, if + // there is no returned value to concern us + assert(!isConcreteWasmType(iff->type)); // can't be, since in the middle of a block + if (ExpressionAnalyzer::getEndingSimpleBreak(iff->ifTrue)) { + iff->ifFalse = builder.blockifyMerge(iff->ifFalse, builder.stealSlice(block, i + 1, list.size())); + return true; + } else if (ExpressionAnalyzer::getEndingSimpleBreak(iff->ifFalse)) { + iff->ifTrue = builder.blockifyMerge(iff->ifTrue, builder.stealSlice(block, i + 1, list.size())); + return true; + } + } + return false; + } else if (auto* brIf = curr->dynCast<Break>()) { + // br_if is similar to if. + if (brIf->condition && !brIf->value && brIf->name != loop->name) { + if (i == list.size() - 2) { + // there is the br_if, and then the br to the top, so just flip them and the condition + brIf->condition = builder.makeUnary(EqZInt32, brIf->condition); + last->name = brIf->name; + brIf->name = loop->name; + return true; + } else { + // there are elements in the middle, + // br_if $somewhere (condition) + // (..more..) + // br $in + // we can convert the br_if to an if. this has a cost, though, + // so only do it if it looks useful, which it definitely is if + // (a) $somewhere is straight out (so the br out vanishes), and + // (b) this br_if is the only branch to that block (so the block will vanish) + if (brIf->name == block->name && BreakSeeker::count(block, block->name) == 1) { + // note that we could drop the last element here, it is a br we know for sure is removable, + // but telling stealSlice to steal all to the end is more efficient, it can just truncate. + list[i] = builder.makeIf(brIf->condition, builder.makeBreak(brIf->name), builder.stealSlice(block, i + 1, list.size())); + return true; + } + } + } + return false; + } + // if there is control flow, we must stop looking + if (EffectAnalyzer(curr).branches) { + return false; + } + if (i == 0) return false; + i--; + } + } + void doWalkFunction(Function* func) { // multiple cycles may be needed bool worked = false; @@ -172,7 +281,8 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R assert(ifStack.empty()); // flows may contain returns, which are flowing out and so can be optimized for (size_t i = 0; i < flows.size(); i++) { - auto* flow = (*flows[i])->cast<Return>(); // cannot be a break + auto* flow = (*flows[i])->dynCast<Return>(); + if (!flow) continue; if (!flow->value) { // return => nop ExpressionManipulator::nop(flow); @@ -184,11 +294,138 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R } } flows.clear(); + // optimize loops (we don't do it while tracking flows, as they can interfere) + for (auto* loop : loops) { + anotherCycle |= optimizeLoop(loop); + } + loops.clear(); if (anotherCycle) worked = true; } while (anotherCycle); - // finally, we may have simplified ifs enough to turn them into selects - struct Selectifier : public WalkerPass<PostWalker<Selectifier, Visitor<Selectifier>>> { + + if (worked) { + // Our work may alter block and if types, they may now return + struct TypeUpdater : public WalkerPass<PostWalker<TypeUpdater, Visitor<TypeUpdater>>> { + void visitBlock(Block* curr) { + curr->finalize(); + } + void visitLoop(Loop* curr) { + curr->finalize(); + } + void visitIf(If* curr) { + curr->finalize(); + } + }; + TypeUpdater typeUpdater; + typeUpdater.walkFunction(func); + } + + // thread trivial jumps + struct JumpThreader : public ControlFlowWalker<JumpThreader, Visitor<JumpThreader>> { + // map of all value-less breaks going to a block (and not a loop) + std::map<Block*, std::vector<Break*>> breaksToBlock; + + // number of definitions of each name - when a name is defined more than once, it is not trivially safe to do this + std::map<Name, Index> numDefs; + + // the names to update, when we can (just one def) + std::map<Break*, Name> newNames; + + void visitBreak(Break* curr) { + if (!curr->value) { + if (auto* target = findBreakTarget(curr->name)->dynCast<Block>()) { + breaksToBlock[target].push_back(curr); + } + } + } + // TODO: Switch? + void visitBlock(Block* curr) { + if (curr->name.is()) numDefs[curr->name]++; + + auto& list = curr->list; + if (list.size() == 1 && curr->name.is()) { + // if this block has just one child, a sub-block, then jumps to the former are jumps to us, really + if (auto* child = list[0]->dynCast<Block>()) { + if (child->name.is() && child->name != curr->name) { + auto& breaks = breaksToBlock[child]; + for (auto* br : breaks) { + newNames[br] = curr->name; + breaksToBlock[curr].push_back(br); // update the list - we may push it even more later + } + breaksToBlock.erase(child); + } + } + } else if (list.size() == 2) { + // if this block has two children, a child-block and a simple jump, then jumps to child-block can be replaced with jumps to the new target + auto* child = list[0]->dynCast<Block>(); + auto* jump = list[1]->dynCast<Break>(); + if (child && child->name.is() && jump && ExpressionAnalyzer::isSimple(jump)) { + auto& breaks = breaksToBlock[child]; + for (auto* br : breaks) { + newNames[br] = jump->name; + } + // if the jump is to another block then we can update the list, and maybe push it even more later + if (auto* newTarget = findBreakTarget(jump->name)->dynCast<Block>()) { + for (auto* br : breaks) { + breaksToBlock[newTarget].push_back(br); + } + } + breaksToBlock.erase(child); + } + } + } + void visitLoop(Loop* curr) { + if (curr->name.is()) numDefs[curr->name]++; + } + + void finish() { + for (auto& iter : newNames) { + auto* br = iter.first; + auto name = iter.second; + if (numDefs[name] == 1) { + br->name = name; + } + } + } + }; + JumpThreader jumpThreader; + jumpThreader.setModule(getModule()); + jumpThreader.walkFunction(func); + jumpThreader.finish(); + + // perform some final optimizations + struct FinalOptimizer : public PostWalker<FinalOptimizer, Visitor<FinalOptimizer>> { + void visitBlock(Block* curr) { + // if a block has an if br else br, we can un-conditionalize the latter, allowing + // the if to become a br_if. + // * note that if not in a block already, then we need to create a block for this, so not useful otherwise + // * note that this only happens at the end of a block, as code after the if is dead + // * note that we do this at the end, because un-conditionalizing can interfere with optimizeLoop()ing. + auto& list = curr->list; + for (Index i = 0; i < list.size(); i++) { + auto* iff = list[i]->dynCast<If>(); + if (!iff || !iff->ifFalse || isConcreteWasmType(iff->type)) continue; // if it lacked an if-false, it would already be a br_if, as that's the easy case + auto* ifTrueBreak = iff->ifTrue->dynCast<Break>(); + if (ifTrueBreak && !ifTrueBreak->condition && canTurnIfIntoBrIf(iff->condition, ifTrueBreak->value)) { + // we are an if-else where the ifTrue is a break without a condition, so we can do this + list[i] = ifTrueBreak; + ifTrueBreak->condition = iff->condition; + ifTrueBreak->finalize(); + ExpressionManipulator::spliceIntoBlock(curr, i + 1, iff->ifFalse); + continue; + } + // otherwise, perhaps we can flip the if + auto* ifFalseBreak = iff->ifFalse->dynCast<Break>(); + if (ifFalseBreak && !ifFalseBreak->condition && canTurnIfIntoBrIf(iff->condition, ifFalseBreak->value)) { + list[i] = ifFalseBreak; + ifFalseBreak->condition = Builder(*getModule()).makeUnary(EqZInt32, iff->condition); + ifFalseBreak->finalize(); + ExpressionManipulator::spliceIntoBlock(curr, i + 1, iff->ifTrue); + continue; + } + } + } void visitIf(If* curr) { + // we may have simplified ifs enough to turn them into selects if (curr->ifFalse && isConcreteWasmType(curr->ifTrue->type) && isConcreteWasmType(curr->ifFalse->type)) { // if with else, consider turning it into a select if there is no control flow // TODO: estimate cost @@ -210,25 +447,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs, Visitor<R } } }; - Selectifier selectifier; - selectifier.setModule(getModule()); - selectifier.walkFunction(func); - if (worked) { - // Our work may alter block and if types, they may now return - struct TypeUpdater : public WalkerPass<PostWalker<TypeUpdater, Visitor<TypeUpdater>>> { - void visitBlock(Block* curr) { - curr->finalize(); - } - void visitLoop(Loop* curr) { - curr->finalize(); - } - void visitIf(If* curr) { - curr->finalize(); - } - }; - TypeUpdater typeUpdater; - typeUpdater.walkFunction(func); - } + FinalOptimizer finalOptimizer; + finalOptimizer.setModule(getModule()); + finalOptimizer.walkFunction(func); } }; diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index 5315edac4..e032acc77 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -43,7 +43,7 @@ namespace wasm { // Helper classes struct GetLocalCounter : public PostWalker<GetLocalCounter, Visitor<GetLocalCounter>> { - std::vector<int>* numGetLocals; + std::vector<Index>* numGetLocals; void visitGetLocal(GetLocal *curr) { (*numGetLocals)[curr->index]++; @@ -51,7 +51,7 @@ struct GetLocalCounter : public PostWalker<GetLocalCounter, Visitor<GetLocalCoun }; struct SetLocalRemover : public PostWalker<SetLocalRemover, Visitor<SetLocalRemover>> { - std::vector<int>* numGetLocals; + std::vector<Index>* numGetLocals; void visitSetLocal(SetLocal *curr) { if ((*numGetLocals)[curr->index] == 0) { @@ -102,8 +102,8 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, // block returns std::map<Name, std::vector<BlockBreak>> blockBreaks; - // blocks that are the targets of a switch; we need to know this - // since we can't produce a block return value for them. + // blocks that we can't produce a block return value for them. + // (switch target, or some other reason) std::set<Name> unoptimizableBlocks; // A stack of sinkables from the current traversal state. When @@ -114,6 +114,12 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, // whether we need to run an additional cycle bool anotherCycle; + // whether this is the first cycle + bool firstCycle; + + // local => # of get_locals for it + std::vector<Index> numGetLocals; + static void doNoteNonLinear(SimplifyLocals* self, Expression** currp) { auto* curr = *currp; if (curr->is<Break>()) { @@ -187,9 +193,15 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, if (found != sinkables.end()) { // sink it, and nop the origin auto* set = (*found->second.item)->cast<SetLocal>(); - replaceCurrent(set); - assert(!set->isTee()); - set->setTee(true); + if (firstCycle) { + // just one get_local of this, so just sink the value + assert(numGetLocals[curr->index] == 1); + replaceCurrent(set->value); + } else { + replaceCurrent(set); + assert(!set->isTee()); + set->setTee(true); + } // reuse the getlocal that is dying *found->second.item = curr; ExpressionManipulator::nop(curr); @@ -259,7 +271,7 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, self->checkInvalidations(effects); } - if (set && !set->isTee()) { + if (set && !set->isTee() && (!self->firstCycle || self->numGetLocals[set->index] == 1)) { Index index = set->index; assert(self->sinkables.count(index) == 0); self->sinkables.emplace(std::make_pair(index, SinkableInfo(currp))); @@ -316,9 +328,18 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, for (size_t j = 0; j < breaks.size(); j++) { // move break set_local's value to the break auto* breakSetLocalPointer = breaks[j].sinkables.at(sharedIndex).item; - assert(!breaks[j].br->value); - breaks[j].br->value = (*breakSetLocalPointer)->cast<SetLocal>()->value; - ExpressionManipulator::nop(*breakSetLocalPointer); + auto* br = breaks[j].br; + assert(!br->value); + // if the break is conditional, then we must set the value here - if the break is not taken, we must still have the new value in the local + auto* set = (*breakSetLocalPointer)->cast<SetLocal>(); + if (br->condition) { + br->value = set; + set->setTee(true); + *breakSetLocalPointer = getModule()->allocator.alloc<Nop>(); + } else { + br->value = set->value; + ExpressionManipulator::nop(set); + } } // finally, create a set_local on the block itself auto* newSetLocal = Builder(*getModule()).makeSetLocal(sharedIndex, block); @@ -397,11 +418,22 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, } void doWalkFunction(Function* func) { + // scan get_locals + numGetLocals.resize(func->getNumLocals()); + std::fill(numGetLocals.begin(), numGetLocals.end(), 0); + GetLocalCounter counter; + counter.numGetLocals = &numGetLocals; + counter.walkFunction(func); // multiple passes may be required per function, consider this: // x = load // y = store // c(x, y) - // the load cannot cross the store, but y can be sunk, after which so can x + // the load cannot cross the store, but y can be sunk, after which so can x. + // + // we start with a cycle focusing on single-use locals, which are easy to + // sink (we don't need to put a set), and a good match for common compiler + // output patterns. further cycles do fully general sinking. + firstCycle = true; do { anotherCycle = false; // main operation @@ -435,15 +467,16 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals, sinkables.clear(); blockBreaks.clear(); unoptimizableBlocks.clear(); + if (firstCycle) { + firstCycle = false; + anotherCycle = true; + } } while (anotherCycle); // Finally, after optimizing a function, we can see if we have set_locals // for a local with no remaining gets, in which case, we can // remove the set. - // First, count get_locals - std::vector<int> numGetLocals; // local => # of get_locals for it - numGetLocals.resize(func->getNumLocals()); - GetLocalCounter counter; - counter.numGetLocals = &numGetLocals; + // First, recount get_locals + std::fill(numGetLocals.begin(), numGetLocals.end(), 0); counter.walkFunction(func); // Second, remove unneeded sets SetLocalRemover remover; diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp index db42a994a..0fdd12dec 100644 --- a/src/passes/Vacuum.cpp +++ b/src/passes/Vacuum.cpp @@ -47,26 +47,31 @@ struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>> case Expression::Id::CallImportId: case Expression::Id::CallIndirectId: case Expression::Id::SetLocalId: - case Expression::Id::LoadId: case Expression::Id::StoreId: case Expression::Id::ReturnId: - case Expression::Id::GetGlobalId: case Expression::Id::SetGlobalId: case Expression::Id::HostId: case Expression::Id::UnreachableId: return curr; // always needed + case Expression::Id::LoadId: { + if (!resultUsed) { + return curr->cast<Load>()->ptr; + } + return curr; + } case Expression::Id::ConstId: case Expression::Id::GetLocalId: + case Expression::Id::GetGlobalId: { + if (!resultUsed) return nullptr; + return curr; + } + case Expression::Id::UnaryId: case Expression::Id::BinaryId: case Expression::Id::SelectId: { if (resultUsed) { return curr; // used, keep it } - // result is not used, perhaps it is dead - if (curr->is<Const>() || curr->is<GetLocal>()) { - return nullptr; - } // for unary, binary, and select, we need to check their arguments for side effects if (auto* unary = curr->dynCast<Unary>()) { if (EffectAnalyzer(unary->value).hasSideEffects()) { @@ -132,13 +137,12 @@ struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>> void visitBlock(Block *curr) { // compress out nops and other dead code - bool resultUsed = ExpressionAnalyzer::isResultUsed(expressionStack, getFunction()); int skip = 0; auto& list = curr->list; size_t size = list.size(); bool needResize = false; for (size_t z = 0; z < size; z++) { - auto* optimized = optimize(list[z], z == size - 1 && resultUsed); + auto* optimized = optimize(list[z], z == size - 1 && isConcreteWasmType(curr->type)); if (!optimized) { skip++; needResize = true; @@ -153,7 +157,12 @@ struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>> Break* br = list[z - skip]->dynCast<Break>(); Switch* sw = list[z - skip]->dynCast<Switch>(); if ((br && !br->condition) || sw) { + auto* last = list.back(); list.resize(z - skip + 1); + // if we removed the last one, and it was a return value, it must be returned + if (list.back() != last && isConcreteWasmType(last->type)) { + list.push_back(last); + } needResize = false; break; } @@ -165,7 +174,7 @@ struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>> if (!curr->name.is()) { if (list.size() == 1) { // just one element. replace the block, either with it or with a nop if it's not needed - if (resultUsed || EffectAnalyzer(list[0]).hasSideEffects()) { + if (isConcreteWasmType(curr->type) || EffectAnalyzer(list[0]).hasSideEffects()) { replaceCurrent(list[0]); } else { ExpressionManipulator::nop(curr); @@ -200,11 +209,53 @@ struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum, Visitor<Vacuum>> } void visitDrop(Drop* curr) { - // if the drop input has no side effects, it can be wiped out - if (!EffectAnalyzer(curr->value).hasSideEffects()) { + // optimize the dropped value, maybe leaving nothing + curr->value = optimize(curr->value, false); + if (curr->value == nullptr) { ExpressionManipulator::nop(curr); return; } + // a drop of a tee is a set + if (auto* set = curr->value->dynCast<SetLocal>()) { + assert(set->isTee()); + set->setTee(false); + replaceCurrent(set); + return; + } + // if we are dropping a block's return value, we might be able to remove it entirely + if (auto* block = curr->value->dynCast<Block>()) { + auto* last = block->list.back(); + if (isConcreteWasmType(last->type)) { + assert(block->type == last->type); + last = optimize(last, false); + if (!last) { + // we may be able to remove this, if there are no brs + bool canPop = true; + if (block->name.is()) { + BreakSeeker breakSeeker(block->name); + Expression* temp = block; + breakSeeker.walk(temp); + if (breakSeeker.found && breakSeeker.valueType != none) { + canPop = false; + } + } + if (canPop) { + block->list.back() = last; + block->list.pop_back(); + block->type = none; + // we don't need the drop anymore, let's see what we have left in the block + if (block->list.size() > 1) { + replaceCurrent(block); + } else if (block->list.size() == 1) { + replaceCurrent(block->list[0]); + } else { + ExpressionManipulator::nop(curr); + } + return; + } + } + } + } // sink a drop into an arm of an if-else if the other arm ends in an unreachable, as it if is a branch, this can make that branch optimizable and more vaccuming possible auto* iff = curr->value->dynCast<If>(); if (iff && iff->ifFalse && isConcreteWasmType(iff->type)) { diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 437b023bc..f3bc6d4c7 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -18,6 +18,7 @@ #include <passes/passes.h> #include <pass.h> +#include <wasm-validator.h> namespace wasm { @@ -65,6 +66,7 @@ void PassRegistry::registerPasses() { registerPass("coalesce-locals-learning", "reduce # of locals by coalescing and learning", createCoalesceLocalsWithLearningPass); registerPass("dce", "removes unreachable code", createDeadCodeEliminationPass); registerPass("duplicate-function-elimination", "removes duplicate functions", createDuplicateFunctionEliminationPass); + registerPass("extract-function", "leaves just one function (useful for debugging)", createExtractFunctionPass); registerPass("merge-blocks", "merges blocks to their parents", createMergeBlocksPass); registerPass("metrics", "reports metrics", createMetricsPass); registerPass("nm", "name list", createNameListPass); @@ -74,6 +76,7 @@ void PassRegistry::registerPasses() { registerPass("print", "print in s-expression format", createPrinterPass); registerPass("print-minified", "print in minified s-expression format", createMinifiedPrinterPass); registerPass("print-full", "print in full s-expression format", createFullPrinterPass); + registerPass("relooper-jump-threading", "thread relooper jumps (fastcomp output only)", createRelooperJumpThreadingPass); registerPass("remove-imports", "removes imports and replaces them with nops", createRemoveImportsPass); registerPass("remove-memory", "removes memory segments", createRemoveMemoryPass); registerPass("remove-unused-brs", "removes breaks from locations that are not needed", createRemoveUnusedBrsPass); @@ -132,10 +135,9 @@ void PassRunner::addDefaultGlobalOptimizationPasses() { void PassRunner::run() { if (debug) { // for debug logging purposes, run each pass in full before running the other - std::chrono::high_resolution_clock::time_point beforeEverything; + auto totalTime = std::chrono::duration<double>(0); size_t padding = 0; std::cerr << "[PassRunner] running passes..." << std::endl; - beforeEverything = std::chrono::high_resolution_clock::now(); for (auto pass : passes) { padding = std::max(padding, pass->name.size()); } @@ -158,10 +160,19 @@ void PassRunner::run() { auto after = std::chrono::high_resolution_clock::now(); std::chrono::duration<double> diff = after - before; std::cerr << diff.count() << " seconds." << std::endl; + totalTime += diff; +#if 0 + // validate, ignoring the time + std::cerr << "[PassRunner] (validating)\n"; + if (!WasmValidator().validate(*wasm)) { + std::cerr << "last pass (" << pass->name << ") broke validation\n"; + abort(); + } +#endif } - auto after = std::chrono::high_resolution_clock::now(); - std::chrono::duration<double> diff = after - beforeEverything; - std::cerr << "[PassRunner] passes took " << diff.count() << " seconds." << std::endl; + std::cerr << "[PassRunner] passes took " << totalTime.count() << " seconds." << std::endl; + // validate + assert(WasmValidator().validate(*wasm)); } else { // non-debug normal mode, run them in an optimal manner - for locality it is better // to run as many passes as possible on a single function before moving to the next @@ -223,6 +234,11 @@ PassRunner::~PassRunner() { } } +void PassRunner::doAdd(Pass* pass) { + passes.push_back(pass); + pass->prepareToRun(this, wasm); +} + void PassRunner::runPassOnFunction(Pass* pass, Function* func) { // function-parallel passes get a new instance per function if (pass->isFunctionParallel()) { diff --git a/src/passes/passes.h b/src/passes/passes.h index 4bb76edad..80fa394e1 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -26,6 +26,7 @@ Pass *createCoalesceLocalsPass(); Pass *createCoalesceLocalsWithLearningPass(); Pass *createDeadCodeEliminationPass(); Pass *createDuplicateFunctionEliminationPass(); +Pass *createExtractFunctionPass(); Pass *createLowerIfElsePass(); Pass *createMergeBlocksPass(); Pass *createMetricsPass(); @@ -36,6 +37,7 @@ Pass *createPostEmscriptenPass(); Pass *createPrinterPass(); Pass *createMinifiedPrinterPass(); Pass *createFullPrinterPass(); +Pass *createRelooperJumpThreadingPass(); Pass *createRemoveImportsPass(); Pass *createRemoveMemoryPass(); Pass *createRemoveUnusedBrsPass(); diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 2841585e6..f9b2e73f6 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -283,6 +283,28 @@ public: return block; } + // ensures the first node is a block, if it isn't already, and merges in the second, + // either as a single element or, if a block, by appending to the first block + Block* blockifyMerge(Expression* any, Expression* append) { + Block* block = nullptr; + if (any) block = any->dynCast<Block>(); + if (!block) { + block = makeBlock(any); + } else { + assert(!isConcreteWasmType(block->type)); + } + auto* other = append->dynCast<Block>(); + if (!other) { + block->list.push_back(append); + } else { + for (auto* item : other->list) { + block->list.push_back(item); + } + } + block->finalize(); // TODO: move out of if + return block; + } + // a helper for the common pattern of a sequence of two expressions. Similar to // blockify, but does *not* reuse a block if the first is one. Block* makeSequence(Expression* left, Expression* right) { @@ -291,6 +313,32 @@ public: block->finalize(); return block; } + + // Grab a slice out of a block, replacing it with nops, and returning + // either another block with the contents (if more than 1) or a single expression + Expression* stealSlice(Block* input, Index from, Index to) { + Expression* ret; + if (to == from + 1) { + // just one + ret = input->list[from]; + } else { + auto* block = allocator.alloc<Block>(); + for (Index i = from; i < to; i++) { + block->list.push_back(input->list[i]); + } + block->finalize(); + ret = block; + } + if (to == input->list.size()) { + input->list.resize(from); + } else { + for (Index i = from; i < to; i++) { + input->list[i] = allocator.alloc<Nop>(); + } + } + input->finalize(); + return ret; + } }; } // namespace wasm diff --git a/src/wasm-module-building.h b/src/wasm-module-building.h index 52a4e7536..92b96d98d 100644 --- a/src/wasm-module-building.h +++ b/src/wasm-module-building.h @@ -97,7 +97,7 @@ public: } // Before parallelism, create all passes on the main thread here, to ensure - // constructors run at least once on the main thread, for one-time init things. + // prepareToRun() is called for each pass before we start to optimize functions. { PassRunner passRunner(wasm); addPrePasses(passRunner); diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 4725237dd..47b9d26e8 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -335,8 +335,8 @@ struct PostWalker : public Walker<SubType, VisitorType> { } case Expression::Id::SwitchId: { self->pushTask(SubType::doVisitSwitch, currp); - self->maybePushTask(SubType::scan, &curr->cast<Switch>()->value); self->pushTask(SubType::scan, &curr->cast<Switch>()->condition); + self->maybePushTask(SubType::scan, &curr->cast<Switch>()->value); break; } case Expression::Id::CallId: { diff --git a/src/wasm-validator.h b/src/wasm-validator.h index 58a30f9a3..e23221337 100644 --- a/src/wasm-validator.h +++ b/src/wasm-validator.h @@ -104,6 +104,9 @@ public: void visitIf(If *curr) { shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32 || curr->condition->type == i64, curr, "if condition must be valid"); + if (!curr->ifFalse) { + shouldBeFalse(isConcreteWasmType(curr->ifTrue->type), curr, "if without else must not return a value in body"); + } } // override scan to add a pre and a post check task to all nodes diff --git a/src/wasm.h b/src/wasm.h index d6bdfe91f..4d5cf4d70 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -981,7 +981,6 @@ public: Expression *condition, *ifTrue, *ifFalse; void finalize() { - assert(ifTrue); if (ifFalse) { if (ifTrue->type == ifFalse->type) { type = ifTrue->type; @@ -992,6 +991,8 @@ public: } else { type = none; } + } else { + type = none; // if without else } } }; @@ -1027,6 +1028,8 @@ public: void finalize() { if (condition) { type = none; + } else { + type = unreachable; } } }; |