diff options
Diffstat (limited to 'src/ir')
-rw-r--r-- | src/ir/type-updating.cpp | 208 | ||||
-rw-r--r-- | src/ir/type-updating.h | 37 |
2 files changed, 245 insertions, 0 deletions
diff --git a/src/ir/type-updating.cpp b/src/ir/type-updating.cpp index a3ce8aad7..91f74cff5 100644 --- a/src/ir/type-updating.cpp +++ b/src/ir/type-updating.cpp @@ -16,9 +16,217 @@ #include "type-updating.h" #include "find_all.h" +#include "ir/module-utils.h" +#include "wasm-type.h" +#include "wasm.h" namespace wasm { +GlobalTypeRewriter::GlobalTypeRewriter(Module& wasm) : wasm(wasm) {} + +void GlobalTypeRewriter::update() { + ModuleUtils::collectHeapTypes(wasm, types, typeIndices); + typeBuilder.grow(types.size()); + + // Create the temporary heap types. + for (Index i = 0; i < types.size(); i++) { + auto type = types[i]; + if (type.isSignature()) { + auto sig = type.getSignature(); + TypeList newParams, newResults; + for (auto t : sig.params) { + newParams.push_back(getTempType(t)); + } + for (auto t : sig.results) { + newResults.push_back(getTempType(t)); + } + Signature newSig(typeBuilder.getTempTupleType(newParams), + typeBuilder.getTempTupleType(newResults)); + modifySignature(types[i], newSig); + typeBuilder.setHeapType(i, newSig); + } else if (type.isStruct()) { + auto struct_ = type.getStruct(); + // Start with a copy to get mutability/packing/etc. + auto newStruct = struct_; + for (auto& field : newStruct.fields) { + field.type = getTempType(field.type); + } + modifyStruct(types[i], newStruct); + typeBuilder.setHeapType(i, newStruct); + } else if (type.isArray()) { + auto array = type.getArray(); + // Start with a copy to get mutability/packing/etc. + auto newArray = array; + newArray.element.type = getTempType(newArray.element.type); + modifyArray(types[i], newArray); + typeBuilder.setHeapType(i, newArray); + } else { + WASM_UNREACHABLE("bad type"); + } + + // Apply a super, if there is one + HeapType super; + if (type.getSuperType(super)) { + typeBuilder.setSubType(i, typeIndices[super]); + } + } + + auto newTypes = typeBuilder.build(); + + // Map the old types to the new ones. This uses the fact that type indices + // are the same in the old and new types, that is, we have not added or + // removed types, just modified them. + using OldToNewTypes = std::unordered_map<HeapType, HeapType>; + OldToNewTypes oldToNewTypes; + for (Index i = 0; i < types.size(); i++) { + oldToNewTypes[types[i]] = newTypes[i]; + } + + // Replace all the old types in the module with the new ones. + struct CodeUpdater + : public WalkerPass< + PostWalker<CodeUpdater, UnifiedExpressionVisitor<CodeUpdater>>> { + bool isFunctionParallel() override { return true; } + + OldToNewTypes& oldToNewTypes; + + CodeUpdater(OldToNewTypes& oldToNewTypes) : oldToNewTypes(oldToNewTypes) {} + + CodeUpdater* create() override { return new CodeUpdater(oldToNewTypes); } + + Type getNew(Type type) { + if (type.isRef()) { + return Type(getNew(type.getHeapType()), type.getNullability()); + } + if (type.isRtt()) { + return Type(Rtt(type.getRtt().depth, getNew(type.getHeapType()))); + } + return type; + } + + HeapType getNew(HeapType type) { + if (type.isBasic()) { + return type; + } + if (type.isFunction() || type.isData()) { + assert(oldToNewTypes.count(type)); + return oldToNewTypes[type]; + } + return type; + } + + Signature getNew(Signature sig) { + return Signature(getNew(sig.params), getNew(sig.results)); + } + + void visitExpression(Expression* curr) { + // Update the type to the new one. + curr->type = getNew(curr->type); + + // Update any other type fields as well. + +#define DELEGATE_ID curr->_id + +#define DELEGATE_START(id) \ + auto* cast = curr->cast<id>(); \ + WASM_UNUSED(cast); + +#define DELEGATE_GET_FIELD(id, name) cast->name + +#define DELEGATE_FIELD_TYPE(id, name) cast->name = getNew(cast->name); + +#define DELEGATE_FIELD_HEAPTYPE(id, name) cast->name = getNew(cast->name); + +#define DELEGATE_FIELD_SIGNATURE(id, name) cast->name = getNew(cast->name); + +#define DELEGATE_FIELD_CHILD(id, name) +#define DELEGATE_FIELD_OPTIONAL_CHILD(id, name) +#define DELEGATE_FIELD_INT(id, name) +#define DELEGATE_FIELD_INT_ARRAY(id, name) +#define DELEGATE_FIELD_LITERAL(id, name) +#define DELEGATE_FIELD_NAME(id, name) +#define DELEGATE_FIELD_NAME_VECTOR(id, name) +#define DELEGATE_FIELD_SCOPE_NAME_DEF(id, name) +#define DELEGATE_FIELD_SCOPE_NAME_USE(id, name) +#define DELEGATE_FIELD_SCOPE_NAME_USE_VECTOR(id, name) +#define DELEGATE_FIELD_ADDRESS(id, name) + +#include "wasm-delegations-fields.def" + } + }; + + CodeUpdater updater(oldToNewTypes); + PassRunner runner(&wasm); + updater.run(&runner, &wasm); + updater.walkModuleCode(&wasm); + + // Update global locations that refer to types. + for (auto& table : wasm.tables) { + table->type = updater.getNew(table->type); + } + for (auto& elementSegment : wasm.elementSegments) { + elementSegment->type = updater.getNew(elementSegment->type); + } + for (auto& global : wasm.globals) { + global->type = updater.getNew(global->type); + } + for (auto& func : wasm.functions) { + func->type = updater.getNew(func->type); + for (auto& var : func->vars) { + var = updater.getNew(var); + } + } + for (auto& tag : wasm.tags) { + tag->sig = updater.getNew(tag->sig); + } + + // Update type names. + for (auto& kv : oldToNewTypes) { + auto old = kv.first; + auto new_ = kv.second; + if (wasm.typeNames.count(old)) { + wasm.typeNames[new_] = wasm.typeNames[old]; + } + } +} + +Type GlobalTypeRewriter::getTempType(Type type) { + if (type.isBasic()) { + return type; + } + if (type.isRef()) { + auto heapType = type.getHeapType(); + if (!typeIndices.count(heapType)) { + // This type was not present in the module, but is now being used when + // defining new types. That is fine; just use it. + return type; + } + return typeBuilder.getTempRefType( + typeBuilder.getTempHeapType(typeIndices[heapType]), + type.getNullability()); + } + if (type.isRtt()) { + auto rtt = type.getRtt(); + auto newRtt = rtt; + auto heapType = type.getHeapType(); + if (!typeIndices.count(heapType)) { + // See above with references. + return type; + } + newRtt.heapType = typeBuilder.getTempHeapType(typeIndices[heapType]); + return typeBuilder.getTempRttType(newRtt); + } + if (type.isTuple()) { + auto& tuple = type.getTuple(); + auto newTuple = tuple; + for (auto& t : newTuple.types) { + t = getTempType(t); + } + return typeBuilder.getTempTupleType(newTuple); + } + WASM_UNREACHABLE("bad type"); +} + namespace TypeUpdating { bool canHandleAsLocal(Type type) { diff --git a/src/ir/type-updating.h b/src/ir/type-updating.h index 4668c0ad5..83c1e1aa1 100644 --- a/src/ir/type-updating.h +++ b/src/ir/type-updating.h @@ -305,6 +305,43 @@ struct TypeUpdater } }; +// Rewrites global heap types across an entire module, allowing changes to be +// made while doing so. +class GlobalTypeRewriter { +public: + GlobalTypeRewriter(Module& wasm); + virtual ~GlobalTypeRewriter() {} + + // Main entry point. This performs the entire process of creating new heap + // types and calling the hooks below, then applies the new types throughout + // the module. + void update(); + + // Subclasses can implement these methods to modify the new set of types that + // we map to. By default, we simply copy over the types, and these functions + // are the hooks to apply changes through. The methods receive as input the + // old type, and a structure that they can modify. That structure is the one + // used to define the new type in the TypeBuilder. + virtual void modifyStruct(HeapType oldType, Struct& struct_) {} + virtual void modifyArray(HeapType oldType, Array& array) {} + virtual void modifySignature(HeapType oldType, Signature& sig) {} + + // Map an old type to a temp type. This can be called from the above hooks, + // so that they can use a proper temp type of the TypeBuilder while modifying + // things. + Type getTempType(Type type); + +private: + Module& wasm; + TypeBuilder typeBuilder; + + // The list of old types. + std::vector<HeapType> types; + + // Type indices of the old types. + std::unordered_map<HeapType, Index> typeIndices; +}; + namespace TypeUpdating { // Checks whether a type is valid as a local, or whether |