summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlon Zakai <azakai@google.com>2021-02-23 22:46:13 +0000
committerGitHub <noreply@github.com>2021-02-23 22:46:13 +0000
commita7c66754ba86854ea3e0381986796e6565b93199 (patch)
tree940abdbce4bcf5a10eab090c6b67c37dbb5d2c3b /src
parentc127eccf753ab86a4d5deecbd0f3fa78a83e42ad (diff)
downloadbinaryen-a7c66754ba86854ea3e0381986796e6565b93199.tar.gz
binaryen-a7c66754ba86854ea3e0381986796e6565b93199.tar.bz2
binaryen-a7c66754ba86854ea3e0381986796e6565b93199.zip
Properly use text format type names in printing (#3591)
This adds a TypeNames entry to modules, which can store names for types. So far this PR uses that to store type names from text format. Future PRs will add support for field names and for the binary format. (Field names are added to wasm.h here to see if we agree on this direction.) Most of the work here is threading a module through the various functions in Print.cpp. This keeps the module optional, so that we can still print an expression independently of a module, which has always been the case, and which I think we should keep (but, if a module was mandatory perhaps this would be a little simpler, and could be refactored into a form that depends on that). 99% of this diff are test updates, since almost all our tests use the text format, and many of them specify a type name but we used to ignore it. This is a step towards a proper solution for #3589
Diffstat (limited to 'src')
-rw-r--r--src/passes/Print.cpp132
-rw-r--r--src/wasm.h10
-rw-r--r--src/wasm/wasm-s-parser.cpp12
3 files changed, 100 insertions, 54 deletions
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index a75c15534..643810be4 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -67,11 +67,13 @@ static std::ostream& printLocal(Index index, Function* func, std::ostream& o) {
return printName(name, o);
}
-static void
-printHeapTypeName(std::ostream& os, HeapType type, bool first = true);
+static void printHeapTypeName(std::ostream& os,
+ HeapType type,
+ Module* wasm = nullptr,
+ bool first = true);
// Prints the name of a type. This output is guaranteed to not contain spaces.
-static void printTypeName(std::ostream& os, Type type) {
+static void printTypeName(std::ostream& os, Type type, Module* wasm = nullptr) {
if (type.isBasic()) {
os << type;
return;
@@ -82,7 +84,7 @@ static void printTypeName(std::ostream& os, Type type) {
if (rtt.hasDepth()) {
os << rtt.depth << '_';
}
- printHeapTypeName(os, rtt.heapType);
+ printHeapTypeName(os, rtt.heapType, wasm);
return;
}
if (type.isTuple()) {
@@ -90,7 +92,7 @@ static void printTypeName(std::ostream& os, Type type) {
for (auto t : type) {
os << sep;
sep = "_";
- printTypeName(os, t);
+ printTypeName(os, t, wasm);
}
return;
}
@@ -100,14 +102,15 @@ static void printTypeName(std::ostream& os, Type type) {
os << "?";
}
os << "|";
- printHeapTypeName(os, type.getHeapType(), false);
+ printHeapTypeName(os, type.getHeapType(), wasm, false);
os << "|";
return;
}
WASM_UNREACHABLE("unsupported print type");
}
-static void printFieldName(std::ostream& os, const Field& field) {
+static void
+printFieldName(std::ostream& os, const Field& field, Module* wasm = nullptr) {
if (field.mutable_) {
os << "mut:";
}
@@ -120,29 +123,39 @@ static void printFieldName(std::ostream& os, const Field& field) {
WASM_UNREACHABLE("invalid packed type");
}
} else {
- printTypeName(os, field.type);
+ printTypeName(os, field.type, wasm);
}
}
// Prints the name of a heap type. As with printTypeName, this output is
// guaranteed to not contain spaces.
-static void printHeapTypeName(std::ostream& os, HeapType type, bool first) {
+static void
+printHeapTypeName(std::ostream& os, HeapType type, Module* wasm, bool first) {
if (type.isBasic()) {
os << type;
return;
}
+ // If there is a name for this type in this module, use it.
+ // FIXME: in theory there could be two types, one with a name, and one
+ // without, and the one without gets an automatic name that matches the
+ // other's. To check for that, if (first) we could assert at the very end of
+ // this function that the automatic name is not present in the given names.
+ if (wasm && wasm->typeNames.count(type)) {
+ os << '$' << wasm->typeNames[type].name;
+ return;
+ }
if (first) {
os << '$';
}
if (type.isSignature()) {
auto sig = type.getSignature();
- printTypeName(os, sig.params);
+ printTypeName(os, sig.params, wasm);
if (first) {
os << "_=>_";
} else {
os << "_->_";
}
- printTypeName(os, sig.results);
+ printTypeName(os, sig.results, wasm);
} else if (type.isStruct()) {
auto struct_ = type.getStruct();
os << "{";
@@ -150,12 +163,12 @@ static void printHeapTypeName(std::ostream& os, HeapType type, bool first) {
for (auto& field : struct_.fields) {
os << sep;
sep = "_";
- printFieldName(os, field);
+ printFieldName(os, field, wasm);
}
os << "}";
} else if (type.isArray()) {
os << "[";
- printFieldName(os, type.getArray().element);
+ printFieldName(os, type.getArray().element, wasm);
os << "]";
} else {
os << type;
@@ -169,13 +182,16 @@ struct SExprType {
SExprType(Type type) : type(type){};
};
-static std::ostream& operator<<(std::ostream& o, const SExprType& sType) {
+static std::ostream& printSExprType(std::ostream& o,
+ const SExprType& sType,
+ Module* wasm = nullptr) {
Type type = sType.type;
if (type.isTuple()) {
o << '(';
auto sep = "";
for (const auto& t : type) {
- o << sep << SExprType(t);
+ o << sep;
+ printSExprType(o, t, wasm);
sep = " ";
}
o << ')';
@@ -185,17 +201,17 @@ static std::ostream& operator<<(std::ostream& o, const SExprType& sType) {
if (rtt.hasDepth()) {
o << rtt.depth << ' ';
}
- printHeapTypeName(o, rtt.heapType);
+ printHeapTypeName(o, rtt.heapType, wasm);
o << ')';
} else if (type.isRef() && !type.isBasic()) {
o << "(ref ";
if (type.isNullable()) {
o << "null ";
}
- printHeapTypeName(o, type.getHeapType());
+ printHeapTypeName(o, type.getHeapType(), wasm);
o << ')';
} else {
- printTypeName(o, sType.type);
+ printTypeName(o, sType.type, wasm);
}
return o;
}
@@ -207,7 +223,9 @@ struct ResultTypeName {
ResultTypeName(Type type) : type(type) {}
};
-std::ostream& operator<<(std::ostream& os, ResultTypeName typeName) {
+std::ostream& printResultTypeName(std::ostream& os,
+ ResultTypeName typeName,
+ Module* wasm = nullptr) {
auto type = typeName.type;
os << "(result ";
if (type.isTuple()) {
@@ -217,10 +235,10 @@ std::ostream& operator<<(std::ostream& os, ResultTypeName typeName) {
for (auto t : type) {
os << sep;
sep = " ";
- os << SExprType(t);
+ printSExprType(os, t, wasm);
}
} else {
- os << SExprType(type);
+ printSExprType(os, type, wasm);
}
os << ')';
return os;
@@ -238,14 +256,13 @@ static Type forceConcrete(Type type) {
// the children.
struct PrintExpressionContents
: public OverriddenVisitor<PrintExpressionContents> {
+ Module* wasm = nullptr;
Function* currFunction = nullptr;
std::ostream& o;
FeatureSet features;
- PrintExpressionContents(Function* currFunction,
- FeatureSet features,
- std::ostream& o)
- : currFunction(currFunction), o(o), features(features) {}
+ PrintExpressionContents(Module* wasm, Function* currFunction, std::ostream& o)
+ : wasm(wasm), currFunction(currFunction), o(o), features(wasm->features) {}
PrintExpressionContents(Function* currFunction, std::ostream& o)
: currFunction(currFunction), o(o), features(FeatureSet::All) {}
@@ -257,13 +274,15 @@ struct PrintExpressionContents
printName(curr->name, o);
}
if (curr->type.isConcrete()) {
- o << ' ' << ResultTypeName(curr->type);
+ o << ' ';
+ printResultTypeName(o, curr->type, wasm);
}
}
void visitIf(If* curr) {
printMedium(o, "if");
if (curr->type.isConcrete()) {
- o << ' ' << ResultTypeName(curr->type);
+ o << ' ';
+ printResultTypeName(o, curr->type, wasm);
}
}
void visitLoop(Loop* curr) {
@@ -273,7 +292,8 @@ struct PrintExpressionContents
printName(curr->name, o);
}
if (curr->type.isConcrete()) {
- o << ' ' << ResultTypeName(curr->type);
+ o << ' ';
+ printResultTypeName(o, curr->type, wasm);
}
}
void visitBreak(Break* curr) {
@@ -315,7 +335,7 @@ struct PrintExpressionContents
o << '(';
printMinor(o, "type ");
- printHeapTypeName(o, curr->sig);
+ printHeapTypeName(o, curr->sig, wasm);
o << ')';
}
void visitLocalGet(LocalGet* curr) {
@@ -1728,7 +1748,8 @@ struct PrintExpressionContents
void visitSelect(Select* curr) {
prepareColor(o) << "select";
if (curr->type.isRef()) {
- o << ' ' << ResultTypeName(curr->type);
+ o << ' ';
+ printResultTypeName(o, curr->type, wasm);
}
}
void visitDrop(Drop* curr) { printMedium(o, "drop"); }
@@ -1737,7 +1758,7 @@ struct PrintExpressionContents
void visitMemoryGrow(MemoryGrow* curr) { printMedium(o, "memory.grow"); }
void visitRefNull(RefNull* curr) {
printMedium(o, "ref.null ");
- printHeapTypeName(o, curr->type.getHeapType());
+ printHeapTypeName(o, curr->type.getHeapType(), wasm);
}
void visitRefIs(RefIs* curr) {
switch (curr->op) {
@@ -1808,11 +1829,11 @@ struct PrintExpressionContents
}
void visitRefTest(RefTest* curr) {
printMedium(o, "ref.test ");
- printHeapTypeName(o, curr->getCastType().getHeapType());
+ printHeapTypeName(o, curr->getCastType().getHeapType(), wasm);
}
void visitRefCast(RefCast* curr) {
printMedium(o, "ref.cast ");
- printHeapTypeName(o, curr->getCastType().getHeapType());
+ printHeapTypeName(o, curr->getCastType().getHeapType(), wasm);
}
void visitBrOn(BrOn* curr) {
switch (curr->op) {
@@ -1838,11 +1859,11 @@ struct PrintExpressionContents
}
void visitRttCanon(RttCanon* curr) {
printMedium(o, "rtt.canon ");
- printHeapTypeName(o, curr->type.getRtt().heapType);
+ printHeapTypeName(o, curr->type.getRtt().heapType, wasm);
}
void visitRttSub(RttSub* curr) {
printMedium(o, "rtt.sub ");
- printHeapTypeName(o, curr->type.getRtt().heapType);
+ printHeapTypeName(o, curr->type.getRtt().heapType, wasm);
}
void visitStructNew(StructNew* curr) {
printMedium(o, "struct.new_");
@@ -1850,7 +1871,7 @@ struct PrintExpressionContents
o << "default_";
}
o << "with_rtt ";
- printHeapTypeName(o, curr->rtt->type.getHeapType());
+ printHeapTypeName(o, curr->rtt->type.getHeapType(), wasm);
}
void printUnreachableReplacement() {
// If we cannot print a valid unreachable instruction (say, a struct.get,
@@ -1875,7 +1896,7 @@ struct PrintExpressionContents
} else {
printMedium(o, "struct.get ");
}
- printHeapTypeName(o, curr->ref->type.getHeapType());
+ printHeapTypeName(o, curr->ref->type.getHeapType(), wasm);
o << ' ';
o << curr->index;
}
@@ -1885,7 +1906,7 @@ struct PrintExpressionContents
return;
}
printMedium(o, "struct.set ");
- printHeapTypeName(o, curr->ref->type.getHeapType());
+ printHeapTypeName(o, curr->ref->type.getHeapType(), wasm);
o << ' ';
o << curr->index;
}
@@ -1895,7 +1916,7 @@ struct PrintExpressionContents
o << "default_";
}
o << "with_rtt ";
- printHeapTypeName(o, curr->rtt->type.getHeapType());
+ printHeapTypeName(o, curr->rtt->type.getHeapType(), wasm);
}
void visitArrayGet(ArrayGet* curr) {
const auto& element = curr->ref->type.getHeapType().getArray().element;
@@ -1908,15 +1929,15 @@ struct PrintExpressionContents
} else {
printMedium(o, "array.get ");
}
- printHeapTypeName(o, curr->ref->type.getHeapType());
+ printHeapTypeName(o, curr->ref->type.getHeapType(), wasm);
}
void visitArraySet(ArraySet* curr) {
printMedium(o, "array.set ");
- printHeapTypeName(o, curr->ref->type.getHeapType());
+ printHeapTypeName(o, curr->ref->type.getHeapType(), wasm);
}
void visitArrayLen(ArrayLen* curr) {
printMedium(o, "array.len ");
- printHeapTypeName(o, curr->ref->type.getHeapType());
+ printHeapTypeName(o, curr->ref->type.getHeapType(), wasm);
}
void visitRefAs(RefAs* curr) {
switch (curr->op) {
@@ -2019,8 +2040,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
void printExpressionContents(Expression* curr) {
if (currModule) {
- PrintExpressionContents(currFunction, currModule->features, o)
- .visit(curr);
+ PrintExpressionContents(currModule, currFunction, o).visit(curr);
} else {
PrintExpressionContents(currFunction, o).visit(curr);
}
@@ -2761,7 +2781,8 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
o << "(param ";
auto sep = "";
for (auto type : curr.params) {
- o << sep << SExprType(type);
+ o << sep;
+ printSExprType(o, type, currModule);
sep = " ";
}
o << ')';
@@ -2771,7 +2792,8 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
o << "(result ";
auto sep = "";
for (auto type : curr.results) {
- o << sep << SExprType(type);
+ o << sep;
+ printSExprType(o, type, currModule);
sep = " ";
}
o << ')';
@@ -2791,7 +2813,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
WASM_UNREACHABLE("invalid packed type");
}
} else {
- o << SExprType(field.type);
+ printSExprType(o, field.type, currModule);
}
if (field.mutable_) {
o << ')';
@@ -2864,9 +2886,10 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
}
void emitGlobalType(Global* curr) {
if (curr->mutable_) {
- o << "(mut " << SExprType(curr->type) << ')';
+ o << "(mut ";
+ printSExprType(o, curr->type, currModule) << ')';
} else {
- o << SExprType(curr->type);
+ printSExprType(o, curr->type, currModule);
}
}
void visitImportedGlobal(Global* curr) {
@@ -2926,21 +2949,22 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
o << '(';
printMinor(o, "param ");
printLocal(i, currFunction, o);
- o << ' ' << SExprType(param) << ')';
+ o << ' ';
+ printSExprType(o, param, currModule) << ')';
++i;
}
}
if (curr->sig.results != Type::none) {
o << maybeSpace;
- o << ResultTypeName(curr->sig.results);
+ printResultTypeName(o, curr->sig.results, currModule);
}
incIndent();
for (size_t i = curr->getVarIndexBase(); i < curr->getNumLocals(); i++) {
doIndent(o, indent);
o << '(';
printMinor(o, "local ");
- printLocal(i, currFunction, o)
- << ' ' << SExprType(curr->getLocalType(i)) << ')';
+ printLocal(i, currFunction, o) << ' ';
+ printSExprType(o, curr->getLocalType(i), currModule) << ')';
o << maybeNewLine;
}
// Print the body.
@@ -3175,7 +3199,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
doIndent(o, indent);
o << '(';
printMedium(o, "type") << ' ';
- printHeapTypeName(o, type);
+ printHeapTypeName(o, type, curr);
o << ' ';
handleHeapType(type);
o << ")" << maybeNewLine;
diff --git a/src/wasm.h b/src/wasm.h
index 199257925..4060bac4d 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -1872,6 +1872,16 @@ public:
// Module name, if specified. Serves a documentary role only.
Name name;
+ // Optional type name information, used in printing only. Note that Types are
+ // globally interned, but type names are specific to a module.
+ struct TypeNames {
+ // The name of the type.
+ Name name;
+ // For a Struct, names of fields.
+ std::unordered_map<Index, Name> fieldNames;
+ };
+ std::unordered_map<HeapType, TypeNames> typeNames;
+
MixedArena allocator;
private:
diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp
index 4e9ebcc45..df01a18d8 100644
--- a/src/wasm/wasm-s-parser.cpp
+++ b/src/wasm/wasm-s-parser.cpp
@@ -843,6 +843,18 @@ void SExpressionWasmBuilder::preParseHeapTypes(Element& module) {
});
types = builder.build();
+
+ for (auto& pair : typeIndices) {
+ auto name = pair.first;
+ auto type = types[pair.second];
+ // A type may appear in the type section more than once, but we canonicalize
+ // types internally, so there will be a single name chosen for that type. Do
+ // so determistically.
+ if (wasm.typeNames.count(type) && wasm.typeNames[type].name.str < name) {
+ continue;
+ }
+ wasm.typeNames[type].name = name;
+ }
}
void SExpressionWasmBuilder::preParseFunctionType(Element& s) {