summaryrefslogtreecommitdiff
path: root/src/passes/Print.cpp
diff options
context:
space:
mode:
authorThomas Lively <7121787+tlively@users.noreply.github.com>2021-02-25 16:21:35 -0800
committerGitHub <noreply@github.com>2021-02-25 16:21:35 -0800
commitb89b601a36e9cfe17dc1f09c641266ac2a715299 (patch)
tree950ff11ff513bef4bf87678043461cb572d9564a /src/passes/Print.cpp
parent142d5f32ce792327de62b62f09f25528dcd86950 (diff)
downloadbinaryen-b89b601a36e9cfe17dc1f09c641266ac2a715299.tar.gz
binaryen-b89b601a36e9cfe17dc1f09c641266ac2a715299.tar.bz2
binaryen-b89b601a36e9cfe17dc1f09c641266ac2a715299.zip
Support comparing, subtyping, and naming recursive types (#3610)
When the type section is emitted, types with an equal number of references are ordered by an arbitrary measure of simplicity, which previously would infinitely recurse on structurally equivalent recursive types. Similarly, calculating whether a recursive type was a subtype of another recursive type could have infinitely recursed. This PR avoids infinite recursions in both cases by switching the algorithms from using normal inductive recursion to using coinductive recursion. The difference is that while the inductive algorithms assume the relations do not hold for a pair of HeapTypes until they have been exhaustively shown to hold, the coinductive algorithms assume the relations hold unless a counterexample can be found. In addition to those two algorithms, this PR also implements name generation for recursive types, using de Bruijn indices to stand in for inner uses of the recursive types. There are additional algorithms that will need to be switched from inductive to coinductive recursion, such as least upper bound generation, but these presented a good starting point and are sufficient to get some interesting programs working end-to-end.
Diffstat (limited to 'src/passes/Print.cpp')
-rw-r--r--src/passes/Print.cpp245
1 files changed, 150 insertions, 95 deletions
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index 445021bcf..8c32642b4 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -67,70 +67,56 @@ static std::ostream& printLocal(Index index, Function* func, std::ostream& o) {
return printName(name, o);
}
-static void printHeapTypeName(std::ostream& os,
- HeapType type,
- Module* wasm = nullptr,
- bool first = true);
+namespace {
+
+// Helper for printing the name of a type. This output is guaranteed to not
+// contain spaces.
+struct TypeNamePrinter {
+ // Optional. If present, the module's HeapType names will be used.
+ Module* wasm;
+
+ // Keep track of the first depth at which we see each HeapType so if we see it
+ // again, we can unambiguously refer to it without infinitely recursing.
+ size_t currHeapTypeDepth = 0;
+ std::unordered_map<HeapType, size_t> heapTypeDepths;
+
+ // The stream we are printing to.
+ std::ostream& os;
+
+ TypeNamePrinter(std::ostream& os, Module* wasm = nullptr)
+ : wasm(wasm), os(os) {}
+
+ void print(Type type);
+ void print(HeapType heapType);
+ void print(const Tuple& tuple);
+ void print(const Field& field);
+ void print(const Signature& sig);
+ void print(const Struct& struct_);
+ void print(const Array& array);
+ void print(const Rtt& rtt);
+};
-// Prints the name of a type. This output is guaranteed to not contain spaces.
-static void printTypeName(std::ostream& os, Type type, Module* wasm = nullptr) {
+void TypeNamePrinter::print(Type type) {
if (type.isBasic()) {
os << type;
- return;
- }
- if (type.isRtt()) {
- auto rtt = type.getRtt();
- os << "rtt_";
- if (rtt.hasDepth()) {
- os << rtt.depth << '_';
- }
- printHeapTypeName(os, rtt.heapType, wasm);
- return;
- }
- if (type.isTuple()) {
- auto sep = "";
- for (auto t : type) {
- os << sep;
- sep = "_";
- printTypeName(os, t, wasm);
- }
- return;
- }
- if (type.isRef()) {
+ } else if (type.isTuple()) {
+ print(type.getTuple());
+ } else if (type.isRtt()) {
+ print(type.getRtt());
+ } else if (type.isRef()) {
os << "ref";
if (type.isNullable()) {
os << "?";
}
- os << "|";
- printHeapTypeName(os, type.getHeapType(), wasm, false);
- os << "|";
- return;
- }
- WASM_UNREACHABLE("unsupported print type");
-}
-
-static void
-printFieldName(std::ostream& os, const Field& field, Module* wasm = nullptr) {
- if (field.mutable_) {
- os << "mut:";
- }
- if (field.type == Type::i32 && field.packedType != Field::not_packed) {
- if (field.packedType == Field::i8) {
- os << "i8";
- } else if (field.packedType == Field::i16) {
- os << "i16";
- } else {
- WASM_UNREACHABLE("invalid packed type");
- }
+ os << '|';
+ print(type.getHeapType());
+ os << '|';
} else {
- printTypeName(os, field.type, wasm);
+ WASM_UNREACHABLE("unexpected type");
}
}
-// Prints the name of a heap type. As with printTypeName, this output is
-// guaranteed to not contain spaces.
-static void
-printHeapTypeName(std::ostream& os, HeapType type, Module* wasm, bool first) {
+void TypeNamePrinter::print(HeapType type) {
if (type.isBasic()) {
os << type;
return;
@@ -144,37 +130,104 @@ printHeapTypeName(std::ostream& os, HeapType type, Module* wasm, bool first) {
os << '$' << wasm->typeNames[type].name;
return;
}
- if (first) {
- os << '$';
+ // If we have seen this HeapType before, just print its relative depth instead
+ // of infinitely recursing.
+ auto it = heapTypeDepths.find(type);
+ if (it != heapTypeDepths.end()) {
+ assert(it->second <= currHeapTypeDepth);
+ size_t relativeDepth = currHeapTypeDepth - it->second;
+ os << "..." << relativeDepth;
+ return;
+ }
+
+ // If this is the top-level heap type, add a $
+ if (currHeapTypeDepth == 0) {
+ os << "$";
}
+
+ // Update the context for the current HeapType before recursing.
+ heapTypeDepths[type] = ++currHeapTypeDepth;
+
if (type.isSignature()) {
- auto sig = type.getSignature();
- printTypeName(os, sig.params, wasm);
- if (first) {
- os << "_=>_";
- } else {
- os << "_->_";
- }
- printTypeName(os, sig.results, wasm);
+ print(type.getSignature());
} else if (type.isStruct()) {
- auto struct_ = type.getStruct();
- os << "{";
- auto sep = "";
- for (auto& field : struct_.fields) {
- os << sep;
- sep = "_";
- printFieldName(os, field, wasm);
- }
- os << "}";
+ print(type.getStruct());
} else if (type.isArray()) {
- os << "[";
- printFieldName(os, type.getArray().element, wasm);
- os << "]";
+ print(type.getArray());
} else {
- os << type;
+ WASM_UNREACHABLE("unexpected type");
+ }
+
+ // Restore the previous context after the recursion.
+ heapTypeDepths.erase(type);
+ --currHeapTypeDepth;
+}
+
+void TypeNamePrinter::print(const Tuple& tuple) {
+ auto sep = "";
+ for (auto type : tuple.types) {
+ os << sep;
+ sep = "_";
+ print(type);
}
}
+void TypeNamePrinter::print(const Field& field) {
+ if (field.mutable_) {
+ os << "mut:";
+ }
+ if (field.type == Type::i32 && field.packedType != Field::not_packed) {
+ if (field.packedType == Field::i8) {
+ os << "i8";
+ } else if (field.packedType == Field::i16) {
+ os << "i16";
+ } else {
+ WASM_UNREACHABLE("invalid packed type");
+ }
+ } else {
+ print(field.type);
+ }
+}
+
+void TypeNamePrinter::print(const Signature& sig) {
+ // TODO: Switch to using an unambiguous delimiter rather than differentiating
+ // only the top level with a different arrow.
+ print(sig.params);
+ if (currHeapTypeDepth == 1) {
+ os << "_=>_";
+ } else {
+ os << "_->_";
+ }
+ print(sig.results);
+}
+
+void TypeNamePrinter::print(const Struct& struct_) {
+ os << '{';
+ auto sep = "";
+ for (const auto& field : struct_.fields) {
+ os << sep;
+ sep = "_";
+ print(field);
+ }
+ os << '}';
+}
+
+void TypeNamePrinter::print(const Array& array) {
+ os << '[';
+ print(array.element);
+ os << ']';
+}
+
+void TypeNamePrinter::print(const Rtt& rtt) {
+ os << "rtt_";
+ if (rtt.hasDepth()) {
+ os << rtt.depth << '_';
+ }
+ print(rtt.heapType);
+}
+
+} // anonymous namespace
+
// Unlike the default format, tuple types in s-expressions should not have
// commas.
struct SExprType {
@@ -186,7 +239,9 @@ static std::ostream& printSExprType(std::ostream& o,
const SExprType& sType,
Module* wasm = nullptr) {
Type type = sType.type;
- if (type.isTuple()) {
+ if (type.isBasic()) {
+ o << type;
+ } else if (type.isTuple()) {
o << '(';
auto sep = "";
for (const auto& t : type) {
@@ -201,17 +256,17 @@ static std::ostream& printSExprType(std::ostream& o,
if (rtt.hasDepth()) {
o << rtt.depth << ' ';
}
- printHeapTypeName(o, rtt.heapType, wasm);
+ TypeNamePrinter(o, wasm).print(rtt.heapType);
o << ')';
} else if (type.isRef() && !type.isBasic()) {
o << "(ref ";
if (type.isNullable()) {
o << "null ";
}
- printHeapTypeName(o, type.getHeapType(), wasm);
+ TypeNamePrinter(o, wasm).print(type.getHeapType());
o << ')';
} else {
- printTypeName(o, sType.type, wasm);
+ WASM_UNREACHABLE("unexpected type");
}
return o;
}
@@ -356,7 +411,7 @@ struct PrintExpressionContents
o << '(';
printMinor(o, "type ");
- printHeapTypeName(o, curr->sig, wasm);
+ TypeNamePrinter(o, wasm).print(HeapType(curr->sig));
o << ')';
}
void visitLocalGet(LocalGet* curr) {
@@ -1793,7 +1848,7 @@ struct PrintExpressionContents
void visitMemoryGrow(MemoryGrow* curr) { printMedium(o, "memory.grow"); }
void visitRefNull(RefNull* curr) {
printMedium(o, "ref.null ");
- printHeapTypeName(o, curr->type.getHeapType(), wasm);
+ TypeNamePrinter(o, wasm).print(curr->type.getHeapType());
}
void visitRefIs(RefIs* curr) {
switch (curr->op) {
@@ -1864,11 +1919,11 @@ struct PrintExpressionContents
}
void visitRefTest(RefTest* curr) {
printMedium(o, "ref.test ");
- printHeapTypeName(o, curr->getCastType().getHeapType(), wasm);
+ TypeNamePrinter(o, wasm).print(curr->getCastType().getHeapType());
}
void visitRefCast(RefCast* curr) {
printMedium(o, "ref.cast ");
- printHeapTypeName(o, curr->getCastType().getHeapType(), wasm);
+ TypeNamePrinter(o, wasm).print(curr->getCastType().getHeapType());
}
void visitBrOn(BrOn* curr) {
switch (curr->op) {
@@ -1894,11 +1949,11 @@ struct PrintExpressionContents
}
void visitRttCanon(RttCanon* curr) {
printMedium(o, "rtt.canon ");
- printHeapTypeName(o, curr->type.getRtt().heapType, wasm);
+ TypeNamePrinter(o, wasm).print(curr->type.getRtt().heapType);
}
void visitRttSub(RttSub* curr) {
printMedium(o, "rtt.sub ");
- printHeapTypeName(o, curr->type.getRtt().heapType, wasm);
+ TypeNamePrinter(o, wasm).print(curr->type.getRtt().heapType);
}
void visitStructNew(StructNew* curr) {
printMedium(o, "struct.new_");
@@ -1906,7 +1961,7 @@ struct PrintExpressionContents
o << "default_";
}
o << "with_rtt ";
- printHeapTypeName(o, curr->rtt->type.getHeapType(), wasm);
+ TypeNamePrinter(o, wasm).print(curr->rtt->type.getHeapType());
}
void printUnreachableReplacement() {
// If we cannot print a valid unreachable instruction (say, a struct.get,
@@ -1940,7 +1995,7 @@ struct PrintExpressionContents
} else {
printMedium(o, "struct.get ");
}
- printHeapTypeName(o, heapType, wasm);
+ TypeNamePrinter(o, wasm).print(heapType);
o << ' ';
printFieldName(heapType, curr->index);
}
@@ -1951,7 +2006,7 @@ struct PrintExpressionContents
}
printMedium(o, "struct.set ");
auto heapType = curr->ref->type.getHeapType();
- printHeapTypeName(o, heapType, wasm);
+ TypeNamePrinter(o, wasm).print(heapType);
o << ' ';
printFieldName(heapType, curr->index);
}
@@ -1961,7 +2016,7 @@ struct PrintExpressionContents
o << "default_";
}
o << "with_rtt ";
- printHeapTypeName(o, curr->rtt->type.getHeapType(), wasm);
+ TypeNamePrinter(o, wasm).print(curr->rtt->type.getHeapType());
}
void visitArrayGet(ArrayGet* curr) {
const auto& element = curr->ref->type.getHeapType().getArray().element;
@@ -1974,15 +2029,15 @@ struct PrintExpressionContents
} else {
printMedium(o, "array.get ");
}
- printHeapTypeName(o, curr->ref->type.getHeapType(), wasm);
+ TypeNamePrinter(o, wasm).print(curr->ref->type.getHeapType());
}
void visitArraySet(ArraySet* curr) {
printMedium(o, "array.set ");
- printHeapTypeName(o, curr->ref->type.getHeapType(), wasm);
+ TypeNamePrinter(o, wasm).print(curr->ref->type.getHeapType());
}
void visitArrayLen(ArrayLen* curr) {
printMedium(o, "array.len ");
- printHeapTypeName(o, curr->ref->type.getHeapType(), wasm);
+ TypeNamePrinter(o, wasm).print(curr->ref->type.getHeapType());
}
void visitRefAs(RefAs* curr) {
switch (curr->op) {
@@ -3251,7 +3306,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
doIndent(o, indent);
o << '(';
printMedium(o, "type") << ' ';
- printHeapTypeName(o, type, curr);
+ TypeNamePrinter(o, curr).print(type);
o << ' ';
handleHeapType(type);
o << ")" << maybeNewLine;
@@ -3432,7 +3487,7 @@ printStackInst(StackInst* inst, std::ostream& o, Function* func) {
case StackInst::TryEnd: {
printMedium(o, "end");
o << " ;; type: ";
- printTypeName(o, inst->type);
+ TypeNamePrinter(o).print(inst->type);
break;
}
case StackInst::IfElse: {