diff options
author | Alon Zakai <azakai@google.com> | 2021-12-14 13:58:12 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-12-14 13:58:12 -0800 |
commit | 01e6429a13abd2d8f5b3d4019a273e992fdd3a66 (patch) | |
tree | 4437eb47d7c6e3cb5c9b63a857e7cbff60212656 | |
parent | e5594bbde36ac96b6f2a41259d8a8d66e5d0a7cc (diff) | |
download | binaryen-01e6429a13abd2d8f5b3d4019a273e992fdd3a66.tar.gz binaryen-01e6429a13abd2d8f5b3d4019a273e992fdd3a66.tar.bz2 binaryen-01e6429a13abd2d8f5b3d4019a273e992fdd3a66.zip |
[Wasm GC] Refine results in SignatureRefining (#4380)
Similar to what DeadArgumentElimination does for individual functions, this
can refine the results of a set of functions all using the same heap type, when
they all return something more specific. After this PR SignatureRefining can
refine both params and results and is basically complete.
-rw-r--r-- | src/passes/SignatureRefining.cpp | 87 | ||||
-rw-r--r-- | test/lit/passes/signature-refining.wast | 134 |
2 files changed, 204 insertions, 17 deletions
diff --git a/src/passes/SignatureRefining.cpp b/src/passes/SignatureRefining.cpp index 8e8ecfbe2..623a393c1 100644 --- a/src/passes/SignatureRefining.cpp +++ b/src/passes/SignatureRefining.cpp @@ -25,13 +25,12 @@ // so while considering all users of the type (across all functions sharing that // type, and all call_refs using it). // -// TODO: optimize results too and not just params. -// #include "ir/find_all.h" #include "ir/lubs.h" #include "ir/module-utils.h" #include "ir/type-updating.h" +#include "ir/utils.h" #include "pass.h" #include "wasm-type.h" #include "wasm.h" @@ -62,43 +61,55 @@ struct SignatureRefining : public Pass { return; } - // First, find all the calls and call_refs. + // First, find all the information we need. Start by collecting inside each + // function in parallel. - struct CallInfo { + struct Info { + // The calls and call_refs. std::vector<Call*> calls; std::vector<CallRef*> callRefs; + + // A possibly improved LUB for the results. + LUBFinder resultsLUB; }; - ModuleUtils::ParallelFunctionAnalysis<CallInfo> analysis( - *module, [&](Function* func, CallInfo& info) { + ModuleUtils::ParallelFunctionAnalysis<Info> analysis( + *module, [&](Function* func, Info& info) { if (func->imported()) { return; } info.calls = std::move(FindAll<Call>(func->body).list); info.callRefs = std::move(FindAll<CallRef>(func->body).list); + info.resultsLUB = LUB::getResultsLUB(func, *module); }); - // A map of types to the calls and call_refs that use that type. - std::unordered_map<HeapType, CallInfo> allCallsTo; + // A map of types to all the information combined over all the functions + // with that type. + std::unordered_map<HeapType, Info> allInfo; // Combine all the information we gathered into that map. for (auto& [func, info] : analysis.map) { // For direct calls, add each call to the type of the function being // called. for (auto* call : info.calls) { - allCallsTo[module->getFunction(call->target)->type].calls.push_back( - call); + allInfo[module->getFunction(call->target)->type].calls.push_back(call); } // For indirect calls, add each call_ref to the type the call_ref uses. for (auto* callRef : info.callRefs) { auto calledType = callRef->target->type; if (calledType != Type::unreachable) { - allCallsTo[calledType.getHeapType()].callRefs.push_back(callRef); + allInfo[calledType.getHeapType()].callRefs.push_back(callRef); } } + + // Add the function's return LUB to the one for the heap type of that + // function. + allInfo[func->type].resultsLUB.combine(info.resultsLUB); } + bool refinedResults = false; + // Compute optimal LUBs. std::unordered_set<HeapType> seen; for (auto& func : module->functions) { @@ -118,11 +129,11 @@ struct SignatureRefining : public Pass { } }; - auto& callsTo = allCallsTo[type]; - for (auto* call : callsTo.calls) { + auto& info = allInfo[type]; + for (auto* call : info.calls) { updateLUBs(call->operands); } - for (auto* callRef : callsTo.callRefs) { + for (auto* callRef : info.callRefs) { updateLUBs(callRef->operands); } @@ -134,20 +145,55 @@ struct SignatureRefining : public Pass { } newParamsTypes.push_back(lub.getBestPossible()); } + Type newParams; if (newParamsTypes.size() < numParams) { // We did not have type information to calculate a LUB (no calls, or // some param is always unreachable), so there is nothing we can improve // here. Other passes might remove the type entirely. + newParams = func->getParams(); + } else { + newParams = Type(newParamsTypes); + } + + auto& resultsLUB = info.resultsLUB; + Type newResults; + if (!resultsLUB.noted()) { + // We did not have type information to calculate a LUB (no returned + // value, or it can return a value but traps instead etc.). + newResults = func->getResults(); + } else { + newResults = resultsLUB.getBestPossible(); + } + + if (newParams == func->getParams() && newResults == func->getResults()) { continue; } - auto newParams = Type(newParamsTypes); + + // We found an improvement! + newSignatures[type] = Signature(newParams, newResults); + + // Update nulls as necessary, now that we are changing things. if (newParams != func->getParams()) { - // We found an improvement! - newSignatures[type] = Signature(newParams, Type::none); for (auto& lub : paramLUBs) { lub.updateNulls(); } } + if (newResults != func->getResults()) { + resultsLUB.updateNulls(); + refinedResults = true; + + // Update the types of calls using the signature. + for (auto* call : info.calls) { + if (call->type != Type::unreachable) { + call->type = newResults; + } + } + for (auto* callRef : info.callRefs) { + if (callRef->type != Type::unreachable) { + callRef->type = newResults; + } + } + } } if (newSignatures.empty()) { @@ -192,11 +238,18 @@ struct SignatureRefining : public Pass { auto iter = parent.newSignatures.find(oldSignatureType); if (iter != parent.newSignatures.end()) { sig.params = getTempType(iter->second.params); + sig.results = getTempType(iter->second.results); } } }; TypeRewriter(*module, *this).update(); + + if (refinedResults) { + // After return types change we need to propagate. + // TODO: we could do this only in relevant functions perhaps + ReFinalize().run(runner, module); + } } }; diff --git a/test/lit/passes/signature-refining.wast b/test/lit/passes/signature-refining.wast index 71eec54b2..fcffcc550 100644 --- a/test/lit/passes/signature-refining.wast +++ b/test/lit/passes/signature-refining.wast @@ -490,3 +490,137 @@ ) ) ) + +(module + ;; CHECK: (type $struct (struct_subtype data)) + (type $struct (struct_subtype data)) + + ;; This signature has a single function using it, which returns a more + ;; refined type, and we can refine to that. + ;; CHECK: (type $sig-can-refine (func_subtype (result (ref $struct)) func)) + (type $sig-can-refine (func_subtype (result anyref) func)) + + ;; Also a single function, but no refinement is possible. + ;; CHECK: (type $sig-cannot-refine (func_subtype (result anyref) func)) + (type $sig-cannot-refine (func_subtype (result anyref) func)) + + ;; The single function never returns, so no refinement is possible. + ;; CHECK: (type $sig-unreachable (func_subtype (result anyref) func)) + (type $sig-unreachable (func_subtype (result anyref) func)) + + ;; CHECK: (type $none_=>_none (func_subtype func)) + + ;; CHECK: (elem declare func $func-can-refine) + + ;; CHECK: (func $func-can-refine (type $sig-can-refine) (result (ref $struct)) + ;; CHECK-NEXT: (struct.new_default $struct) + ;; CHECK-NEXT: ) + (func $func-can-refine (type $sig-can-refine) (result anyref) + (struct.new $struct) + ) + + ;; CHECK: (func $func-cannot-refine (type $sig-cannot-refine) (result anyref) + ;; CHECK-NEXT: (ref.null any) + ;; CHECK-NEXT: ) + (func $func-cannot-refine (type $sig-cannot-refine) (result anyref) + (ref.null any) + ) + + ;; CHECK: (func $func-unreachable (type $sig-unreachable) (result anyref) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + (func $func-unreachable (type $sig-unreachable) (result anyref) + (unreachable) + ) + + ;; CHECK: (func $caller (type $none_=>_none) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (if (result (ref $struct)) + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: (call $func-can-refine) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (if (result (ref $struct)) + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: (call_ref + ;; CHECK-NEXT: (ref.func $func-can-refine) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $caller + ;; Add a call to see that we update call types properly. + ;; Put the call in an if so the refinalize will update the if type and get + ;; printed out conveniently. + (drop + (if (result anyref) + (i32.const 1) + (call $func-can-refine) + (unreachable) + ) + ) + ;; The same with a call_ref. + (drop + (if (result anyref) + (i32.const 1) + (call_ref + (ref.func $func-can-refine) + ) + (unreachable) + ) + ) + ) +) + +(module + ;; CHECK: (type $struct (struct_subtype data)) + (type $struct (struct_subtype data)) + + ;; This signature has multiple functions using it, and some of them have nulls + ;; which should be updated when we refine. + ;; CHECK: (type $sig (func_subtype (result (ref null $struct)) func)) + (type $sig (func_subtype (result anyref) func)) + + ;; CHECK: (func $func-1 (type $sig) (result (ref null $struct)) + ;; CHECK-NEXT: (struct.new_default $struct) + ;; CHECK-NEXT: ) + (func $func-1 (type $sig) (result anyref) + (struct.new $struct) + ) + + ;; CHECK: (func $func-2 (type $sig) (result (ref null $struct)) + ;; CHECK-NEXT: (ref.null $struct) + ;; CHECK-NEXT: ) + (func $func-2 (type $sig) (result anyref) + (ref.null any) + ) + + ;; CHECK: (func $func-3 (type $sig) (result (ref null $struct)) + ;; CHECK-NEXT: (ref.null $struct) + ;; CHECK-NEXT: ) + (func $func-3 (type $sig) (result anyref) + (ref.null eq) + ) + + ;; CHECK: (func $func-4 (type $sig) (result (ref null $struct)) + ;; CHECK-NEXT: (if + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: (return + ;; CHECK-NEXT: (ref.null $struct) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + (func $func-4 (type $sig) (result anyref) + (if + (i32.const 1) + (return + (ref.null any) + ) + ) + (unreachable) + ) +) |