diff options
-rw-r--r-- | src/passes/Print.cpp | 86 | ||||
-rw-r--r-- | src/wasm-type.h | 18 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 389 | ||||
-rw-r--r-- | test/example/type-builder.cpp | 14 | ||||
-rw-r--r-- | test/example/type-builder.txt | 18 |
5 files changed, 285 insertions, 240 deletions
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 7b0d7a678..2467bf9bd 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -229,17 +229,7 @@ void TypeNamePrinter::print(const Rtt& rtt) { } // anonymous namespace -// Unlike the default format, tuple types in s-expressions should not have -// commas. -struct SExprType { - Type type; - SExprType(Type type) : type(type){}; -}; - -static std::ostream& printSExprType(std::ostream& o, - const SExprType& sType, - Module* wasm = nullptr) { - Type type = sType.type; +static std::ostream& printType(std::ostream& o, Type type, Module* wasm) { if (type.isBasic()) { o << type; } else if (type.isTuple()) { @@ -247,7 +237,7 @@ static std::ostream& printSExprType(std::ostream& o, auto sep = ""; for (const auto& t : type) { o << sep; - printSExprType(o, t, wasm); + printType(o, t, wasm); sep = " "; } o << ')'; @@ -272,32 +262,35 @@ static std::ostream& printSExprType(std::ostream& o, return o; } -// TODO: try to simplify or even remove this, as we may be able to do the same -// things with SExprType -struct ResultTypeName { - Type type; - ResultTypeName(Type type) : type(type) {} -}; - -std::ostream& printResultTypeName(std::ostream& os, - ResultTypeName typeName, - Module* wasm = nullptr) { - auto type = typeName.type; - os << "(result "; +static std::ostream& printPrefixedTypes(std::ostream& o, + const char* prefix, + Type type, + Module* wasm) { + o << '(' << prefix; + if (type == Type::none) { + return o << ')'; + } if (type.isTuple()) { // Tuple types are not printed in parens, we can just emit them one after // the other in the same list as the "result". - auto sep = ""; for (auto t : type) { - os << sep; - sep = " "; - printSExprType(os, t, wasm); + o << ' '; + printType(o, t, wasm); } } else { - printSExprType(os, type, wasm); + o << ' '; + printType(o, type, wasm); } - os << ')'; - return os; + o << ')'; + return o; +} + +static std::ostream& printResultType(std::ostream& o, Type type, Module* wasm) { + return printPrefixedTypes(o, "result", type, wasm); +} + +static std::ostream& printParamType(std::ostream& o, Type type, Module* wasm) { + return printPrefixedTypes(o, "param", type, wasm); } // Generic processing of a struct's field, given an optional module. Calls func @@ -352,14 +345,14 @@ struct PrintExpressionContents } if (curr->type.isConcrete()) { o << ' '; - printResultTypeName(o, curr->type, wasm); + printResultType(o, curr->type, wasm); } } void visitIf(If* curr) { printMedium(o, "if"); if (curr->type.isConcrete()) { o << ' '; - printResultTypeName(o, curr->type, wasm); + printResultType(o, curr->type, wasm); } } void visitLoop(Loop* curr) { @@ -370,7 +363,7 @@ struct PrintExpressionContents } if (curr->type.isConcrete()) { o << ' '; - printResultTypeName(o, curr->type, wasm); + printResultType(o, curr->type, wasm); } } void visitBreak(Break* curr) { @@ -1840,7 +1833,7 @@ struct PrintExpressionContents restoreNormalColor(o); if (curr->type.isRef()) { o << ' '; - printResultTypeName(o, curr->type, wasm); + printResultType(o, curr->type, wasm); } } void visitDrop(Drop* curr) { printMedium(o, "drop"); } @@ -1881,7 +1874,8 @@ struct PrintExpressionContents printName(curr->name, o); } if (curr->type.isConcrete()) { - o << ' ' << ResultType(curr->type); + o << ' '; + printResultType(o, curr->type, wasm); } } void visitThrow(Throw* curr) { @@ -2401,7 +2395,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { auto sep = ""; for (auto type : curr.params) { o << sep; - printSExprType(o, type, currModule); + printType(o, type, currModule); sep = " "; } o << ')'; @@ -2412,7 +2406,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { auto sep = ""; for (auto type : curr.results) { o << sep; - printSExprType(o, type, currModule); + printType(o, type, currModule); sep = " "; } o << ')'; @@ -2432,7 +2426,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { WASM_UNREACHABLE("invalid packed type"); } } else { - printSExprType(o, field.type, currModule); + printType(o, field.type, currModule); } if (field.mutable_) { o << ')'; @@ -2513,9 +2507,9 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { void emitGlobalType(Global* curr) { if (curr->mutable_) { o << "(mut "; - printSExprType(o, curr->type, currModule) << ')'; + printType(o, curr->type, currModule) << ')'; } else { - printSExprType(o, curr->type, currModule); + printType(o, curr->type, currModule); } } void visitImportedGlobal(Global* curr) { @@ -2576,13 +2570,13 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { printMinor(o, "param "); printLocal(i, currFunction, o); o << ' '; - printSExprType(o, param, currModule) << ')'; + printType(o, param, currModule) << ')'; ++i; } } if (curr->sig.results != Type::none) { o << maybeSpace; - printResultTypeName(o, curr->sig.results, currModule); + printResultType(o, curr->sig.results, currModule); } incIndent(); for (size_t i = curr->getVarIndexBase(); i < curr->getNumLocals(); i++) { @@ -2590,7 +2584,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { o << '('; printMinor(o, "local "); printLocal(i, currFunction, o) << ' '; - printSExprType(o, curr->getLocalType(i), currModule) << ')'; + printType(o, curr->getLocalType(i), currModule) << ')'; o << maybeNewLine; } // Print the body. @@ -2641,7 +2635,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { o << "(event "; printName(curr->name, o); o << maybeSpace << "(attr " << curr->attribute << ')' << maybeSpace; - o << ParamType(curr->sig.params); + printParamType(o, curr->sig.params, currModule); o << "))"; o << maybeNewLine; } @@ -2651,7 +2645,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { printMedium(o, "event "); printName(curr->name, o); o << maybeSpace << "(attr " << curr->attribute << ')' << maybeSpace; - o << ParamType(curr->sig.params); + printParamType(o, curr->sig.params, currModule); o << ")" << maybeNewLine; } void printTableHeader(Table* curr) { diff --git a/src/wasm-type.h b/src/wasm-type.h index d0008d357..8c1d72085 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -281,20 +281,6 @@ public: const Type& operator[](size_t i) const; }; -// Wrapper type for formatting types as "(param i32 i64 f32)" -struct ParamType { - Type type; - ParamType(Type type) : type(type) {} - std::string toString() const; -}; - -// Wrapper type for formatting types as "(result i32 i64 f32)" -struct ResultType { - Type type; - ResultType(Type type) : type(type) {} - std::string toString() const; -}; - class HeapType { // Unlike `Type`, which represents the types of values on the WebAssembly // stack, `HeapType` is used to describe the structures that reference types @@ -509,14 +495,12 @@ struct TypeBuilder { }; std::ostream& operator<<(std::ostream&, Type); -std::ostream& operator<<(std::ostream&, ParamType); -std::ostream& operator<<(std::ostream&, ResultType); +std::ostream& operator<<(std::ostream&, HeapType); std::ostream& operator<<(std::ostream&, Tuple); std::ostream& operator<<(std::ostream&, Signature); std::ostream& operator<<(std::ostream&, Field); std::ostream& operator<<(std::ostream&, Struct); std::ostream& operator<<(std::ostream&, Array); -std::ostream& operator<<(std::ostream&, HeapType); std::ostream& operator<<(std::ostream&, Rtt); } // namespace wasm diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index 1000e8b92..0941f8c58 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -132,6 +132,40 @@ struct SubTyper { bool isSubType(const Rtt& a, const Rtt& b); }; +// Helper for printing types without infinitely recursing on recursive types. +struct TypePrinter { + size_t currDepth = 0; + std::unordered_map<TypeID, size_t> depths; + + // The stream we are printing to. + std::ostream& os; + + TypePrinter(std::ostream& os) : os(os) {} + + std::ostream& print(Type type); + std::ostream& print(HeapType heapType); + std::ostream& print(const Tuple& tuple); + std::ostream& print(const Field& field); + std::ostream& print(const Signature& sig); + std::ostream& print(const Struct& struct_); + std::ostream& print(const Array& array); + std::ostream& print(const Rtt& rtt); + +private: + template<typename T, typename F> std::ostream& printChild(T curr, F printer) { + auto it = depths.find(curr.getID()); + if (it != depths.end()) { + assert(it->second <= currDepth); + size_t relativeDepth = currDepth - it->second; + return os << "..." << relativeDepth; + } + depths[curr.getID()] = ++currDepth; + printer(); + depths.erase(curr.getID()); + return os; + } +}; + } // anonymous namespace } // namespace wasm @@ -761,206 +795,41 @@ bool Signature::operator<(const Signature& other) const { return TypeComparator().lessThan(*this, other); } -namespace { - -std::ostream& -printPrefixedTypes(std::ostream& os, const char* prefix, Type type) { - os << '(' << prefix; - for (const auto& t : type) { - os << " " << t; - } - os << ')'; - return os; -} - -template<typename T> std::string genericToString(const T& t) { +template<typename T> static std::string genericToString(const T& t) { std::ostringstream ss; ss << t; return ss.str(); } - -} // anonymous namespace - std::string Type::toString() const { return genericToString(*this); } - -std::string ParamType::toString() const { return genericToString(*this); } - -std::string ResultType::toString() const { return genericToString(*this); } - +std::string HeapType::toString() const { return genericToString(*this); } std::string Tuple::toString() const { return genericToString(*this); } - std::string Signature::toString() const { return genericToString(*this); } - std::string Struct::toString() const { return genericToString(*this); } - std::string Array::toString() const { return genericToString(*this); } - -std::string HeapType::toString() const { return genericToString(*this); } - std::string Rtt::toString() const { return genericToString(*this); } - -std::ostream& operator<<(std::ostream&, TypeInfo); -std::ostream& operator<<(std::ostream&, HeapTypeInfo); - std::ostream& operator<<(std::ostream& os, Type type) { - if (type.isBasic()) { - switch (type.getBasic()) { - case Type::none: - return os << "none"; - case Type::unreachable: - return os << "unreachable"; - case Type::i32: - return os << "i32"; - case Type::i64: - return os << "i64"; - case Type::f32: - return os << "f32"; - case Type::f64: - return os << "f64"; - case Type::v128: - return os << "v128"; - case Type::funcref: - return os << "funcref"; - case Type::externref: - return os << "externref"; - case Type::anyref: - return os << "anyref"; - case Type::eqref: - return os << "eqref"; - case Type::i31ref: - return os << "i31ref"; - case Type::dataref: - return os << "dataref"; - } - } - return os << *getTypeInfo(type); -} - -std::ostream& operator<<(std::ostream& os, ParamType param) { - return printPrefixedTypes(os, "param", param.type); + return TypePrinter(os).print(type); } - -std::ostream& operator<<(std::ostream& os, ResultType param) { - return printPrefixedTypes(os, "result", param.type); +std::ostream& operator<<(std::ostream& os, HeapType heapType) { + return TypePrinter(os).print(heapType); } - std::ostream& operator<<(std::ostream& os, Tuple tuple) { - auto& types = tuple.types; - auto size = types.size(); - os << "("; - if (size) { - os << types[0]; - for (size_t i = 1; i < size; ++i) { - os << " " << types[i]; - } - } - return os << ")"; + return TypePrinter(os).print(tuple); } - std::ostream& operator<<(std::ostream& os, Signature sig) { - os << "(func"; - if (sig.params.getID() != Type::none) { - os << " "; - printPrefixedTypes(os, "param", sig.params); - } - if (sig.results.getID() != Type::none) { - os << " "; - printPrefixedTypes(os, "result", sig.results); - } - return os << ")"; + return TypePrinter(os).print(sig); } - std::ostream& operator<<(std::ostream& os, Field field) { - if (field.mutable_) { - os << "(mut "; - } - if (field.isPacked()) { - auto packedType = field.packedType; - if (packedType == Field::PackedType::i8) { - os << "i8"; - } else if (packedType == Field::PackedType::i16) { - os << "i16"; - } else { - WASM_UNREACHABLE("unexpected packed type"); - } - } else { - os << field.type; - } - if (field.mutable_) { - os << ")"; - } - return os; -}; - + return TypePrinter(os).print(field); +} std::ostream& operator<<(std::ostream& os, Struct struct_) { - os << "(struct"; - if (struct_.fields.size()) { - os << " (field"; - for (auto f : struct_.fields) { - os << " " << f; - } - os << ")"; - } - return os << ")"; + return TypePrinter(os).print(struct_); } - std::ostream& operator<<(std::ostream& os, Array array) { - return os << "(array " << array.element << ")"; -} - -std::ostream& operator<<(std::ostream& os, HeapType heapType) { - if (heapType.isBasic()) { - switch (heapType.getBasic()) { - case HeapType::func: - return os << "func"; - case HeapType::ext: - return os << "extern"; - case HeapType::any: - return os << "any"; - case HeapType::eq: - return os << "eq"; - case HeapType::i31: - return os << "i31"; - case HeapType::data: - return os << "data"; - } - } - return os << *getHeapTypeInfo(heapType); + return TypePrinter(os).print(array); } - std::ostream& operator<<(std::ostream& os, Rtt rtt) { - return os << "(rtt " << rtt.depth << " " << rtt.heapType << ")"; -} - -std::ostream& operator<<(std::ostream& os, TypeInfo info) { - switch (info.kind) { - case TypeInfo::TupleKind: { - return os << info.tuple; - } - case TypeInfo::RefKind: { - os << "(ref "; - if (info.ref.nullable) { - os << "null "; - } - return os << info.ref.heapType << ")"; - } - case TypeInfo::RttKind: { - return os << info.rtt; - } - } - WASM_UNREACHABLE("unexpected kind"); -} - -std::ostream& operator<<(std::ostream& os, HeapTypeInfo info) { - switch (info.kind) { - case HeapTypeInfo::SignatureKind: - return os << info.signature; - case HeapTypeInfo::StructKind: - return os << info.struct_; - case HeapTypeInfo::ArrayKind: - return os << info.array; - } - WASM_UNREACHABLE("unexpected kind"); + return TypePrinter(os).print(rtt); } namespace { @@ -1196,6 +1065,172 @@ bool SubTyper::isSubType(const Rtt& a, const Rtt& b) { return a.heapType == b.heapType && a.hasDepth() && !b.hasDepth(); } +std::ostream& TypePrinter::print(Type type) { + if (type.isBasic()) { + switch (type.getBasic()) { + case Type::none: + return os << "none"; + case Type::unreachable: + return os << "unreachable"; + case Type::i32: + return os << "i32"; + case Type::i64: + return os << "i64"; + case Type::f32: + return os << "f32"; + case Type::f64: + return os << "f64"; + case Type::v128: + return os << "v128"; + case Type::funcref: + return os << "funcref"; + case Type::externref: + return os << "externref"; + case Type::anyref: + return os << "anyref"; + case Type::eqref: + return os << "eqref"; + case Type::i31ref: + return os << "i31ref"; + case Type::dataref: + return os << "dataref"; + } + } + + return printChild(type, [&]() { + if (type.isTuple()) { + print(type.getTuple()); + } else if (type.isRef()) { + os << "(ref "; + if (type.isNullable()) { + os << "null "; + } + print(type.getHeapType()); + os << ')'; + } else if (type.isRtt()) { + print(type.getRtt()); + } else { + WASM_UNREACHABLE("unexpected type"); + } + }); +} + +std::ostream& TypePrinter::print(HeapType heapType) { + if (heapType.isBasic()) { + switch (heapType.getBasic()) { + case HeapType::func: + return os << "func"; + case HeapType::ext: + return os << "extern"; + case HeapType::any: + return os << "any"; + case HeapType::eq: + return os << "eq"; + case HeapType::i31: + return os << "i31"; + case HeapType::data: + return os << "data"; + } + } + + return printChild(heapType, [&]() { + if (heapType.isSignature()) { + print(heapType.getSignature()); + } else if (heapType.isStruct()) { + print(heapType.getStruct()); + } else if (heapType.isArray()) { + print(heapType.getArray()); + } else { + WASM_UNREACHABLE("unexpected type"); + } + }); +} + +std::ostream& TypePrinter::print(const Tuple& tuple) { + os << '('; + auto sep = ""; + for (Type type : tuple.types) { + os << sep; + sep = " "; + print(type); + } + return os << ')'; +} + +std::ostream& TypePrinter::print(const Field& field) { + if (field.mutable_) { + os << "(mut "; + } + if (field.isPacked()) { + auto packedType = field.packedType; + if (packedType == Field::PackedType::i8) { + os << "i8"; + } else if (packedType == Field::PackedType::i16) { + os << "i16"; + } else { + WASM_UNREACHABLE("unexpected packed type"); + } + } else { + print(field.type); + } + if (field.mutable_) { + os << ')'; + } + return os; +} + +std::ostream& TypePrinter::print(const Signature& sig) { + auto printPrefixed = [&](const char* prefix, Type type) { + os << '(' << prefix; + for (Type t : type) { + os << ' '; + print(t); + } + os << ')'; + }; + + os << "(func"; + if (sig.params.getID() != Type::none) { + os << ' '; + printPrefixed("param", sig.params); + } + if (sig.results.getID() != Type::none) { + os << ' '; + printPrefixed("result", sig.results); + } + return os << ')'; +} + +std::ostream& TypePrinter::print(const Struct& struct_) { + os << "(struct"; + if (struct_.fields.size()) { + os << " (field"; + } + for (const Field& field : struct_.fields) { + os << ' '; + print(field); + } + if (struct_.fields.size()) { + os << ')'; + } + return os << ')'; +} + +std::ostream& TypePrinter::print(const Array& array) { + os << "(array "; + print(array.element); + return os << ')'; +} + +std::ostream& TypePrinter::print(const Rtt& rtt) { + os << "(rtt "; + if (rtt.hasDepth()) { + os << rtt.depth << ' '; + } + print(rtt.heapType); + return os << ')'; +} + } // anonymous namespace struct TypeBuilder::Impl { diff --git a/test/example/type-builder.cpp b/test/example/type-builder.cpp index 8503c037b..59222775b 100644 --- a/test/example/type-builder.cpp +++ b/test/example/type-builder.cpp @@ -112,6 +112,7 @@ void test_recursive() { builder.setHeapType(0, Signature(Type::none, temp)); built = builder.build(); } + std::cout << built[0] << "\n\n"; assert(built[0] == built[0].getSignature().results.getHeapType()); assert(Type(built[0], Nullable) == built[0].getSignature().results); } @@ -127,6 +128,8 @@ void test_recursive() { builder.setHeapType(1, Signature(Type::none, temp0)); built = builder.build(); } + std::cout << built[0] << "\n"; + std::cout << built[1] << "\n\n"; assert(built[0].getSignature().results.getHeapType() == built[1]); assert(built[1].getSignature().results.getHeapType() == built[0]); } @@ -148,6 +151,11 @@ void test_recursive() { builder.setHeapType(4, Signature(Type::none, temp0)); built = builder.build(); } + std::cout << built[0] << "\n"; + std::cout << built[1] << "\n"; + std::cout << built[2] << "\n"; + std::cout << built[3] << "\n"; + std::cout << built[4] << "\n\n"; assert(built[0].getSignature().results.getHeapType() == built[1]); assert(built[1].getSignature().results.getHeapType() == built[2]); assert(built[2].getSignature().results.getHeapType() == built[3]); @@ -175,6 +183,12 @@ void test_recursive() { builder.setHeapType(5, Signature(Type::none, temp1)); built = builder.build(); } + std::cout << built[0] << "\n"; + std::cout << built[1] << "\n"; + std::cout << built[2] << "\n"; + std::cout << built[3] << "\n"; + std::cout << built[4] << "\n"; + std::cout << built[5] << "\n\n"; assert(built[0] != built[1]); // TODO: canonicalize recursive types assert(built[2] == built[3]); assert(built[4] != built[5]); // Contain "different" recursive types diff --git a/test/example/type-builder.txt b/test/example/type-builder.txt index 0cb414be1..6b42ade8c 100644 --- a/test/example/type-builder.txt +++ b/test/example/type-builder.txt @@ -22,3 +22,21 @@ After building types: ;; Test canonicalization ;; Test recursive types +(func (result (ref null ...1))) + +(func (result (ref null (func (result (ref null ...3)))))) +(func (result (ref null (func (result (ref null ...3)))))) + +(func (result (ref null (func (result (ref null (func (result (ref null (func (result (ref null (func (result (ref null ...9))))))))))))))) +(func (result (ref null (func (result (ref null (func (result (ref null (func (result (ref null (func (result (ref null ...9))))))))))))))) +(func (result (ref null (func (result (ref null (func (result (ref null (func (result (ref null (func (result (ref null ...9))))))))))))))) +(func (result (ref null (func (result (ref null (func (result (ref null (func (result (ref null (func (result (ref null ...9))))))))))))))) +(func (result (ref null (func (result (ref null (func (result (ref null (func (result (ref null (func (result (ref null ...9))))))))))))))) + +(func (result (ref null ...1) (ref null (func)))) +(func (result (ref null ...1) (ref null (func)))) +(func) +(func) +(func (result (ref null (func (result ...1 (ref null (func))))))) +(func (result (ref null (func (result ...1 (ref null (func))))))) + |