diff options
author | Thomas Lively <7121787+tlively@users.noreply.github.com> | 2022-01-14 11:50:52 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-01-14 11:50:52 -0800 |
commit | 8e8284e6464d524bd9091f21a62982ed54df0093 (patch) | |
tree | 13b88033d4e5d2f8d53289807efdd6b88055e0be | |
parent | 80329023c30ca108b0a8ce1b3939f5e9a96250bb (diff) | |
download | binaryen-8e8284e6464d524bd9091f21a62982ed54df0093.tar.gz binaryen-8e8284e6464d524bd9091f21a62982ed54df0093.tar.bz2 binaryen-8e8284e6464d524bd9091f21a62982ed54df0093.zip |
Refactor ModuleUtils::collectHeapTypes (#4455)
Update the API to make both the type indices and optimized sorting optional.
It will become more important to avoid unnecessary sorting once isorecursive
types have been implemented because they will make the sorting more complicated.
-rw-r--r-- | src/ir/module-utils.cpp | 65 | ||||
-rw-r--r-- | src/ir/module-utils.h | 27 | ||||
-rw-r--r-- | src/ir/subtypes.h | 3 | ||||
-rw-r--r-- | src/ir/type-updating.cpp | 31 | ||||
-rw-r--r-- | src/ir/type-updating.h | 8 | ||||
-rw-r--r-- | src/passes/NameTypes.cpp | 4 | ||||
-rw-r--r-- | src/passes/Print.cpp | 8 | ||||
-rw-r--r-- | src/wasm-binary.h | 3 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 22 |
9 files changed, 103 insertions, 68 deletions
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp index c738c6143..9b2293984 100644 --- a/src/ir/module-utils.cpp +++ b/src/ir/module-utils.cpp @@ -19,22 +19,23 @@ namespace wasm::ModuleUtils { -void collectHeapTypes(Module& wasm, - std::vector<HeapType>& types, - std::unordered_map<HeapType, Index>& typeIndices) { - struct Counts : public InsertOrderedMap<HeapType, size_t> { - void note(HeapType type) { - if (!type.isBasic()) { - (*this)[type]++; - } +namespace { + +// Helper for collecting HeapTypes and their frequencies. +struct Counts : public InsertOrderedMap<HeapType, size_t> { + void note(HeapType type) { + if (!type.isBasic()) { + (*this)[type]++; } - void note(Type type) { - for (HeapType ht : type.getHeapTypeChildren()) { - note(ht); - } + } + void note(Type type) { + for (HeapType ht : type.getHeapTypeChildren()) { + note(ht); } - }; + } +}; +Counts getHeapTypeCounts(Module& wasm) { struct CodeScanner : PostWalker<CodeScanner, UnifiedExpressionVisitor<CodeScanner>> { Counts& counts; @@ -160,15 +161,49 @@ void collectHeapTypes(Module& wasm, } } + return counts; +} + +} // anonymous namespace + +std::vector<HeapType> collectHeapTypes(Module& wasm) { + Counts counts = getHeapTypeCounts(wasm); + std::vector<HeapType> types; + types.reserve(counts.size()); + for (auto& [type, _] : counts) { + types.push_back(type); + } + return types; +} + +IndexedHeapTypes getIndexedHeapTypes(Module& wasm) { + Counts counts = getHeapTypeCounts(wasm); + IndexedHeapTypes indexedTypes; + Index i = 0; + for (auto& [type, _] : counts) { + indexedTypes.types.push_back(type); + indexedTypes.indices[type] = i++; + } + return indexedTypes; +} + +IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) { + Counts counts = getHeapTypeCounts(wasm); + // Sort by frequency and then original insertion order. std::vector<std::pair<HeapType, size_t>> sorted(counts.begin(), counts.end()); std::stable_sort(sorted.begin(), sorted.end(), [&](auto a, auto b) { return a.second > b.second; }); + + // Collect the results. + IndexedHeapTypes indexedTypes; for (Index i = 0; i < sorted.size(); ++i) { - typeIndices[sorted[i].first] = i; - types.push_back(sorted[i].first); + HeapType type = sorted[i].first; + indexedTypes.types.push_back(type); + indexedTypes.indices[type] = i; } + return indexedTypes; } } // namespace wasm::ModuleUtils diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h index d645ae5f7..c0b4dbcfd 100644 --- a/src/ir/module-utils.h +++ b/src/ir/module-utils.h @@ -455,17 +455,22 @@ template<typename T> struct CallGraphPropertyAnalysis { } }; -// Helper function for collecting all the types that are declared in a module, -// which means the HeapTypes (that are non-basic, that is, not eqref etc., which -// do not need to be defined). -// -// Used when emitting or printing a module to give HeapTypes canonical -// indices. HeapTypes are sorted in order of decreasing frequency to minize the -// size of their collective encoding. Both a vector mapping indices to -// HeapTypes and a map mapping HeapTypes to indices are produced. -void collectHeapTypes(Module& wasm, - std::vector<HeapType>& types, - std::unordered_map<HeapType, Index>& typeIndices); +// Helper function for collecting all the non-basic heap types used in the +// module, i.e. the types that would appear in the type section. +std::vector<HeapType> collectHeapTypes(Module& wasm); + +struct IndexedHeapTypes { + std::vector<HeapType> types; + std::unordered_map<HeapType, Index> indices; +}; + +// Similar to `collectHeapTypes`, but provides fast lookup of the index for each +// type as well. +IndexedHeapTypes getIndexedHeapTypes(Module& wasm); + +// The same as `getIndexedHeapTypes`, but also sorts the types by frequency of +// use to minimize code size. +IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm); } // namespace wasm::ModuleUtils diff --git a/src/ir/subtypes.h b/src/ir/subtypes.h index 64697e2a3..2315934d7 100644 --- a/src/ir/subtypes.h +++ b/src/ir/subtypes.h @@ -26,8 +26,7 @@ namespace wasm { // them. struct SubTypes { SubTypes(Module& wasm) { - std::unordered_map<HeapType, Index> typeIndices; - ModuleUtils::collectHeapTypes(wasm, types, typeIndices); + types = ModuleUtils::collectHeapTypes(wasm); for (auto type : types) { note(type); } diff --git a/src/ir/type-updating.cpp b/src/ir/type-updating.cpp index 2920a4a95..3516a10a9 100644 --- a/src/ir/type-updating.cpp +++ b/src/ir/type-updating.cpp @@ -26,15 +26,15 @@ namespace wasm { GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {} void GlobalTypeRewriter::update() { - ModuleUtils::collectHeapTypes(wasm, types, typeIndices); - if (types.empty()) { + indexedTypes = ModuleUtils::getIndexedHeapTypes(wasm); + if (indexedTypes.types.empty()) { return; } - typeBuilder.grow(types.size()); + typeBuilder.grow(indexedTypes.types.size()); // Create the temporary heap types. - for (Index i = 0; i < types.size(); i++) { - auto type = types[i]; + for (Index i = 0; i < indexedTypes.types.size(); i++) { + auto type = indexedTypes.types[i]; if (type.isSignature()) { auto sig = type.getSignature(); TypeList newParams, newResults; @@ -46,7 +46,7 @@ void GlobalTypeRewriter::update() { } Signature newSig(typeBuilder.getTempTupleType(newParams), typeBuilder.getTempTupleType(newResults)); - modifySignature(types[i], newSig); + modifySignature(indexedTypes.types[i], newSig); typeBuilder.setHeapType(i, newSig); } else if (type.isStruct()) { auto struct_ = type.getStruct(); @@ -55,14 +55,14 @@ void GlobalTypeRewriter::update() { for (auto& field : newStruct.fields) { field.type = getTempType(field.type); } - modifyStruct(types[i], newStruct); + modifyStruct(indexedTypes.types[i], newStruct); typeBuilder.setHeapType(i, newStruct); } else if (type.isArray()) { auto array = type.getArray(); // Start with a copy to get mutability/packing/etc. auto newArray = array; newArray.element.type = getTempType(newArray.element.type); - modifyArray(types[i], newArray); + modifyArray(indexedTypes.types[i], newArray); typeBuilder.setHeapType(i, newArray); } else { WASM_UNREACHABLE("bad type"); @@ -70,7 +70,7 @@ void GlobalTypeRewriter::update() { // Apply a super, if there is one if (auto super = type.getSuperType()) { - typeBuilder.setSubType(i, typeIndices[*super]); + typeBuilder.setSubType(i, indexedTypes.indices[*super]); } } @@ -81,8 +81,8 @@ void GlobalTypeRewriter::update() { // removed types, just modified them. using OldToNewTypes = std::unordered_map<HeapType, HeapType>; OldToNewTypes oldToNewTypes; - for (Index i = 0; i < types.size(); i++) { - oldToNewTypes[types[i]] = newTypes[i]; + for (Index i = 0; i < indexedTypes.types.size(); i++) { + oldToNewTypes[indexedTypes.types[i]] = newTypes[i]; } // Replace all the old types in the module with the new ones. @@ -202,24 +202,25 @@ Type GlobalTypeRewriter::getTempType(Type type) { } if (type.isRef()) { auto heapType = type.getHeapType(); - if (!typeIndices.count(heapType)) { + if (!indexedTypes.indices.count(heapType)) { // This type was not present in the module, but is now being used when // defining new types. That is fine; just use it. return type; } return typeBuilder.getTempRefType( - typeBuilder.getTempHeapType(typeIndices[heapType]), + typeBuilder.getTempHeapType(indexedTypes.indices[heapType]), type.getNullability()); } if (type.isRtt()) { auto rtt = type.getRtt(); auto newRtt = rtt; auto heapType = type.getHeapType(); - if (!typeIndices.count(heapType)) { + if (!indexedTypes.indices.count(heapType)) { // See above with references. return type; } - newRtt.heapType = typeBuilder.getTempHeapType(typeIndices[heapType]); + newRtt.heapType = + typeBuilder.getTempHeapType(indexedTypes.indices[heapType]); return typeBuilder.getTempRttType(newRtt); } if (type.isTuple()) { diff --git a/src/ir/type-updating.h b/src/ir/type-updating.h index 8794cd0da..151b0d41a 100644 --- a/src/ir/type-updating.h +++ b/src/ir/type-updating.h @@ -18,6 +18,7 @@ #define wasm_ir_type_updating_h #include "ir/branch-utils.h" +#include "ir/module-utils.h" #include "wasm-traversal.h" namespace wasm { @@ -336,11 +337,8 @@ public: private: TypeBuilder typeBuilder; - // The list of old types. - std::vector<HeapType> types; - - // Type indices of the old types. - std::unordered_map<HeapType, Index> typeIndices; + // The old types and their indices. + ModuleUtils::IndexedHeapTypes indexedTypes; }; namespace TypeUpdating { diff --git a/src/passes/NameTypes.cpp b/src/passes/NameTypes.cpp index 25b13dfb3..b855e4c0d 100644 --- a/src/passes/NameTypes.cpp +++ b/src/passes/NameTypes.cpp @@ -30,9 +30,7 @@ static const size_t NameLenLimit = 20; struct NameTypes : public Pass { void run(PassRunner* runner, Module* module) override { // Find all the types. - std::vector<HeapType> types; - std::unordered_map<HeapType, Index> typeIndices; - ModuleUtils::collectHeapTypes(*module, types, typeIndices); + std::vector<HeapType> types = ModuleUtils::collectHeapTypes(*module); // Ensure simple names. If a name already exists, and is short enough, keep // it. diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 33d817326..7bcf23b8e 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -3064,10 +3064,10 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { printName(curr->name, o); } incIndent(); - std::vector<HeapType> types; - std::unordered_map<HeapType, Index> indices; - ModuleUtils::collectHeapTypes(*curr, types, indices); - for (auto type : types) { + // Use the same type order as the binary output would even though there is + // no code size benefit in the text format. + auto indexedTypes = ModuleUtils::getOptimizedIndexedHeapTypes(*curr); + for (auto type : indexedTypes.types) { doIndent(o, indent); o << '('; printMedium(o, "type") << ' '; diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 01f37d9aa..0157f4858 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -1310,8 +1310,7 @@ private: Module* wasm; BufferWithRandomAccess& o; BinaryIndexes indexes; - std::unordered_map<HeapType, Index> typeIndices; - std::vector<HeapType> types; + ModuleUtils::IndexedHeapTypes indexedTypes; bool debugInfo = true; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 5ebc4d864..36cb0fd1a 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -34,7 +34,7 @@ namespace wasm { void WasmBinaryWriter::prepare() { // Collect function types and their frequencies. Collect information in each // function in parallel, then merge. - ModuleUtils::collectHeapTypes(*wasm, types, typeIndices); + indexedTypes = ModuleUtils::getOptimizedIndexedHeapTypes(*wasm); importInfo = wasm::make_unique<ImportInfo>(*wasm); } @@ -216,14 +216,14 @@ void WasmBinaryWriter::writeMemory() { } void WasmBinaryWriter::writeTypes() { - if (types.size() == 0) { + if (indexedTypes.types.size() == 0) { return; } BYN_TRACE("== writeTypes\n"); auto start = startSection(BinaryConsts::Section::Type); - o << U32LEB(types.size()); - for (Index i = 0; i < types.size(); ++i) { - auto type = types[i]; + o << U32LEB(indexedTypes.types.size()); + for (Index i = 0; i < indexedTypes.types.size(); ++i) { + auto type = indexedTypes.types[i]; bool nominal = type.isNominal() || getTypeSystem() == TypeSystem::Nominal; BYN_TRACE("write " << type << std::endl); if (type.isSignature()) { @@ -539,9 +539,9 @@ uint32_t WasmBinaryWriter::getTagIndex(Name name) const { } uint32_t WasmBinaryWriter::getTypeIndex(HeapType type) const { - auto it = typeIndices.find(type); + auto it = indexedTypes.indices.find(type); #ifndef NDEBUG - if (it == typeIndices.end()) { + if (it == indexedTypes.indices.end()) { std::cout << "Missing type: " << type << '\n'; assert(0); } @@ -774,7 +774,7 @@ void WasmBinaryWriter::writeNames() { // type names { std::vector<HeapType> namedTypes; - for (auto& [type, _] : typeIndices) { + for (auto& [type, _] : indexedTypes.indices) { if (wasm->typeNames.count(type) && wasm->typeNames[type].name.is()) { namedTypes.push_back(type); } @@ -784,7 +784,7 @@ void WasmBinaryWriter::writeNames() { startSubsection(BinaryConsts::UserSections::Subsection::NameType); o << U32LEB(namedTypes.size()); for (auto type : namedTypes) { - o << U32LEB(typeIndices[type]); + o << U32LEB(indexedTypes.indices[type]); writeEscapedName(wasm->typeNames[type].name.str); } finishSubsection(substart); @@ -909,7 +909,7 @@ void WasmBinaryWriter::writeNames() { // GC field names if (wasm->features.hasGC()) { std::vector<HeapType> relevantTypes; - for (auto& type : types) { + for (auto& type : indexedTypes.types) { if (type.isStruct() && wasm->typeNames.count(type) && !wasm->typeNames[type].fieldNames.empty()) { relevantTypes.push_back(type); @@ -921,7 +921,7 @@ void WasmBinaryWriter::writeNames() { o << U32LEB(relevantTypes.size()); for (Index i = 0; i < relevantTypes.size(); i++) { auto type = relevantTypes[i]; - o << U32LEB(typeIndices[type]); + o << U32LEB(indexedTypes.indices[type]); std::unordered_map<Index, Name>& fieldNames = wasm->typeNames.at(type).fieldNames; o << U32LEB(fieldNames.size()); |