diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ir/module-utils.cpp | 101 | ||||
-rw-r--r-- | src/ir/module-utils.h | 8 | ||||
-rw-r--r-- | src/ir/type-updating.cpp | 93 | ||||
-rw-r--r-- | src/ir/type-updating.h | 5 | ||||
-rw-r--r-- | src/pass.h | 33 | ||||
-rw-r--r-- | src/passes/pass.cpp | 8 | ||||
-rw-r--r-- | src/support/insert_ordered.h | 9 | ||||
-rw-r--r-- | src/tools/wasm-opt.cpp | 6 | ||||
-rw-r--r-- | src/wasm-validator.h | 5 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 39 |
10 files changed, 244 insertions, 63 deletions
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp index 22f07a2e8..f2c90dcfd 100644 --- a/src/ir/module-utils.cpp +++ b/src/ir/module-utils.cpp @@ -198,10 +198,86 @@ void setIndices(IndexedHeapTypes& indexedTypes) { } } +InsertOrderedSet<HeapType> getPublicTypeSet(Module& wasm) { + InsertOrderedSet<HeapType> publicTypes; + + auto notePublic = [&](HeapType type) { + if (type.isBasic()) { + return; + } + // All the rec group members are public as well. + for (auto member : type.getRecGroup()) { + if (!publicTypes.insert(member)) { + // We've already inserted this rec group. + break; + } + } + }; + + // TODO: Consider Tags as well, but they should store HeapTypes instead of + // Signatures first. + ModuleUtils::iterImportedTables(wasm, [&](Table* table) { + assert(table->type.isRef()); + notePublic(table->type.getHeapType()); + }); + ModuleUtils::iterImportedGlobals(wasm, [&](Global* global) { + if (global->type.isRef()) { + notePublic(global->type.getHeapType()); + } + }); + ModuleUtils::iterImportedFunctions( + wasm, [&](Function* func) { notePublic(func->type); }); + for (auto& ex : wasm.exports) { + switch (ex->kind) { + case ExternalKind::Function: { + auto* func = wasm.getFunction(ex->value); + notePublic(func->type); + continue; + } + case ExternalKind::Table: { + auto* table = wasm.getTable(ex->value); + assert(table->type.isRef()); + notePublic(table->type.getHeapType()); + continue; + } + case ExternalKind::Memory: + // Never a reference type. + continue; + case ExternalKind::Global: { + auto* global = wasm.getGlobal(ex->value); + if (global->type.isRef()) { + notePublic(global->type.getHeapType()); + } + continue; + } + case ExternalKind::Tag: + // TODO + continue; + case ExternalKind::Invalid: + break; + } + WASM_UNREACHABLE("unexpected export kind"); + } + + // Find all the other public types reachable from directly publicized types. + std::vector<HeapType> workList(publicTypes.begin(), publicTypes.end()); + while (workList.size()) { + auto curr = workList.back(); + workList.pop_back(); + for (auto t : curr.getReferencedHeapTypes()) { + if (!t.isBasic() && publicTypes.insert(t)) { + workList.push_back(t); + } + } + } + + return publicTypes; +} + } // anonymous namespace std::vector<HeapType> collectHeapTypes(Module& wasm) { - Counts counts = getHeapTypeCounts(wasm); + auto counts = getHeapTypeCounts(wasm); std::vector<HeapType> types; types.reserve(counts.size()); for (auto& [type, _] : counts) { @@ -210,6 +286,29 @@ std::vector<HeapType> collectHeapTypes(Module& wasm) { return types; } +std::vector<HeapType> getPublicHeapTypes(Module& wasm) { + auto publicTypes = getPublicTypeSet(wasm); + std::vector<HeapType> types; + types.reserve(publicTypes.size()); + for (auto type : publicTypes) { + types.push_back(type); + } + return types; +} + +std::vector<HeapType> getPrivateHeapTypes(Module& wasm) { + auto allTypes = getHeapTypeCounts(wasm); + auto publicTypes = getPublicTypeSet(wasm); + std::vector<HeapType> types; + types.reserve(allTypes.size() - publicTypes.size()); + for (auto [type, _] : allTypes) { + if (!publicTypes.count(type)) { + types.push_back(type); + } + } + return types; +} + IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) { TypeSystem system = getTypeSystem(); Counts counts = getHeapTypeCounts(wasm); diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h index aba5b1ec6..7dfd2d42d 100644 --- a/src/ir/module-utils.h +++ b/src/ir/module-utils.h @@ -519,6 +519,14 @@ template<typename T> struct CallGraphPropertyAnalysis { // module, i.e. the types that would appear in the type section. std::vector<HeapType> collectHeapTypes(Module& wasm); +// Collect all the heap types visible on the module boundary that cannot be +// changed. TODO: For open world use cases, this needs to include all subtypes +// of public types as well. +std::vector<HeapType> getPublicHeapTypes(Module& wasm); + +// getHeapTypes - getPublicHeapTypes +std::vector<HeapType> getPrivateHeapTypes(Module& wasm); + struct IndexedHeapTypes { std::vector<HeapType> types; std::unordered_map<HeapType, Index> indices; diff --git a/src/ir/type-updating.cpp b/src/ir/type-updating.cpp index 31c110738..9f9692a7b 100644 --- a/src/ir/type-updating.cpp +++ b/src/ir/type-updating.cpp @@ -19,6 +19,7 @@ #include "ir/local-structural-dominance.h" #include "ir/module-utils.h" #include "ir/utils.h" +#include "support/topological_sort.h" #include "wasm-type.h" #include "wasm.h" @@ -27,24 +28,55 @@ namespace wasm { GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {} void GlobalTypeRewriter::update() { - indexedTypes = ModuleUtils::getOptimizedIndexedHeapTypes(wasm); - if (indexedTypes.types.empty()) { + // Find the heap types that are not publicly observable. Even in a closed + // world scenario, don't modify public types because we assume that they may + // be reflected on or used for linking. Figure out where each private type + // will be located in the builder. Sort the private types so that supertypes + // come before their subtypes. + struct SortedPrivateTypes : TopologicalSort<HeapType, SortedPrivateTypes> { + SortedPrivateTypes(Module& wasm) { + auto privateTypes = ModuleUtils::getPrivateHeapTypes(wasm); + std::unordered_set<HeapType> supertypes; + for (auto type : privateTypes) { + if (auto super = type.getSuperType()) { + supertypes.insert(*super); + } + } + // Types that are not supertypes of others are the roots. + for (auto type : privateTypes) { + if (!supertypes.count(type)) { + push(type); + } + } + } + + void pushPredecessors(HeapType type) { + if (auto super = type.getSuperType()) { + push(*super); + } + } + }; + + Index i = 0; + for (auto type : SortedPrivateTypes(wasm)) { + typeIndices[type] = i++; + } + + if (typeIndices.size() == 0) { return; } - typeBuilder.grow(indexedTypes.types.size()); + typeBuilder.grow(typeIndices.size()); // All the input types are distinct, so we need to make sure the output types // are distinct as well. Further, the new types may have more recursions than // the original types, so the old recursion groups may not be sufficient any // more. Both of these problems are solved by putting all the new types into a // single large recursion group. - // TODO: When we properly analyze which types are external and which are - // internal to the module, only optimize internal types. typeBuilder.createRecGroup(0, typeBuilder.size()); // Create the temporary heap types. - for (Index i = 0; i < indexedTypes.types.size(); i++) { - auto type = indexedTypes.types[i]; + i = 0; + for (auto [type, _] : typeIndices) { if (type.isSignature()) { auto sig = type.getSignature(); TypeList newParams, newResults; @@ -56,8 +88,8 @@ void GlobalTypeRewriter::update() { } Signature newSig(typeBuilder.getTempTupleType(newParams), typeBuilder.getTempTupleType(newResults)); - modifySignature(indexedTypes.types[i], newSig); - typeBuilder.setHeapType(i, newSig); + modifySignature(type, newSig); + typeBuilder[i] = newSig; } else if (type.isStruct()) { auto struct_ = type.getStruct(); // Start with a copy to get mutability/packing/etc. @@ -65,24 +97,29 @@ void GlobalTypeRewriter::update() { for (auto& field : newStruct.fields) { field.type = getTempType(field.type); } - modifyStruct(indexedTypes.types[i], newStruct); - typeBuilder.setHeapType(i, newStruct); + modifyStruct(type, newStruct); + typeBuilder[i] = newStruct; } else if (type.isArray()) { auto array = type.getArray(); // Start with a copy to get mutability/packing/etc. auto newArray = array; newArray.element.type = getTempType(newArray.element.type); - modifyArray(indexedTypes.types[i], newArray); - typeBuilder.setHeapType(i, newArray); + modifyArray(type, newArray); + typeBuilder[i] = newArray; } else { WASM_UNREACHABLE("bad type"); } // Apply a super, if there is one if (auto super = type.getSuperType()) { - typeBuilder.setSubType( - i, typeBuilder.getTempHeapType(indexedTypes.indices[*super])); + if (auto it = typeIndices.find(*super); it != typeIndices.end()) { + assert(it->second < i); + typeBuilder[i].subTypeOf(typeBuilder[it->second]); + } else { + typeBuilder[i].subTypeOf(*super); + } } + i++; } auto buildResults = typeBuilder.build(); @@ -94,19 +131,17 @@ void GlobalTypeRewriter::update() { #endif auto& newTypes = *buildResults; - // Map the old types to the new ones. This uses the fact that type indices - // are the same in the old and new types, that is, we have not added or - // removed types, just modified them. + // Map the old types to the new ones. TypeMap oldToNewTypes; - for (Index i = 0; i < indexedTypes.types.size(); i++) { - oldToNewTypes[indexedTypes.types[i]] = newTypes[i]; + for (auto [type, index] : typeIndices) { + oldToNewTypes[type] = newTypes[index]; } // Update type names (doing it before mapTypes can help debugging there, but // has no other effect; mapTypes does not look at type names). for (auto& [old, new_] : oldToNewTypes) { - if (wasm.typeNames.count(old)) { - wasm.typeNames[new_] = wasm.typeNames[old]; + if (auto it = wasm.typeNames.find(old); it != wasm.typeNames.end()) { + wasm.typeNames[new_] = it->second; } } @@ -114,7 +149,6 @@ void GlobalTypeRewriter::update() { } void GlobalTypeRewriter::mapTypes(const TypeMap& oldToNewTypes) { - // Replace all the old types in the module with the new ones. struct CodeUpdater : public WalkerPass< @@ -246,14 +280,13 @@ Type GlobalTypeRewriter::getTempType(Type type) { } if (type.isRef()) { auto heapType = type.getHeapType(); - if (!indexedTypes.indices.count(heapType)) { - // This type was not present in the module, but is now being used when - // defining new types. That is fine; just use it. - return type; + if (auto it = typeIndices.find(heapType); it != typeIndices.end()) { + return typeBuilder.getTempRefType(typeBuilder[it->second], + type.getNullability()); } - return typeBuilder.getTempRefType( - typeBuilder.getTempHeapType(indexedTypes.indices[heapType]), - type.getNullability()); + // This type is not one that is eligible for optimizing. That is fine; just + // use it unmodified. + return type; } if (type.isTuple()) { auto& tuple = type.getTuple(); diff --git a/src/ir/type-updating.h b/src/ir/type-updating.h index 12e0b8b57..5cd8eef47 100644 --- a/src/ir/type-updating.h +++ b/src/ir/type-updating.h @@ -19,6 +19,7 @@ #include "ir/branch-utils.h" #include "ir/module-utils.h" +#include "support/insert_ordered.h" #include "wasm-traversal.h" namespace wasm { @@ -393,8 +394,8 @@ public: private: TypeBuilder typeBuilder; - // The old types and their indices. - ModuleUtils::IndexedHeapTypes indexedTypes; + // Map old types to their indices in the builder. + InsertOrderedMap<HeapType, Index> typeIndices; }; namespace TypeUpdating { diff --git a/src/pass.h b/src/pass.h index 2cb661e72..4ab0a8c34 100644 --- a/src/pass.h +++ b/src/pass.h @@ -185,12 +185,16 @@ struct PassOptions { // applied.) bool zeroFilledMemory = false; // Assume code outside of the module does not inspect or interact with GC and - // function references, even if they are passed out. The outside may hold on - // to them and pass them back in, but not inspect their contents or call them. - // By default we do not make that assumption, and assume anything that escapes + // function references, with the goal of being able to aggressively optimize + // all user-defined types. The outside may hold on to references and pass them + // back in, but may not inspect their contents, call them, or reflect on their + // types in any way. + // + // By default we do not make this assumption, and assume anything that escapes // to the outside may be inspected in detail, which prevents us from e.g. - // changing a type that escapes (so we can't remove or refine fields on an - // escaping struct type, for example). + // changing the type of any value that may escape except by refining it (so we + // can't remove or refine fields on an escaping struct type, for example, + // unless the new type declares the original type as a supertype). // // Note that the module can still have imports and exports - otherwise it // could do nothing at all! - so the meaning of "closed world" is a little @@ -202,17 +206,14 @@ struct PassOptions { // but we also want to keep types of things on the boundary unchanged. For // example, we should not change an exported function's signature, as the // outside may need that type to properly call the export. - // * For now we disallow nontrivial types on the boundary, that is, you - // cannot use a custom struct type as a function parameter type. Such a - // type would not be optimizable due to the constraint to not modify the - // type of the import/export, but it is very simple to use a type in a - // single import/export and then all of its supertypes become - // unoptimizable; likewise, in some optimizations all subtypes may be - // affected (say in not being able to remove a field from them). Overall, - // there is a risk of missing out on significant optimization - // opportunities here, and for that reason we error on using such types on - // the boundary for now. Instead, use basic types like anyref, externref, - // etc. + // + // * Since the goal of closedWorld is to optimize types aggressively but + // types on the module boundary cannot be changed, we assume the producer + // has made a mistake and we consider it a validation error if any user + // defined types besides the types of imported or exported functions + // themselves appear on the module boundary. For example, no user defined + // struct type may be a parameter or result of an exported function. This + // error may be relaxed or made more configurable in the future. bool closedWorld = false; // Whether to try to preserve debug info through, which are special calls. bool debugInfo = false; diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index e62b1f3a6..21501c62f 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -698,10 +698,6 @@ void PassRunner::run() { // for debug logging purposes, run each pass in full before running the // other auto totalTime = std::chrono::duration<double>(0); - WasmValidator::Flags validationFlags = WasmValidator::Minimal; - if (options.validateGlobally) { - validationFlags = validationFlags | WasmValidator::Globally; - } auto what = isNested ? "nested passes" : "passes"; std::cerr << "[PassRunner] running " << what << std::endl; size_t padding = 0; @@ -738,7 +734,7 @@ void PassRunner::run() { if (options.validate && !isNested) { // validate, ignoring the time std::cerr << "[PassRunner] (validating)\n"; - if (!WasmValidator().validate(*wasm, validationFlags)) { + if (!WasmValidator().validate(*wasm, options)) { std::cout << *wasm << '\n'; if (passDebug >= 2) { Fatal() << "Last pass (" << pass->name @@ -760,7 +756,7 @@ void PassRunner::run() { << " seconds." << std::endl; if (options.validate && !isNested) { std::cerr << "[PassRunner] (final validation)\n"; - if (!WasmValidator().validate(*wasm, validationFlags)) { + if (!WasmValidator().validate(*wasm, options)) { std::cout << *wasm << '\n'; Fatal() << "final module does not validate\n"; } diff --git a/src/support/insert_ordered.h b/src/support/insert_ordered.h index 1460c4329..f39eb61b0 100644 --- a/src/support/insert_ordered.h +++ b/src/support/insert_ordered.h @@ -50,12 +50,13 @@ template<typename T> struct InsertOrderedSet { } // cheating a bit, not returning the iterator - void insert(const T& val) { - auto it = Map.find(val); - if (it == Map.end()) { + bool insert(const T& val) { + auto [it, inserted] = Map.insert({val, List.begin()}); + if (inserted) { List.push_back(val); - Map.insert(std::make_pair(val, --List.end())); + it->second = --List.end(); } + return inserted; } size_t size() const { return Map.size(); } diff --git a/src/tools/wasm-opt.cpp b/src/tools/wasm-opt.cpp index 8f958f164..0d279dc08 100644 --- a/src/tools/wasm-opt.cpp +++ b/src/tools/wasm-opt.cpp @@ -295,7 +295,7 @@ int main(int argc, const char* argv[]) { } if (options.passOptions.validate) { - if (!WasmValidator().validate(wasm)) { + if (!WasmValidator().validate(wasm, options.passOptions)) { exitOnInvalidWasm("error validating input"); } } @@ -309,7 +309,7 @@ int main(int argc, const char* argv[]) { reader.setAllowOOB(fuzzOOB); reader.build(); if (options.passOptions.validate) { - if (!WasmValidator().validate(wasm)) { + if (!WasmValidator().validate(wasm, options.passOptions)) { std::cout << wasm << '\n'; Fatal() << "error after translate-to-fuzz"; } @@ -368,7 +368,7 @@ int main(int argc, const char* argv[]) { auto runPasses = [&]() { options.runPasses(wasm); if (options.passOptions.validate) { - bool valid = WasmValidator().validate(wasm); + bool valid = WasmValidator().validate(wasm, options.passOptions); if (!valid) { exitOnInvalidWasm("error after opts"); } diff --git a/src/wasm-validator.h b/src/wasm-validator.h index ccef84733..6aaa8197c 100644 --- a/src/wasm-validator.h +++ b/src/wasm-validator.h @@ -43,6 +43,7 @@ #include <sstream> #include <unordered_set> +#include "pass.h" #include "wasm.h" namespace wasm { @@ -52,12 +53,14 @@ struct WasmValidator { Minimal = 0, Web = 1 << 0, Globally = 1 << 1, - Quiet = 1 << 2 + Quiet = 1 << 2, + ClosedWorld = 1 << 3, }; using Flags = uint32_t; // Validate an entire module. bool validate(Module& module, Flags flags = Globally); + bool validate(Module& module, const PassOptions& options); // Validate a specific function. bool validate(Function* func, Module& module, Flags flags = Globally); diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 16b6f5314..726c70c7c 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -61,6 +61,7 @@ struct ValidationInfo { bool validateWeb; bool validateGlobally; bool quiet; + bool closedWorld; std::atomic<bool> valid; @@ -3500,6 +3501,32 @@ static void validateFeatures(Module& module, ValidationInfo& info) { } } +static void validateClosedWorldInterface(Module& module, ValidationInfo& info) { + // Error if there are any publicly exposed heap types beyond the types of + // publicly exposed functions. + std::unordered_set<HeapType> publicFuncTypes; + ModuleUtils::iterImportedFunctions( + module, [&](Function* func) { publicFuncTypes.insert(func->type); }); + for (auto& ex : module.exports) { + if (ex->kind == ExternalKind::Function) { + publicFuncTypes.insert(module.getFunction(ex->value)->type); + } + } + + for (auto type : ModuleUtils::getPublicHeapTypes(module)) { + if (!publicFuncTypes.count(type)) { + auto name = type.toString(); + if (auto it = module.typeNames.find(type); it != module.typeNames.end()) { + name = it->second.name.toString(); + } + info.fail("publicly exposed type disallowed with a closed world: $" + + name, + type, + nullptr); + } + } +} + // TODO: If we want the validator to be part of libwasm rather than libpasses, // then Using PassRunner::getPassDebug causes a circular dependence. We should // fix that, perhaps by moving some of the pass infrastructure into libsupport. @@ -3508,6 +3535,7 @@ bool WasmValidator::validate(Module& module, Flags flags) { info.validateWeb = (flags & Web) != 0; info.validateGlobally = (flags & Globally) != 0; info.quiet = (flags & Quiet) != 0; + info.closedWorld = (flags & ClosedWorld) != 0; // parallel wasm logic validation PassRunner runner(&module); FunctionValidator(module, &info).validate(&runner); @@ -3522,6 +3550,9 @@ bool WasmValidator::validate(Module& module, Flags flags) { validateTags(module, info); validateModule(module, info); validateFeatures(module, info); + if (info.closedWorld) { + validateClosedWorldInterface(module, info); + } } // validate additional internal IR details when in pass-debug mode if (PassRunner::getPassDebug()) { @@ -3537,6 +3568,14 @@ bool WasmValidator::validate(Module& module, Flags flags) { return info.valid.load(); } +bool WasmValidator::validate(Module& module, const PassOptions& options) { + Flags flags = options.validateGlobally ? Globally : Minimal; + if (options.closedWorld) { + flags |= ClosedWorld; + } + return validate(module, flags); +} + bool WasmValidator::validate(Function* func, Module& module, Flags flags) { ValidationInfo info(module); info.validateWeb = (flags & Web) != 0; |