/*
 * Copyright 2022 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Refine uses of locals where possible. For example, consider this:
//
//  (some.operation
//    (ref.cast .. (local.get $ref))
//    (local.get $ref)
//  )
//
// The second use might as well use the refined/cast value as well:
//
//  (some.operation
//    (local.tee $temp
//      (ref.cast .. (local.get $ref))
//    )
//    (local.get $temp)
//  )
//
// This change adds a local but it switches some local.gets to use a local of a
// more refined type. That can help other optimizations later.
//
// An example of an important pattern this handles is itable calls:
//
//  (call_ref
//    (ref.cast $actual.type
//      (local.get $object)
//    )
//    (struct.get $vtable ..
//      (ref.cast $vtable
//        (struct.get $itable ..
//          (local.get $object)
//        )
//      )
//    )
//  )
//
// We cast to the actual type for the |this| parameter, but we technically do
// not need to do so for reading its itable - since the itable may be of a
// generic type, and we cast the vtable afterwards anyhow. But since we cast
// |this|, we can use the cast value for the itable get, which may then lead to
// removing the vtable cast after we refine the itable type. And that can lead
// to devirtualization later.
//
// Closely related things appear in other passes:
//
//  * SimplifyLocals will find locals already containing a more refined type and
//    switch to them. RedundantSetElimination does the same across basic blocks.
// In theory one of them could be extended to also add new locals, and then // they would be doing something similar to this pass. // * LocalCSE finds repeated expressions and stores them in locals for use // later. In theory that pass could be extended to look not for exact copies // but for equivalent things through a cast, and then it would be doing // something similar to this pass. // // However, while those other passes could be extended to cover what this pass // does, we will have further cast-specific optimizations to add, which make // sense in new pass anyhow, and things should be simpler overall to keep such // casts all in one pass, here. // // TODO: Move casts earlier in a basic block as well, at least in traps-never- // happen mode where we can assume they never fail. // TODO: Look past individual basic blocks? // TODO: Look at LocalSet as well and not just Get. That would add some overlap // with the other passes mentioned above, but once we do things like // moving casts earlier as in the other TODO, we'd be doing uniquely // useful things with LocalSet here. // #include "ir/linear-execution.h" #include "ir/properties.h" #include "ir/utils.h" #include "pass.h" #include "wasm-builder.h" #include "wasm.h" namespace wasm { namespace { // Find the best casted verisons of local.gets: other local.gets with the same // value, but cast to a more refined type. struct BestCastFinder : public LinearExecutionWalker { PassOptions options; // Map local indices to the most refined downcastings of local.gets from those // indices. // // This is tracked in each basic block, and cleared between them. std::unordered_map mostCastedGets; // For each most-downcasted local.get, a vector of other local.gets that could // be replaced with gets of the downcasted value. // // This is tracked until the end of the entire function, and contains the // information we need to optimize later. That is, entries here are things we // want to apply. 
std::unordered_map> lessCastedGets; static void doNoteNonLinear(BestCastFinder* self, Expression** currp) { self->mostCastedGets.clear(); } void visitLocalSet(LocalSet* curr) { // Clear any information about this local; it has a new value here. mostCastedGets.erase(curr->index); } void visitLocalGet(LocalGet* curr) { auto iter = mostCastedGets.find(curr->index); if (iter != mostCastedGets.end()) { auto* bestCast = iter->second; if (curr->type != bestCast->type && Type::isSubType(bestCast->type, curr->type)) { // The best cast has a more refined type, note that we want to use it. lessCastedGets[bestCast].push_back(curr); } } } void visitRefAs(RefAs* curr) { handleRefinement(curr); } void visitRefCast(RefCast* curr) { handleRefinement(curr); } void handleRefinement(Expression* curr) { auto* fallthrough = Properties::getFallthrough(curr, options, *getModule()); if (auto* get = fallthrough->dynCast()) { auto*& bestCast = mostCastedGets[get->index]; if (!bestCast) { // This is the first. bestCast = curr; return; } // See if we are better than the current best. if (curr->type != bestCast->type && Type::isSubType(curr->type, bestCast->type)) { bestCast = curr; } } } }; // Given a set of best casts, apply them: save each best cast in a local and use // it in the places that want to. // // It is simpler to do this in another pass after BestCastFinder so that we do // not need to worry about corner cases with invalidation of pointers in things // we've already walked past. struct FindingApplier : public PostWalker { BestCastFinder& finder; FindingApplier(BestCastFinder& finder) : finder(finder) {} void visitRefAs(RefAs* curr) { handleRefinement(curr); } void visitRefCast(RefCast* curr) { handleRefinement(curr); } void handleRefinement(Expression* curr) { auto iter = finder.lessCastedGets.find(curr); if (iter == finder.lessCastedGets.end()) { return; } // This expression was the best cast for some gets. Add a new local to // store this value, then use it for the gets. 
auto var = Builder::addVar(getFunction(), curr->type); auto& gets = iter->second; for (auto* get : gets) { get->index = var; get->type = curr->type; } // Replace ourselves with a tee. replaceCurrent(Builder(*getModule()).makeLocalTee(var, curr, curr->type)); } }; } // anonymous namespace struct OptimizeCasts : public WalkerPass> { bool isFunctionParallel() override { return true; } std::unique_ptr create() override { return std::make_unique(); } void doWalkFunction(Function* func) { if (!getModule()->features.hasGC()) { return; } // First, find the best casts that we want to use. BestCastFinder finder; finder.options = getPassOptions(); finder.walkFunctionInModule(func, getModule()); if (finder.lessCastedGets.empty()) { // Nothing to do. return; } // Apply the requests: use the best casts. FindingApplier applier(finder); applier.walkFunctionInModule(func, getModule()); // LocalGet type changes must be propagated. ReFinalize().walkFunctionInModule(func, getModule()); } }; Pass* createOptimizeCastsPass() { return new OptimizeCasts(); } } // namespace wasm