author Alon Zakai <azakai@google.com> 2023-09-14 14:21:25 -0700
committer GitHub <noreply@github.com> 2023-09-14 14:21:25 -0700
commit 3e8a9dacf6c65c29054094ce9f1d34ae8480df65 (patch)
tree 928b04d774c2540a6a4c170d6013d5b60663c447 /src/passes/TupleOptimization.cpp
parent f774effa54c6a40448487033a28a47caa3394f61 (diff)
Add a simple tuple optimization pass (#5937)
In some cases tuples are obviously not needed, such as when they are only used in local operations (get/set/tee) and in tuple.make/tuple.extract. Such tuples are not used as return values or in control flow structures, so we might as well lower them to individual locals, one per lane, which other passes can optimize a lot better. I believe LLVM does the same with its own tuples: it lowers them as much as possible, leaving only the necessary ones. Fixes #5923
Diffstat (limited to 'src/passes/TupleOptimization.cpp')
-rw-r--r-- src/passes/TupleOptimization.cpp | 359
1 file changed, 359 insertions, 0 deletions
diff --git a/src/passes/TupleOptimization.cpp b/src/passes/TupleOptimization.cpp
new file mode 100644
index 000000000..f4fd0910b
--- /dev/null
+++ b/src/passes/TupleOptimization.cpp
@@ -0,0 +1,359 @@
+/*
+ * Copyright 2023 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Optimize away trivial tuples. When values are bundled together in a tuple, we
+// are limited in how we can optimize them in the various local-related passes,
+// like this:
+//
+// (local.set $tuple
+// (tuple.make (A) (B) (C)))
+// (use
+// (tuple.extract 0
+// (local.get $tuple)))
+//
+// If there are no other uses, then we just need one of the three lanes. By
+// lowering them to three separate locals, other passes can remove the other two.
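+//
+// For example, the code above could then become something like this (a rough
+// sketch; the local names are illustrative, and later passes do the actual
+// removal of dead code):
+//
+// (local.set $tuple.0 (A))
+// (local.set $tuple.1 (B))
+// (local.set $tuple.2 (C))
+// (use
+// (local.get $tuple.0))
+//
+// Here the sets of $tuple.1 and $tuple.2 become trivially removable if B and
+// C have no side effects.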
+//
+// Specifically, this pass seeks out tuple locals that have these properties:
+//
+// * They are always written from either a tuple.make or another tuple local with
+// these properties.
+// * They are always used either in a tuple.extract or copied to another
+// tuple local with these properties.
+//
+// The set of those tuple locals can be easily optimized into individual locals,
+// as the tuple does not "escape" into, say, a return value.
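+//
+// For contrast, a hypothetical tuple local whose value escapes, say through a
+// return, cannot be optimized:
+//
+// (local.set $tuple
+// (tuple.make (A) (B) (C)))
+// (return
+// (local.get $tuple))
+//
+// The local.get here is neither a tuple.extract input nor a copy to another
+// tuple local, so the local fails the second property above.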
+//
+// TODO: Blocks etc. might be handled here, but it's not clear if we want to:
+// there are situations where multivalue leads to smaller code using
+// those constructs. Atm this pass should only remove things that are
+// definitely worth lowering.
+//
+
+#include <pass.h>
+#include <support/unique_deferring_queue.h>
+#include <wasm-builder.h>
+#include <wasm.h>
+
+namespace wasm {
+
+struct TupleOptimization : public WalkerPass<PostWalker<TupleOptimization>> {
+ bool isFunctionParallel() override { return true; }
+
+ std::unique_ptr<Pass> create() override {
+ return std::make_unique<TupleOptimization>();
+ }
+
+ // Track the number of uses for each tuple local. We consider a use to be a
+ // local.get, a set, or a tee. A tee counts as two uses, since it both sets
+ // and gets, and we must verify that both of those uses are valid.
+ std::vector<Index> uses;
+
+ // Tracks which tuple local uses are valid, that is, follow the properties
+ // above. If we have more uses than valid uses then we must have an invalid
+ // one, and the local cannot be optimized.
+ std::vector<Index> validUses;
+
+ // When one tuple local copies the value of another, we need to track the
+ // index that was copied, as if the source ends up bad then the target is bad
+ // as well.
+ //
+ // This is a symmetrical map, that is, we consider copies to work both ways:
+ //
+ // x \in copiedIndexes[y] <==> y \in copiedIndexes[x]
+ //
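+ // For example, after (local.set $x (local.get $y)) we link $x and $y in
+ // both directions: if either one turns out to have an invalid use, the copy
+ // must remain a tuple operation, so the other cannot be optimized either.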
+ std::vector<std::unordered_set<Index>> copiedIndexes;
+
+ void doWalkFunction(Function* func) {
+ // If multivalue is not enabled, or there are no tuple locals, then there is
+ // no work to do.
+ if (!getModule()->features.hasMultivalue()) {
+ return;
+ }
+ bool hasTuple = false;
+ for (auto var : func->vars) {
+ if (var.isTuple()) {
+ hasTuple = true;
+ break;
+ }
+ }
+ if (!hasTuple) {
+ return;
+ }
+
+ // Prepare global data structures before we collect info.
+ auto numLocals = func->getNumLocals();
+ uses.resize(numLocals);
+ validUses.resize(numLocals);
+ copiedIndexes.resize(numLocals);
+
+ // Walk the code to collect info.
+ super::doWalkFunction(func);
+
+ // Analyze and optimize.
+ optimize(func);
+ }
+
+ void visitLocalGet(LocalGet* curr) {
+ if (curr->type.isTuple()) {
+ uses[curr->index]++;
+ }
+ }
+
+ void visitLocalSet(LocalSet* curr) {
+ if (getFunction()->getLocalType(curr->index).isTuple()) {
+ // See comment above about tees (we consider their set and get each a
+ // separate use).
+ uses[curr->index] += curr->isTee() ? 2 : 1;
+ auto* value = curr->value;
+
+ // We need the input to the local to be another such local (from a tee, or
+ // a get), or a tuple.make.
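+ //
+ // For example, each of these writes is valid (hypothetical code):
+ //
+ // (local.set $x (tuple.make (A) (B))) ;; a tuple.make
+ // (local.set $x (local.get $y)) ;; a copy from a get
+ // (local.set $x (local.tee $y (..))) ;; a copy from a tee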
+ if (auto* tee = value->dynCast<LocalSet>()) {
+ assert(tee->isTee());
+ validUses[tee->index]++;
+ validUses[curr->index]++;
+ copiedIndexes[tee->index].insert(curr->index);
+ copiedIndexes[curr->index].insert(tee->index);
+ } else if (auto* get = value->dynCast<LocalGet>()) {
+ validUses[get->index]++;
+ validUses[curr->index]++;
+ copiedIndexes[get->index].insert(curr->index);
+ copiedIndexes[curr->index].insert(get->index);
+ } else if (value->is<TupleMake>()) {
+ validUses[curr->index]++;
+ }
+ }
+ }
+
+ void visitTupleExtract(TupleExtract* curr) {
+ // We need the input to be a local, either from a tee or a get.
+ if (auto* set = curr->tuple->dynCast<LocalSet>()) {
+ validUses[set->index]++;
+ } else if (auto* get = curr->tuple->dynCast<LocalGet>()) {
+ validUses[get->index]++;
+ }
+ }
+
+ void optimize(Function* func) {
+ auto numLocals = func->getNumLocals();
+
+ // Find the set of bad indexes. We add each bad index to a worklist, then
+ // flow badness along the copy links to find everything that is corrupted.
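+ // For example, if $a is copied to $b and $b to $c, then an invalid use of
+ // $a makes all three bad, since badness flows along the copy links.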
+ std::vector<bool> bad(numLocals);
+ UniqueDeferredQueue<Index> work;
+
+ for (Index i = 0; i < uses.size(); i++) {
+ assert(validUses[i] <= uses[i]);
+ if (uses[i] > 0 && validUses[i] < uses[i]) {
+ // This is a bad tuple.
+ work.push(i);
+ }
+ }
+
+ // Flow badness forward.
+ while (!work.empty()) {
+ auto i = work.pop();
+ if (bad[i]) {
+ continue;
+ }
+ bad[i] = true;
+ for (auto target : copiedIndexes[i]) {
+ work.push(target);
+ }
+ }
+
+ // The good indexes we can optimize are tuple locals that have uses and are
+ // not bad.
+ std::vector<bool> good(numLocals);
+ bool hasGood = false;
+ for (Index i = 0; i < uses.size(); i++) {
+ if (uses[i] > 0 && !bad[i]) {
+ good[i] = true;
+ hasGood = true;
+ }
+ }
+
+ if (!hasGood) {
+ return;
+ }
+
+ // We found things to optimize! Create new non-tuple locals for their
+ // contents, and then rewrite the code to use those according to the
+ // mapping from tuple locals to normal ones. The mapping maps a tuple local
+ // to the base index used for its contents: an index and several others
+ // right after it, depending on the tuple size.
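+ //
+ // For example, a tuple local of type (i32, i64) maps to some base index B,
+ // with the i32 lane in local B and the i64 lane in local B + 1 (B here is
+ // just an illustrative name).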
+ std::unordered_map<Index, Index> tupleToNewBaseMap;
+ for (Index i = 0; i < good.size(); i++) {
+ if (!good[i]) {
+ continue;
+ }
+
+ auto newBase = func->getNumLocals();
+ tupleToNewBaseMap[i] = newBase;
+ Index lastNewIndex = 0;
+ for (auto t : func->getLocalType(i)) {
+ Index newIndex = Builder::addVar(func, t);
+ if (lastNewIndex == 0) {
+ // This is the first new local we added. (0 is an impossible value here,
+ // since tuple locals exist, so index 0 was already taken.) It must
+ // therefore be equal to the base.
+ assert(newIndex == newBase);
+ } else {
+ // This must be right after the former.
+ assert(newIndex == lastNewIndex + 1);
+ }
+ lastNewIndex = newIndex;
+ }
+ }
+
+ MapApplier mapApplier(tupleToNewBaseMap);
+ mapApplier.walkFunctionInModule(func, getModule());
+ }
+
+ struct MapApplier : public PostWalker<MapApplier> {
+ std::unordered_map<Index, Index>& tupleToNewBaseMap;
+
+ MapApplier(std::unordered_map<Index, Index>& tupleToNewBaseMap)
+ : tupleToNewBaseMap(tupleToNewBaseMap) {}
+
+ // Gets the new base index if there is one, or 0 if not. (0 is an impossible
+ // value for a new index: tuple locals existed before we added any new
+ // locals, so local index 0 was already taken.)
+ Index getNewBaseIndex(Index i) {
+ auto iter = tupleToNewBaseMap.find(i);
+ if (iter == tupleToNewBaseMap.end()) {
+ return 0;
+ }
+ return iter->second;
+ }
+
+ // Given a local.get or local.set, returns the new base index for the local
+ // index used there, or 0 (an impossible value, see above) otherwise.
+ Index getSetOrGetBaseIndex(Expression* setOrGet) {
+ Index index;
+ if (auto* set = setOrGet->dynCast<LocalSet>()) {
+ index = set->index;
+ } else if (auto* get = setOrGet->dynCast<LocalGet>()) {
+ index = get->index;
+ } else {
+ return 0;
+ }
+
+ return getNewBaseIndex(index);
+ }
+
+ // Replacing a local.tee requires some care, since we might have
+ //
+ // (local.set
+ // (local.tee
+ // ..
+ //
+ // We replace the local.tee with a block of sets of the new non-tuple
+ // locals, and the outer set must then (1) keep those around and also (2)
+ // identify the local that was tee'd, so we know what to get (which has been
+ // replaced by the block). To make that simple, we keep a map of the
+ // expressions that replaced tees.
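+ //
+ // For example (a sketch of the mechanism, not exact output), in
+ //
+ // (local.set $x
+ // (local.tee $y
+ // (tuple.make ..)))
+ //
+ // the tee becomes a block of sets of $y's new locals, and the outer set
+ // finds the original tee via replacedTees so it knows to copy from $y.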
+ std::unordered_map<Expression*, LocalSet*> replacedTees;
+
+ void visitLocalSet(LocalSet* curr) {
+ auto replace = [&](Expression* replacement) {
+ if (curr->isTee()) {
+ replacedTees[replacement] = curr;
+ }
+ replaceCurrent(replacement);
+ };
+
+ if (auto targetBase = getNewBaseIndex(curr->index)) {
+ Builder builder(*getModule());
+ auto type = getFunction()->getLocalType(curr->index);
+
+ auto* value = curr->value;
+ if (auto* make = value->dynCast<TupleMake>()) {
+ // Write each of the tuple.make fields into the proper local.
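+ // For example (names illustrative), (local.set $t (tuple.make (A) (B)))
+ // becomes roughly:
+ //
+ // (block
+ // (local.set $t.0 (A))
+ // (local.set $t.1 (B)))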
+ std::vector<Expression*> sets;
+ for (Index i = 0; i < type.size(); i++) {
+ auto* value = make->operands[i];
+ sets.push_back(builder.makeLocalSet(targetBase + i, value));
+ }
+ replace(builder.makeBlock(sets));
+ return;
+ }
+
+ std::vector<Expression*> contents;
+
+ auto iter = replacedTees.find(value);
+ if (iter != replacedTees.end()) {
+ // The input to us was a tee that has been replaced. The actual value
+ // we read from (the tee) can be found in replacedTees. Also, we
+ // need to keep around the replacement of the tee.
+ contents.push_back(value);
+ value = iter->second;
+ }
+
+ // This is a copy of a tuple local into another. Copy all the fields
+ // between them.
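+ // For example (names illustrative), a copy (local.set $x (local.get $y))
+ // becomes roughly:
+ //
+ // (block
+ // (local.set $x.0 (local.get $y.0))
+ // (local.set $x.1 (local.get $y.1)))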
+ Index sourceBase = getSetOrGetBaseIndex(value);
+
+ // The target is being optimized, so the source must be as well, or else
+ // we were confused earlier and the target should not be.
+ assert(sourceBase);
+
+ for (Index i = 0; i < type.size(); i++) {
+ auto* get = builder.makeLocalGet(sourceBase + i, type[i]);
+ contents.push_back(builder.makeLocalSet(targetBase + i, get));
+ }
+ replace(builder.makeBlock(contents));
+ }
+ }
+
+ void visitTupleExtract(TupleExtract* curr) {
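+ // In the simple case (names illustrative), (tuple.extract 1 (local.get $t))
+ // becomes just (local.get $t.1), a direct read of the new non-tuple local.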
+ auto* value = curr->tuple;
+ Expression* extraContents = nullptr;
+
+ auto iter = replacedTees.find(value);
+ if (iter != replacedTees.end()) {
+ // The input to us was a tee that has been replaced. Handle it as in
+ // visitLocalSet.
+ extraContents = value;
+ value = iter->second;
+ }
+
+ auto type = value->type;
+ if (type == Type::unreachable) {
+ return;
+ }
+
+ Index sourceBase = getSetOrGetBaseIndex(value);
+ if (!sourceBase) {
+ return;
+ }
+
+ Builder builder(*getModule());
+ auto i = curr->index;
+ auto* get = builder.makeLocalGet(sourceBase + i, type[i]);
+ if (extraContents) {
+ replaceCurrent(builder.makeSequence(extraContents, get));
+ } else {
+ replaceCurrent(get);
+ }
+ }
+ };
+};
+
+Pass* createTupleOptimizationPass() { return new TupleOptimization(); }
+
+} // namespace wasm