Diffstat (limited to 'src/ir/possible-contents.cpp')
-rw-r--r-- | src/ir/possible-contents.cpp | 561
1 file changed, 550 insertions, 11 deletions
diff --git a/src/ir/possible-contents.cpp b/src/ir/possible-contents.cpp
index 916a39f2c..232ebb7ad 100644
--- a/src/ir/possible-contents.cpp
+++ b/src/ir/possible-contents.cpp
@@ -17,10 +17,12 @@
 #include <optional>
 #include <variant>
 
+#include "analysis/cfg.h"
 #include "ir/bits.h"
 #include "ir/branch-utils.h"
 #include "ir/eh-utils.h"
 #include "ir/gc-type-utils.h"
+#include "ir/linear-execution.h"
 #include "ir/local-graph.h"
 #include "ir/module-utils.h"
 #include "ir/possible-contents.h"
@@ -1321,12 +1323,506 @@
   }
 };
 
+// TrapsNeverHappen Oracle. This makes inferences *backwards* from traps that
+// we know will not happen due to the TNH assumption. For example,
+//
+//  (local.get $a)
+//  (ref.cast $B (local.get $a))
+//
+// The cast happens right after the first local.get, and we assume it does not
+// fail, so the local must contain a B, even though the IR only has A.
+//
+// This analysis complements ContentOracle, which uses this analysis
+// internally. ContentOracle does a forward flow analysis (as content moves
+// from place to place) which increases from "nothing", while this does a
+// backwards analysis that decreases from "everything" (or rather, from the
+// type declared in the IR), so the two cannot be done at once.
+//
+// TODO: We could cycle between this and ContentOracle for repeated
+// improvements.
+// TODO: This pass itself could benefit from internal cycles.
+//
+// This analysis mainly focuses on information across calls, as simple
+// backwards inference is done in OptimizeCasts. Note that it is not needed if
+// a call is inlined, obviously, and so it mostly helps cases like functions
+// too large to inline, or when optimizing for size, or with indirect calls.
+//
+// We track cast parameters by mapping an index to the type it is definitely
+// cast to if the function is entered. From that information we can infer
+// things about the values being sent to the function (which we can assume
+// must have the right type so that the casts do not trap).
+using CastParams = std::unordered_map<Index, Type>;
+
+// The information we collect and utilize as we operate in parallel in each
+// function.
+struct TNHInfo {
+  CastParams castParams;
+
+  // TODO: Returns as well: when we see (ref.cast (call $foo)) in all callers
+  // then we can refine inside $foo (in closed world).
+
+  // We gather calls in parallel in order to process them later.
+  std::vector<Call*> calls;
+  std::vector<CallRef*> callRefs;
+
+  // Note if a function body definitely traps.
+  bool traps = false;
+
+  // We gather inferences in parallel and combine them at the end.
+  std::unordered_map<Expression*, PossibleContents> inferences;
+};
+
+class TNHOracle : public ModuleUtils::ParallelFunctionAnalysis<TNHInfo> {
+  const PassOptions& options;
+
+public:
+  using Parent = ModuleUtils::ParallelFunctionAnalysis<TNHInfo>;
+
+  TNHOracle(Module& wasm, const PassOptions& options)
+    : Parent(wasm,
+             [this, &options](Function* func, TNHInfo& info) {
+               scan(func, info, options);
+             }),
+      options(options) {
+
+    // After the scanning phase that we run in the constructor, continue to
+    // the second phase of analysis: inference.
+    infer();
+  }
+
+  // Get the type we inferred was possible at a location.
+  PossibleContents getContents(Expression* curr) {
+    auto naiveContents = PossibleContents::fullConeType(curr->type);
+
+    // If we inferred nothing, use the naive type.
+    auto iter = inferences.find(curr);
+    if (iter == inferences.end()) {
+      return naiveContents;
+    }
+
+    auto& contents = iter->second;
+    // We only store useful contents that improve on the naive estimate that
+    // uses the type in the IR.
+    assert(contents != naiveContents);
+    return contents;
+  }
+
+private:
+  // Maps expressions to the content we inferred there. If an expression is
+  // not here then expression->type (the type in Binaryen IR) is all we have.
+  std::unordered_map<Expression*, PossibleContents> inferences;
+
+  // Phase 1: Scan to find cast parameters and calls. This operates on a
+  // single function, and is called in parallel.
+  void scan(Function* func, TNHInfo& info, const PassOptions& options);
+
+  // Phase 2: Infer contents based on what we scanned.
+  void infer();
+
+  // Optimize one specific call (or call_ref).
+  void optimizeCallCasts(Expression* call,
+                         const ExpressionList& operands,
+                         const CastParams& targetCastParams,
+                         const analysis::CFGBlockIndexes& blockIndexes,
+                         TNHInfo& info);
+};
+
+void TNHOracle::scan(Function* func,
+                     TNHInfo& info,
+                     const PassOptions& options) {
+  if (func->imported()) {
+    return;
+  }
+
+  // Gather parameters that are definitely cast in the function entry.
+  struct EntryScanner : public LinearExecutionWalker<EntryScanner> {
+    Module& wasm;
+    const PassOptions& options;
+    TNHInfo& info;
+
+    EntryScanner(Module& wasm, const PassOptions& options, TNHInfo& info)
+      : wasm(wasm), options(options), info(info) {}
+
+    // Note while we are still in the entry (first) block.
+    bool inEntryBlock = true;
+
+    static void doNoteNonLinear(EntryScanner* self, Expression** currp) {
+      // This is the end of the first basic block.
+      self->inEntryBlock = false;
+    }
+
+    void visitCall(Call* curr) { info.calls.push_back(curr); }
+
+    void visitCallRef(CallRef* curr) {
+      // We can only optimize call_ref in a closed world, as otherwise the
+      // call can go somewhere we can't see.
+      if (options.closedWorld) {
+        info.callRefs.push_back(curr);
+      }
+    }
+
+    void visitRefAs(RefAs* curr) {
+      if (curr->op == RefAsNonNull) {
+        noteCast(curr->value, curr->type);
+      }
+    }
+
+    void visitRefCast(RefCast* curr) { noteCast(curr->ref, curr->type); }
+
+    // Note a cast of an expression to a particular type.
+    void noteCast(Expression* expr, Type type) {
+      if (!inEntryBlock) {
+        return;
+      }
+
+      auto* fallthrough = Properties::getFallthrough(expr, options, wasm);
+      if (auto* get = fallthrough->dynCast<LocalGet>()) {
+        // To optimize, this needs to be a param, and of a useful type.
+        //
+        // Note that if we see more than one cast we keep the first one. This
+        // is not important in optimized code, as the most refined cast would
+        // be the only one to exist there, so it's ok to keep things simple
+        // here.
+        if (getFunction()->isParam(get->index) && type != get->type &&
+            info.castParams.count(get->index) == 0) {
+          info.castParams[get->index] = type;
+        }
+      }
+    }
+
+    // Operations that trap on null are equivalent to casts to non-null, in
+    // that they imply that their input is non-null if traps never happen.
+    //
+    // We only look at them if the input is actually nullable, since if it is
+    // non-nullable then we can add no information. (This is equivalent to
+    // the handling of RefAsNonNull above, in the sense that in optimized
+    // code the RefAs will not appear if the input is already non-nullable.)
+    // This function is called with the reference that will be trapped on, if
+    // it is null.
+    void notePossibleTrap(Expression* expr) {
+      if (!expr->type.isRef() || expr->type.isNonNullable()) {
+        return;
+      }
+      noteCast(expr, Type(expr->type.getHeapType(), NonNullable));
+    }
+
+    void visitStructGet(StructGet* curr) { notePossibleTrap(curr->ref); }
+    void visitStructSet(StructSet* curr) { notePossibleTrap(curr->ref); }
+    void visitArrayGet(ArrayGet* curr) { notePossibleTrap(curr->ref); }
+    void visitArraySet(ArraySet* curr) { notePossibleTrap(curr->ref); }
+    void visitArrayLen(ArrayLen* curr) { notePossibleTrap(curr->ref); }
+    void visitArrayCopy(ArrayCopy* curr) {
+      notePossibleTrap(curr->srcRef);
+      notePossibleTrap(curr->destRef);
+    }
+    void visitArrayFill(ArrayFill* curr) { notePossibleTrap(curr->ref); }
+    void visitArrayInitData(ArrayInitData* curr) {
+      notePossibleTrap(curr->ref);
+    }
+    void visitArrayInitElem(ArrayInitElem* curr) {
+      notePossibleTrap(curr->ref);
+    }
+
+    void visitFunction(Function* curr) {
+      // In optimized TNH code, a function that always traps will be turned
+      // into a singleton unreachable instruction, so it is enough to check
+      // for that.
+      if (curr->body->is<Unreachable>()) {
+        info.traps = true;
+      }
+    }
+  } scanner(wasm, options, info);
+  scanner.walkFunction(func);
+}
+
+void TNHOracle::infer() {
+  // Phase 2: Inside each function, optimize calls based on the cast params
+  // of the called function (which we noted during phase 1).
+  //
+  // Specifically, each time we call a target that will cast a param, we can
+  // infer that the param must have that type (or else we'd trap, but we are
+  // assuming traps never happen).
+  //
+  // While doing so we must be careful of control flow transfers right before
+  // the call:
+  //
+  //  (call $target
+  //    (A)
+  //    (br_if ..)
+  //    (B)
+  //  )
+  //
+  // If we branch in the br_if then we might execute A and then something
+  // else entirely, and not reach B or the call. In that case we can't infer
+  // anything about A (perhaps, for example, we branch away exactly when A
+  // would fail the cast). Therefore in the optimization below we only
+  // optimize code that, if reached, will definitely reach the call, like B.
+  //
+  // TODO: Some control flow transfers are ok, so long as we must reach the
+  // call, like if we replace the br_if with an if with two arms (and no
+  // branches in either).
+  // TODO: We can also infer backwards past basic blocks from casts, even
+  // without calls. Any cast tells us something about the uses of that value
+  // that must reach the cast.
+  // TODO: We can do a whole-program flow of this information.
+
+  // For call_ref, we need to know which functions belong to each type.
+  // Gather that first. This map will map each heap type to each function
+  // that is of that type or a subtype, i.e., might be called when that type
+  // is seen in a call_ref target.
+  std::unordered_map<HeapType, std::vector<Function*>> typeFunctions;
+  if (options.closedWorld) {
+    for (auto& func : wasm.functions) {
+      auto type = func->type;
+      auto& info = map[wasm.getFunction(func->name)];
+      if (info.traps) {
+        // This function definitely traps, so we can assume it is never
+        // called, and don't need to even bother putting it in
+        // |typeFunctions|.
+        continue;
+      }
+      while (1) {
+        typeFunctions[type].push_back(func.get());
+        if (auto super = type.getSuperType()) {
+          type = *super;
+        } else {
+          break;
+        }
+      }
+    }
+  }
+
+  doAnalysis([&](Function* func, TNHInfo& info) {
+    // We will need some CFG information below. Computing this is expensive,
+    // so only do it if we find optimization opportunities.
+    std::optional<analysis::CFGBlockIndexes> blockIndexes;
+
+    auto ensureCFG = [&]() {
+      if (!blockIndexes) {
+        auto cfg = analysis::CFG::fromFunction(func);
+        blockIndexes = analysis::CFGBlockIndexes(cfg);
+      }
+    };
+
+    for (auto* call : info.calls) {
+      auto& targetInfo = map[wasm.getFunction(call->target)];
+
+      auto& targetCastParams = targetInfo.castParams;
+      if (targetCastParams.empty()) {
+        continue;
+      }
+
+      // This looks promising: create the CFG if we haven't already, and
+      // optimize.
+      ensureCFG();
+      optimizeCallCasts(
+        call, call->operands, targetCastParams, *blockIndexes, info);
+
+      // Note that we don't need to do anything for targetInfo.traps for a
+      // direct call: the inliner will inline the singleton unreachable in
+      // the target function anyhow.
+    }
+
+    for (auto* call : info.callRefs) {
+      auto targetType = call->target->type;
+      if (!targetType.isRef()) {
+        // This is unreachable or null, and other passes will optimize that.
+        continue;
+      }
+
+      // We should only get here in a closed world, in which we know which
+      // functions might be called (the scan phase only notes callRefs if we
+      // are in fact in a closed world).
+      assert(options.closedWorld);
+
+      auto iter = typeFunctions.find(targetType.getHeapType());
+      if (iter == typeFunctions.end()) {
+        // No function exists of this type, so the call_ref will trap. We can
+        // mark the target as empty, which has the identical effect.
+        info.inferences[call->target] = PossibleContents::none();
+        continue;
+      }
+
+      // Go through the targets and ignore any that will trap. That will
+      // leave us with the actually possible targets.
+      //
+      // Note that we did not even add functions that certainly trap to
+      // |typeFunctions| at all, so those are already excluded.
+      const auto& targets = iter->second;
+      std::vector<Function*> possibleTargets;
+      for (Function* target : targets) {
+        auto& targetInfo = map[target];
+
+        // If any of our operands will fail a cast, then we will trap.
+        bool traps = false;
+        for (auto& [castIndex, castType] : targetInfo.castParams) {
+          auto operandType = call->operands[castIndex]->type;
+          auto result = GCTypeUtils::evaluateCastCheck(operandType, castType);
+          if (result == GCTypeUtils::Failure) {
+            traps = true;
+            break;
+          }
+        }
+        if (!traps) {
+          possibleTargets.push_back(target);
+        }
+      }
+
+      if (possibleTargets.empty()) {
+        // No target is possible.
+        info.inferences[call->target] = PossibleContents::none();
+        continue;
+      }
+
+      if (possibleTargets.size() == 1) {
+        // There is exactly one possible call target, which means we can
+        // actually infer what the call_ref is calling. Add that as an
+        // inference.
+        // TODO: We could also optimizeCallCasts() here, but it is low
+        // priority as other opts will make this call direct later, after
+        // which a lot of other optimizations become possible anyhow.
+        auto target = possibleTargets[0]->name;
+        info.inferences[call->target] = PossibleContents::literal(
+          Literal(target, wasm.getFunction(target)->type));
+        continue;
+      }
+
+      // More than one target exists: apply the intersection of their
+      // constraints. That is, if they all cast the k-th parameter to type T
+      // (or more) then we can apply that here.
+      auto numParams = call->operands.size();
+      std::vector<Type> sharedCastParamsVec(numParams, Type::unreachable);
+      for (auto* target : possibleTargets) {
+        auto& targetInfo = map[target];
+        auto& targetCastParams = targetInfo.castParams;
+        for (Index i = 0; i < numParams; i++) {
+          auto iter = targetCastParams.find(i);
+          if (iter == targetCastParams.end()) {
+            // If the target does not cast, we cannot do anything with this
+            // parameter; mark it as unoptimizable with an impossible type.
+            sharedCastParamsVec[i] = Type::none;
+            continue;
+          }
+
+          // This function casts this param. Combine this with existing info.
+          auto castType = iter->second;
+          sharedCastParamsVec[i] =
+            Type::getLeastUpperBound(sharedCastParamsVec[i], castType);
+        }
+      }
+
+      // Build a map of the interesting cast params we found, and if there
+      // are any, optimize using them.
+      CastParams sharedCastParams;
+      for (Index i = 0; i < numParams; i++) {
+        auto type = sharedCastParamsVec[i];
+        if (type != Type::none) {
+          sharedCastParams[i] = type;
+        }
+      }
+      if (!sharedCastParams.empty()) {
+        ensureCFG();
+        optimizeCallCasts(
+          call, call->operands, sharedCastParams, *blockIndexes, info);
+      }
+    }
+  });
+
+  // Combine all of our inferences from the parallel phase above into the
+  // final list of inferences.
+  for (auto& [_, info] : map) {
+    for (auto& [expr, contents] : info.inferences) {
+      inferences[expr] = contents;
+    }
+  }
+}
+
+void TNHOracle::optimizeCallCasts(Expression* call,
+                                  const ExpressionList& operands,
+                                  const CastParams& targetCastParams,
+                                  const analysis::CFGBlockIndexes& blockIndexes,
+                                  TNHInfo& info) {
+  // Optimize in the same basic block as the call: all instructions still in
+  // that block will definitely execute if the call is reached. We will do
+  // that by going backwards through the call's operands and fallthrough
+  // values, and optimizing while we are still in the same basic block.
+  auto callBlockIndex = blockIndexes.get(call);
+
+  // Operands must exist since there is a cast param, which means the target
+  // has at least one param.
+  assert(operands.size() > 0);
+  for (int i = int(operands.size() - 1); i >= 0; i--) {
+    auto* operand = operands[i];
+
+    if (blockIndexes.get(operand) != callBlockIndex) {
+      // Control flow might transfer; stop.
+      break;
+    }
+
+    auto iter = targetCastParams.find(i);
+    if (iter == targetCastParams.end()) {
+      // This param is not cast, so skip it.
+      continue;
+    }
+
+    // If the call executes then this parameter is definitely reached (since
+    // it is in the same basic block), and we know that it will be cast to a
+    // more refined type.
+    auto castType = iter->second;
+
+    // Apply what we found to the operand and also to its fallthrough values.
+    //
+    // At the loop entry |curr| has been checked for a possible control flow
+    // transfer (and that problem ruled out).
+    auto* curr = operand;
+    while (1) {
+      // Note the type if it is useful.
+      if (castType != curr->type) {
+        // There are two constraints on this location: any value there must
+        // be of the declared type (curr->type) and also the cast type, so we
+        // know only their intersection can appear here.
+        auto declared = PossibleContents::fullConeType(curr->type);
+        auto intersection = PossibleContents::fullConeType(castType);
+        intersection.intersect(declared);
+        if (intersection.isConeType()) {
+          auto intersectionType = intersection.getType();
+          if (intersectionType != curr->type) {
+            // We inferred a more refined type.
+            info.inferences[curr] = intersection;
+          }
+        } else {
+          // Otherwise, the intersection can be a null (if the heap types are
+          // incompatible, but a null is allowed), or empty. We can apply
+          // either.
+          assert(intersection.isNull() || intersection.isNone());
+          info.inferences[curr] = intersection;
+        }
+      }
+
+      auto* next = Properties::getImmediateFallthrough(curr, options, wasm);
+      if (next == curr) {
+        // No fallthrough; we're done with this param.
+        break;
+      }
+
+      // There is a fallthrough. Check for a control flow transfer.
+      if (blockIndexes.get(next) != callBlockIndex) {
+        // Control flow might transfer; stop. We also cannot look at any
+        // further operands (if a child of this operand is in a different
+        // basic block than the call, then so are the previous operands), so
+        // return from the entire function.
+        return;
+      }
+
+      // Continue to the fallthrough.
+      curr = next;
+    }
+  }
+}
+
 // Main logic for building data for the flow analysis and then performing that
 // analysis.
 struct Flower {
   Module& wasm;
+  const PassOptions& options;
 
-  Flower(Module& wasm);
+  Flower(Module& wasm, const PassOptions& options);
 
   // Each LocationIndex will have one LocationInfo that contains the relevant
   // information we need for each location.
@@ -1361,7 +1857,19 @@ struct Flower {
     return locations[index].contents;
   }
 
+  // Check what we know about the type of an expression, using static
+  // information from a TrapsNeverHappen oracle (see TNHOracle), if we have
+  // one.
+  PossibleContents getTNHContents(Expression* curr) {
+    if (!tnhOracle) {
+      // No oracle; just use the type in the IR.
+      return PossibleContents::fullConeType(curr->type);
+    }
+    return tnhOracle->getContents(curr);
+  }
+
 private:
+  std::unique_ptr<TNHOracle> tnhOracle;
+
   std::vector<LocationIndex>& getTargets(LocationIndex index) {
     assert(index < locations.size());
     return locations[index].targets;
@@ -1527,7 +2035,19 @@ private:
 #endif
 };
 
-Flower::Flower(Module& wasm) : wasm(wasm) {
+Flower::Flower(Module& wasm, const PassOptions& options)
+  : wasm(wasm), options(options) {
+
+  // If traps never happen, create a TNH oracle.
+  //
+  // At the moment this oracle only helps on GC content, so disable it when
+  // GC is not enabled.
+  if (options.trapsNeverHappen && wasm.features.hasGC()) {
+#ifdef POSSIBLE_CONTENTS_DEBUG
+    std::cout << "tnh phase\n";
+#endif
+    tnhOracle = std::make_unique<TNHOracle>(wasm, options);
+  }
+
 #ifdef POSSIBLE_CONTENTS_DEBUG
   std::cout << "parallel phase\n";
 #endif
@@ -1560,6 +2080,10 @@ Flower::Flower(Module& wasm) : wasm(wasm) {
   InfoCollector finder(globalInfo);
   finder.walkModuleCode(&wasm);
 
+#ifdef POSSIBLE_CONTENTS_DEBUG
+  std::cout << "global init phase\n";
+#endif
+
   // Connect global init values (which we've just processed, as part of the
   // module code) to the globals they initialize.
   for (auto& global : wasm.globals) {
@@ -1733,7 +2257,7 @@ bool Flower::updateContents(LocationIndex locationIndex,
   auto oldContents = contents;
 
 #if defined(POSSIBLE_CONTENTS_DEBUG) && POSSIBLE_CONTENTS_DEBUG >= 2
-  std::cout << "updateContents\n";
+  std::cout << "\nupdateContents\n";
   dump(getLocation(locationIndex));
   contents.dump(std::cout, &wasm);
   std::cout << "\n with new contents \n";
@@ -1961,7 +2485,10 @@ void Flower::filterExpressionContents(PossibleContents& contents,
                                       const ExpressionLocation& exprLoc,
                                       bool& worthSendingMore) {
   auto type = exprLoc.expr->type;
-  if (!type.isRef()) {
+
+  if (type.isTuple()) {
+    // TODO: Optimize tuples here as well. We would need to take into account
+    // exprLoc.tupleIndex for that in all the below.
     return;
   }
 
@@ -1969,17 +2496,28 @@ void Flower::filterExpressionContents(PossibleContents& contents,
   // more to a reference - all that logic is in here. That is, the rest of
   // this function is the only place we can mark |worthSendingMore| as false
   // for a reference.
-  assert(worthSendingMore);
+  bool isRef = type.isRef();
+  assert(!isRef || worthSendingMore);
 
-  // The maximal contents here are the declared type and all subtypes.
-  // Nothing else can pass through, so filter such things out.
-  auto maximalContents = PossibleContents::fullConeType(type);
+  // The TNH oracle informs us of the maximal contents possible here.
+  auto maximalContents = getTNHContents(exprLoc.expr);
+#if defined(POSSIBLE_CONTENTS_DEBUG) && POSSIBLE_CONTENTS_DEBUG >= 2
+  std::cout << "TNHOracle informs us that " << *exprLoc.expr << " contains "
+            << maximalContents << "\n";
+#endif
   contents.intersect(maximalContents);
   if (contents.isNone()) {
     // Nothing was left here at all.
     return;
   }
 
+  // For references we need to normalize the intersection; see below. For
+  // non-references, we are done (we did all the relevant work in the
+  // intersect() call).
+  if (!isRef) {
+    return;
+  }
+
   // Normalize the intersection. We want to check later if any more content
   // can arrive here, and also we want to avoid flowing around anything
   // non-normalized, as explained earlier.
@@ -2232,7 +2770,8 @@ void Flower::writeToData(Expression* ref, Expression* value, Index fieldIndex) {
 #if defined(POSSIBLE_CONTENTS_DEBUG) && POSSIBLE_CONTENTS_DEBUG >= 2
 void Flower::dump(Location location) {
   if (auto* loc = std::get_if<ExpressionLocation>(&location)) {
-    std::cout << " exprloc \n" << *loc->expr << '\n';
+    std::cout << " exprloc \n"
+              << *loc->expr << " : " << loc->tupleIndex << '\n';
   } else if (auto* loc = std::get_if<DataLocation>(&location)) {
     std::cout << " dataloc ";
     if (wasm.typeNames.count(loc->type)) {
@@ -2242,7 +2781,7 @@ void Flower::dump(Location location) {
     }
     std::cout << " : " << loc->index << '\n';
   } else if (auto* loc = std::get_if<TagLocation>(&location)) {
-    std::cout << " tagloc " << loc->tag << '\n';
+    std::cout << " tagloc " << loc->tag << " : " << loc->tupleIndex << '\n';
   } else if (auto* loc = std::get_if<ParamLocation>(&location)) {
     std::cout << " paramloc " << loc->func->name << " : " << loc->index
               << '\n';
@@ -2269,7 +2808,7 @@ void Flower::dump(Location location) {
 } // anonymous namespace
 
 void ContentOracle::analyze() {
-  Flower flower(wasm);
+  Flower flower(wasm, options);
   for (LocationIndex i = 0; i < flower.locations.size(); i++) {
     locationContents[flower.getLocation(i)] = flower.getContents(i);
  }
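
The core safety rule in optimizeCallCasts above — only operands in the call's own basic block may be refined — can be distilled into a small standalone sketch. Everything below is illustrative only and not part of the patch: BlockIndex, TypeId, OperandInfo, and inferCallCasts are hypothetical stand-ins for analysis::CFGBlockIndexes, wasm::Type, the operand expressions, and the real member function.

    #include <cstddef>
    #include <functional>
    #include <unordered_map>
    #include <vector>

    // Hypothetical stand-ins for the real types, for illustration only.
    using BlockIndex = int;
    using TypeId = int;

    struct OperandInfo {
      BlockIndex block; // the basic block containing this operand
    };

    // Walk a call's operands backwards, refining each operand that the
    // target is known to cast, but only while we remain in the call's own
    // basic block: an operand in another block might not execute on every
    // path that reaches the call, so no inference about it is safe.
    void inferCallCasts(
      BlockIndex callBlock,
      const std::vector<OperandInfo>& operands,
      const std::unordered_map<size_t, TypeId>& castParams,
      const std::function<void(size_t, TypeId)>& noteInference) {
      for (size_t i = operands.size(); i-- > 0;) {
        if (operands[i].block != callBlock) {
          // Control flow might transfer before the call; earlier operands
          // are not guaranteed to reach it either, so stop entirely.
          break;
        }
        auto it = castParams.find(i);
        if (it != castParams.end()) {
          // The target definitely casts parameter i to it->second, and since
          // we assume traps never happen, the operand must have that type.
          noteInference(i, it->second);
        }
      }
    }

Walking backwards matters here: operands execute in order just before the call, so once one operand falls outside the call's block, every earlier operand does too — which is why the sketch, like the patch, breaks out of the loop rather than skipping to the next operand.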
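Similarly, the call_ref handling in infer() applies only the intersection of the constraints of all possible targets. Below is a sketch of just that combination step, again with hypothetical stand-in types; the patch itself uses Type::getLeastUpperBound over wasm types, with Type::unreachable as the starting bottom element and Type::none marking an unoptimizable parameter.

    #include <cstddef>
    #include <optional>
    #include <unordered_map>
    #include <vector>

    using TypeId = int;
    // Stand-in least upper bound on a toy lattice where a larger id is a
    // supertype (hypothetical; the real code works on wasm types).
    static TypeId leastUpperBound(TypeId a, TypeId b) { return a > b ? a : b; }

    using CastMap = std::unordered_map<size_t, TypeId>;

    // Combine per-target cast maps into the casts shared by all targets. A
    // parameter is refinable only if *every* target casts it, and then only
    // to the least upper bound of the individual cast types.
    CastMap sharedCasts(size_t numParams, const std::vector<CastMap>& targets) {
      CastMap shared;
      for (size_t i = 0; i < numParams; i++) {
        std::optional<TypeId> combined;
        bool castByAll = true;
        for (auto& casts : targets) {
          auto it = casts.find(i);
          if (it == casts.end()) {
            castByAll = false; // one target does not cast: nothing to infer
            break;
          }
          combined =
            combined ? leastUpperBound(*combined, it->second) : it->second;
        }
        if (castByAll && combined) {
          shared[i] = *combined;
        }
      }
      return shared;
    }

The least upper bound is what keeps the inference sound for a multi-target call: the value sent must satisfy whichever target is actually invoked, and since that target is unknown, only the most general of the cast types may be assumed.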