diff options
author | Alon Zakai <azakai@google.com> | 2021-12-14 13:58:12 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-14 13:58:12 -0800 |
commit | 01e6429a13abd2d8f5b3d4019a273e992fdd3a66 (patch) | |
tree | 4437eb47d7c6e3cb5c9b63a857e7cbff60212656 /src | |
parent | e5594bbde36ac96b6f2a41259d8a8d66e5d0a7cc (diff) | |
download | binaryen-01e6429a13abd2d8f5b3d4019a273e992fdd3a66.tar.gz binaryen-01e6429a13abd2d8f5b3d4019a273e992fdd3a66.tar.bz2 binaryen-01e6429a13abd2d8f5b3d4019a273e992fdd3a66.zip |
[Wasm GC] Refine results in SignatureRefining (#4380)
Similar to what DeadArgumentElimination does for individual functions, this
can refine the results of a set of functions all using the same heap type, when
they all return something more specific. After this PR SignatureRefining can
refine both params and results and is basically complete.
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/SignatureRefining.cpp | 87 |
1 files changed, 70 insertions, 17 deletions
diff --git a/src/passes/SignatureRefining.cpp b/src/passes/SignatureRefining.cpp index 8e8ecfbe2..623a393c1 100644 --- a/src/passes/SignatureRefining.cpp +++ b/src/passes/SignatureRefining.cpp @@ -25,13 +25,12 @@ // so while considering all users of the type (across all functions sharing that // type, and all call_refs using it). // -// TODO: optimize results too and not just params. -// #include "ir/find_all.h" #include "ir/lubs.h" #include "ir/module-utils.h" #include "ir/type-updating.h" +#include "ir/utils.h" #include "pass.h" #include "wasm-type.h" #include "wasm.h" @@ -62,43 +61,55 @@ struct SignatureRefining : public Pass { return; } - // First, find all the calls and call_refs. + // First, find all the information we need. Start by collecting inside each + // function in parallel. - struct CallInfo { + struct Info { + // The calls and call_refs. std::vector<Call*> calls; std::vector<CallRef*> callRefs; + + // A possibly improved LUB for the results. + LUBFinder resultsLUB; }; - ModuleUtils::ParallelFunctionAnalysis<CallInfo> analysis( - *module, [&](Function* func, CallInfo& info) { + ModuleUtils::ParallelFunctionAnalysis<Info> analysis( + *module, [&](Function* func, Info& info) { if (func->imported()) { return; } info.calls = std::move(FindAll<Call>(func->body).list); info.callRefs = std::move(FindAll<CallRef>(func->body).list); + info.resultsLUB = LUB::getResultsLUB(func, *module); }); - // A map of types to the calls and call_refs that use that type. - std::unordered_map<HeapType, CallInfo> allCallsTo; + // A map of types to all the information combined over all the functions + // with that type. + std::unordered_map<HeapType, Info> allInfo; // Combine all the information we gathered into that map. for (auto& [func, info] : analysis.map) { // For direct calls, add each call to the type of the function being // called. for (auto* call : info.calls) { - allCallsTo[module->getFunction(call->target)->type].calls.push_back( - call); + allInfo[module->getFunction(call->target)->type].calls.push_back(call); } // For indirect calls, add each call_ref to the type the call_ref uses. for (auto* callRef : info.callRefs) { auto calledType = callRef->target->type; if (calledType != Type::unreachable) { - allCallsTo[calledType.getHeapType()].callRefs.push_back(callRef); + allInfo[calledType.getHeapType()].callRefs.push_back(callRef); } } + + // Add the function's return LUB to the one for the heap type of that + // function. + allInfo[func->type].resultsLUB.combine(info.resultsLUB); } + bool refinedResults = false; + // Compute optimal LUBs. std::unordered_set<HeapType> seen; for (auto& func : module->functions) { @@ -118,11 +129,11 @@ struct SignatureRefining : public Pass { } }; - auto& callsTo = allCallsTo[type]; - for (auto* call : callsTo.calls) { + auto& info = allInfo[type]; + for (auto* call : info.calls) { updateLUBs(call->operands); } - for (auto* callRef : callsTo.callRefs) { + for (auto* callRef : info.callRefs) { updateLUBs(callRef->operands); } @@ -134,20 +145,55 @@ struct SignatureRefining : public Pass { } newParamsTypes.push_back(lub.getBestPossible()); } + Type newParams; if (newParamsTypes.size() < numParams) { // We did not have type information to calculate a LUB (no calls, or // some param is always unreachable), so there is nothing we can improve // here. Other passes might remove the type entirely. + newParams = func->getParams(); + } else { + newParams = Type(newParamsTypes); + } + + auto& resultsLUB = info.resultsLUB; + Type newResults; + if (!resultsLUB.noted()) { + // We did not have type information to calculate a LUB (no returned + // value, or it can return a value but traps instead etc.). + newResults = func->getResults(); + } else { + newResults = resultsLUB.getBestPossible(); + } + + if (newParams == func->getParams() && newResults == func->getResults()) { continue; } - auto newParams = Type(newParamsTypes); + + // We found an improvement! + newSignatures[type] = Signature(newParams, newResults); + + // Update nulls as necessary, now that we are changing things. if (newParams != func->getParams()) { - // We found an improvement! - newSignatures[type] = Signature(newParams, Type::none); for (auto& lub : paramLUBs) { lub.updateNulls(); } } + if (newResults != func->getResults()) { + resultsLUB.updateNulls(); + refinedResults = true; + + // Update the types of calls using the signature. + for (auto* call : info.calls) { + if (call->type != Type::unreachable) { + call->type = newResults; + } + } + for (auto* callRef : info.callRefs) { + if (callRef->type != Type::unreachable) { + callRef->type = newResults; + } + } + } } if (newSignatures.empty()) { @@ -192,11 +238,18 @@ struct SignatureRefining : public Pass { auto iter = parent.newSignatures.find(oldSignatureType); if (iter != parent.newSignatures.end()) { sig.params = getTempType(iter->second.params); + sig.results = getTempType(iter->second.results); } } }; TypeRewriter(*module, *this).update(); + + if (refinedResults) { + // After return types change we need to propagate. + // TODO: we could do this only in relevant functions perhaps + ReFinalize().run(runner, module); + } } }; |