diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ir/subtypes.h | 23 | ||||
-rw-r--r-- | src/ir/type-updating.h | 70 | ||||
-rw-r--r-- | src/passes/AbstractTypeRefining.cpp | 298 | ||||
-rw-r--r-- | src/passes/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/passes/TypeMerging.cpp | 74 | ||||
-rw-r--r-- | src/passes/pass.cpp | 4 | ||||
-rw-r--r-- | src/passes/passes.h | 1 | ||||
-rw-r--r-- | src/wasm.h | 6 |
8 files changed, 395 insertions, 82 deletions
diff --git a/src/ir/subtypes.h b/src/ir/subtypes.h index d3a6dceaa..420bdcc1d 100644 --- a/src/ir/subtypes.h +++ b/src/ir/subtypes.h @@ -70,16 +70,12 @@ struct SubTypes { return ret; } - // Computes the depth of children for each type. This is 0 if the type has no - // subtypes, 1 if it has subtypes but none of those have subtypes themselves, - // and so forth. - // - // This depth ignores bottom types. - std::unordered_map<HeapType, Index> getMaxDepths() { - struct DepthSort : TopologicalSort<HeapType, DepthSort> { + // A topological sort that visits subtypes first. + auto getSubTypesFirstSort() const { + struct SubTypesFirstSort : TopologicalSort<HeapType, SubTypesFirstSort> { const SubTypes& parent; - DepthSort(const SubTypes& parent) : parent(parent) { + SubTypesFirstSort(const SubTypes& parent) : parent(parent) { for (auto type : parent.types) { // The roots are types with no supertype. if (!type.getSuperType()) { @@ -97,9 +93,18 @@ struct SubTypes { } }; + return SubTypesFirstSort(*this); + } + + // Computes the depth of children for each type. This is 0 if the type has no + // subtypes, 1 if it has subtypes but none of those have subtypes themselves, + // and so forth. + // + // This depth ignores bottom types. + std::unordered_map<HeapType, Index> getMaxDepths() { std::unordered_map<HeapType, Index> depths; - for (auto type : DepthSort(*this)) { + for (auto type : getSubTypesFirstSort()) { // Begin with depth 0, then take into account the subtype depths. Index depth = 0; for (auto subType : getStrictSubTypes(type)) { diff --git a/src/ir/type-updating.h b/src/ir/type-updating.h index e1b5e42a9..d224f93dc 100644 --- a/src/ir/type-updating.h +++ b/src/ir/type-updating.h @@ -405,6 +405,76 @@ private: InsertOrderedMap<HeapType, Index> typeIndices; }; +class TypeMapper : public GlobalTypeRewriter { +public: + using TypeUpdates = std::unordered_map<HeapType, HeapType>; + + const TypeUpdates& mapping; + + std::unordered_map<HeapType, Signature> newSignatures; + +public: + TypeMapper(Module& wasm, const TypeUpdates& mapping) + : GlobalTypeRewriter(wasm), mapping(mapping) {} + + void map() { + // Map the types of expressions (curr->type, etc.) to their merged + // types. + mapTypes(mapping); + + // Update the internals of types (struct fields, signatures, etc.) to + // refer to the merged types. + update(); + } + + Type getNewType(Type type) { + if (!type.isRef()) { + return type; + } + auto heapType = type.getHeapType(); + auto iter = mapping.find(heapType); + if (iter != mapping.end()) { + return getTempType(Type(iter->second, type.getNullability())); + } + return getTempType(type); + } + + void modifyStruct(HeapType oldType, Struct& struct_) override { + auto& oldFields = oldType.getStruct().fields; + for (Index i = 0; i < oldFields.size(); i++) { + auto& oldField = oldFields[i]; + auto& newField = struct_.fields[i]; + newField.type = getNewType(oldField.type); + } + } + void modifyArray(HeapType oldType, Array& array) override { + array.element.type = getNewType(oldType.getArray().element.type); + } + void modifySignature(HeapType oldSignatureType, Signature& sig) override { + auto getUpdatedTypeList = [&](Type type) { + std::vector<Type> vec; + for (auto t : type) { + vec.push_back(getNewType(t)); + } + return getTempTupleType(vec); + }; + + auto oldSig = oldSignatureType.getSignature(); + sig.params = getUpdatedTypeList(oldSig.params); + sig.results = getUpdatedTypeList(oldSig.results); + } + std::optional<HeapType> getSuperType(HeapType oldType) override { + // If the super is mapped, get it from the mapping. + auto super = oldType.getSuperType(); + if (super) { + if (auto it = mapping.find(*super); it != mapping.end()) { + return it->second; + } + } + return super; + } +}; + namespace TypeUpdating { // Checks whether a type is valid as a local, or whether diff --git a/src/passes/AbstractTypeRefining.cpp b/src/passes/AbstractTypeRefining.cpp new file mode 100644 index 000000000..1d3ff3f74 --- /dev/null +++ b/src/passes/AbstractTypeRefining.cpp @@ -0,0 +1,298 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +// Refine types based on global information about abstract types, that is, types +// that are not created anywhere (no struct.new etc.). +// +// In trapsNeverHappen mode, if we see a cast to $B and the type hierarchy is +// this: +// +// $A :> $B :> $C +// +// and $B has no struct.new instructions, and we are in closed world, then we +// can infer that the cast must be to $C. That is necessarily so since we will +// not trap by assumption, and $C or a subtype of it is all that remains +// possible. +// +// Even without trapsNeverHappen we can optimize certain cases. When we see a +// cast to a type that is never created, nor any subtype is created, then it +// must fail unless it allows null. +// + +#include "ir/module-utils.h" +#include "ir/subtypes.h" +#include "ir/type-updating.h" +#include "ir/utils.h" +#include "pass.h" +#include "wasm-type.h" +#include "wasm.h" + +namespace wasm { + +namespace { + +using Types = std::unordered_set<HeapType>; + +// Gather all types in StructNews. +struct NewFinder : public PostWalker<NewFinder> { + Types& types; + + NewFinder(Types& types) : types(types) {} + + void visitStructNew(StructNew* curr) { + auto type = curr->type; + if (type != Type::unreachable) { + types.insert(type.getHeapType()); + } + } +}; + +struct AbstractTypeRefining : public Pass { + // Changes types by refining them. We never add new non-nullable locals here + // (even if we refine a type to a bottom type, we only change the heap type + // there, not nullability). + bool requiresNonNullableLocalFixups() override { return false; } + + // The types that are created (have a struct.new). + Types createdTypes; + + // The types that are created, or have a subtype that is created. + Types createdTypesOrSubTypes; + + // A map of a cast type to refine and the type to refine it to. + TypeMapper::TypeUpdates refinableTypes; + + bool trapsNeverHappen; + + void run(Module* module) override { + if (!module->features.hasGC()) { + return; + } + + if (!getPassOptions().closedWorld) { + Fatal() << "AbstractTypeRefining requires --closed-world"; + } + + trapsNeverHappen = getPassOptions().trapsNeverHappen; + + // First, find all the created types (that have a struct.new) both in module + // code and in functions. + NewFinder(createdTypes).walkModuleCode(module); + + ModuleUtils::ParallelFunctionAnalysis<Types> analysis( + *module, [&](Function* func, Types& types) { + if (!func->imported()) { + NewFinder(types).walk(func->body); + } + }); + + for (auto& [_, types] : analysis.map) { + for (auto type : types) { + createdTypes.insert(type); + } + } + + SubTypes subTypes(*module); + + // Compute createdTypesOrSubTypes by starting with the created types and + // then propagating subtypes. + createdTypesOrSubTypes = createdTypes; + for (auto type : subTypes.getSubTypesFirstSort()) { + // If any of our subtypes are created, so are we. + for (auto subType : subTypes.getStrictSubTypes(type)) { + if (createdTypesOrSubTypes.count(subType)) { + createdTypesOrSubTypes.insert(type); + break; + } + } + } + + if (trapsNeverHappen) { + computeAbstractTypes(subTypes); + } + + // Use what we found about abstract types and never-created types to + // optimize. + optimize(module, subTypes); + } + + void computeAbstractTypes(const SubTypes& subTypes) { + // Abstract types are those with no news, i.e., the complement of + // |createdTypes|. As mentioned above, we can only optimize this case if + // traps never happen. + // TODO: We could do some of this even if traps are possible. If an abstract + // type has no casts at all, then no traps are relevant, and we could + // remove it from the module. That might also make sense in MergeTypes + // perhaps (which atm will not merge such types if they add fields, + // in particular). + Types abstractTypes; + for (auto type : subTypes.types) { + if (createdTypes.count(type) == 0) { + abstractTypes.insert(type); + } + } + + // We found abstract types. Next, find which of them are refinable. We + // need an abstract type to have a single subtype, to which we will switch + // all of their casts. + // + // Do this depth-first, so that we visit subtypes first. That will handle + // chains where we want to refine a type A to a subtype of a subtype of + // it. + for (auto type : subTypes.getSubTypesFirstSort()) { + if (!abstractTypes.count(type)) { + continue; + } + + std::optional<HeapType> refinedType; + auto& typeSubTypes = subTypes.getStrictSubTypes(type); + if (typeSubTypes.size() == 1) { + // There is only a single possibility, so we can definitely use that + /// one. + refinedType = typeSubTypes[0]; + } else if (!typeSubTypes.empty()) { + // There are multiple possibilities. However, perhaps only one of them + // is relevant, if nothing is ever created of the others or their + // subtypes. + for (auto subType : typeSubTypes) { + if (createdTypesOrSubTypes.count(subType)) { + if (!refinedType) { + // This is the first relevant thing, and hopefully will remain + // the only one. + refinedType = subType; + } else { + // We've seen more than one as relevant, so we have failed to + // find a singleton. + refinedType = std::nullopt; + break; + } + } + } + } + if (refinedType) { + // Propagate anything from the child, to handle chains. + auto iter = refinableTypes.find(*refinedType); + if (iter != refinableTypes.end()) { + *refinedType = iter->second; + } + + refinableTypes[type] = *refinedType; + } + } + } + + void optimize(Module* module, const SubTypes& subTypes) { + // To optimize we rewrite types. That is, if we want to optimize all casts + // of $A to instead cast to the refined type $B, we can do that by simply + // replacing all appearances of $A with $B. That is possible here since we + // only optimize when we know $A is never created, and we are removing all + // casts to it, which means no other references to it are needed - so we can + // just rewrite all references to $A to point to $B. Doing such a rewrite + // will also remove the unneeded type from the type section, which is nice + // for code size. + // + // Even though this pass removes types, it does not on its own inhibit + // further optimizations. In more detail, a possible issue could have been + // something like this: imagine that we replace all $A with $B, and we had + // types like this: + // + // $C = [.., $A, ..] + // $D = [.., $B, ..] + // + // After replacing $A with $B, we cause $C and $D to be structurally + // identical. If we merged $C and $D then we might lose some optimization + // potential (perhaps different values are written to each, and GUFA or + // another pass can optimize each separately, but not if they were merged). + // However, the type rewriter will create a single new rec group for all new + // types anyhow, so they all remain distinct from each other. The only thing + // that would actually merge them is if we run TypeMerging, which is not run + // by default exactly for this reason, that it can limit optimizations. + // Thus, this pass does only "safe" merging, that cannot limit later + // optimizations - merging $A and $B is of course fine as one of them was + // not even used anywhere. + + TypeMapper::TypeUpdates mapping; + + for (auto type : subTypes.types) { + if (!type.isStruct()) { + // TODO: support arrays and funcs + continue; + } + + // Add a mapping of types that are never created (and none of their + // subtypes) to the bottom type. This is valid because all locations of + // that type, like a local variable, will only contain null at runtime. + // Likewise, if we have a ref.test of such a type, we can only be looking + // for a null at best. This can be seen as "refining" uses of these + // never-created types to the bottom type. + // + // We check this first as it is the most powerful change. + if (createdTypesOrSubTypes.count(type) == 0) { + mapping[type] = type.getBottom(); + continue; + } + + // Otherwise, apply a refining if we found one before. + if (auto iter = refinableTypes.find(type); iter != refinableTypes.end()) { + mapping[type] = iter->second; + } + } + + if (mapping.empty()) { + return; + } + + // A TypeMapper that handles the patterns we have in our mapping, where we + // end up mapping a type to a *subtype*. We need to properly create + // supertypes while doing this rewriting. For example, say we have this: + // + // A :> B :> C + // + // Say we see B is never created, so we want to map B to its subtype C. C's + // supertype must now be A. + class AbstractTypeRefiningTypeMapper : public TypeMapper { + public: + AbstractTypeRefiningTypeMapper(Module& wasm, const TypeUpdates& mapping) + : TypeMapper(wasm, mapping) {} + + std::optional<HeapType> getSuperType(HeapType oldType) override { + auto super = oldType.getSuperType(); + + // Go up the chain of supertypes, skipping things we are mapping away, + // as those things will not appear in the output. This skips B in the + // example above. + while (super && mapping.count(*super)) { + super = super->getSuperType(); + } + return super; + } + }; + + AbstractTypeRefiningTypeMapper(*module, mapping).map(); + + // Refinalize to propagate the type changes we made. For example, a refined + // cast may lead to a struct.get reading a more refined type using that + // type. + ReFinalize().run(getPassRunner(), module); + } +}; + +} // anonymous namespace + +Pass* createAbstractTypeRefiningPass() { return new AbstractTypeRefining(); } + +} // namespace wasm diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index f29e7db9b..9d16f9782 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -16,6 +16,7 @@ set(passes_SOURCES param-utils.cpp pass.cpp test_passes.cpp + AbstractTypeRefining.cpp AlignmentLowering.cpp Asyncify.cpp AvoidReinterprets.cpp diff --git a/src/passes/TypeMerging.cpp b/src/passes/TypeMerging.cpp index 2ef6f8f54..d3652d010 100644 --- a/src/passes/TypeMerging.cpp +++ b/src/passes/TypeMerging.cpp @@ -110,8 +110,6 @@ struct CastFinder : public PostWalker<CastFinder> { // refine the partitions so that types that turn out to not be mergeable will be // split out into separate partitions. struct TypeMerging : public Pass { - using TypeUpdates = std::unordered_map<HeapType, HeapType>; - // Only modifies types. bool requiresNonNullableLocalFixups() override { return false; } @@ -125,7 +123,7 @@ struct TypeMerging : public Pass { CastTypes findCastTypes(); std::vector<HeapType> getPublicChildren(HeapType type); DFA::State<HeapType> makeDFAState(HeapType type); - void applyMerges(const TypeUpdates& merges); + void applyMerges(const TypeMapper::TypeUpdates& merges); }; // Hash and equality-compare HeapTypes based on their top-level structure (i.e. @@ -285,7 +283,7 @@ void TypeMerging::run(Module* module_) { auto refinedPartitions = DFA::refinePartitions(dfa); // The types we can merge mapped to the type we are merging them into. - TypeUpdates merges; + TypeMapper::TypeUpdates merges; // Merge each refined partition into a single type. We should only merge into // supertypes or siblings because if we try to merge into a subtype then we @@ -366,78 +364,14 @@ DFA::State<HeapType> TypeMerging::makeDFAState(HeapType type) { return {type, std::move(succs)}; } -void TypeMerging::applyMerges(const TypeUpdates& merges) { +void TypeMerging::applyMerges(const TypeMapper::TypeUpdates& merges) { if (merges.empty()) { return; } // We found things to optimize! Rewrite types in the module to apply those // changes. - - class TypeInternalsUpdater : public GlobalTypeRewriter { - const TypeUpdates& merges; - - std::unordered_map<HeapType, Signature> newSignatures; - - public: - TypeInternalsUpdater(Module& wasm, const TypeUpdates& merges) - : GlobalTypeRewriter(wasm), merges(merges) { - - // Map the types of expressions (curr->type, etc.) to their merged - // types. - mapTypes(merges); - - // Update the internals of types (struct fields, signatures, etc.) to - // refer to the merged types. - update(); - } - - Type getNewType(Type type) { - if (!type.isRef()) { - return type; - } - auto heapType = type.getHeapType(); - auto iter = merges.find(heapType); - if (iter != merges.end()) { - return getTempType(Type(iter->second, type.getNullability())); - } - return getTempType(type); - } - - void modifyStruct(HeapType oldType, Struct& struct_) override { - auto& oldFields = oldType.getStruct().fields; - for (Index i = 0; i < oldFields.size(); i++) { - auto& oldField = oldFields[i]; - auto& newField = struct_.fields[i]; - newField.type = getNewType(oldField.type); - } - } - void modifyArray(HeapType oldType, Array& array) override { - array.element.type = getNewType(oldType.getArray().element.type); - } - void modifySignature(HeapType oldSignatureType, Signature& sig) override { - auto getUpdatedTypeList = [&](Type type) { - std::vector<Type> vec; - for (auto t : type) { - vec.push_back(getNewType(t)); - } - return getTempTupleType(vec); - }; - - auto oldSig = oldSignatureType.getSignature(); - sig.params = getUpdatedTypeList(oldSig.params); - sig.results = getUpdatedTypeList(oldSig.results); - } - std::optional<HeapType> getSuperType(HeapType oldType) override { - auto super = oldType.getSuperType(); - if (super) { - if (auto it = merges.find(*super); it != merges.end()) { - return it->second; - } - } - return super; - } - } rewriter(*module, merges); + TypeMapper(*module, merges).map(); } bool shapeEq(HeapType a, HeapType b) { diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 13d700af9..d567fee95 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -101,6 +101,9 @@ void PassRegistry::registerPasses() { "removes arguments to calls in an lto-like manner, and " "optimizes where we removed", createDAEOptimizingPass); + registerPass("abstract-type-refining", + "refine and merge abstract (never-created) types", + createAbstractTypeRefiningPass); registerPass("coalesce-locals", "reduce # of locals by coalescing", createCoalesceLocalsPass); @@ -626,6 +629,7 @@ void PassRunner::addDefaultGlobalOptimizationPrePasses() { addIfNoDWARFIssues("remove-unused-types"); addIfNoDWARFIssues("cfp"); addIfNoDWARFIssues("gsi"); + addIfNoDWARFIssues("abstract-type-refining"); } } // TODO: generate-global-effects here, right before function passes, then diff --git a/src/passes/passes.h b/src/passes/passes.h index 4d6da0ca9..346741b1a 100644 --- a/src/passes/passes.h +++ b/src/passes/passes.h @@ -22,6 +22,7 @@ namespace wasm { class Pass; // Normal passes: +Pass* createAbstractTypeRefiningPass(); Pass* createAlignmentLoweringPass(); Pass* createAsyncifyPass(); Pass* createAvoidReinterpretsPass(); diff --git a/src/wasm.h b/src/wasm.h index 779efe991..931c560c1 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -1514,7 +1514,7 @@ public: void finalize(); - Type getCastType() { return castType; } + Type& getCastType() { return castType; } }; class RefCast : public SpecificExpression<Expression::RefCastId> { @@ -1530,7 +1530,7 @@ public: void finalize(); - Type getCastType() { return type; } + Type& getCastType() { return type; } }; class BrOn : public SpecificExpression<Expression::BrOnId> { @@ -1544,7 +1544,7 @@ public: void finalize(); - Type getCastType() { return castType; } + Type& getCastType() { return castType; } // Returns the type sent on the branch, if it is taken. Type getSentType(); |