diff options
author | Alon Zakai <azakai@google.com> | 2021-07-28 13:54:29 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-28 13:54:29 -0700 |
commit | 1ed257a587a30885f42f2c1b6170d662e741ae40 (patch) | |
tree | ad6cce346ea5a620b47944f5270b0fe7bba454a1 /src | |
parent | c5166f636c5835413046be76e26c362ef4bbecc5 (diff) | |
download | binaryen-1ed257a587a30885f42f2c1b6170d662e741ae40.tar.gz binaryen-1ed257a587a30885f42f2c1b6170d662e741ae40.tar.bz2 binaryen-1ed257a587a30885f42f2c1b6170d662e741ae40.zip |
[Wasm GC] Handle uses of default values in LocalSubtyping (#4024)
It is ok to use the default value of a reference even if we refine the type,
as it would be a more specifically-typed null, and all nulls compare the
same. However, if the default is used then we *cannot* alter the type to
be non-nullable, as then we'd use a null where that is not allowed.
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/LocalSubtyping.cpp | 79 | ||||
-rw-r--r-- | src/passes/Precompute.cpp | 6 | ||||
-rw-r--r-- | src/tools/execution-results.h | 4 |
3 files changed, 59 insertions, 30 deletions
diff --git a/src/passes/LocalSubtyping.cpp b/src/passes/LocalSubtyping.cpp index 472463f44..ae197858c 100644 --- a/src/passes/LocalSubtyping.cpp +++ b/src/passes/LocalSubtyping.cpp @@ -24,6 +24,7 @@ #include <ir/find_all.h> #include <ir/linear-execution.h> +#include <ir/local-graph.h> #include <ir/utils.h> #include <pass.h> #include <wasm.h> @@ -35,29 +36,60 @@ struct LocalSubtyping : public WalkerPass<PostWalker<LocalSubtyping>> { Pass* create() override { return new LocalSubtyping(); } - // Shared code to find all sets or gets for each local index. Returns a vector - // that maps local indexes => vector of LocalGet*|LocalSet* expressions. - template<typename T> - std::vector<std::vector<T*>> getLocalOperations(Function* func) { - std::vector<std::vector<T*>> ret; - ret.resize(func->getNumLocals()); - FindAll<T> operations(func->body); - for (auto* operation : operations.list) { - ret[operation->index].push_back(operation); - } - return ret; - } - void doWalkFunction(Function* func) { if (!getModule()->features.hasGC()) { return; } - auto varBase = func->getVarIndexBase(); auto numLocals = func->getNumLocals(); - auto setsForLocal = getLocalOperations<LocalSet>(func); - auto getsForLocal = getLocalOperations<LocalGet>(func); + // Compute the local graph. We need to get the list of gets and sets for + // each local, so that we can do the analysis. For non-nullable locals, we + // also need to know when the default value of a local is used: if so then + // we cannot change that type, as if we change the local type to + // non-nullable then we'd be accessing the default, which is not allowed. + // + // TODO: Optimize this, as LocalGraph computes more than we need, and on + // more locals than we need. + LocalGraph localGraph(func); + + // For each local index, compute all the the sets and gets. + std::vector<std::vector<LocalSet*>> setsForLocal(numLocals); + std::vector<std::vector<LocalGet*>> getsForLocal(numLocals); + + for (auto& kv : localGraph.locations) { + auto* curr = kv.first; + if (auto* set = curr->dynCast<LocalSet>()) { + setsForLocal[set->index].push_back(set); + } else { + auto* get = curr->cast<LocalGet>(); + getsForLocal[get->index].push_back(get); + } + } + + // Find which vars use the default value, if we allow non-nullable locals. + // + // If that feature is not enabled, then we can safely assume that the + // default is never used - the default would be a null value, and the type + // of the null does not really matter as all nulls compare equally, so we do + // not need to worry. + std::unordered_set<Index> usesDefault; + + if (getModule()->features.hasGCNNLocals()) { + for (auto& kv : localGraph.getSetses) { + auto* get = kv.first; + auto& sets = kv.second; + auto index = get->index; + if (func->isVar(index) && + std::any_of(sets.begin(), sets.end(), [&](LocalSet* set) { + return set == nullptr; + })) { + usesDefault.insert(index); + } + } + } + + auto varBase = func->getVarIndexBase(); // Keep iterating while we find things to change. There can be chains like // X -> Y -> Z where one change enables more. Note that we are O(N^2) on @@ -85,10 +117,6 @@ struct LocalSubtyping : public WalkerPass<PostWalker<LocalSubtyping>> { for (Index i = varBase; i < numLocals; i++) { // Find all the types assigned to the var, and compute the optimal LUB. - // Note that we do not need to take into account the initial value of - // zero or null that locals have: that value has the type of the local, - // which is a supertype of all the assigned values anyhow. It will never - // be able to tell us of a more specific subtype that is possible. std::unordered_set<Type> types; for (auto* set : setsForLocal[i]) { types.insert(set->value->type); @@ -104,14 +132,11 @@ struct LocalSubtyping : public WalkerPass<PostWalker<LocalSubtyping>> { // Remove non-nullability if we disallow that in locals. if (newType.isNonNullable()) { - if (!getModule()->features.hasGCNNLocals()) { + // As mentioned earlier, even if we allow non-nullability, there may + // be a problem if the default value - a null - is used. In that case, + // remove non-nullability as well. + if (!getModule()->features.hasGCNNLocals() || usesDefault.count(i)) { newType = Type(newType.getHeapType(), Nullable); - // Note that the old type must have been nullable as well, as non- - // nullable types cannot be locals without that feature being - // enabled, which means that we will not have to do any extra work - // to handle non-nullability if we update the type: we are just - // updating the heap type, and leaving the type nullable as it was. - assert(oldType.isNullable()); } } else if (!newType.isDefaultable()) { // Aside from the case we just handled of allowed non-nullability, we diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp index 44227a50f..99741ebe3 100644 --- a/src/passes/Precompute.cpp +++ b/src/passes/Precompute.cpp @@ -351,8 +351,10 @@ private: if (set == nullptr) { if (getFunction()->isVar(get->index)) { auto localType = getFunction()->getLocalType(get->index); - assert(!localType.isNonNullable() && - "Non-nullable locals must not use the default value"); + if (localType.isNonNullable()) { + Fatal() << "Non-nullable local accessing the default value in " + << getFunction()->name << " (" << get->index << ')'; + } curr = Literal::makeZeros(localType); } else { // it's a param, so it's hopeless diff --git a/src/tools/execution-results.h b/src/tools/execution-results.h index aabbb4819..8fc4c2eed 100644 --- a/src/tools/execution-results.h +++ b/src/tools/execution-results.h @@ -144,7 +144,9 @@ struct ExecutionResults { } bool areEqual(Literal a, Literal b) { - if (a.type != b.type) { + // We allow nulls to have different types (as they compare equal regardless) + // but anything else must have an identical type. + if (a.type != b.type && !(a.isNull() && b.isNull())) { std::cout << "types not identical! " << a << " != " << b << '\n'; return false; } |