summary | refs | log | tree | commit | diff
path: root/src/passes/LocalSubtyping.cpp
diff options
context:
space:
mode:
author: Alon Zakai <azakai@google.com> 2024-09-10 09:54:51 -0700
committer: GitHub <noreply@github.com> 2024-09-10 09:54:51 -0700
commit: 203dcd5c47d6ea784e613f647f8addd9815a3d5b (patch)
tree: 1212721bf13e0cc945ac5fea9c559ab00f129779 /src/passes/LocalSubtyping.cpp
parent: 2467e70524c96481c34e5ac23b9f068eb60abcbf (diff)
download: binaryen-203dcd5c47d6ea784e613f647f8addd9815a3d5b.tar.gz
download: binaryen-203dcd5c47d6ea784e613f647f8addd9815a3d5b.tar.bz2
download: binaryen-203dcd5c47d6ea784e613f647f8addd9815a3d5b.zip
[NFC-ish] Remove LocalGraph from LocalSubtyping (#6921)
The LocalGraph there was used for two purposes: (1) to get the list of gets and sets, and (2) to get only the reachable gets and sets. It is trivial to get all the gets and sets in a much faster way, by just walking the code, as this PR does. The downside is that we now also consider unreachable gets and sets, so unreachable code can prevent us from optimizing. That trade-off seems worthwhile, as many other passes likewise assume that unreachable code has already been removed (and they all become maximally effective after --dce). That is the only part of this change that is not NFC (no functional change). Removing LocalGraph, together with the fixup code for unreachability, makes this pass significantly shorter, and also 2-3x faster.
Diffstat (limited to 'src/passes/LocalSubtyping.cpp')
-rw-r--r-- src/passes/LocalSubtyping.cpp 111
1 file changed, 42 insertions, 69 deletions
diff --git a/src/passes/LocalSubtyping.cpp b/src/passes/LocalSubtyping.cpp
index ad0cfa7d2..7b30e3538 100644
--- a/src/passes/LocalSubtyping.cpp
+++ b/src/passes/LocalSubtyping.cpp
@@ -50,32 +50,49 @@ struct LocalSubtyping : public WalkerPass<PostWalker<LocalSubtyping>> {
return;
}
- auto numLocals = func->getNumLocals();
+ // Compute the list of gets and sets for each local.
+ struct Scanner : public PostWalker<Scanner> {
+ // Which locals are relevant for us (we can ignore non-references).
+ std::vector<bool> relevant;
+
+ // The lists of gets and sets.
+ std::vector<std::vector<LocalSet*>> setsForLocal;
+ std::vector<std::vector<LocalGet*>> getsForLocal;
+
+ Scanner(Function* func) {
+ auto numLocals = func->getNumLocals();
+ relevant.resize(numLocals);
+ setsForLocal.resize(numLocals);
+ getsForLocal.resize(numLocals);
+
+ for (Index i = 0; i < numLocals; i++) {
+ // TODO: Ignore params here? That may require changes below.
+ if (func->getLocalType(i).isRef()) {
+ relevant[i] = true;
+ }
+ }
- // Compute the local graph. We need to get the list of gets and sets for
- // each local, so that we can do the analysis. For non-nullable locals, we
- // also need to know when the default value of a local is used: if so then
- // we cannot change that type, as if we change the local type to
- // non-nullable then we'd be accessing the default, which is not allowed.
- //
- // TODO: Optimize this, as LocalGraph computes more than we need, and on
- // more locals than we need.
- LocalGraph localGraph(func, getModule());
-
- // For each local index, compute all the the sets and gets.
- std::vector<std::vector<LocalSet*>> setsForLocal(numLocals);
- std::vector<std::vector<LocalGet*>> getsForLocal(numLocals);
-
- for (auto& [curr, _] : localGraph.locations) {
- if (auto* set = curr->dynCast<LocalSet>()) {
- setsForLocal[set->index].push_back(set);
- } else {
- auto* get = curr->cast<LocalGet>();
- getsForLocal[get->index].push_back(get);
+ walk(func->body);
}
- }
- // Find which vars can be non-nullable.
+ void visitLocalGet(LocalGet* curr) {
+ if (relevant[curr->index]) {
+ getsForLocal[curr->index].push_back(curr);
+ }
+ }
+
+ void visitLocalSet(LocalSet* curr) {
+ if (relevant[curr->index]) {
+ setsForLocal[curr->index].push_back(curr);
+ }
+ }
+ } scanner(func);
+
+ auto& setsForLocal = scanner.setsForLocal;
+ auto& getsForLocal = scanner.getsForLocal;
+
+ // Find which vars can be non-nullable (if a null is written, or the default
+ // null is used, then a local cannot become non-nullable).
std::unordered_set<Index> cannotBeNonNullable;
// All gets must be dominated structurally by sets for the local to be non-
@@ -98,7 +115,8 @@ struct LocalSubtyping : public WalkerPass<PostWalker<LocalSubtyping>> {
// TODO: handle cycles of X -> Y -> X etc.
bool more;
- bool optimized = false;
+
+ auto numLocals = func->getNumLocals();
do {
more = false;
@@ -148,7 +166,6 @@ struct LocalSubtyping : public WalkerPass<PostWalker<LocalSubtyping>> {
assert(Type::isSubType(newType, oldType));
func->vars[i - varBase] = newType;
more = true;
- optimized = true;
// Update gets and tees.
for (auto* get : getsForLocal[i]) {
@@ -166,50 +183,6 @@ struct LocalSubtyping : public WalkerPass<PostWalker<LocalSubtyping>> {
}
}
} while (more);
-
- // If we ever optimized, then we also need to do a final pass to update any
- // unreachable gets and tees. They are not seen or updated in the above
- // analysis, but must be fixed up for validation to work.
- if (optimized) {
- for (auto* get : FindAll<LocalGet>(func->body).list) {
- get->type = func->getLocalType(get->index);
- }
- for (auto* set : FindAll<LocalSet>(func->body).list) {
- auto newType = func->getLocalType(set->index);
- if (set->isTee()) {
- set->type = newType;
- set->finalize();
- }
-
- // If this set was not processed earlier - that is, if it is in
- // unreachable code - then it may have an incompatible type. That is,
- // If we saw a reachable set that writes type A, and this set writes
- // type B, we may have specialized the local type to A, but the value
- // of type B in this unreachable set is no longer valid to write to
- // that local. In such a case we must do additional work.
- if (!Type::isSubType(set->value->type, newType)) {
- // The type is incompatible. To fix this, replace
- //
- // (set (bad-value))
- //
- // with
- //
- // (set (block
- // (drop (bad-value))
- // (unreachable)
- // ))
- //
- // (We cannot just ignore the bad value, as it may contain a break to
- // a target that is necessary for validation.)
- Builder builder(*getModule());
- set->value = builder.makeSequence(builder.makeDrop(set->value),
- builder.makeUnreachable());
- }
- }
-
- // Also update their parents.
- ReFinalize().walkFunctionInModule(func, getModule());
- }
}
};