Diffstat (limited to 'src/ir/possible-contents.cpp')
-rw-r--r-- src/ir/possible-contents.cpp | 561
1 file changed, 550 insertions(+), 11 deletions(-)
diff --git a/src/ir/possible-contents.cpp b/src/ir/possible-contents.cpp
index 916a39f2c..232ebb7ad 100644
--- a/src/ir/possible-contents.cpp
+++ b/src/ir/possible-contents.cpp
@@ -17,10 +17,12 @@
#include <optional>
#include <variant>
+#include "analysis/cfg.h"
#include "ir/bits.h"
#include "ir/branch-utils.h"
#include "ir/eh-utils.h"
#include "ir/gc-type-utils.h"
+#include "ir/linear-execution.h"
#include "ir/local-graph.h"
#include "ir/module-utils.h"
#include "ir/possible-contents.h"
@@ -1321,12 +1323,506 @@ struct InfoCollector
}
};
+// TrapsNeverHappen Oracle. This makes inferences *backwards* from traps that we
+// know will not happen due to the TNH assumption. For example,
+//
+// (local.get $a)
+// (ref.cast $B (local.get $a))
+//
+// The cast happens right after the first local.get, and we assume it does not
+// fail, so the local must already contain a B at the first local.get as well,
+// even though the IR only declares type A for it.
+//
+// This analysis complements ContentOracle, which uses this analysis internally.
+// ContentOracle does a forward flow analysis (as content moves from place to
+// place) which increases from "nothing", while this does a backwards analysis
+// that decreases from "everything" (or rather, from the type declared in the
+// IR), so the two cannot be done at once.
+//
+// TODO: We could cycle between this and ContentOracle for repeated
+// improvements.
+// TODO: This pass itself could benefit from internal cycles.
+//
+// This analysis mainly focuses on information across calls, as simple
+// backwards inference within a function is already done in OptimizeCasts. Note
+// that it is not needed once a call is inlined, so it mostly helps in cases
+// like functions too large to inline, when optimizing for size, or with
+// indirect calls.
+//
+// We track cast parameters by mapping an index to the type it is definitely
+// cast to if the function is entered. From that information we can infer things
+// about the values being sent to the function (which we can assume must have
+// the right type so that the casts do not trap).
+using CastParams = std::unordered_map<Index, Type>;
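+
+// For example (an illustrative sketch; $A and $B are hypothetical types,
+// with $B a subtype of $A):
+//
+//  (func $foo (param $x (ref null $A))
+//    (ref.cast $B (local.get $x)) ;; a cast in the entry block
+//    ..
+//  )
+//
+// would be summarized as castParams = {0 => (ref $B)}: under TNH, every
+// caller must pass a $B as the first operand, or the cast would trap.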
+
+// The information we collect and utilize as we operate in parallel in each
+// function.
+struct TNHInfo {
+ CastParams castParams;
+
+ // TODO: Returns as well: when we see (ref.cast (call $foo)) in all callers
+ // then we can refine inside $foo (in closed world).
+
+ // We gather calls in parallel in order to process them later.
+ std::vector<Call*> calls;
+ std::vector<CallRef*> callRefs;
+
+ // Note if a function body definitely traps.
+ bool traps = false;
+
+ // We gather inferences in parallel and combine them at the end.
+ std::unordered_map<Expression*, PossibleContents> inferences;
+};
+
+class TNHOracle : public ModuleUtils::ParallelFunctionAnalysis<TNHInfo> {
+ const PassOptions& options;
+
+public:
+ using Parent = ModuleUtils::ParallelFunctionAnalysis<TNHInfo>;
+ TNHOracle(Module& wasm, const PassOptions& options)
+ : Parent(wasm,
+ [this, &options](Function* func, TNHInfo& info) {
+ scan(func, info, options);
+ }),
+ options(options) {
+
+ // After the scanning phase that we run in the constructor, continue to the
+ // second phase of analysis: inference.
+ infer();
+ }
+
+ // Get the contents we inferred are possible at a location.
+ PossibleContents getContents(Expression* curr) {
+ auto naiveContents = PossibleContents::fullConeType(curr->type);
+
+ // If we inferred nothing, use the naive type.
+ auto iter = inferences.find(curr);
+ if (iter == inferences.end()) {
+ return naiveContents;
+ }
+
+ auto& contents = iter->second;
+ // We only store useful contents that improve on the naive estimate that
+ // uses the type in the IR.
+ assert(contents != naiveContents);
+ return contents;
+ }
+
+private:
+ // Maps expressions to the content we inferred there. If an expression is not
+ // here then expression->type (the type in Binaryen IR) is all we have.
+ std::unordered_map<Expression*, PossibleContents> inferences;
+
+ // Phase 1: Scan to find cast parameters and calls. This operates on a single
+ // function, and is called in parallel.
+ void scan(Function* func, TNHInfo& info, const PassOptions& options);
+
+ // Phase 2: Infer contents based on what we scanned.
+ void infer();
+
+ // Optimize one specific call (or call_ref).
+ void optimizeCallCasts(Expression* call,
+ const ExpressionList& operands,
+ const CastParams& targetCastParams,
+ const analysis::CFGBlockIndexes& blockIndexes,
+ TNHInfo& info);
+};
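+
+// (Usage sketch: Flower, below, constructs a TNHOracle when trapsNeverHappen
+// is enabled and GC is present, and then queries getContents() for each
+// expression; see Flower::getTNHContents().)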
+
+void TNHOracle::scan(Function* func,
+ TNHInfo& info,
+ const PassOptions& options) {
+ if (func->imported()) {
+ return;
+ }
+
+ // Gather parameters that are definitely cast in the function entry.
+ struct EntryScanner : public LinearExecutionWalker<EntryScanner> {
+ Module& wasm;
+ const PassOptions& options;
+ TNHInfo& info;
+
+ EntryScanner(Module& wasm, const PassOptions& options, TNHInfo& info)
+ : wasm(wasm), options(options), info(info) {}
+
+ // Whether we are still in the entry (first) basic block.
+ bool inEntryBlock = true;
+
+ static void doNoteNonLinear(EntryScanner* self, Expression** currp) {
+ // This is the end of the first basic block.
+ self->inEntryBlock = false;
+ }
+
+ void visitCall(Call* curr) { info.calls.push_back(curr); }
+
+ void visitCallRef(CallRef* curr) {
+ // We can only optimize call_ref in closed world, as otherwise the
+ // call can go somewhere we can't see.
+ if (options.closedWorld) {
+ info.callRefs.push_back(curr);
+ }
+ }
+
+ void visitRefAs(RefAs* curr) {
+ if (curr->op == RefAsNonNull) {
+ noteCast(curr->value, curr->type);
+ }
+ }
+ void visitRefCast(RefCast* curr) { noteCast(curr->ref, curr->type); }
+
+ // Note a cast of an expression to a particular type.
+ void noteCast(Expression* expr, Type type) {
+ if (!inEntryBlock) {
+ return;
+ }
+
+ auto* fallthrough = Properties::getFallthrough(expr, options, wasm);
+ if (auto* get = fallthrough->dynCast<LocalGet>()) {
+ // To optimize, this needs to be a param, and of a useful type.
+ //
+ // Note that if we see more than one cast we keep the first one. This is
+ // not important in optimized code, as the most refined cast would be
+ // the only one to exist there, so it's ok to keep things simple here.
+ if (getFunction()->isParam(get->index) && type != get->type &&
+ info.castParams.count(get->index) == 0) {
+ info.castParams[get->index] = type;
+ }
+ }
+ }
+
+ // Operations that trap on null are equivalent to casts to non-null, in that
+ // they imply that their input is non-null if traps never happen.
+ //
+ // We only look at them if the input is actually nullable, since if they
+ // are non-nullable then we can add no information. (This is equivalent
+ // to the handling of RefAsNonNull above, in the sense that in optimized
+ // code the RefAs will not appear if the input is already non-nullable).
+ // This function is called with the reference that the operation will trap
+ // on if it is null.
+ void notePossibleTrap(Expression* expr) {
+ if (!expr->type.isRef() || expr->type.isNonNullable()) {
+ return;
+ }
+ noteCast(expr, Type(expr->type.getHeapType(), NonNullable));
+ }
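+
+ // For example (illustrative): given (struct.get $T 0 (local.get $p)) in
+ // the entry block, where param $p has type (ref null $T), the struct.get
+ // would trap on a null, so under TNH we can record that $p must in fact be
+ // a non-nullable (ref $T).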
+
+ void visitStructGet(StructGet* curr) { notePossibleTrap(curr->ref); }
+ void visitStructSet(StructSet* curr) { notePossibleTrap(curr->ref); }
+ void visitArrayGet(ArrayGet* curr) { notePossibleTrap(curr->ref); }
+ void visitArraySet(ArraySet* curr) { notePossibleTrap(curr->ref); }
+ void visitArrayLen(ArrayLen* curr) { notePossibleTrap(curr->ref); }
+ void visitArrayCopy(ArrayCopy* curr) {
+ notePossibleTrap(curr->srcRef);
+ notePossibleTrap(curr->destRef);
+ }
+ void visitArrayFill(ArrayFill* curr) { notePossibleTrap(curr->ref); }
+ void visitArrayInitData(ArrayInitData* curr) {
+ notePossibleTrap(curr->ref);
+ }
+ void visitArrayInitElem(ArrayInitElem* curr) {
+ notePossibleTrap(curr->ref);
+ }
+
+ void visitFunction(Function* curr) {
+ // In optimized TNH code, a function that always traps will be turned
+ // into a singleton unreachable instruction, so it is enough to check
+ // for that.
+ if (curr->body->is<Unreachable>()) {
+ info.traps = true;
+ }
+ }
+ } scanner(wasm, options, info);
+ scanner.walkFunction(func);
+}
+
+void TNHOracle::infer() {
+ // Phase 2: Inside each function, optimize calls based on the cast params of
+ // the called function (which we noted during phase 1).
+ //
+ // Specifically, each time we call a target that will cast a param, we can
+ // infer that the param must have that type (or else we'd trap, but we are
+ // assuming traps never happen).
+ //
+ // While doing so we must be careful of control flow transfers right before
+ // the call:
+ //
+ // (call $target
+ // (A)
+ // (br_if ..)
+ // (B)
+ // )
+ //
+ // If we branch in the br_if then we might execute A and then something else
+ // entirely, and not reach B or the call. In that case we can't infer anything
+ // about A (perhaps, for example, we branch away exactly when A would fail the
+ // cast). Therefore in the optimization below we only optimize code that, if
+ // reached, will definitely reach the call, like B.
+ //
+ // TODO: Some control flow transfers are ok, so long as we must reach the
+ // call, like if we replace the br_if with an if with two arms (and no
+ // branches in either).
+ // TODO: We can also infer backwards past basic blocks from casts, even
+ // without calls. Any cast tells us something about the uses of that
+ // value that must reach the cast.
+ // TODO: We can do a whole-program flow of this information.
+
+ // For call_ref, we need to know which functions belong to each type. Gather
+ // that first. This map will map each heap type to each function that is of
+ // that type or a subtype, i.e., might be called when that type is seen in a
+ // call_ref target.
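+ //
+ // For example (illustrative types): if function $f has type $sub, and $sub
+ // is a subtype of $super, then $f is added to both typeFunctions[$sub] and
+ // typeFunctions[$super], since a call_ref whose target type is $super might
+ // end up calling $f.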
+ std::unordered_map<HeapType, std::vector<Function*>> typeFunctions;
+ if (options.closedWorld) {
+ for (auto& func : wasm.functions) {
+ auto type = func->type;
+ auto& info = map[wasm.getFunction(func->name)];
+ if (info.traps) {
+ // This function definitely traps, so we can assume it is never called,
+ // and don't need to even bother putting it in |typeFunctions|.
+ continue;
+ }
+ while (1) {
+ typeFunctions[type].push_back(func.get());
+ if (auto super = type.getSuperType()) {
+ type = *super;
+ } else {
+ break;
+ }
+ }
+ }
+ }
+
+ doAnalysis([&](Function* func, TNHInfo& info) {
+ // We will need some CFG information below. Computing this is expensive, so
+ // only do it if we find optimization opportunities.
+ std::optional<analysis::CFGBlockIndexes> blockIndexes;
+
+ auto ensureCFG = [&]() {
+ if (!blockIndexes) {
+ auto cfg = analysis::CFG::fromFunction(func);
+ blockIndexes = analysis::CFGBlockIndexes(cfg);
+ }
+ };
+
+ for (auto* call : info.calls) {
+ auto& targetInfo = map[wasm.getFunction(call->target)];
+
+ auto& targetCastParams = targetInfo.castParams;
+ if (targetCastParams.empty()) {
+ continue;
+ }
+
+ // This looks promising: create the CFG if we haven't already, and then
+ // optimize.
+ ensureCFG();
+ optimizeCallCasts(
+ call, call->operands, targetCastParams, *blockIndexes, info);
+
+ // Note that we don't need to do anything for targetInfo.traps for a
+ // direct call: the inliner will inline the singleton unreachable in the
+ // target function anyhow.
+ }
+
+ for (auto* call : info.callRefs) {
+ auto targetType = call->target->type;
+ if (!targetType.isRef()) {
+ // This is unreachable or null, and other passes will optimize that.
+ continue;
+ }
+
+ // We should only get here in a closed world, in which we know which
+ // functions might be called (the scan phase only notes callRefs if we are
+ // in fact in a closed world).
+ assert(options.closedWorld);
+
+ auto iter = typeFunctions.find(targetType.getHeapType());
+ if (iter == typeFunctions.end()) {
+ // No function exists of this type, so the call_ref will trap. We can
+ // mark the target as empty, which has the identical effect.
+ info.inferences[call->target] = PossibleContents::none();
+ continue;
+ }
+
+ // Go through the targets and ignore any that will trap. That will leave
+ // us with the actually possible targets.
+ //
+ // Note that we did not even add functions that certainly trap to
+ // |typeFunctions| at all, so those are already excluded.
+ const auto& targets = iter->second;
+ std::vector<Function*> possibleTargets;
+ for (Function* target : targets) {
+ auto& targetInfo = map[target];
+
+ // If any of our operands will fail a cast, then we will trap.
+ bool traps = false;
+ for (auto& [castIndex, castType] : targetInfo.castParams) {
+ auto operandType = call->operands[castIndex]->type;
+ auto result = GCTypeUtils::evaluateCastCheck(operandType, castType);
+ if (result == GCTypeUtils::Failure) {
+ traps = true;
+ break;
+ }
+ }
+ if (!traps) {
+ possibleTargets.push_back(target);
+ }
+ }
+
+ if (possibleTargets.empty()) {
+ // No target is possible.
+ info.inferences[call->target] = PossibleContents::none();
+ continue;
+ }
+
+ if (possibleTargets.size() == 1) {
+ // There is exactly one possible call target, which means we can
+ // actually infer what the call_ref is calling. Add that as an
+ // inference.
+ // TODO: We could also optimizeCallCasts() here, but it is low priority
+ // as other opts will make this call direct later, after which a
+ // lot of other optimizations become possible anyhow.
+ auto target = possibleTargets[0]->name;
+ info.inferences[call->target] = PossibleContents::literal(
+ Literal(target, wasm.getFunction(target)->type));
+ continue;
+ }
+
+ // More than one target exists: apply the intersection of their
+ // constraints. That is, if they all cast the k-th parameter to type T (or
+ // a subtype of T), then we can apply that cast here.
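+ //
+ // For example (illustrative): if one target casts param 0 to $B and
+ // another casts it to $C, then all we can assume here is their least upper
+ // bound (some common supertype); and if any target does not cast a param
+ // at all, we can assume nothing about that param.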
+ auto numParams = call->operands.size();
+ std::vector<Type> sharedCastParamsVec(numParams, Type::unreachable);
+ for (auto* target : possibleTargets) {
+ auto& targetInfo = map[target];
+ auto& targetCastParams = targetInfo.castParams;
+ for (Index i = 0; i < numParams; i++) {
+ auto iter = targetCastParams.find(i);
+ if (iter == targetCastParams.end()) {
+ // If the target does not cast, we cannot do anything with this
+ // parameter; mark it as unoptimizable with an impossible type.
+ sharedCastParamsVec[i] = Type::none;
+ continue;
+ }
+
+ // This function casts this param. Combine this with existing info.
+ auto castType = iter->second;
+ sharedCastParamsVec[i] =
+ Type::getLeastUpperBound(sharedCastParamsVec[i], castType);
+ }
+ }
+
+ // Build a map of the interesting cast params we found, and if there are
+ // any, optimize using them.
+ CastParams sharedCastParams;
+ for (Index i = 0; i < numParams; i++) {
+ auto type = sharedCastParamsVec[i];
+ if (type != Type::none) {
+ sharedCastParams[i] = type;
+ }
+ }
+ if (!sharedCastParams.empty()) {
+ ensureCFG();
+ optimizeCallCasts(
+ call, call->operands, sharedCastParams, *blockIndexes, info);
+ }
+ }
+ });
+
+ // Combine all the inferences from the parallel phase above into the final
+ // list of inferences.
+ for (auto& [_, info] : map) {
+ for (auto& [expr, contents] : info.inferences) {
+ inferences[expr] = contents;
+ }
+ }
+}
+
+void TNHOracle::optimizeCallCasts(Expression* call,
+ const ExpressionList& operands,
+ const CastParams& targetCastParams,
+ const analysis::CFGBlockIndexes& blockIndexes,
+ TNHInfo& info) {
+ // Optimize in the same basic block as the call: all instructions still in
+ // that block will definitely execute if the call is reached. We will do that
+ // by going backwards through the call's operands and fallthrough values, and
+ // optimizing while we are still in the same basic block.
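+ //
+ // For example, in (call $target (A) (br_if ..) (B)) from the comment in
+ // infer() above, operand B shares the call's basic block and may be
+ // refined, but the br_if ends the previous block, so we stop before it
+ // (and never reach A).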
+ auto callBlockIndex = blockIndexes.get(call);
+
+ // Operands must exist here: since there is a cast param, the target must
+ // have at least one parameter.
+ assert(operands.size() > 0);
+ for (int i = int(operands.size() - 1); i >= 0; i--) {
+ auto* operand = operands[i];
+
+ if (blockIndexes.get(operand) != callBlockIndex) {
+ // Control flow might transfer; stop.
+ break;
+ }
+
+ auto iter = targetCastParams.find(i);
+ if (iter == targetCastParams.end()) {
+ // This param is not cast, so skip it.
+ continue;
+ }
+
+ // If the call executes then this parameter is definitely reached (since it
+ // is in the same basic block), and we know that it will be cast to a more
+ // refined type.
+ auto castType = iter->second;
+
+ // Apply what we found to the operand and also to its fallthrough
+ // values.
+ //
+ // At the loop entry |curr| has been checked for a possible control flow
+ // transfer (and that problem ruled out).
+ auto* curr = operand;
+ while (1) {
+ // Note the type if it is useful.
+ if (castType != curr->type) {
+ // There are two constraints on this location: any value there must
+ // be of the declared type (curr->type) and also the cast type, so
+ // we know only their intersection can appear here.
+ auto declared = PossibleContents::fullConeType(curr->type);
+ auto intersection = PossibleContents::fullConeType(castType);
+ intersection.intersect(declared);
+ if (intersection.isConeType()) {
+ auto intersectionType = intersection.getType();
+ if (intersectionType != curr->type) {
+ // We inferred a more refined type.
+ info.inferences[curr] = intersection;
+ }
+ } else {
+ // Otherwise, the intersection can be a null (if the heap types are
+ // incompatible, but a null is allowed), or empty. We can apply
+ // either.
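+ // For example (illustrative): intersecting cones of (ref null $A) and
+ // (ref null $B) for unrelated $A and $B leaves only a null; if either
+ // side is non-nullable, nothing at all remains.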
+ assert(intersection.isNull() || intersection.isNone());
+ info.inferences[curr] = intersection;
+ }
+ }
+
+ auto* next = Properties::getImmediateFallthrough(curr, options, wasm);
+ if (next == curr) {
+ // No fallthrough, we're done with this param.
+ break;
+ }
+
+ // There is a fallthrough. Check for a control flow transfer.
+ if (blockIndexes.get(next) != callBlockIndex) {
+ // Control flow might transfer; stop. We also cannot look at any earlier
+ // operands (if a child of this operand is in a different basic block than
+ // the call, then so is everything before it), so return from the entire
+ // function.
+ return;
+ }
+
+ // Continue to the fallthrough.
+ curr = next;
+ }
+ }
+}
+
// Main logic for building data for the flow analysis and then performing that
// analysis.
struct Flower {
Module& wasm;
+ const PassOptions& options;
- Flower(Module& wasm);
+ Flower(Module& wasm, const PassOptions& options);
// Each LocationIndex will have one LocationInfo that contains the relevant
// information we need for each location.
@@ -1361,7 +1857,19 @@ struct Flower {
return locations[index].contents;
}
+ // Check what we know about the type of an expression, using static
+ // information from a TrapsNeverHappen oracle (see TNHOracle), if we have one.
+ PossibleContents getTNHContents(Expression* curr) {
+ if (!tnhOracle) {
+ // No oracle; just use the type in the IR.
+ return PossibleContents::fullConeType(curr->type);
+ }
+ return tnhOracle->getContents(curr);
+ }
+
private:
+ std::unique_ptr<TNHOracle> tnhOracle;
+
std::vector<LocationIndex>& getTargets(LocationIndex index) {
assert(index < locations.size());
return locations[index].targets;
@@ -1527,7 +2035,19 @@ private:
#endif
};
-Flower::Flower(Module& wasm) : wasm(wasm) {
+Flower::Flower(Module& wasm, const PassOptions& options)
+ : wasm(wasm), options(options) {
+
+ // If traps never happen, create a TNH oracle.
+ //
+ // At the moment this oracle only helps on GC content, so only create it
+ // when GC is enabled.
+ if (options.trapsNeverHappen && wasm.features.hasGC()) {
+#ifdef POSSIBLE_CONTENTS_DEBUG
+ std::cout << "tnh phase\n";
+#endif
+ tnhOracle = std::make_unique<TNHOracle>(wasm, options);
+ }
+
#ifdef POSSIBLE_CONTENTS_DEBUG
std::cout << "parallel phase\n";
#endif
@@ -1560,6 +2080,10 @@ Flower::Flower(Module& wasm) : wasm(wasm) {
InfoCollector finder(globalInfo);
finder.walkModuleCode(&wasm);
+#ifdef POSSIBLE_CONTENTS_DEBUG
+ std::cout << "global init phase\n";
+#endif
+
// Connect global init values (which we've just processed, as part of the
// module code) to the globals they initialize.
for (auto& global : wasm.globals) {
@@ -1733,7 +2257,7 @@ bool Flower::updateContents(LocationIndex locationIndex,
auto oldContents = contents;
#if defined(POSSIBLE_CONTENTS_DEBUG) && POSSIBLE_CONTENTS_DEBUG >= 2
- std::cout << "updateContents\n";
+ std::cout << "\nupdateContents\n";
dump(getLocation(locationIndex));
contents.dump(std::cout, &wasm);
std::cout << "\n with new contents \n";
@@ -1961,7 +2485,10 @@ void Flower::filterExpressionContents(PossibleContents& contents,
const ExpressionLocation& exprLoc,
bool& worthSendingMore) {
auto type = exprLoc.expr->type;
- if (!type.isRef()) {
+
+ if (type.isTuple()) {
+ // TODO: Optimize tuples here as well. We would need to take into account
+ // exprLoc.tupleIndex for that in all of the code below.
return;
}
@@ -1969,17 +2496,28 @@ void Flower::filterExpressionContents(PossibleContents& contents,
// more to a reference - all that logic is in here. That is, the rest of this
// function is the only place we can mark |worthSendingMore| as false for a
// reference.
- assert(worthSendingMore);
+ bool isRef = type.isRef();
+ assert(!isRef || worthSendingMore);
- // The maximal contents here are the declared type and all subtypes. Nothing
- // else can pass through, so filter such things out.
- auto maximalContents = PossibleContents::fullConeType(type);
+ // The TNH oracle informs us of the maximal contents possible here.
+ auto maximalContents = getTNHContents(exprLoc.expr);
+#if defined(POSSIBLE_CONTENTS_DEBUG) && POSSIBLE_CONTENTS_DEBUG >= 2
+ std::cout << "TNHOracle informs us that " << *exprLoc.expr << " contains "
+ << maximalContents << "\n";
+#endif
contents.intersect(maximalContents);
if (contents.isNone()) {
// Nothing was left here at all.
return;
}
+ // For references we need to normalize the intersection, see below. For non-
+ // references, we are done (we did all the relevant work in the intersect()
+ // call).
+ if (!isRef) {
+ return;
+ }
+
// Normalize the intersection. We want to check later if any more content can
// arrive here, and also we want to avoid flowing around anything non-
// normalized, as explained earlier.
@@ -2232,7 +2770,8 @@ void Flower::writeToData(Expression* ref, Expression* value, Index fieldIndex) {
#if defined(POSSIBLE_CONTENTS_DEBUG) && POSSIBLE_CONTENTS_DEBUG >= 2
void Flower::dump(Location location) {
if (auto* loc = std::get_if<ExpressionLocation>(&location)) {
- std::cout << " exprloc \n" << *loc->expr << '\n';
+ std::cout << " exprloc \n"
+ << *loc->expr << " : " << loc->tupleIndex << '\n';
} else if (auto* loc = std::get_if<DataLocation>(&location)) {
std::cout << " dataloc ";
if (wasm.typeNames.count(loc->type)) {
@@ -2242,7 +2781,7 @@ void Flower::dump(Location location) {
}
std::cout << " : " << loc->index << '\n';
} else if (auto* loc = std::get_if<TagLocation>(&location)) {
- std::cout << " tagloc " << loc->tag << '\n';
+ std::cout << " tagloc " << loc->tag << " : " << loc->tupleIndex << '\n';
} else if (auto* loc = std::get_if<ParamLocation>(&location)) {
std::cout << " paramloc " << loc->func->name << " : " << loc->index
<< '\n';
@@ -2269,7 +2808,7 @@ void Flower::dump(Location location) {
} // anonymous namespace
void ContentOracle::analyze() {
- Flower flower(wasm);
+ Flower flower(wasm, options);
for (LocationIndex i = 0; i < flower.locations.size(); i++) {
locationContents[flower.getLocation(i)] = flower.getContents(i);
}