diff options
author | Alon Zakai <azakai@google.com> | 2019-02-26 17:47:28 -0800 |
---|---|---|
committer | Alon Zakai <alonzakai@gmail.com> | 2019-03-06 16:34:34 -0800 |
commit | 56fc114121716c672d4a16f92e8323eada557177 (patch) | |
tree | b895d369688c3357a892bd93bfecac78d9b25db8 /src/passes/OptimizeAddedConstants.cpp | |
parent | 8b5b85463cd5f5fcfec4d0bc6acb52f2acb30d79 (diff) | |
download | binaryen-56fc114121716c672d4a16f92e8323eada557177.tar.gz binaryen-56fc114121716c672d4a16f92e8323eada557177.tar.bz2 binaryen-56fc114121716c672d4a16f92e8323eada557177.zip |
Propagate a load/store offset even if locals are not in ssa form
The initial OptimizeAddedConstants pass did not try to handle the case of non-ssa locals. However, that can happen, and optimizing those cases too improves us by almost 1% of code size on some large benchmarks like bullet.
How this works is that if we see
b = a + 10
a = c
load(b)
then we copy the base value at the add,
a' = a
b = a' + 10
a = c
load(a', offset=10)
This no longer has a guarantee of improving code size, since in theory both b and a may have other uses. However, in practice it's very common for b to be optimized out later.
Diffstat (limited to 'src/passes/OptimizeAddedConstants.cpp')
-rw-r--r-- | src/passes/OptimizeAddedConstants.cpp | 100 |
1 files changed, 89 insertions, 11 deletions
diff --git a/src/passes/OptimizeAddedConstants.cpp b/src/passes/OptimizeAddedConstants.cpp index 119d61637..9e339ed38 100644 --- a/src/passes/OptimizeAddedConstants.cpp +++ b/src/passes/OptimizeAddedConstants.cpp @@ -37,10 +37,11 @@ namespace wasm { -template<typename T> +template<typename P, typename T> class MemoryAccessOptimizer { public: - MemoryAccessOptimizer(T* curr, Module* module, LocalGraph* localGraph) : curr(curr), module(module), localGraph(localGraph) { + MemoryAccessOptimizer(P* parent, T* curr, Module* module, LocalGraph* localGraph) : + parent(parent), curr(curr), module(module), localGraph(localGraph) { // The pointer itself may be a constant, if e.g. it was precomputed or // a get that we propagated. if (curr->ptr->template is<Const>()) { @@ -81,8 +82,8 @@ public: // a constant *and* the other side cannot change in the middle. // TODO If it could change, we may add a new local to capture the // old value. - if (tryToOptimizePropagatedAdd(add->right, add->left, get) || - tryToOptimizePropagatedAdd(add->left, add->right, get)) { + if (tryToOptimizePropagatedAdd(add->right, add->left, get, set) || + tryToOptimizePropagatedAdd(add->left, add->right, get, set)) { return; } } @@ -94,6 +95,7 @@ public: } private: + P* parent; T* curr; Module* module; LocalGraph* localGraph; @@ -144,8 +146,12 @@ private: return false; } - bool tryToOptimizePropagatedAdd(Expression* oneSide, Expression* otherSide, GetLocal* ptr) { + bool tryToOptimizePropagatedAdd(Expression* oneSide, Expression* otherSide, GetLocal* ptr, SetLocal* set) { if (auto* c = oneSide->template dynCast<Const>()) { + if (otherSide->template is<Const>()) { + // Both sides are constant - this is not optimized code, ignore. + return false; + } auto result = canOptimizeConstant(c->value); if (result.succeeded) { // Looks good, but we need to make sure the other side cannot change: @@ -168,14 +174,32 @@ private: // // This is valid since dominance is transitive, so y's definition dominates the load, // and it is ok to replace x with y + 10 there. - // TODO otherwise, create a new local + Index index = -1; + bool canReuseIndex = false; if (auto* get = otherSide->template dynCast<GetLocal>()) { if (localGraph->isSSA(get->index) && localGraph->isSSA(ptr->index)) { - curr->offset = result.total; - curr->ptr = Builder(*module).makeGetLocal(get->index, get->type); - return true; + index = get->index; + canReuseIndex = true; } } + // If we can't reuse the index, then create a new one, + // + // x = y + 10 + // y = y + 1 + // load(x) + // => + // y' = y + // x = y' + 10 + // y = y + 1 + // load(y', offset=10) + // + // Often x has no other uses and later passes can remove it. + if (!canReuseIndex) { + index = parent->getHelperIndex(set); + } + curr->offset = result.total; + curr->ptr = Builder(*module).makeGetLocal(index, i32); + return true; } } return false; @@ -209,11 +233,11 @@ struct OptimizeAddedConstants : public WalkerPass<PostWalker<OptimizeAddedConsta std::unique_ptr<LocalGraph> localGraph; void visitLoad(Load* curr) { - MemoryAccessOptimizer<Load>(curr, getModule(), localGraph.get()); + MemoryAccessOptimizer<OptimizeAddedConstants, Load>(this, curr, getModule(), localGraph.get()); } void visitStore(Store* curr) { - MemoryAccessOptimizer<Store>(curr, getModule(), localGraph.get()); + MemoryAccessOptimizer<OptimizeAddedConstants, Store>(this, curr, getModule(), localGraph.get()); } void doWalkFunction(Function* func) { @@ -224,6 +248,60 @@ struct OptimizeAddedConstants : public WalkerPass<PostWalker<OptimizeAddedConsta localGraph->computeSSAIndexes(); } super::doWalkFunction(func); + if (!helperIndexes.empty()) { + createHelperIndexes(); + } + } + + // For a given expression, store it to a local and return us the local index we can use, + // in order to get that value someplace else. We are provided not the expression, + // but the set in which it is in, as the arm of an add that is the set's value (the other + // arm is a constant, and we are not a constant). + // We cache these, that is, use a single one for all requests. + Index getHelperIndex(SetLocal* set) { + auto iter = helperIndexes.find(set); + if (iter != helperIndexes.end()) { + return iter->second; + } + return helperIndexes[set] = Builder(*getModule()).addVar(getFunction(), i32); + } + +private: + std::map<SetLocal*, Index> helperIndexes; + + void createHelperIndexes() { + struct Creator : public PostWalker<Creator> { + std::map<SetLocal*, Index>& helperIndexes; + Module* module; + + Creator(std::map<SetLocal*, Index>& helperIndexes) : helperIndexes(helperIndexes) {} + + void visitSetLocal(SetLocal* curr) { + auto iter = helperIndexes.find(curr); + if (iter != helperIndexes.end()) { + auto index = iter->second; + auto* binary = curr->value->cast<Binary>(); + Expression** target; + if (binary->left->is<Const>()) { + target = &binary->right; + } else { + assert(binary->right->is<Const>()); + target = &binary->left; + } + auto* value = *target; + Builder builder(*module); + *target = builder.makeGetLocal(index, i32); + replaceCurrent( + builder.makeSequence( + builder.makeSetLocal(index, value), + curr + ) + ); + } + } + } creator(helperIndexes); + creator.module = getModule(); + creator.walk(getFunction()->body); } }; |