summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlon Zakai <azakai@google.com>2021-12-14 13:58:12 -0800
committerGitHub <noreply@github.com>2021-12-14 13:58:12 -0800
commit01e6429a13abd2d8f5b3d4019a273e992fdd3a66 (patch)
tree4437eb47d7c6e3cb5c9b63a857e7cbff60212656
parente5594bbde36ac96b6f2a41259d8a8d66e5d0a7cc (diff)
downloadbinaryen-01e6429a13abd2d8f5b3d4019a273e992fdd3a66.tar.gz
binaryen-01e6429a13abd2d8f5b3d4019a273e992fdd3a66.tar.bz2
binaryen-01e6429a13abd2d8f5b3d4019a273e992fdd3a66.zip
[Wasm GC] Refine results in SignatureRefining (#4380)
Similar to what DeadArgumentElimination does for individual functions, this can refine the results of a set of functions all using the same heap type, when they all return something more specific. After this PR SignatureRefining can refine both params and results and is basically complete.
-rw-r--r--src/passes/SignatureRefining.cpp87
-rw-r--r--test/lit/passes/signature-refining.wast134
2 files changed, 204 insertions, 17 deletions
diff --git a/src/passes/SignatureRefining.cpp b/src/passes/SignatureRefining.cpp
index 8e8ecfbe2..623a393c1 100644
--- a/src/passes/SignatureRefining.cpp
+++ b/src/passes/SignatureRefining.cpp
@@ -25,13 +25,12 @@
// so while considering all users of the type (across all functions sharing that
// type, and all call_refs using it).
//
-// TODO: optimize results too and not just params.
-//
#include "ir/find_all.h"
#include "ir/lubs.h"
#include "ir/module-utils.h"
#include "ir/type-updating.h"
+#include "ir/utils.h"
#include "pass.h"
#include "wasm-type.h"
#include "wasm.h"
@@ -62,43 +61,55 @@ struct SignatureRefining : public Pass {
return;
}
- // First, find all the calls and call_refs.
+ // First, find all the information we need. Start by collecting inside each
+ // function in parallel.
- struct CallInfo {
+ struct Info {
+ // The calls and call_refs.
std::vector<Call*> calls;
std::vector<CallRef*> callRefs;
+
+ // A possibly improved LUB for the results.
+ LUBFinder resultsLUB;
};
- ModuleUtils::ParallelFunctionAnalysis<CallInfo> analysis(
- *module, [&](Function* func, CallInfo& info) {
+ ModuleUtils::ParallelFunctionAnalysis<Info> analysis(
+ *module, [&](Function* func, Info& info) {
if (func->imported()) {
return;
}
info.calls = std::move(FindAll<Call>(func->body).list);
info.callRefs = std::move(FindAll<CallRef>(func->body).list);
+ info.resultsLUB = LUB::getResultsLUB(func, *module);
});
- // A map of types to the calls and call_refs that use that type.
- std::unordered_map<HeapType, CallInfo> allCallsTo;
+ // A map of types to all the information combined over all the functions
+ // with that type.
+ std::unordered_map<HeapType, Info> allInfo;
// Combine all the information we gathered into that map.
for (auto& [func, info] : analysis.map) {
// For direct calls, add each call to the type of the function being
// called.
for (auto* call : info.calls) {
- allCallsTo[module->getFunction(call->target)->type].calls.push_back(
- call);
+ allInfo[module->getFunction(call->target)->type].calls.push_back(call);
}
// For indirect calls, add each call_ref to the type the call_ref uses.
for (auto* callRef : info.callRefs) {
auto calledType = callRef->target->type;
if (calledType != Type::unreachable) {
- allCallsTo[calledType.getHeapType()].callRefs.push_back(callRef);
+ allInfo[calledType.getHeapType()].callRefs.push_back(callRef);
}
}
+
+ // Add the function's return LUB to the one for the heap type of that
+ // function.
+ allInfo[func->type].resultsLUB.combine(info.resultsLUB);
}
+ bool refinedResults = false;
+
// Compute optimal LUBs.
std::unordered_set<HeapType> seen;
for (auto& func : module->functions) {
@@ -118,11 +129,11 @@ struct SignatureRefining : public Pass {
}
};
- auto& callsTo = allCallsTo[type];
- for (auto* call : callsTo.calls) {
+ auto& info = allInfo[type];
+ for (auto* call : info.calls) {
updateLUBs(call->operands);
}
- for (auto* callRef : callsTo.callRefs) {
+ for (auto* callRef : info.callRefs) {
updateLUBs(callRef->operands);
}
@@ -134,20 +145,55 @@ struct SignatureRefining : public Pass {
}
newParamsTypes.push_back(lub.getBestPossible());
}
+ Type newParams;
if (newParamsTypes.size() < numParams) {
// We did not have type information to calculate a LUB (no calls, or
// some param is always unreachable), so there is nothing we can improve
// here. Other passes might remove the type entirely.
+ newParams = func->getParams();
+ } else {
+ newParams = Type(newParamsTypes);
+ }
+
+ auto& resultsLUB = info.resultsLUB;
+ Type newResults;
+ if (!resultsLUB.noted()) {
+ // We did not have type information to calculate a LUB (no returned
+ // value, or it can return a value but traps instead etc.).
+ newResults = func->getResults();
+ } else {
+ newResults = resultsLUB.getBestPossible();
+ }
+
+ if (newParams == func->getParams() && newResults == func->getResults()) {
continue;
}
- auto newParams = Type(newParamsTypes);
+
+ // We found an improvement!
+ newSignatures[type] = Signature(newParams, newResults);
+
+ // Update nulls as necessary, now that we are changing things.
if (newParams != func->getParams()) {
- // We found an improvement!
- newSignatures[type] = Signature(newParams, Type::none);
for (auto& lub : paramLUBs) {
lub.updateNulls();
}
}
+ if (newResults != func->getResults()) {
+ resultsLUB.updateNulls();
+ refinedResults = true;
+
+ // Update the types of calls using the signature.
+ for (auto* call : info.calls) {
+ if (call->type != Type::unreachable) {
+ call->type = newResults;
+ }
+ }
+ for (auto* callRef : info.callRefs) {
+ if (callRef->type != Type::unreachable) {
+ callRef->type = newResults;
+ }
+ }
+ }
}
if (newSignatures.empty()) {
@@ -192,11 +238,18 @@ struct SignatureRefining : public Pass {
auto iter = parent.newSignatures.find(oldSignatureType);
if (iter != parent.newSignatures.end()) {
sig.params = getTempType(iter->second.params);
+ sig.results = getTempType(iter->second.results);
}
}
};
TypeRewriter(*module, *this).update();
+
+ if (refinedResults) {
+ // After return types change we need to propagate.
+ // TODO: we could do this only in relevant functions perhaps
+ ReFinalize().run(runner, module);
+ }
}
};
diff --git a/test/lit/passes/signature-refining.wast b/test/lit/passes/signature-refining.wast
index 71eec54b2..fcffcc550 100644
--- a/test/lit/passes/signature-refining.wast
+++ b/test/lit/passes/signature-refining.wast
@@ -490,3 +490,137 @@
)
)
)
+
+(module
+ ;; CHECK: (type $struct (struct_subtype data))
+ (type $struct (struct_subtype data))
+
+ ;; This signature has a single function using it, which returns a more
+ ;; refined type, and we can refine to that.
+ ;; CHECK: (type $sig-can-refine (func_subtype (result (ref $struct)) func))
+ (type $sig-can-refine (func_subtype (result anyref) func))
+
+ ;; Also a single function, but no refinement is possible.
+ ;; CHECK: (type $sig-cannot-refine (func_subtype (result anyref) func))
+ (type $sig-cannot-refine (func_subtype (result anyref) func))
+
+ ;; The single function never returns, so no refinement is possible.
+ ;; CHECK: (type $sig-unreachable (func_subtype (result anyref) func))
+ (type $sig-unreachable (func_subtype (result anyref) func))
+
+ ;; CHECK: (type $none_=>_none (func_subtype func))
+
+ ;; CHECK: (elem declare func $func-can-refine)
+
+ ;; CHECK: (func $func-can-refine (type $sig-can-refine) (result (ref $struct))
+ ;; CHECK-NEXT: (struct.new_default $struct)
+ ;; CHECK-NEXT: )
+ (func $func-can-refine (type $sig-can-refine) (result anyref)
+ (struct.new $struct)
+ )
+
+ ;; CHECK: (func $func-cannot-refine (type $sig-cannot-refine) (result anyref)
+ ;; CHECK-NEXT: (ref.null any)
+ ;; CHECK-NEXT: )
+ (func $func-cannot-refine (type $sig-cannot-refine) (result anyref)
+ (ref.null any)
+ )
+
+ ;; CHECK: (func $func-unreachable (type $sig-unreachable) (result anyref)
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ (func $func-unreachable (type $sig-unreachable) (result anyref)
+ (unreachable)
+ )
+
+ ;; CHECK: (func $caller (type $none_=>_none)
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (if (result (ref $struct))
+ ;; CHECK-NEXT: (i32.const 1)
+ ;; CHECK-NEXT: (call $func-can-refine)
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (if (result (ref $struct))
+ ;; CHECK-NEXT: (i32.const 1)
+ ;; CHECK-NEXT: (call_ref
+ ;; CHECK-NEXT: (ref.func $func-can-refine)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $caller
+ ;; Add a call to see that we update call types properly.
+ ;; Put the call in an if so the refinalize will update the if type and get
+ ;; printed out conveniently.
+ (drop
+ (if (result anyref)
+ (i32.const 1)
+ (call $func-can-refine)
+ (unreachable)
+ )
+ )
+ ;; The same with a call_ref.
+ (drop
+ (if (result anyref)
+ (i32.const 1)
+ (call_ref
+ (ref.func $func-can-refine)
+ )
+ (unreachable)
+ )
+ )
+ )
+)
+
+(module
+ ;; CHECK: (type $struct (struct_subtype data))
+ (type $struct (struct_subtype data))
+
+ ;; This signature has multiple functions using it, and some of them have nulls
+ ;; which should be updated when we refine.
+ ;; CHECK: (type $sig (func_subtype (result (ref null $struct)) func))
+ (type $sig (func_subtype (result anyref) func))
+
+ ;; CHECK: (func $func-1 (type $sig) (result (ref null $struct))
+ ;; CHECK-NEXT: (struct.new_default $struct)
+ ;; CHECK-NEXT: )
+ (func $func-1 (type $sig) (result anyref)
+ (struct.new $struct)
+ )
+
+ ;; CHECK: (func $func-2 (type $sig) (result (ref null $struct))
+ ;; CHECK-NEXT: (ref.null $struct)
+ ;; CHECK-NEXT: )
+ (func $func-2 (type $sig) (result anyref)
+ (ref.null any)
+ )
+
+ ;; CHECK: (func $func-3 (type $sig) (result (ref null $struct))
+ ;; CHECK-NEXT: (ref.null $struct)
+ ;; CHECK-NEXT: )
+ (func $func-3 (type $sig) (result anyref)
+ (ref.null eq)
+ )
+
+ ;; CHECK: (func $func-4 (type $sig) (result (ref null $struct))
+ ;; CHECK-NEXT: (if
+ ;; CHECK-NEXT: (i32.const 1)
+ ;; CHECK-NEXT: (return
+ ;; CHECK-NEXT: (ref.null $struct)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ (func $func-4 (type $sig) (result anyref)
+ (if
+ (i32.const 1)
+ (return
+ (ref.null any)
+ )
+ )
+ (unreachable)
+ )
+)