summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r-- src/ir/module-utils.cpp 101
-rw-r--r-- src/ir/module-utils.h 8
-rw-r--r-- src/ir/type-updating.cpp 93
-rw-r--r-- src/ir/type-updating.h 5
-rw-r--r-- src/pass.h 33
-rw-r--r-- src/passes/pass.cpp 8
-rw-r--r-- src/support/insert_ordered.h 9
-rw-r--r-- src/tools/wasm-opt.cpp 6
-rw-r--r-- src/wasm-validator.h 5
-rw-r--r-- src/wasm/wasm-validator.cpp 39
10 files changed, 244 insertions, 63 deletions
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp
index 22f07a2e8..f2c90dcfd 100644
--- a/src/ir/module-utils.cpp
+++ b/src/ir/module-utils.cpp
@@ -198,10 +198,86 @@ void setIndices(IndexedHeapTypes& indexedTypes) {
}
}
+// Compute, in insertion order, the set of heap types that are visible on the
+// module boundary: the types of imported and exported tables, globals and
+// functions, every member of their rec groups, and every non-basic heap type
+// transitively referenced from those.
+InsertOrderedSet<HeapType> getPublicTypeSet(Module& wasm) {
+  InsertOrderedSet<HeapType> result;
+
+  // Record a heap type seen on the boundary together with its rec group.
+  auto markPublic = [&](HeapType heapType) {
+    if (!heapType.isBasic()) {
+      // All the rec group members are public as well.
+      for (auto groupMember : heapType.getRecGroup()) {
+        if (!result.insert(groupMember)) {
+          // A failed insertion means this rec group was already handled.
+          break;
+        }
+      }
+    }
+  };
+
+  // Per-kind helpers shared by the import and export scans below.
+  auto markTable = [&](Table* table) {
+    assert(table->type.isRef());
+    markPublic(table->type.getHeapType());
+  };
+  auto markGlobal = [&](Global* global) {
+    if (global->type.isRef()) {
+      markPublic(global->type.getHeapType());
+    }
+  };
+  auto markFunction = [&](Function* func) { markPublic(func->type); };
+
+  // TODO: Consider Tags as well, but they should store HeapTypes instead of
+  // Signatures first.
+  ModuleUtils::iterImportedTables(wasm, markTable);
+  ModuleUtils::iterImportedGlobals(wasm, markGlobal);
+  ModuleUtils::iterImportedFunctions(wasm, markFunction);
+
+  for (auto& exp : wasm.exports) {
+    switch (exp->kind) {
+      case ExternalKind::Function:
+        markFunction(wasm.getFunction(exp->value));
+        continue;
+      case ExternalKind::Table:
+        markTable(wasm.getTable(exp->value));
+        continue;
+      case ExternalKind::Global:
+        markGlobal(wasm.getGlobal(exp->value));
+        continue;
+      case ExternalKind::Memory:
+        // Never a reference type.
+        continue;
+      case ExternalKind::Tag:
+        // TODO
+        continue;
+      case ExternalKind::Invalid:
+        break;
+    }
+    // Only reached for an export kind that is not handled above.
+    WASM_UNREACHABLE("unexpected export kind");
+  }
+
+  // Find all the other public types reachable from directly publicized types.
+  // The stack is seeded with everything found so far and processed LIFO.
+  std::vector<HeapType> pending(result.begin(), result.end());
+  while (!pending.empty()) {
+    auto current = pending.back();
+    pending.pop_back();
+    for (auto referenced : current.getReferencedHeapTypes()) {
+      if (referenced.isBasic()) {
+        continue;
+      }
+      if (result.insert(referenced)) {
+        pending.push_back(referenced);
+      }
+    }
+  }
+
+  return result;
+}
+
} // anonymous namespace
std::vector<HeapType> collectHeapTypes(Module& wasm) {
- Counts counts = getHeapTypeCounts(wasm);
+ auto counts = getHeapTypeCounts(wasm);
std::vector<HeapType> types;
types.reserve(counts.size());
for (auto& [type, _] : counts) {
@@ -210,6 +286,29 @@ std::vector<HeapType> collectHeapTypes(Module& wasm) {
return types;
}
+// Return the public heap types of the module as a vector, in the same order
+// in which getPublicTypeSet discovered them.
+std::vector<HeapType> getPublicHeapTypes(Module& wasm) {
+  auto publicSet = getPublicTypeSet(wasm);
+  // Flatten the insertion-ordered set; iteration order is preserved.
+  return std::vector<HeapType>(publicSet.begin(), publicSet.end());
+}
+
+// Return every heap type counted by getHeapTypeCounts that is not in the
+// public type set, in the iteration order of the counts.
+std::vector<HeapType> getPrivateHeapTypes(Module& wasm) {
+  auto allTypes = getHeapTypeCounts(wasm);
+  auto publicTypes = getPublicTypeSet(wasm);
+  std::vector<HeapType> types;
+  // The reservation is only a capacity hint. Guard the unsigned subtraction:
+  // the public set also contains rec group members and transitively
+  // referenced types, so do not rely on it being no larger than the counted
+  // set. A wrapped difference would make reserve() throw.
+  if (allTypes.size() > publicTypes.size()) {
+    types.reserve(allTypes.size() - publicTypes.size());
+  }
+  // Bind by reference to avoid copying each (type, count) entry.
+  for (auto& [type, _] : allTypes) {
+    if (!publicTypes.count(type)) {
+      types.push_back(type);
+    }
+  }
+  return types;
+}
+
IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
TypeSystem system = getTypeSystem();
Counts counts = getHeapTypeCounts(wasm);
diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h
index aba5b1ec6..7dfd2d42d 100644
--- a/src/ir/module-utils.h
+++ b/src/ir/module-utils.h
@@ -519,6 +519,14 @@ template<typename T> struct CallGraphPropertyAnalysis {
// module, i.e. the types that would appear in the type section.
std::vector<HeapType> collectHeapTypes(Module& wasm);
+// Collect all the heap types visible on the module boundary that cannot be
+// changed. TODO: For open world use cases, this needs to include all subtypes
+// of public types as well.
+std::vector<HeapType> getPublicHeapTypes(Module& wasm);
+
+// Collect the heap types used in the module that are not public, i.e.
+// collectHeapTypes - getPublicHeapTypes.
+std::vector<HeapType> getPrivateHeapTypes(Module& wasm);
+
struct IndexedHeapTypes {
std::vector<HeapType> types;
std::unordered_map<HeapType, Index> indices;
diff --git a/src/ir/type-updating.cpp b/src/ir/type-updating.cpp
index 31c110738..9f9692a7b 100644
--- a/src/ir/type-updating.cpp
+++ b/src/ir/type-updating.cpp
@@ -19,6 +19,7 @@
#include "ir/local-structural-dominance.h"
#include "ir/module-utils.h"
#include "ir/utils.h"
+#include "support/topological_sort.h"
#include "wasm-type.h"
#include "wasm.h"
@@ -27,24 +28,55 @@ namespace wasm {
GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {}
void GlobalTypeRewriter::update() {
- indexedTypes = ModuleUtils::getOptimizedIndexedHeapTypes(wasm);
- if (indexedTypes.types.empty()) {
+ // Find the heap types that are not publicly observable. Even in a closed
+ // world scenario, don't modify public types because we assume that they may
+ // be reflected on or used for linking. Figure out where each private type
+ // will be located in the builder. Sort the private types so that supertypes
+ // come before their subtypes.
+  // Topologically sorts the module's private heap types so that supertypes
+  // come before their subtypes. Only private types are ever scheduled.
+  struct SortedPrivateTypes : TopologicalSort<HeapType, SortedPrivateTypes> {
+    // The private types, used to keep public supertypes out of the sort.
+    std::unordered_set<HeapType> privateSet;
+
+    SortedPrivateTypes(Module& wasm) {
+      auto privateTypes = ModuleUtils::getPrivateHeapTypes(wasm);
+      privateSet.insert(privateTypes.begin(), privateTypes.end());
+      std::unordered_set<HeapType> supertypes;
+      for (auto type : privateTypes) {
+        if (auto super = type.getSuperType()) {
+          supertypes.insert(*super);
+        }
+      }
+      // Types that are not supertypes of others are the roots.
+      for (auto type : privateTypes) {
+        if (!supertypes.count(type)) {
+          push(type);
+        }
+      }
+    }
+
+    void pushPredecessors(HeapType type) {
+      if (auto super = type.getSuperType()) {
+        // Only follow supertypes that are themselves private. Every type the
+        // sort visits is given a slot in the builder and rebuilt by the
+        // caller, so scheduling a public supertype here would cause a public
+        // type to be rewritten.
+        if (privateSet.count(*super)) {
+          push(*super);
+        }
+      }
+    }
+  };
+
+ Index i = 0;
+ for (auto type : SortedPrivateTypes(wasm)) {
+ typeIndices[type] = i++;
+ }
+
+ if (typeIndices.size() == 0) {
return;
}
- typeBuilder.grow(indexedTypes.types.size());
+ typeBuilder.grow(typeIndices.size());
// All the input types are distinct, so we need to make sure the output types
// are distinct as well. Further, the new types may have more recursions than
// the original types, so the old recursion groups may not be sufficient any
// more. Both of these problems are solved by putting all the new types into a
// single large recursion group.
- // TODO: When we properly analyze which types are external and which are
- // internal to the module, only optimize internal types.
typeBuilder.createRecGroup(0, typeBuilder.size());
// Create the temporary heap types.
- for (Index i = 0; i < indexedTypes.types.size(); i++) {
- auto type = indexedTypes.types[i];
+ i = 0;
+ for (auto [type, _] : typeIndices) {
if (type.isSignature()) {
auto sig = type.getSignature();
TypeList newParams, newResults;
@@ -56,8 +88,8 @@ void GlobalTypeRewriter::update() {
}
Signature newSig(typeBuilder.getTempTupleType(newParams),
typeBuilder.getTempTupleType(newResults));
- modifySignature(indexedTypes.types[i], newSig);
- typeBuilder.setHeapType(i, newSig);
+ modifySignature(type, newSig);
+ typeBuilder[i] = newSig;
} else if (type.isStruct()) {
auto struct_ = type.getStruct();
// Start with a copy to get mutability/packing/etc.
@@ -65,24 +97,29 @@ void GlobalTypeRewriter::update() {
for (auto& field : newStruct.fields) {
field.type = getTempType(field.type);
}
- modifyStruct(indexedTypes.types[i], newStruct);
- typeBuilder.setHeapType(i, newStruct);
+ modifyStruct(type, newStruct);
+ typeBuilder[i] = newStruct;
} else if (type.isArray()) {
auto array = type.getArray();
// Start with a copy to get mutability/packing/etc.
auto newArray = array;
newArray.element.type = getTempType(newArray.element.type);
- modifyArray(indexedTypes.types[i], newArray);
- typeBuilder.setHeapType(i, newArray);
+ modifyArray(type, newArray);
+ typeBuilder[i] = newArray;
} else {
WASM_UNREACHABLE("bad type");
}
// Apply a super, if there is one
if (auto super = type.getSuperType()) {
- typeBuilder.setSubType(
- i, typeBuilder.getTempHeapType(indexedTypes.indices[*super]));
+ if (auto it = typeIndices.find(*super); it != typeIndices.end()) {
+ assert(it->second < i);
+ typeBuilder[i].subTypeOf(typeBuilder[it->second]);
+ } else {
+ typeBuilder[i].subTypeOf(*super);
+ }
}
+ i++;
}
auto buildResults = typeBuilder.build();
@@ -94,19 +131,17 @@ void GlobalTypeRewriter::update() {
#endif
auto& newTypes = *buildResults;
- // Map the old types to the new ones. This uses the fact that type indices
- // are the same in the old and new types, that is, we have not added or
- // removed types, just modified them.
+ // Map the old types to the new ones.
TypeMap oldToNewTypes;
- for (Index i = 0; i < indexedTypes.types.size(); i++) {
- oldToNewTypes[indexedTypes.types[i]] = newTypes[i];
+ for (auto [type, index] : typeIndices) {
+ oldToNewTypes[type] = newTypes[index];
}
// Update type names (doing it before mapTypes can help debugging there, but
// has no other effect; mapTypes does not look at type names).
for (auto& [old, new_] : oldToNewTypes) {
- if (wasm.typeNames.count(old)) {
- wasm.typeNames[new_] = wasm.typeNames[old];
+ if (auto it = wasm.typeNames.find(old); it != wasm.typeNames.end()) {
+ wasm.typeNames[new_] = it->second;
}
}
@@ -114,7 +149,6 @@ void GlobalTypeRewriter::update() {
}
void GlobalTypeRewriter::mapTypes(const TypeMap& oldToNewTypes) {
-
// Replace all the old types in the module with the new ones.
struct CodeUpdater
: public WalkerPass<
@@ -246,14 +280,13 @@ Type GlobalTypeRewriter::getTempType(Type type) {
}
if (type.isRef()) {
auto heapType = type.getHeapType();
- if (!indexedTypes.indices.count(heapType)) {
- // This type was not present in the module, but is now being used when
- // defining new types. That is fine; just use it.
- return type;
+ if (auto it = typeIndices.find(heapType); it != typeIndices.end()) {
+ return typeBuilder.getTempRefType(typeBuilder[it->second],
+ type.getNullability());
}
- return typeBuilder.getTempRefType(
- typeBuilder.getTempHeapType(indexedTypes.indices[heapType]),
- type.getNullability());
+ // This type is not one that is eligible for optimizing. That is fine; just
+ // use it unmodified.
+ return type;
}
if (type.isTuple()) {
auto& tuple = type.getTuple();
diff --git a/src/ir/type-updating.h b/src/ir/type-updating.h
index 12e0b8b57..5cd8eef47 100644
--- a/src/ir/type-updating.h
+++ b/src/ir/type-updating.h
@@ -19,6 +19,7 @@
#include "ir/branch-utils.h"
#include "ir/module-utils.h"
+#include "support/insert_ordered.h"
#include "wasm-traversal.h"
namespace wasm {
@@ -393,8 +394,8 @@ public:
private:
TypeBuilder typeBuilder;
- // The old types and their indices.
- ModuleUtils::IndexedHeapTypes indexedTypes;
+ // Map old types to their indices in the builder.
+ InsertOrderedMap<HeapType, Index> typeIndices;
};
namespace TypeUpdating {
diff --git a/src/pass.h b/src/pass.h
index 2cb661e72..4ab0a8c34 100644
--- a/src/pass.h
+++ b/src/pass.h
@@ -185,12 +185,16 @@ struct PassOptions {
// applied.)
bool zeroFilledMemory = false;
// Assume code outside of the module does not inspect or interact with GC and
- // function references, even if they are passed out. The outside may hold on
- // to them and pass them back in, but not inspect their contents or call them.
- // By default we do not make that assumption, and assume anything that escapes
+ // function references, with the goal of being able to aggressively optimize
+ // all user-defined types. The outside may hold on to references and pass them
+ // back in, but may not inspect their contents, call them, or reflect on their
+ // types in any way.
+ //
+ // By default we do not make this assumption, and assume anything that escapes
// to the outside may be inspected in detail, which prevents us from e.g.
- // changing a type that escapes (so we can't remove or refine fields on an
- // escaping struct type, for example).
+ // changing the type of any value that may escape except by refining it (so we
+ // can't remove or refine fields on an escaping struct type, for example,
+ // unless the new type declares the original type as a supertype).
//
// Note that the module can still have imports and exports - otherwise it
// could do nothing at all! - so the meaning of "closed world" is a little
@@ -202,17 +206,14 @@ struct PassOptions {
// but we also want to keep types of things on the boundary unchanged. For
// example, we should not change an exported function's signature, as the
// outside may need that type to properly call the export.
- // * For now we disallow nontrivial types on the boundary, that is, you
- // cannot use a custom struct type as a function parameter type. Such a
- // type would not be optimizable due to the constraint to not modify the
- // type of the import/export, but it is very simple to use a type in a
- // single import/export and then all of its supertypes become
- // unoptimizable; likewise, in some optimizations all subtypes may be
- // affected (say in not being able to remove a field from them). Overall,
- // there is a risk of missing out on significant optimization
- // opportunities here, and for that reason we error on using such types on
- // the boundary for now. Instead, use basic types like anyref, externref,
- // etc.
+ //
+ // * Since the goal of closedWorld is to optimize types aggressively but
+ // types on the module boundary cannot be changed, we assume the producer
+ // has made a mistake and we consider it a validation error if any user
+ // defined types besides the types of imported or exported functions
+ // themselves appear on the module boundary. For example, no user defined
+ // struct type may be a parameter or result of an exported function. This
+ // error may be relaxed or made more configurable in the future.
bool closedWorld = false;
// Whether to try to preserve debug info through, which are special calls.
bool debugInfo = false;
diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp
index e62b1f3a6..21501c62f 100644
--- a/src/passes/pass.cpp
+++ b/src/passes/pass.cpp
@@ -698,10 +698,6 @@ void PassRunner::run() {
// for debug logging purposes, run each pass in full before running the
// other
auto totalTime = std::chrono::duration<double>(0);
- WasmValidator::Flags validationFlags = WasmValidator::Minimal;
- if (options.validateGlobally) {
- validationFlags = validationFlags | WasmValidator::Globally;
- }
auto what = isNested ? "nested passes" : "passes";
std::cerr << "[PassRunner] running " << what << std::endl;
size_t padding = 0;
@@ -738,7 +734,7 @@ void PassRunner::run() {
if (options.validate && !isNested) {
// validate, ignoring the time
std::cerr << "[PassRunner] (validating)\n";
- if (!WasmValidator().validate(*wasm, validationFlags)) {
+ if (!WasmValidator().validate(*wasm, options)) {
std::cout << *wasm << '\n';
if (passDebug >= 2) {
Fatal() << "Last pass (" << pass->name
@@ -760,7 +756,7 @@ void PassRunner::run() {
<< " seconds." << std::endl;
if (options.validate && !isNested) {
std::cerr << "[PassRunner] (final validation)\n";
- if (!WasmValidator().validate(*wasm, validationFlags)) {
+ if (!WasmValidator().validate(*wasm, options)) {
std::cout << *wasm << '\n';
Fatal() << "final module does not validate\n";
}
diff --git a/src/support/insert_ordered.h b/src/support/insert_ordered.h
index 1460c4329..f39eb61b0 100644
--- a/src/support/insert_ordered.h
+++ b/src/support/insert_ordered.h
@@ -50,12 +50,13 @@ template<typename T> struct InsertOrderedSet {
}
// cheating a bit, not returning the iterator
- void insert(const T& val) {
- auto it = Map.find(val);
- if (it == Map.end()) {
+ // Returns true if `val` was newly inserted and false if it was already
+ // present, in which case nothing is changed.
+ bool insert(const T& val) {
+ // Insert with a placeholder list iterator so that one map operation both
+ // detects duplicates and creates the entry; the iterator is fixed up below.
+ auto [it, inserted] = Map.insert({val, List.begin()});
+ if (inserted) {
List.push_back(val);
- Map.insert(std::make_pair(val, --List.end()));
+ // Point the map entry at the list element that was just appended.
+ it->second = --List.end();
}
+ return inserted;
}
size_t size() const { return Map.size(); }
diff --git a/src/tools/wasm-opt.cpp b/src/tools/wasm-opt.cpp
index 8f958f164..0d279dc08 100644
--- a/src/tools/wasm-opt.cpp
+++ b/src/tools/wasm-opt.cpp
@@ -295,7 +295,7 @@ int main(int argc, const char* argv[]) {
}
if (options.passOptions.validate) {
- if (!WasmValidator().validate(wasm)) {
+ if (!WasmValidator().validate(wasm, options.passOptions)) {
exitOnInvalidWasm("error validating input");
}
}
@@ -309,7 +309,7 @@ int main(int argc, const char* argv[]) {
reader.setAllowOOB(fuzzOOB);
reader.build();
if (options.passOptions.validate) {
- if (!WasmValidator().validate(wasm)) {
+ if (!WasmValidator().validate(wasm, options.passOptions)) {
std::cout << wasm << '\n';
Fatal() << "error after translate-to-fuzz";
}
@@ -368,7 +368,7 @@ int main(int argc, const char* argv[]) {
auto runPasses = [&]() {
options.runPasses(wasm);
if (options.passOptions.validate) {
- bool valid = WasmValidator().validate(wasm);
+ bool valid = WasmValidator().validate(wasm, options.passOptions);
if (!valid) {
exitOnInvalidWasm("error after opts");
}
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index ccef84733..6aaa8197c 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -43,6 +43,7 @@
#include <sstream>
#include <unordered_set>
+#include "pass.h"
#include "wasm.h"
namespace wasm {
@@ -52,12 +53,14 @@ struct WasmValidator {
Minimal = 0,
Web = 1 << 0,
Globally = 1 << 1,
- Quiet = 1 << 2
+ Quiet = 1 << 2,
+ ClosedWorld = 1 << 3,
};
using Flags = uint32_t;
// Validate an entire module.
bool validate(Module& module, Flags flags = Globally);
+ bool validate(Module& module, const PassOptions& options);
// Validate a specific function.
bool validate(Function* func, Module& module, Flags flags = Globally);
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index 16b6f5314..726c70c7c 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -61,6 +61,7 @@ struct ValidationInfo {
bool validateWeb;
bool validateGlobally;
bool quiet;
+ bool closedWorld;
std::atomic<bool> valid;
@@ -3500,6 +3501,32 @@ static void validateFeatures(Module& module, ValidationInfo& info) {
}
}
+// Closed-world check: report a validation failure for every public heap type
+// that is not itself the type of an imported or exported function.
+static void validateClosedWorldInterface(Module& module, ValidationInfo& info) {
+  // Gather the types of all functions that sit on the module boundary.
+  std::unordered_set<HeapType> boundaryFuncTypes;
+  ModuleUtils::iterImportedFunctions(module, [&](Function* imported) {
+    boundaryFuncTypes.insert(imported->type);
+  });
+  for (auto& exp : module.exports) {
+    if (exp->kind != ExternalKind::Function) {
+      continue;
+    }
+    boundaryFuncTypes.insert(module.getFunction(exp->value)->type);
+  }
+
+  // Any other publicly exposed heap type is an error.
+  for (auto publicType : ModuleUtils::getPublicHeapTypes(module)) {
+    if (boundaryFuncTypes.count(publicType)) {
+      continue;
+    }
+    // Prefer the module's name for the type; fall back to its printed form.
+    auto name = publicType.toString();
+    auto named = module.typeNames.find(publicType);
+    if (named != module.typeNames.end()) {
+      name = named->second.name.toString();
+    }
+    info.fail("publicly exposed type disallowed with a closed world: $" +
+                name,
+              publicType,
+              nullptr);
+  }
+}
+
// TODO: If we want the validator to be part of libwasm rather than libpasses,
// then using PassRunner::getPassDebug causes a circular dependency. We should
// fix that, perhaps by moving some of the pass infrastructure into libsupport.
@@ -3508,6 +3535,7 @@ bool WasmValidator::validate(Module& module, Flags flags) {
info.validateWeb = (flags & Web) != 0;
info.validateGlobally = (flags & Globally) != 0;
info.quiet = (flags & Quiet) != 0;
+ info.closedWorld = (flags & ClosedWorld) != 0;
// parallel wasm logic validation
PassRunner runner(&module);
FunctionValidator(module, &info).validate(&runner);
@@ -3522,6 +3550,9 @@ bool WasmValidator::validate(Module& module, Flags flags) {
validateTags(module, info);
validateModule(module, info);
validateFeatures(module, info);
+ if (info.closedWorld) {
+ validateClosedWorldInterface(module, info);
+ }
}
// validate additional internal IR details when in pass-debug mode
if (PassRunner::getPassDebug()) {
@@ -3537,6 +3568,14 @@ bool WasmValidator::validate(Module& module, Flags flags) {
return info.valid.load();
}
+// Validate an entire module, deriving the validator flags from the pass
+// options: Globally when options.validateGlobally is set and ClosedWorld when
+// options.closedWorld is set, starting from Minimal.
+bool WasmValidator::validate(Module& module, const PassOptions& options) {
+  Flags flags = Minimal;
+  if (options.validateGlobally) {
+    flags |= Globally;
+  }
+  if (options.closedWorld) {
+    flags |= ClosedWorld;
+  }
+  return validate(module, flags);
+}
+
bool WasmValidator::validate(Function* func, Module& module, Flags flags) {
ValidationInfo info(module);
info.validateWeb = (flags & Web) != 0;