diff options
author | Alon Zakai <azakai@google.com> | 2021-02-23 22:46:13 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-23 22:46:13 +0000 |
commit | a7c66754ba86854ea3e0381986796e6565b93199 (patch) | |
tree | 940abdbce4bcf5a10eab090c6b67c37dbb5d2c3b /src | |
parent | c127eccf753ab86a4d5deecbd0f3fa78a83e42ad (diff) | |
download | binaryen-a7c66754ba86854ea3e0381986796e6565b93199.tar.gz binaryen-a7c66754ba86854ea3e0381986796e6565b93199.tar.bz2 binaryen-a7c66754ba86854ea3e0381986796e6565b93199.zip |
Properly use text format type names in printing (#3591)
This adds a TypeNames entry to modules, which can store names for types. So
far this PR uses that to store type names from text format. Future PRs will add
support for field names and for the binary format.
(Field names are added to wasm.h here to see if we agree on this direction.)
Most of the work here is threading a module through the various functions in
Print.cpp. This keeps the module optional, so that we can still print an
expression independently of a module, which has always been the case, and
which I think we should keep (but, if a module was mandatory perhaps this
would be a little simpler, and could be refactored into a form that depends on
that).
99% of this diff are test updates, since almost all our tests use the text
format, and many of them specify a type name but we used to ignore it.
This is a step towards a proper solution for #3589
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/Print.cpp | 132 | ||||
-rw-r--r-- | src/wasm.h | 10 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 12 |
3 files changed, 100 insertions, 54 deletions
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index a75c15534..643810be4 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -67,11 +67,13 @@ static std::ostream& printLocal(Index index, Function* func, std::ostream& o) { return printName(name, o); } -static void -printHeapTypeName(std::ostream& os, HeapType type, bool first = true); +static void printHeapTypeName(std::ostream& os, + HeapType type, + Module* wasm = nullptr, + bool first = true); // Prints the name of a type. This output is guaranteed to not contain spaces. -static void printTypeName(std::ostream& os, Type type) { +static void printTypeName(std::ostream& os, Type type, Module* wasm = nullptr) { if (type.isBasic()) { os << type; return; @@ -82,7 +84,7 @@ static void printTypeName(std::ostream& os, Type type) { if (rtt.hasDepth()) { os << rtt.depth << '_'; } - printHeapTypeName(os, rtt.heapType); + printHeapTypeName(os, rtt.heapType, wasm); return; } if (type.isTuple()) { @@ -90,7 +92,7 @@ static void printTypeName(std::ostream& os, Type type) { for (auto t : type) { os << sep; sep = "_"; - printTypeName(os, t); + printTypeName(os, t, wasm); } return; } @@ -100,14 +102,15 @@ static void printTypeName(std::ostream& os, Type type) { os << "?"; } os << "|"; - printHeapTypeName(os, type.getHeapType(), false); + printHeapTypeName(os, type.getHeapType(), wasm, false); os << "|"; return; } WASM_UNREACHABLE("unsupported print type"); } -static void printFieldName(std::ostream& os, const Field& field) { +static void +printFieldName(std::ostream& os, const Field& field, Module* wasm = nullptr) { if (field.mutable_) { os << "mut:"; } @@ -120,29 +123,39 @@ static void printFieldName(std::ostream& os, const Field& field) { WASM_UNREACHABLE("invalid packed type"); } } else { - printTypeName(os, field.type); + printTypeName(os, field.type, wasm); } } // Prints the name of a heap type. As with printTypeName, this output is // guaranteed to not contain spaces. -static void printHeapTypeName(std::ostream& os, HeapType type, bool first) { +static void +printHeapTypeName(std::ostream& os, HeapType type, Module* wasm, bool first) { if (type.isBasic()) { os << type; return; } + // If there is a name for this type in this module, use it. + // FIXME: in theory there could be two types, one with a name, and one + // without, and the one without gets an automatic name that matches the + // other's. To check for that, if (first) we could assert at the very end of + // this function that the automatic name is not present in the given names. + if (wasm && wasm->typeNames.count(type)) { + os << '$' << wasm->typeNames[type].name; + return; + } if (first) { os << '$'; } if (type.isSignature()) { auto sig = type.getSignature(); - printTypeName(os, sig.params); + printTypeName(os, sig.params, wasm); if (first) { os << "_=>_"; } else { os << "_->_"; } - printTypeName(os, sig.results); + printTypeName(os, sig.results, wasm); } else if (type.isStruct()) { auto struct_ = type.getStruct(); os << "{"; @@ -150,12 +163,12 @@ static void printHeapTypeName(std::ostream& os, HeapType type, bool first) { for (auto& field : struct_.fields) { os << sep; sep = "_"; - printFieldName(os, field); + printFieldName(os, field, wasm); } os << "}"; } else if (type.isArray()) { os << "["; - printFieldName(os, type.getArray().element); + printFieldName(os, type.getArray().element, wasm); os << "]"; } else { os << type; @@ -169,13 +182,16 @@ struct SExprType { SExprType(Type type) : type(type){}; }; -static std::ostream& operator<<(std::ostream& o, const SExprType& sType) { +static std::ostream& printSExprType(std::ostream& o, + const SExprType& sType, + Module* wasm = nullptr) { Type type = sType.type; if (type.isTuple()) { o << '('; auto sep = ""; for (const auto& t : type) { - o << sep << SExprType(t); + o << sep; + printSExprType(o, t, wasm); sep = " "; } o << ')'; @@ -185,17 +201,17 @@ static std::ostream& operator<<(std::ostream& o, const SExprType& sType) { if (rtt.hasDepth()) { o << rtt.depth << ' '; } - printHeapTypeName(o, rtt.heapType); + printHeapTypeName(o, rtt.heapType, wasm); o << ')'; } else if (type.isRef() && !type.isBasic()) { o << "(ref "; if (type.isNullable()) { o << "null "; } - printHeapTypeName(o, type.getHeapType()); + printHeapTypeName(o, type.getHeapType(), wasm); o << ')'; } else { - printTypeName(o, sType.type); + printTypeName(o, sType.type, wasm); } return o; } @@ -207,7 +223,9 @@ struct ResultTypeName { ResultTypeName(Type type) : type(type) {} }; -std::ostream& operator<<(std::ostream& os, ResultTypeName typeName) { +std::ostream& printResultTypeName(std::ostream& os, + ResultTypeName typeName, + Module* wasm = nullptr) { auto type = typeName.type; os << "(result "; if (type.isTuple()) { @@ -217,10 +235,10 @@ std::ostream& operator<<(std::ostream& os, ResultTypeName typeName) { for (auto t : type) { os << sep; sep = " "; - os << SExprType(t); + printSExprType(os, t, wasm); } } else { - os << SExprType(type); + printSExprType(os, type, wasm); } os << ')'; return os; @@ -238,14 +256,13 @@ static Type forceConcrete(Type type) { // the children. struct PrintExpressionContents : public OverriddenVisitor<PrintExpressionContents> { + Module* wasm = nullptr; Function* currFunction = nullptr; std::ostream& o; FeatureSet features; - PrintExpressionContents(Function* currFunction, - FeatureSet features, - std::ostream& o) - : currFunction(currFunction), o(o), features(features) {} + PrintExpressionContents(Module* wasm, Function* currFunction, std::ostream& o) + : wasm(wasm), currFunction(currFunction), o(o), features(wasm->features) {} PrintExpressionContents(Function* currFunction, std::ostream& o) : currFunction(currFunction), o(o), features(FeatureSet::All) {} @@ -257,13 +274,15 @@ struct PrintExpressionContents printName(curr->name, o); } if (curr->type.isConcrete()) { - o << ' ' << ResultTypeName(curr->type); + o << ' '; + printResultTypeName(o, curr->type, wasm); } } void visitIf(If* curr) { printMedium(o, "if"); if (curr->type.isConcrete()) { - o << ' ' << ResultTypeName(curr->type); + o << ' '; + printResultTypeName(o, curr->type, wasm); } } void visitLoop(Loop* curr) { @@ -273,7 +292,8 @@ struct PrintExpressionContents printName(curr->name, o); } if (curr->type.isConcrete()) { - o << ' ' << ResultTypeName(curr->type); + o << ' '; + printResultTypeName(o, curr->type, wasm); } } void visitBreak(Break* curr) { @@ -315,7 +335,7 @@ struct PrintExpressionContents o << '('; printMinor(o, "type "); - printHeapTypeName(o, curr->sig); + printHeapTypeName(o, curr->sig, wasm); o << ')'; } void visitLocalGet(LocalGet* curr) { @@ -1728,7 +1748,8 @@ struct PrintExpressionContents void visitSelect(Select* curr) { prepareColor(o) << "select"; if (curr->type.isRef()) { - o << ' ' << ResultTypeName(curr->type); + o << ' '; + printResultTypeName(o, curr->type, wasm); } } void visitDrop(Drop* curr) { printMedium(o, "drop"); } @@ -1737,7 +1758,7 @@ struct PrintExpressionContents void visitMemoryGrow(MemoryGrow* curr) { printMedium(o, "memory.grow"); } void visitRefNull(RefNull* curr) { printMedium(o, "ref.null "); - printHeapTypeName(o, curr->type.getHeapType()); + printHeapTypeName(o, curr->type.getHeapType(), wasm); } void visitRefIs(RefIs* curr) { switch (curr->op) { @@ -1808,11 +1829,11 @@ struct PrintExpressionContents } void visitRefTest(RefTest* curr) { printMedium(o, "ref.test "); - printHeapTypeName(o, curr->getCastType().getHeapType()); + printHeapTypeName(o, curr->getCastType().getHeapType(), wasm); } void visitRefCast(RefCast* curr) { printMedium(o, "ref.cast "); - printHeapTypeName(o, curr->getCastType().getHeapType()); + printHeapTypeName(o, curr->getCastType().getHeapType(), wasm); } void visitBrOn(BrOn* curr) { switch (curr->op) { @@ -1838,11 +1859,11 @@ struct PrintExpressionContents } void visitRttCanon(RttCanon* curr) { printMedium(o, "rtt.canon "); - printHeapTypeName(o, curr->type.getRtt().heapType); + printHeapTypeName(o, curr->type.getRtt().heapType, wasm); } void visitRttSub(RttSub* curr) { printMedium(o, "rtt.sub "); - printHeapTypeName(o, curr->type.getRtt().heapType); + printHeapTypeName(o, curr->type.getRtt().heapType, wasm); } void visitStructNew(StructNew* curr) { printMedium(o, "struct.new_"); @@ -1850,7 +1871,7 @@ struct PrintExpressionContents o << "default_"; } o << "with_rtt "; - printHeapTypeName(o, curr->rtt->type.getHeapType()); + printHeapTypeName(o, curr->rtt->type.getHeapType(), wasm); } void printUnreachableReplacement() { // If we cannot print a valid unreachable instruction (say, a struct.get, @@ -1875,7 +1896,7 @@ struct PrintExpressionContents } else { printMedium(o, "struct.get "); } - printHeapTypeName(o, curr->ref->type.getHeapType()); + printHeapTypeName(o, curr->ref->type.getHeapType(), wasm); o << ' '; o << curr->index; } @@ -1885,7 +1906,7 @@ struct PrintExpressionContents return; } printMedium(o, "struct.set "); - printHeapTypeName(o, curr->ref->type.getHeapType()); + printHeapTypeName(o, curr->ref->type.getHeapType(), wasm); o << ' '; o << curr->index; } @@ -1895,7 +1916,7 @@ struct PrintExpressionContents o << "default_"; } o << "with_rtt "; - printHeapTypeName(o, curr->rtt->type.getHeapType()); + printHeapTypeName(o, curr->rtt->type.getHeapType(), wasm); } void visitArrayGet(ArrayGet* curr) { const auto& element = curr->ref->type.getHeapType().getArray().element; @@ -1908,15 +1929,15 @@ struct PrintExpressionContents } else { printMedium(o, "array.get "); } - printHeapTypeName(o, curr->ref->type.getHeapType()); + printHeapTypeName(o, curr->ref->type.getHeapType(), wasm); } void visitArraySet(ArraySet* curr) { printMedium(o, "array.set "); - printHeapTypeName(o, curr->ref->type.getHeapType()); + printHeapTypeName(o, curr->ref->type.getHeapType(), wasm); } void visitArrayLen(ArrayLen* curr) { printMedium(o, "array.len "); - printHeapTypeName(o, curr->ref->type.getHeapType()); + printHeapTypeName(o, curr->ref->type.getHeapType(), wasm); } void visitRefAs(RefAs* curr) { switch (curr->op) { @@ -2019,8 +2040,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { void printExpressionContents(Expression* curr) { if (currModule) { - PrintExpressionContents(currFunction, currModule->features, o) - .visit(curr); + PrintExpressionContents(currModule, currFunction, o).visit(curr); } else { PrintExpressionContents(currFunction, o).visit(curr); } @@ -2761,7 +2781,8 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { o << "(param "; auto sep = ""; for (auto type : curr.params) { - o << sep << SExprType(type); + o << sep; + printSExprType(o, type, currModule); sep = " "; } o << ')'; @@ -2771,7 +2792,8 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { o << "(result "; auto sep = ""; for (auto type : curr.results) { - o << sep << SExprType(type); + o << sep; + printSExprType(o, type, currModule); sep = " "; } o << ')'; @@ -2791,7 +2813,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { WASM_UNREACHABLE("invalid packed type"); } } else { - o << SExprType(field.type); + printSExprType(o, field.type, currModule); } if (field.mutable_) { o << ')'; @@ -2864,9 +2886,10 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { } void emitGlobalType(Global* curr) { if (curr->mutable_) { - o << "(mut " << SExprType(curr->type) << ')'; + o << "(mut "; + printSExprType(o, curr->type, currModule) << ')'; } else { - o << SExprType(curr->type); + printSExprType(o, curr->type, currModule); } } void visitImportedGlobal(Global* curr) { @@ -2926,21 +2949,22 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { o << '('; printMinor(o, "param "); printLocal(i, currFunction, o); - o << ' ' << SExprType(param) << ')'; + o << ' '; + printSExprType(o, param, currModule) << ')'; ++i; } } if (curr->sig.results != Type::none) { o << maybeSpace; - o << ResultTypeName(curr->sig.results); + printResultTypeName(o, curr->sig.results, currModule); } incIndent(); for (size_t i = curr->getVarIndexBase(); i < curr->getNumLocals(); i++) { doIndent(o, indent); o << '('; printMinor(o, "local "); - printLocal(i, currFunction, o) - << ' ' << SExprType(curr->getLocalType(i)) << ')'; + printLocal(i, currFunction, o) << ' '; + printSExprType(o, curr->getLocalType(i), currModule) << ')'; o << maybeNewLine; } // Print the body. @@ -3175,7 +3199,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { doIndent(o, indent); o << '('; printMedium(o, "type") << ' '; - printHeapTypeName(o, type); + printHeapTypeName(o, type, curr); o << ' '; handleHeapType(type); o << ")" << maybeNewLine; diff --git a/src/wasm.h b/src/wasm.h index 199257925..4060bac4d 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -1872,6 +1872,16 @@ public: // Module name, if specified. Serves a documentary role only. Name name; + // Optional type name information, used in printing only. Note that Types are + // globally interned, but type names are specific to a module. + struct TypeNames { + // The name of the type. + Name name; + // For a Struct, names of fields. + std::unordered_map<Index, Name> fieldNames; + }; + std::unordered_map<HeapType, TypeNames> typeNames; + MixedArena allocator; private: diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 4e9ebcc45..df01a18d8 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -843,6 +843,18 @@ void SExpressionWasmBuilder::preParseHeapTypes(Element& module) { }); types = builder.build(); + + for (auto& pair : typeIndices) { + auto name = pair.first; + auto type = types[pair.second]; + // A type may appear in the type section more than once, but we canonicalize + // types internally, so there will be a single name chosen for that type. Do + // so determistically. + if (wasm.typeNames.count(type) && wasm.typeNames[type].name.str < name) { + continue; + } + wasm.typeNames[type].name = name; + } } void SExpressionWasmBuilder::preParseFunctionType(Element& s) { |