author | Alon Zakai <azakai@google.com> | 2023-09-14 14:21:25 -0700
---|---|---
committer | GitHub <noreply@github.com> | 2023-09-14 14:21:25 -0700
commit | 3e8a9dacf6c65c29054094ce9f1d34ae8480df65 (patch) |
tree | 928b04d774c2540a6a4c170d6013d5b60663c447 | /src/passes/TupleOptimization.cpp
parent | f774effa54c6a40448487033a28a47caa3394f61 (diff) |
Add a simple tuple optimization pass (#5937)
In some cases tuples are obviously not needed, such as when they are only used
in local operations and tuple.make/tuple.extract. Such tuples are not used as
return values or in control flow structures, so we might as well lower them to
individual locals per lane, which other passes can optimize a lot better.
I believe LLVM does the same with its own tuples: it lowers them as much as
possible, leaving only necessary ones.
Fixes #5923
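
As a rough sketch of how the new pass can be exercised (the registered pass name "tuple-optimization" and this exact PassRunner usage are assumptions based on Binaryen's usual conventions, not something this commit shows):

    // Hypothetical driver; the pass name "tuple-optimization" is assumed,
    // since the registration is not part of this file.
    #include <pass.h>
    #include <wasm.h>

    int main() {
      wasm::Module module;
      // The pass exits early unless the multivalue feature is enabled.
      module.features.setMultivalue();
      wasm::PassRunner runner(&module);
      runner.add("tuple-optimization"); // assumed registered name
      runner.run();
      return 0;
    }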
Diffstat (limited to 'src/passes/TupleOptimization.cpp')
-rw-r--r-- | src/passes/TupleOptimization.cpp | 359
1 file changed, 359 insertions, 0 deletions
diff --git a/src/passes/TupleOptimization.cpp b/src/passes/TupleOptimization.cpp
new file mode 100644
index 000000000..f4fd0910b
--- /dev/null
+++ b/src/passes/TupleOptimization.cpp
@@ -0,0 +1,359 @@
+/*
+ * Copyright 2023 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Optimize away trivial tuples. When values are bundled together in a tuple,
+// we are limited in how we can optimize them in the various local-related
+// passes, like this:
+//
+//   (local.set $tuple
+//     (tuple.make (A) (B) (C)))
+//   (use
+//     (tuple.extract 0
+//       (local.get $tuple)))
+//
+// If there are no other uses, then we just need one of the three lanes. By
+// lowering them to three separate locals, other passes can remove the other
+// two.
+//
+// Specifically, this pass seeks out tuple locals that have these properties:
+//
+//   * They are always written either by a tuple.make or by another tuple
+//     local with these properties.
+//   * They are always used either in a tuple.extract or copied to another
+//     tuple local with these properties.
+//
+// The set of such tuple locals can easily be optimized into individual
+// locals, as the tuple does not "escape" into, say, a return value.
+//
+// TODO: Blocks etc. might be handled here, but it's not clear if we want to:
+//       there are situations where multivalue leads to smaller code using
+//       those constructs. At the moment this pass should only remove things
+//       that are definitely worth lowering.
+//
+
+#include <pass.h>
+#include <support/unique_deferring_queue.h>
+#include <wasm-builder.h>
+#include <wasm.h>
+
+namespace wasm {
+
+struct TupleOptimization : public WalkerPass<PostWalker<TupleOptimization>> {
+  bool isFunctionParallel() override { return true; }
+
+  std::unique_ptr<Pass> create() override {
+    return std::make_unique<TupleOptimization>();
+  }
+
+  // Track the number of uses for each tuple local. We consider a use to be a
+  // local.get, a set, or a tee. A tee counts as two uses, since it both sets
+  // and gets, and both of those must be valid.
+  std::vector<Index> uses;
+
+  // Tracks which tuple local uses are valid, that is, follow the properties
+  // above. If we have more uses than valid uses then we must have an invalid
+  // one, and the local cannot be optimized.
+  std::vector<Index> validUses;
+
+  // When one tuple local copies the value of another, we need to track the
+  // index that was copied, as if the source ends up bad then the target is
+  // bad as well.
+  //
+  // This is a symmetrical map, that is, we consider copies to work both
+  // ways:
+  //
+  //   x \in copiedIndexes[y]  <==>  y \in copiedIndexes[x]
+  //
+  std::vector<std::unordered_set<Index>> copiedIndexes;
+
+  void doWalkFunction(Function* func) {
+    // If tuples are not enabled, or there are no tuple locals, then there is
+    // no work to do.
+    if (!getModule()->features.hasMultivalue()) {
+      return;
+    }
+    bool hasTuple = false;
+    for (auto var : func->vars) {
+      if (var.isTuple()) {
+        hasTuple = true;
+        break;
+      }
+    }
+    if (!hasTuple) {
+      return;
+    }
+
+    // Prepare global data structures before we collect info.
+    auto numLocals = func->getNumLocals();
+    uses.resize(numLocals);
+    validUses.resize(numLocals);
+    copiedIndexes.resize(numLocals);
+
+    // Walk the code to collect info.
+    super::doWalkFunction(func);
+
+    // Analyze and optimize.
+    optimize(func);
+  }
+
+  void visitLocalGet(LocalGet* curr) {
+    if (curr->type.isTuple()) {
+      uses[curr->index]++;
+    }
+  }
+
+  void visitLocalSet(LocalSet* curr) {
+    if (getFunction()->getLocalType(curr->index).isTuple()) {
+      // See the comment above about tees (we consider their set and get each
+      // a separate use).
+      uses[curr->index] += curr->isTee() ? 2 : 1;
+      auto* value = curr->value;
+
+      // We need the input to the local to be another such local (from a tee
+      // or a get), or a tuple.make.
+      if (auto* tee = value->dynCast<LocalSet>()) {
+        assert(tee->isTee());
+        validUses[tee->index]++;
+        validUses[curr->index]++;
+        copiedIndexes[tee->index].insert(curr->index);
+        copiedIndexes[curr->index].insert(tee->index);
+      } else if (auto* get = value->dynCast<LocalGet>()) {
+        validUses[get->index]++;
+        validUses[curr->index]++;
+        copiedIndexes[get->index].insert(curr->index);
+        copiedIndexes[curr->index].insert(get->index);
+      } else if (value->is<TupleMake>()) {
+        validUses[curr->index]++;
+      }
+    }
+  }
+
+  void visitTupleExtract(TupleExtract* curr) {
+    // We need the input to be a local, either from a tee or a get.
+    if (auto* set = curr->tuple->dynCast<LocalSet>()) {
+      validUses[set->index]++;
+    } else if (auto* get = curr->tuple->dynCast<LocalGet>()) {
+      validUses[get->index]++;
+    }
+  }
+
+  void optimize(Function* func) {
+    auto numLocals = func->getNumLocals();
+
+    // Find the set of bad indexes. We add each bad candidate to a worklist,
+    // and then flow badness onward to find everything affected.
+    std::vector<bool> bad(numLocals);
+    UniqueDeferredQueue<Index> work;
+
+    for (Index i = 0; i < uses.size(); i++) {
+      assert(validUses[i] <= uses[i]);
+      if (uses[i] > 0 && validUses[i] < uses[i]) {
+        // This is a bad tuple.
+        work.push(i);
+      }
+    }
+
+    // Flow badness forward.
+    while (!work.empty()) {
+      auto i = work.pop();
+      if (bad[i]) {
+        continue;
+      }
+      bad[i] = true;
+      for (auto target : copiedIndexes[i]) {
+        work.push(target);
+      }
+    }
+
+    // The good indexes we can optimize are tuple locals with uses that are
+    // not bad.
+    std::vector<bool> good(numLocals);
+    bool hasGood = false;
+    for (Index i = 0; i < uses.size(); i++) {
+      if (uses[i] > 0 && !bad[i]) {
+        good[i] = true;
+        hasGood = true;
+      }
+    }
+
+    if (!hasGood) {
+      return;
+    }
+
+    // We found things to optimize! Create new non-tuple locals for their
+    // contents, and then rewrite the code to use those according to the
+    // mapping from tuple locals to normal ones. The mapping maps a tuple
+    // local to the base index used for its contents: that index and several
+    // right after it, depending on the tuple size.
+    std::unordered_map<Index, Index> tupleToNewBaseMap;
+    for (Index i = 0; i < good.size(); i++) {
+      if (!good[i]) {
+        continue;
+      }
+
+      auto newBase = func->getNumLocals();
+      tupleToNewBaseMap[i] = newBase;
+      Index lastNewIndex = 0;
+      for (auto t : func->getLocalType(i)) {
+        Index newIndex = Builder::addVar(func, t);
+        if (lastNewIndex == 0) {
+          // This is the first new local we added (0 is an impossible value,
+          // since tuple locals exist, hence index 0 was already taken), so
+          // it must be equal to the base.
+          assert(newIndex == newBase);
+        } else {
+          // This must be right after the previous one.
+          assert(newIndex == lastNewIndex + 1);
+        }
+        lastNewIndex = newIndex;
+      }
+    }
+
+    MapApplier mapApplier(tupleToNewBaseMap);
+    mapApplier.walkFunctionInModule(func, getModule());
+  }
+
+  struct MapApplier : public PostWalker<MapApplier> {
+    std::unordered_map<Index, Index>& tupleToNewBaseMap;
+
+    MapApplier(std::unordered_map<Index, Index>& tupleToNewBaseMap)
+      : tupleToNewBaseMap(tupleToNewBaseMap) {}
+
+    // Gets the new base index if there is one, or 0 if not (0 is an
+    // impossible value for a new index, as local index 0 was already taken,
+    // since tuple locals existed).
+    Index getNewBaseIndex(Index i) {
+      auto iter = tupleToNewBaseMap.find(i);
+      if (iter == tupleToNewBaseMap.end()) {
+        return 0;
+      }
+      return iter->second;
+    }
+
+    // Given a local.get or local.set, returns the new base index for the
+    // local index used there. Returns 0 (an impossible value, see above)
+    // otherwise.
+    Index getSetOrGetBaseIndex(Expression* setOrGet) {
+      Index index;
+      if (auto* set = setOrGet->dynCast<LocalSet>()) {
+        index = set->index;
+      } else if (auto* get = setOrGet->dynCast<LocalGet>()) {
+        index = get->index;
+      } else {
+        return 0;
+      }
+
+      return getNewBaseIndex(index);
+    }
+
+    // Replacing a local.tee requires some care, since we might have
+    //
+    //   (local.set
+    //     (local.tee
+    //       ..
+    //
+    // We replace the local.tee with a block of sets of the new non-tuple
+    // locals, and the outer set must then (1) keep those around and also (2)
+    // identify the local that was tee'd, so we know what to get (since the
+    // tee has been replaced by the block). To make that simple, keep a map
+    // of the things that replaced tees.
+    std::unordered_map<Expression*, LocalSet*> replacedTees;
+
+    void visitLocalSet(LocalSet* curr) {
+      auto replace = [&](Expression* replacement) {
+        if (curr->isTee()) {
+          replacedTees[replacement] = curr;
+        }
+        replaceCurrent(replacement);
+      };
+
+      if (auto targetBase = getNewBaseIndex(curr->index)) {
+        Builder builder(*getModule());
+        auto type = getFunction()->getLocalType(curr->index);
+
+        auto* value = curr->value;
+        if (auto* make = value->dynCast<TupleMake>()) {
+          // Write each of the tuple.make fields into the proper local.
+          std::vector<Expression*> sets;
+          for (Index i = 0; i < type.size(); i++) {
+            auto* value = make->operands[i];
+            sets.push_back(builder.makeLocalSet(targetBase + i, value));
+          }
+          replace(builder.makeBlock(sets));
+          return;
+        }
+
+        std::vector<Expression*> contents;
+
+        auto iter = replacedTees.find(value);
+        if (iter != replacedTees.end()) {
+          // The input to us was a tee that has been replaced. The actual
+          // value we read from (the tee) can be found in replacedTees. Also,
+          // we need to keep around the replacement of the tee.
+          contents.push_back(value);
+          value = iter->second;
+        }
+
+        // This is a copy of one tuple local into another. Copy all the
+        // fields between them.
+        Index sourceBase = getSetOrGetBaseIndex(value);
+
+        // The target is being optimized, so the source must be as well, or
+        // else we were confused earlier and the target should not be.
+        assert(sourceBase);
+
+        for (Index i = 0; i < type.size(); i++) {
+          auto* get = builder.makeLocalGet(sourceBase + i, type[i]);
+          contents.push_back(builder.makeLocalSet(targetBase + i, get));
+        }
+        replace(builder.makeBlock(contents));
+      }
+    }
+
+    void visitTupleExtract(TupleExtract* curr) {
+      auto* value = curr->tuple;
+      Expression* extraContents = nullptr;
+
+      auto iter = replacedTees.find(value);
+      if (iter != replacedTees.end()) {
+        // The input to us was a tee that has been replaced. Handle it as in
+        // visitLocalSet.
+        extraContents = value;
+        value = iter->second;
+      }
+
+      auto type = value->type;
+      if (type == Type::unreachable) {
+        return;
+      }
+
+      Index sourceBase = getSetOrGetBaseIndex(value);
+      if (!sourceBase) {
+        return;
+      }
+
+      Builder builder(*getModule());
+      auto i = curr->index;
+      auto* get = builder.makeLocalGet(sourceBase + i, type[i]);
+      if (extraContents) {
+        replaceCurrent(builder.makeSequence(extraContents, get));
+      } else {
+        replaceCurrent(get);
+      }
+    }
+  };
+};
+
+Pass* createTupleOptimizationPass() { return new TupleOptimization(); }
+
+} // namespace wasm
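
The heart of the analysis above is the "flow badness forward" step: any tuple local with an invalid use seeds a worklist, and badness then spreads across the symmetric copy edges recorded in copiedIndexes. Below is a minimal standalone sketch of that propagation, using std::queue in place of Binaryen's UniqueDeferredQueue; the toy data and names are illustrative only:

    #include <cstddef>
    #include <iostream>
    #include <queue>
    #include <unordered_set>
    #include <vector>

    int main() {
      using Index = size_t;
      // Per-local use counts, valid-use counts, and symmetric copy edges:
      // local 2 has one use that is not valid, and copies chain 0 <-> 1 <-> 2.
      std::vector<Index> uses = {2, 1, 1};
      std::vector<Index> validUses = {2, 1, 0};
      std::vector<std::unordered_set<Index>> copiedIndexes = {{1}, {0, 2}, {1}};

      std::vector<bool> bad(uses.size());
      std::queue<Index> work;
      for (Index i = 0; i < uses.size(); i++) {
        if (uses[i] > 0 && validUses[i] < uses[i]) {
          work.push(i); // seed: this local has an invalid use
        }
      }
      while (!work.empty()) {
        Index i = work.front();
        work.pop();
        if (bad[i]) {
          continue;
        }
        bad[i] = true;
        for (auto target : copiedIndexes[i]) {
          work.push(target); // anything copied to/from a bad local is also bad
        }
      }

      // All three locals end up bad: 2 is seeded, and the copies taint 1 and 0.
      for (Index i = 0; i < bad.size(); i++) {
        std::cout << "local " << i << ": " << (bad[i] ? "bad" : "good") << "\n";
      }
    }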