diff options
-rw-r--r-- | src/ir/module-utils.cpp | 65 | ||||
-rw-r--r-- | src/ir/module-utils.h | 27 | ||||
-rw-r--r-- | src/ir/subtypes.h | 3 | ||||
-rw-r--r-- | src/ir/type-updating.cpp | 31 | ||||
-rw-r--r-- | src/ir/type-updating.h | 8 | ||||
-rw-r--r-- | src/passes/NameTypes.cpp | 4 | ||||
-rw-r--r-- | src/passes/Print.cpp | 8 | ||||
-rw-r--r-- | src/wasm-binary.h | 3 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 22 |
9 files changed, 103 insertions, 68 deletions
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp index c738c6143..9b2293984 100644 --- a/src/ir/module-utils.cpp +++ b/src/ir/module-utils.cpp @@ -19,22 +19,23 @@ namespace wasm::ModuleUtils { -void collectHeapTypes(Module& wasm, - std::vector<HeapType>& types, - std::unordered_map<HeapType, Index>& typeIndices) { - struct Counts : public InsertOrderedMap<HeapType, size_t> { - void note(HeapType type) { - if (!type.isBasic()) { - (*this)[type]++; - } +namespace { + +// Helper for collecting HeapTypes and their frequencies. +struct Counts : public InsertOrderedMap<HeapType, size_t> { + void note(HeapType type) { + if (!type.isBasic()) { + (*this)[type]++; } - void note(Type type) { - for (HeapType ht : type.getHeapTypeChildren()) { - note(ht); - } + } + void note(Type type) { + for (HeapType ht : type.getHeapTypeChildren()) { + note(ht); } - }; + } +}; +Counts getHeapTypeCounts(Module& wasm) { struct CodeScanner : PostWalker<CodeScanner, UnifiedExpressionVisitor<CodeScanner>> { Counts& counts; @@ -160,15 +161,49 @@ void collectHeapTypes(Module& wasm, } } + return counts; +} + +} // anonymous namespace + +std::vector<HeapType> collectHeapTypes(Module& wasm) { + Counts counts = getHeapTypeCounts(wasm); + std::vector<HeapType> types; + types.reserve(counts.size()); + for (auto& [type, _] : counts) { + types.push_back(type); + } + return types; +} + +IndexedHeapTypes getIndexedHeapTypes(Module& wasm) { + Counts counts = getHeapTypeCounts(wasm); + IndexedHeapTypes indexedTypes; + Index i = 0; + for (auto& [type, _] : counts) { + indexedTypes.types.push_back(type); + indexedTypes.indices[type] = i++; + } + return indexedTypes; +} + +IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) { + Counts counts = getHeapTypeCounts(wasm); + // Sort by frequency and then original insertion order. std::vector<std::pair<HeapType, size_t>> sorted(counts.begin(), counts.end()); std::stable_sort(sorted.begin(), sorted.end(), [&](auto a, auto b) { return a.second > b.second; }); + + // Collect the results. + IndexedHeapTypes indexedTypes; for (Index i = 0; i < sorted.size(); ++i) { - typeIndices[sorted[i].first] = i; - types.push_back(sorted[i].first); + HeapType type = sorted[i].first; + indexedTypes.types.push_back(type); + indexedTypes.indices[type] = i; } + return indexedTypes; } } // namespace wasm::ModuleUtils diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h index d645ae5f7..c0b4dbcfd 100644 --- a/src/ir/module-utils.h +++ b/src/ir/module-utils.h @@ -455,17 +455,22 @@ template<typename T> struct CallGraphPropertyAnalysis { } }; -// Helper function for collecting all the types that are declared in a module, -// which means the HeapTypes (that are non-basic, that is, not eqref etc., which -// do not need to be defined). -// -// Used when emitting or printing a module to give HeapTypes canonical -// indices. HeapTypes are sorted in order of decreasing frequency to minize the -// size of their collective encoding. Both a vector mapping indices to -// HeapTypes and a map mapping HeapTypes to indices are produced. -void collectHeapTypes(Module& wasm, - std::vector<HeapType>& types, - std::unordered_map<HeapType, Index>& typeIndices); +// Helper function for collecting all the non-basic heap types used in the +// module, i.e. the types that would appear in the type section. +std::vector<HeapType> collectHeapTypes(Module& wasm); + +struct IndexedHeapTypes { + std::vector<HeapType> types; + std::unordered_map<HeapType, Index> indices; +}; + +// Similar to `collectHeapTypes`, but provides fast lookup of the index for each +// type as well. +IndexedHeapTypes getIndexedHeapTypes(Module& wasm); + +// The same as `getIndexedHeapTypes`, but also sorts the types by frequency of +// use to minimize code size. +IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm); } // namespace wasm::ModuleUtils diff --git a/src/ir/subtypes.h b/src/ir/subtypes.h index 64697e2a3..2315934d7 100644 --- a/src/ir/subtypes.h +++ b/src/ir/subtypes.h @@ -26,8 +26,7 @@ namespace wasm { // them. struct SubTypes { SubTypes(Module& wasm) { - std::unordered_map<HeapType, Index> typeIndices; - ModuleUtils::collectHeapTypes(wasm, types, typeIndices); + types = ModuleUtils::collectHeapTypes(wasm); for (auto type : types) { note(type); } diff --git a/src/ir/type-updating.cpp b/src/ir/type-updating.cpp index 2920a4a95..3516a10a9 100644 --- a/src/ir/type-updating.cpp +++ b/src/ir/type-updating.cpp @@ -26,15 +26,15 @@ namespace wasm { GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {} void GlobalTypeRewriter::update() { - ModuleUtils::collectHeapTypes(wasm, types, typeIndices); - if (types.empty()) { + indexedTypes = ModuleUtils::getIndexedHeapTypes(wasm); + if (indexedTypes.types.empty()) { return; } - typeBuilder.grow(types.size()); + typeBuilder.grow(indexedTypes.types.size()); // Create the temporary heap types. - for (Index i = 0; i < types.size(); i++) { - auto type = types[i]; + for (Index i = 0; i < indexedTypes.types.size(); i++) { + auto type = indexedTypes.types[i]; if (type.isSignature()) { auto sig = type.getSignature(); TypeList newParams, newResults; @@ -46,7 +46,7 @@ void GlobalTypeRewriter::update() { } Signature newSig(typeBuilder.getTempTupleType(newParams), typeBuilder.getTempTupleType(newResults)); - modifySignature(types[i], newSig); + modifySignature(indexedTypes.types[i], newSig); typeBuilder.setHeapType(i, newSig); } else if (type.isStruct()) { auto struct_ = type.getStruct(); @@ -55,14 +55,14 @@ void GlobalTypeRewriter::update() { for (auto& field : newStruct.fields) { field.type = getTempType(field.type); } - modifyStruct(types[i], newStruct); + modifyStruct(indexedTypes.types[i], newStruct); typeBuilder.setHeapType(i, newStruct); } else if (type.isArray()) { auto array = type.getArray(); // Start with a copy to get mutability/packing/etc. auto newArray = array; newArray.element.type = getTempType(newArray.element.type); - modifyArray(types[i], newArray); + modifyArray(indexedTypes.types[i], newArray); typeBuilder.setHeapType(i, newArray); } else { WASM_UNREACHABLE("bad type"); @@ -70,7 +70,7 @@ void GlobalTypeRewriter::update() { // Apply a super, if there is one if (auto super = type.getSuperType()) { - typeBuilder.setSubType(i, typeIndices[*super]); + typeBuilder.setSubType(i, indexedTypes.indices[*super]); } } @@ -81,8 +81,8 @@ void GlobalTypeRewriter::update() { // removed types, just modified them. using OldToNewTypes = std::unordered_map<HeapType, HeapType>; OldToNewTypes oldToNewTypes; - for (Index i = 0; i < types.size(); i++) { - oldToNewTypes[types[i]] = newTypes[i]; + for (Index i = 0; i < indexedTypes.types.size(); i++) { + oldToNewTypes[indexedTypes.types[i]] = newTypes[i]; } // Replace all the old types in the module with the new ones. @@ -202,24 +202,25 @@ Type GlobalTypeRewriter::getTempType(Type type) { } if (type.isRef()) { auto heapType = type.getHeapType(); - if (!typeIndices.count(heapType)) { + if (!indexedTypes.indices.count(heapType)) { // This type was not present in the module, but is now being used when // defining new types. That is fine; just use it. return type; } return typeBuilder.getTempRefType( - typeBuilder.getTempHeapType(typeIndices[heapType]), + typeBuilder.getTempHeapType(indexedTypes.indices[heapType]), type.getNullability()); } if (type.isRtt()) { auto rtt = type.getRtt(); auto newRtt = rtt; auto heapType = type.getHeapType(); - if (!typeIndices.count(heapType)) { + if (!indexedTypes.indices.count(heapType)) { // See above with references. return type; } - newRtt.heapType = typeBuilder.getTempHeapType(typeIndices[heapType]); + newRtt.heapType = + typeBuilder.getTempHeapType(indexedTypes.indices[heapType]); return typeBuilder.getTempRttType(newRtt); } if (type.isTuple()) { diff --git a/src/ir/type-updating.h b/src/ir/type-updating.h index 8794cd0da..151b0d41a 100644 --- a/src/ir/type-updating.h +++ b/src/ir/type-updating.h @@ -18,6 +18,7 @@ #define wasm_ir_type_updating_h #include "ir/branch-utils.h" +#include "ir/module-utils.h" #include "wasm-traversal.h" namespace wasm { @@ -336,11 +337,8 @@ public: private: TypeBuilder typeBuilder; - // The list of old types. - std::vector<HeapType> types; - - // Type indices of the old types. - std::unordered_map<HeapType, Index> typeIndices; + // The old types and their indices. + ModuleUtils::IndexedHeapTypes indexedTypes; }; namespace TypeUpdating { diff --git a/src/passes/NameTypes.cpp b/src/passes/NameTypes.cpp index 25b13dfb3..b855e4c0d 100644 --- a/src/passes/NameTypes.cpp +++ b/src/passes/NameTypes.cpp @@ -30,9 +30,7 @@ static const size_t NameLenLimit = 20; struct NameTypes : public Pass { void run(PassRunner* runner, Module* module) override { // Find all the types. - std::vector<HeapType> types; - std::unordered_map<HeapType, Index> typeIndices; - ModuleUtils::collectHeapTypes(*module, types, typeIndices); + std::vector<HeapType> types = ModuleUtils::collectHeapTypes(*module); // Ensure simple names. If a name already exists, and is short enough, keep // it. diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 33d817326..7bcf23b8e 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -3064,10 +3064,10 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { printName(curr->name, o); } incIndent(); - std::vector<HeapType> types; - std::unordered_map<HeapType, Index> indices; - ModuleUtils::collectHeapTypes(*curr, types, indices); - for (auto type : types) { + // Use the same type order as the binary output would even though there is + // no code size benefit in the text format. + auto indexedTypes = ModuleUtils::getOptimizedIndexedHeapTypes(*curr); + for (auto type : indexedTypes.types) { doIndent(o, indent); o << '('; printMedium(o, "type") << ' '; diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 01f37d9aa..0157f4858 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -1310,8 +1310,7 @@ private: Module* wasm; BufferWithRandomAccess& o; BinaryIndexes indexes; - std::unordered_map<HeapType, Index> typeIndices; - std::vector<HeapType> types; + ModuleUtils::IndexedHeapTypes indexedTypes; bool debugInfo = true; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 5ebc4d864..36cb0fd1a 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -34,7 +34,7 @@ namespace wasm { void WasmBinaryWriter::prepare() { // Collect function types and their frequencies. Collect information in each // function in parallel, then merge. - ModuleUtils::collectHeapTypes(*wasm, types, typeIndices); + indexedTypes = ModuleUtils::getOptimizedIndexedHeapTypes(*wasm); importInfo = wasm::make_unique<ImportInfo>(*wasm); } @@ -216,14 +216,14 @@ void WasmBinaryWriter::writeMemory() { } void WasmBinaryWriter::writeTypes() { - if (types.size() == 0) { + if (indexedTypes.types.size() == 0) { return; } BYN_TRACE("== writeTypes\n"); auto start = startSection(BinaryConsts::Section::Type); - o << U32LEB(types.size()); - for (Index i = 0; i < types.size(); ++i) { - auto type = types[i]; + o << U32LEB(indexedTypes.types.size()); + for (Index i = 0; i < indexedTypes.types.size(); ++i) { + auto type = indexedTypes.types[i]; bool nominal = type.isNominal() || getTypeSystem() == TypeSystem::Nominal; BYN_TRACE("write " << type << std::endl); if (type.isSignature()) { @@ -539,9 +539,9 @@ uint32_t WasmBinaryWriter::getTagIndex(Name name) const { } uint32_t WasmBinaryWriter::getTypeIndex(HeapType type) const { - auto it = typeIndices.find(type); + auto it = indexedTypes.indices.find(type); #ifndef NDEBUG - if (it == typeIndices.end()) { + if (it == indexedTypes.indices.end()) { std::cout << "Missing type: " << type << '\n'; assert(0); } @@ -774,7 +774,7 @@ void WasmBinaryWriter::writeNames() { // type names { std::vector<HeapType> namedTypes; - for (auto& [type, _] : typeIndices) { + for (auto& [type, _] : indexedTypes.indices) { if (wasm->typeNames.count(type) && wasm->typeNames[type].name.is()) { namedTypes.push_back(type); } @@ -784,7 +784,7 @@ void WasmBinaryWriter::writeNames() { startSubsection(BinaryConsts::UserSections::Subsection::NameType); o << U32LEB(namedTypes.size()); for (auto type : namedTypes) { - o << U32LEB(typeIndices[type]); + o << U32LEB(indexedTypes.indices[type]); writeEscapedName(wasm->typeNames[type].name.str); } finishSubsection(substart); @@ -909,7 +909,7 @@ void WasmBinaryWriter::writeNames() { // GC field names if (wasm->features.hasGC()) { std::vector<HeapType> relevantTypes; - for (auto& type : types) { + for (auto& type : indexedTypes.types) { if (type.isStruct() && wasm->typeNames.count(type) && !wasm->typeNames[type].fieldNames.empty()) { relevantTypes.push_back(type); @@ -921,7 +921,7 @@ void WasmBinaryWriter::writeNames() { o << U32LEB(relevantTypes.size()); for (Index i = 0; i < relevantTypes.size(); i++) { auto type = relevantTypes[i]; - o << U32LEB(typeIndices[type]); + o << U32LEB(indexedTypes.indices[type]); std::unordered_map<Index, Name>& fieldNames = wasm->typeNames.at(type).fieldNames; o << U32LEB(fieldNames.size()); |