/*
 * Copyright 2023 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Optimize away trivial tuples. When values are bundled together in a tuple, we
// are limited in how we can optimize them in the various local-related passes,
// like this:
//
//   (local.set $tuple
//     (tuple.make 3 (A) (B) (C)))
//   (use
//     (tuple.extract 3 0
//       (local.get $tuple)))
//
// If there are no other uses, then we just need one of the three elements. By
// lowering them to three separate locals, other passes can remove the other
// two.
//
// Specifically, this pass seeks out tuple locals that have these properties:
//
//  * They are always written from either a tuple.make or another tuple local
//    with these properties.
//  * They are always used either in tuple.extract or they are copied to another
//    tuple local with these properties.
//
// The set of those tuple locals can be easily optimized into individual locals,
// as the tuple does not "escape" into, say, a return value.
//
// TODO: Blocks etc. might be handled here, but it's not clear if we want to:
//       there are situations where multivalue leads to smaller code using
//       those constructs. Atm this pass should only remove things that are
//       definitely worth lowering.
// #include #include #include #include namespace wasm { struct TupleOptimization : public WalkerPass> { bool isFunctionParallel() override { return true; } std::unique_ptr create() override { return std::make_unique(); } // Track the number of uses for each tuple local. We consider a use as a // local.get, a set, or a tee. A tee counts as two uses (since it both sets // and gets, and so we must see that it is both used and uses properly). std::vector uses; // Tracks which tuple local uses are valid, that is, follow the properties // above. If we have more uses than valid uses then we must have an invalid // one, and the local cannot be optimized. std::vector validUses; // When one tuple local copies the value of another, we need to track the // index that was copied, as if the source ends up bad then the target is bad // as well. // // This is a symmetrical map, that is, we consider copies to work both ways: // // x \in copiedIndexed[y] <==> y \in copiedIndexed[x] // std::vector> copiedIndexes; void doWalkFunction(Function* func) { // If tuples are not enabled, or there are no tuple locals, then there is no // work to do. if (!getModule()->features.hasMultivalue()) { return; } bool hasTuple = false; for (auto var : func->vars) { if (var.isTuple()) { hasTuple = true; break; } } if (!hasTuple) { return; } // Prepare global data structures before we collect info. auto numLocals = func->getNumLocals(); uses.resize(numLocals); validUses.resize(numLocals); copiedIndexes.resize(numLocals); // Walk the code to collect info. Super::doWalkFunction(func); // Analyze and optimize. optimize(func); } void visitLocalGet(LocalGet* curr) { if (curr->type.isTuple()) { uses[curr->index]++; } } void visitLocalSet(LocalSet* curr) { if (getFunction()->getLocalType(curr->index).isTuple()) { // See comment above about tees (we consider their set and get each a // separate use). uses[curr->index] += curr->isTee() ? 
2 : 1; auto* value = curr->value; // We need the input to the local to be another such local (from a tee, or // a get), or a tuple.make. if (auto* tee = value->dynCast()) { assert(tee->isTee()); // We don't want to count anything as valid if the inner tee is // unreachable. In that case the outer tee is also unreachable, of // course, and in fact they might not even have the same tuple type or // the inner one might not even be a tuple (since we are in unreachable // code, that is possible). We could try to optimize unreachable tees in // some cases, but it's simpler to let DCE simplify the code first. if (tee->type != Type::unreachable) { validUses[tee->index]++; validUses[curr->index]++; copiedIndexes[tee->index].insert(curr->index); copiedIndexes[curr->index].insert(tee->index); } } else if (auto* get = value->dynCast()) { validUses[get->index]++; validUses[curr->index]++; copiedIndexes[get->index].insert(curr->index); copiedIndexes[curr->index].insert(get->index); } else if (value->is()) { validUses[curr->index]++; } } } void visitTupleExtract(TupleExtract* curr) { // We need the input to be a local, either from a tee or a get. if (auto* set = curr->tuple->dynCast()) { validUses[set->index]++; } else if (auto* get = curr->tuple->dynCast()) { validUses[get->index]++; } } void optimize(Function* func) { auto numLocals = func->getNumLocals(); // Find the set of bad indexes. We add each such candidate to a worklist // that we will then flow to find all those corrupted. std::vector bad(numLocals); UniqueDeferredQueue work; for (Index i = 0; i < uses.size(); i++) { assert(validUses[i] <= uses[i]); if (uses[i] > 0 && validUses[i] < uses[i]) { // This is a bad tuple. work.push(i); } } // Flow badness forward. while (!work.empty()) { auto i = work.pop(); if (bad[i]) { continue; } bad[i] = true; for (auto target : copiedIndexes[i]) { work.push(target); } } // Good indexes we can optimize are tuple locals with uses that are not bad. 
std::vector good(numLocals); bool hasGood = false; for (Index i = 0; i < uses.size(); i++) { if (uses[i] > 0 && !bad[i]) { good[i] = true; hasGood = true; } } if (!hasGood) { return; } // We found things to optimize! Create new non-tuple locals for their // contents, and then rewrite the code to use those according to the // mapping from tuple locals to normal ones. The mapping maps a tuple local // to the base index used for its contents: an index and several others // right after it, depending on the tuple size. std::unordered_map tupleToNewBaseMap; for (Index i = 0; i < good.size(); i++) { if (!good[i]) { continue; } auto newBase = func->getNumLocals(); tupleToNewBaseMap[i] = newBase; Index lastNewIndex = 0; for (auto t : func->getLocalType(i)) { Index newIndex = Builder::addVar(func, t); if (lastNewIndex == 0) { // This is the first new local we added (0 is an impossible value, // since tuple locals exist, hence index 0 was already taken), so it // must be equal to the base. assert(newIndex == newBase); } else { // This must be right after the former. assert(newIndex == lastNewIndex + 1); } lastNewIndex = newIndex; } } MapApplier mapApplier(tupleToNewBaseMap); mapApplier.walkFunctionInModule(func, getModule()); } struct MapApplier : public PostWalker { std::unordered_map& tupleToNewBaseMap; MapApplier(std::unordered_map& tupleToNewBaseMap) : tupleToNewBaseMap(tupleToNewBaseMap) {} // Gets the new base index if there is one, or 0 if not (0 is an impossible // value for a new index, as local index 0 was taken before, as tuple // locals existed). Index getNewBaseIndex(Index i) { auto iter = tupleToNewBaseMap.find(i); if (iter == tupleToNewBaseMap.end()) { return 0; } return iter->second; } // Given a local.get or local.set, return the new base index for the local // index used there. Returns 0 (an impossible value, see above) otherwise. 
Index getSetOrGetBaseIndex(Expression* setOrGet) { Index index; if (auto* set = setOrGet->dynCast()) { index = set->index; } else if (auto* get = setOrGet->dynCast()) { index = get->index; } else { return 0; } return getNewBaseIndex(index); } // Replacing a local.tee requires some care, since we might have // // (local.set // (local.tee // .. // // We replace the local.tee with a block of sets of the new non-tuple // locals, and the outer set must then (1) keep those around and also (2) // identify the local that was tee'd, so we know what to get (which has been // replaced by the block). To make that simple keep a map of the things that // replaced tees. std::unordered_map replacedTees; void visitLocalSet(LocalSet* curr) { auto replace = [&](Expression* replacement) { if (curr->isTee()) { replacedTees[replacement] = curr; } replaceCurrent(replacement); }; if (auto targetBase = getNewBaseIndex(curr->index)) { Builder builder(*getModule()); auto type = getFunction()->getLocalType(curr->index); auto* value = curr->value; if (auto* make = value->dynCast()) { // Write each of the tuple.make fields into the proper local. std::vector sets; for (Index i = 0; i < type.size(); i++) { auto* value = make->operands[i]; sets.push_back(builder.makeLocalSet(targetBase + i, value)); } replace(builder.makeBlock(sets)); return; } std::vector contents; auto iter = replacedTees.find(value); if (iter != replacedTees.end()) { // The input to us was a tee that has been replaced. The actual value // we read from (the tee) can be found in replacedTees. Also, we // need to keep around the replacement of the tee. contents.push_back(value); value = iter->second; } // This is a copy of a tuple local into another. Copy all the fields // between them. Index sourceBase = getSetOrGetBaseIndex(value); // The target is being optimized, so the source must be as well, or else // we were confused earlier and the target should not be. 
assert(sourceBase); // The source and target may have different element types due to // subtyping (but their sizes must be equal). auto sourceType = value->type; assert(sourceType.size() == type.size()); for (Index i = 0; i < type.size(); i++) { auto* get = builder.makeLocalGet(sourceBase + i, sourceType[i]); contents.push_back(builder.makeLocalSet(targetBase + i, get)); } replace(builder.makeBlock(contents)); } } void visitTupleExtract(TupleExtract* curr) { auto* value = curr->tuple; Expression* extraContents = nullptr; auto iter = replacedTees.find(value); if (iter != replacedTees.end()) { // The input to us was a tee that has been replaced. Handle it as in // visitLocalSet. extraContents = value; value = iter->second; } auto type = value->type; if (type == Type::unreachable) { return; } Index sourceBase = getSetOrGetBaseIndex(value); if (!sourceBase) { return; } Builder builder(*getModule()); auto i = curr->index; auto* get = builder.makeLocalGet(sourceBase + i, type[i]); if (extraContents) { replaceCurrent(builder.makeSequence(extraContents, get)); } else { replaceCurrent(get); } } }; }; Pass* createTupleOptimizationPass() { return new TupleOptimization(); } } // namespace wasm