diff options
-rw-r--r-- | src/wasm-type.h | 50 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 10 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 21 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 66 | ||||
-rw-r--r-- | test/example/type-builder.cpp | 119 | ||||
-rw-r--r-- | test/example/type-builder.txt | 3 |
6 files changed, 190 insertions, 79 deletions
diff --git a/src/wasm-type.h b/src/wasm-type.h index b638108fc..befac3c1a 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -499,28 +499,70 @@ struct TypeBuilder { std::unique_ptr<Impl> impl; TypeBuilder(size_t n); + TypeBuilder() : TypeBuilder(0) {} ~TypeBuilder(); TypeBuilder(TypeBuilder& other) = delete; TypeBuilder(TypeBuilder&& other) = delete; TypeBuilder& operator=(TypeBuilder&) = delete; + // Append `n` new uninitialized HeapType slots to the end of the TypeBuilder. + void grow(size_t n); + + // The number of HeapType slots in the TypeBuilder. + size_t size(); + // Sets the heap type at index `i`. May only be called before `build`. + void setHeapType(size_t i, HeapType::BasicHeapType basic); void setHeapType(size_t i, Signature signature); void setHeapType(size_t i, const Struct& struct_); void setHeapType(size_t i, Struct&& struct_); void setHeapType(size_t i, Array array); + // Gets the temporary HeapType at index `i`. This HeapType should only be used + // to construct temporary Types using the methods below. + HeapType getTempHeapType(size_t i); + // Gets a temporary type or heap type for use in initializing the - // TypeBuilder's HeapTypes. Temporary Ref and Rtt types are backed by the - // HeapType at index `i`. + // TypeBuilder's HeapTypes. For Ref and Rtt types, the HeapType may be a + // temporary HeapType owned by this builder or a canonical HeapType. HeapType Type getTempTupleType(const Tuple&); - Type getTempRefType(size_t i, Nullability nullable); - Type getTempRttType(size_t i, uint32_t depth); + Type getTempRefType(HeapType heapType, Nullability nullable); + Type getTempRttType(Rtt rtt); // Canonicalizes and returns all of the heap types. May only be called once // all of the heap types have been initialized with `setHeapType`. std::vector<HeapType> build(); + + // Utility for ergonomically using operator[] instead of explicit setHeapType + // and getTempHeapType methods. + struct Entry { + TypeBuilder& builder; + size_t index; + operator HeapType() const { return builder.getTempHeapType(index); } + Entry& operator=(HeapType::BasicHeapType basic) { + builder.setHeapType(index, basic); + return *this; + } + Entry& operator=(Signature signature) { + builder.setHeapType(index, signature); + return *this; + } + Entry& operator=(const Struct& struct_) { + builder.setHeapType(index, struct_); + return *this; + } + Entry& operator=(Struct&& struct_) { + builder.setHeapType(index, std::move(struct_)); + return *this; + } + Entry& operator=(Array array) { + builder.setHeapType(index, array); + return *this; + } + }; + + Entry operator[](size_t i) { return Entry{*this, i}; } }; std::ostream& operator<<(std::ostream&, Type); diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 5fff58045..5bd368e17 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -1816,7 +1816,7 @@ void WasmBinaryBuilder::readTypes() { if (size_t(htCode) >= numTypes) { throwError("invalid type index: " + std::to_string(htCode)); } - return builder.getTempRefType(size_t(htCode), nullability); + return builder.getTempRefType(builder[size_t(htCode)], nullability); } case BinaryConsts::EncodedType::rtt_n: case BinaryConsts::EncodedType::rtt: { @@ -1826,7 +1826,7 @@ void WasmBinaryBuilder::readTypes() { if (size_t(htCode) >= numTypes) { throwError("invalid type index: " + std::to_string(htCode)); } - return builder.getTempRttType(htCode, depth); + return builder.getTempRttType(Rtt(depth, builder[htCode])); } default: throwError("unexpected type index: " + std::to_string(typeCode)); @@ -1896,11 +1896,11 @@ void WasmBinaryBuilder::readTypes() { BYN_TRACE("read one\n"); auto form = getS32LEB(); if (form == BinaryConsts::EncodedType::Func) { - builder.setHeapType(i, readSignatureDef()); + builder[i] = readSignatureDef(); } else if (form == BinaryConsts::EncodedType::Struct) { - builder.setHeapType(i, readStructDef()); + builder[i] = readStructDef(); } else if (form == BinaryConsts::EncodedType::Array) { - builder.setHeapType(i, Array(readFieldDef())); + builder[i] = Array(readFieldDef()); } else { throwError("bad type form " + std::to_string(form)); } diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 7774eac3e..32f555d5a 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -691,13 +691,13 @@ void SExpressionWasmBuilder::preParseHeapTypes(Element& module) { auto& referent = nullable ? *elem[2] : *elem[1]; const char* name = referent.c_str(); if (referent.dollared()) { - return builder.getTempRefType(typeIndices[name], nullable); + return builder.getTempRefType(builder[typeIndices[name]], nullable); } else if (String::isNumber(name)) { size_t index = atoi(name); if (index >= numTypes) { throw ParseException("invalid type index", elem.line, elem.col); } - return builder.getTempRefType(index, nullable); + return builder.getTempRefType(builder[index], nullable); } else { return Type(stringToHeapType(name), nullable); } @@ -725,12 +725,15 @@ void SExpressionWasmBuilder::preParseHeapTypes(Element& module) { break; } if (idx->dollared()) { - return builder.getTempRttType(typeIndices[idx->c_str()], depth); + HeapType type = builder[typeIndices[idx->c_str()]]; + return builder.getTempRttType(Rtt(depth, type)); } else if (String::isNumber(idx->c_str())) { - return builder.getTempRttType(atoi(idx->c_str()), depth); - } else { - throw ParseException("invalid type index", idx->line, idx->col); + size_t index = atoi(idx->c_str()); + if (index < numTypes) { + return builder.getTempRttType(Rtt(depth, builder[index])); + } } + throw ParseException("invalid type index", idx->line, idx->col); }; auto parseValType = [&](Element& elem) { @@ -840,12 +843,12 @@ void SExpressionWasmBuilder::preParseHeapTypes(Element& module) { Element& def = elem[1]->dollared() ? *elem[2] : *elem[1]; Element& kind = *def[0]; if (kind == FUNC) { - builder.setHeapType(index++, parseSignatureDef(def)); + builder[index++] = parseSignatureDef(def); } else if (kind == STRUCT) { - builder.setHeapType(index, parseStructDef(def, index)); + builder[index] = parseStructDef(def, index); index++; } else if (kind == ARRAY) { - builder.setHeapType(index++, parseArrayDef(def)); + builder[index++] = parseArrayDef(def); } else { throw ParseException("unknown heaptype kind", kind.line, kind.col); } diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index 12b223707..73ff47ee8 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -80,16 +80,19 @@ struct HeapTypeInfo { // TypeBuilder interface, so hashing and equality use pointer identity. bool isFinalized = true; enum Kind { + BasicKind, SignatureKind, StructKind, ArrayKind, } kind; union { + HeapType::BasicHeapType basic; Signature signature; Struct struct_; Array array; }; + HeapTypeInfo(HeapType::BasicHeapType basic) : kind(BasicKind), basic(basic) {} HeapTypeInfo(Signature sig) : kind(SignatureKind), signature(sig) {} HeapTypeInfo(const Struct& struct_) : kind(StructKind), struct_(struct_) {} HeapTypeInfo(Struct&& struct_) @@ -323,6 +326,9 @@ bool TypeInfo::operator==(const TypeInfo& other) const { HeapTypeInfo::HeapTypeInfo(const HeapTypeInfo& other) { kind = other.kind; switch (kind) { + case BasicKind: + new (&basic) auto(other.basic); + return; case SignatureKind: new (&signature) auto(other.signature); return; @@ -338,6 +344,8 @@ HeapTypeInfo::HeapTypeInfo(const HeapTypeInfo& other) { HeapTypeInfo::~HeapTypeInfo() { switch (kind) { + case BasicKind: + return; case SignatureKind: signature.~Signature(); return; @@ -385,8 +393,11 @@ struct TypeStore : Store<TypeInfo> { }; struct HeapTypeStore : Store<HeapTypeInfo> { - HeapType canonicalize(const HeapTypeInfo& other) { - return Store<HeapTypeInfo>::canonicalize(other); + HeapType canonicalize(const HeapTypeInfo& info) { + if (info.kind == HeapTypeInfo::BasicKind) { + return info.basic; + } + return Store<HeapTypeInfo>::canonicalize(info); } HeapType canonicalize(std::unique_ptr<HeapTypeInfo>&& info); }; @@ -438,6 +449,9 @@ typename Info::type_t Store<Info>::canonicalize(const Info& info) { } HeapType HeapTypeStore::canonicalize(std::unique_ptr<HeapTypeInfo>&& info) { + if (info->kind == HeapTypeInfo::BasicKind) { + return info->basic; + } std::lock_guard<std::mutex> lock(mutex); auto indexIt = typeIDs.find(std::cref(*info)); if (indexIt != typeIDs.end()) { @@ -993,6 +1007,8 @@ bool TypeComparator::lessThan(const HeapTypeInfo& a, const HeapTypeInfo& b) { return a.kind < b.kind; } switch (a.kind) { + case HeapTypeInfo::BasicKind: + return a.basic < b.basic; case HeapTypeInfo::SignatureKind: return lessThan(a.signature, b.signature); case HeapTypeInfo::StructKind: @@ -1411,6 +1427,9 @@ size_t FiniteShapeHasher::hash(const HeapTypeInfo& info) { } rehash(digest, info.kind); switch (info.kind) { + case HeapTypeInfo::BasicKind: + hash_combine(digest, wasm::hash(info.basic)); + return digest; case HeapTypeInfo::SignatureKind: hash_combine(digest, hash(info.signature)); return digest; @@ -1523,6 +1542,8 @@ bool FiniteShapeEquator::eq(const HeapTypeInfo& a, const HeapTypeInfo& b) { return false; } switch (a.kind) { + case HeapTypeInfo::BasicKind: + return a.basic == b.basic; case HeapTypeInfo::SignatureKind: return eq(a.signature, b.signature); case HeapTypeInfo::StructKind: @@ -1601,26 +1622,43 @@ TypeBuilder::TypeBuilder(size_t n) { TypeBuilder::~TypeBuilder() = default; +void TypeBuilder::grow(size_t n) { + assert(size() + n > size()); + impl->entries.resize(size() + n); +} + +size_t TypeBuilder::size() { return impl->entries.size(); } + +void TypeBuilder::setHeapType(size_t i, HeapType::BasicHeapType basic) { + assert(i < size() && "Index out of bounds"); + impl->entries[i].set(basic); +} + void TypeBuilder::setHeapType(size_t i, Signature signature) { - assert(i < impl->entries.size() && "Index out of bounds"); + assert(i < size() && "Index out of bounds"); impl->entries[i].set(signature); } void TypeBuilder::setHeapType(size_t i, const Struct& struct_) { - assert(i < impl->entries.size() && "index out of bounds"); + assert(i < size() && "index out of bounds"); impl->entries[i].set(struct_); } void TypeBuilder::setHeapType(size_t i, Struct&& struct_) { - assert(i < impl->entries.size() && "index out of bounds"); + assert(i < size() && "index out of bounds"); impl->entries[i].set(std::move(struct_)); } void TypeBuilder::setHeapType(size_t i, Array array) { - assert(i < impl->entries.size() && "index out of bounds"); + assert(i < size() && "index out of bounds"); impl->entries[i].set(array); } +HeapType TypeBuilder::getTempHeapType(size_t i) { + assert(i < size() && "index out of bounds"); + return impl->entries[i].get(); +} + Type TypeBuilder::getTempTupleType(const Tuple& tuple) { Type ret = impl->typeStore.canonicalize(tuple); if (tuple.types.size() > 1) { @@ -1631,16 +1669,12 @@ Type TypeBuilder::getTempTupleType(const Tuple& tuple) { } } -Type TypeBuilder::getTempRefType(size_t i, Nullability nullable) { - assert(i < impl->entries.size() && "Index out of bounds"); - return markTemp( - impl->typeStore.canonicalize(TypeInfo(impl->entries[i].get(), nullable))); +Type TypeBuilder::getTempRefType(HeapType type, Nullability nullable) { + return markTemp(impl->typeStore.canonicalize(TypeInfo(type, nullable))); } -Type TypeBuilder::getTempRttType(size_t i, uint32_t depth) { - assert(i < impl->entries.size() && "Index out of bounds"); - return markTemp( - impl->typeStore.canonicalize(Rtt(depth, impl->entries[i].get()))); +Type TypeBuilder::getTempRttType(Rtt rtt) { + return markTemp(impl->typeStore.canonicalize(rtt)); } namespace { @@ -1879,6 +1913,8 @@ std::vector<HeapType*> ShapeCanonicalizer::getChildren(HeapType heapType) { assert(!heapType.isBasic() && "Cannot have basic defined HeapType"); auto* info = getHeapTypeInfo(heapType); switch (info->kind) { + case HeapTypeInfo::BasicKind: + return children; case HeapTypeInfo::SignatureKind: scanType(info->signature.params); scanType(info->signature.results); @@ -2049,6 +2085,8 @@ void GlobalCanonicalizer::scanHeapType(HeapType* ht) { auto* info = getHeapTypeInfo(*ht); switch (info->kind) { + case HeapTypeInfo::BasicKind: + break; case HeapTypeInfo::SignatureKind: noteChild(*ht, &info->signature.params); noteChild(*ht, &info->signature.results); diff --git a/test/example/type-builder.cpp b/test/example/type-builder.cpp index 0ac13f974..29ff20f75 100644 --- a/test/example/type-builder.cpp +++ b/test/example/type-builder.cpp @@ -13,13 +13,16 @@ void test_builder() { // (type $struct (struct (field (ref null $array) (mut rtt 0 $array)))) // (type $array (array (mut externref))) - TypeBuilder builder(3); - - Type refSig = builder.getTempRefType(0, NonNullable); - Type refStruct = builder.getTempRefType(1, NonNullable); - Type refArray = builder.getTempRefType(2, NonNullable); - Type refNullArray = builder.getTempRefType(2, Nullable); - Type rttArray = builder.getTempRttType(2, 0); + TypeBuilder builder; + assert(builder.size() == 0); + builder.grow(3); + assert(builder.size() == 3); + + Type refSig = builder.getTempRefType(builder[0], NonNullable); + Type refStruct = builder.getTempRefType(builder[1], NonNullable); + Type refArray = builder.getTempRefType(builder[2], NonNullable); + Type refNullArray = builder.getTempRefType(builder[2], Nullable); + Type rttArray = builder.getTempRttType(Rtt(0, builder[2])); Type refNullExt(HeapType::ext, Nullable); Signature sig(refStruct, builder.getTempTupleType({refArray, Type::i32})); @@ -33,9 +36,9 @@ void test_builder() { std::cout << "(ref null $array) => " << refNullArray << "\n"; std::cout << "(rtt 0 $array) => " << rttArray << "\n\n"; - builder.setHeapType(0, sig); - builder.setHeapType(1, struct_); - builder.setHeapType(2, array); + builder[0] = sig; + builder[1] = struct_; + builder[2] = array; std::cout << "After setting heap types:\n"; std::cout << "(ref $sig) => " << refSig << "\n"; @@ -78,19 +81,19 @@ void test_canonicalization() { TypeBuilder builder(4); - Type tempSigRef1 = builder.getTempRefType(2, Nullable); - Type tempSigRef2 = builder.getTempRefType(3, Nullable); + Type tempSigRef1 = builder.getTempRefType(builder[2], Nullable); + Type tempSigRef2 = builder.getTempRefType(builder[3], Nullable); assert(tempSigRef1 != tempSigRef2); assert(tempSigRef1 != Type(sig, Nullable)); assert(tempSigRef2 != Type(sig, Nullable)); - builder.setHeapType( - 0, Struct({Field(tempSigRef1, Immutable), Field(tempSigRef1, Immutable)})); - builder.setHeapType( - 1, Struct({Field(tempSigRef2, Immutable), Field(tempSigRef2, Immutable)})); - builder.setHeapType(2, Signature(Type::none, Type::none)); - builder.setHeapType(3, Signature(Type::none, Type::none)); + builder[0] = + Struct({Field(tempSigRef1, Immutable), Field(tempSigRef1, Immutable)}); + builder[1] = + Struct({Field(tempSigRef2, Immutable), Field(tempSigRef2, Immutable)}); + builder[2] = Signature(Type::none, Type::none); + builder[3] = Signature(Type::none, Type::none); std::vector<HeapType> built = builder.build(); @@ -108,8 +111,8 @@ void test_recursive() { std::vector<HeapType> built; { TypeBuilder builder(1); - Type temp = builder.getTempRefType(0, Nullable); - builder.setHeapType(0, Signature(Type::none, temp)); + Type temp = builder.getTempRefType(builder[0], Nullable); + builder[0] = Signature(Type::none, temp); built = builder.build(); } std::cout << built[0] << "\n\n"; @@ -122,10 +125,10 @@ void test_recursive() { std::vector<HeapType> built; { TypeBuilder builder(2); - Type temp0 = builder.getTempRefType(0, Nullable); - Type temp1 = builder.getTempRefType(1, Nullable); - builder.setHeapType(0, Signature(Type::none, temp1)); - builder.setHeapType(1, Signature(Type::none, temp0)); + Type temp0 = builder.getTempRefType(builder[0], Nullable); + Type temp1 = builder.getTempRefType(builder[1], Nullable); + builder[0] = Signature(Type::none, temp1); + builder[1] = Signature(Type::none, temp0); built = builder.build(); } std::cout << built[0] << "\n"; @@ -140,16 +143,16 @@ void test_recursive() { std::vector<HeapType> built; { TypeBuilder builder(5); - Type temp0 = builder.getTempRefType(0, Nullable); - Type temp1 = builder.getTempRefType(1, Nullable); - Type temp2 = builder.getTempRefType(2, Nullable); - Type temp3 = builder.getTempRefType(3, Nullable); - Type temp4 = builder.getTempRefType(4, Nullable); - builder.setHeapType(0, Signature(Type::none, temp1)); - builder.setHeapType(1, Signature(Type::none, temp2)); - builder.setHeapType(2, Signature(Type::none, temp3)); - builder.setHeapType(3, Signature(Type::none, temp4)); - builder.setHeapType(4, Signature(Type::none, temp0)); + Type temp0 = builder.getTempRefType(builder[0], Nullable); + Type temp1 = builder.getTempRefType(builder[1], Nullable); + Type temp2 = builder.getTempRefType(builder[2], Nullable); + Type temp3 = builder.getTempRefType(builder[3], Nullable); + Type temp4 = builder.getTempRefType(builder[4], Nullable); + builder[0] = Signature(Type::none, temp1); + builder[1] = Signature(Type::none, temp2); + builder[2] = Signature(Type::none, temp3); + builder[3] = Signature(Type::none, temp4); + builder[4] = Signature(Type::none, temp0); built = builder.build(); } std::cout << built[0] << "\n"; @@ -174,18 +177,18 @@ void test_recursive() { std::vector<HeapType> built; { TypeBuilder builder(6); - Type temp0 = builder.getTempRefType(0, Nullable); - Type temp1 = builder.getTempRefType(1, Nullable); - Type temp2 = builder.getTempRefType(2, Nullable); - Type temp3 = builder.getTempRefType(3, Nullable); + Type temp0 = builder.getTempRefType(builder[0], Nullable); + Type temp1 = builder.getTempRefType(builder[1], Nullable); + Type temp2 = builder.getTempRefType(builder[2], Nullable); + Type temp3 = builder.getTempRefType(builder[3], Nullable); Type tuple0_2 = builder.getTempTupleType({temp0, temp2}); Type tuple1_3 = builder.getTempTupleType({temp1, temp3}); - builder.setHeapType(0, Signature(Type::none, tuple0_2)); - builder.setHeapType(1, Signature(Type::none, tuple1_3)); - builder.setHeapType(2, Signature()); - builder.setHeapType(3, Signature()); - builder.setHeapType(4, Signature(Type::none, temp0)); - builder.setHeapType(5, Signature(Type::none, temp1)); + builder[0] = Signature(Type::none, tuple0_2); + builder[1] = Signature(Type::none, tuple1_3); + builder[2] = Signature(); + builder[3] = Signature(); + builder[4] = Signature(Type::none, temp0); + builder[5] = Signature(Type::none, temp1); built = builder.build(); } std::cout << built[0] << "\n"; @@ -210,16 +213,38 @@ void test_recursive() { std::vector<HeapType> built; { TypeBuilder builder(2); - Type temp0 = builder.getTempRefType(0, Nullable); - builder.setHeapType(0, Signature(Type::none, temp0)); - builder.setHeapType(1, Signature(Type::none, temp0)); + Type temp0 = builder.getTempRefType(builder[0], Nullable); + builder[0] = Signature(Type::none, temp0); + builder[1] = Signature(Type::none, temp0); + built = builder.build(); + } + std::cout << built[0] << "\n"; + std::cout << built[1] << "\n\n"; + assert(built[0].getSignature().results.getHeapType() == built[0]); + assert(built[1].getSignature().results.getHeapType() == built[0]); + assert(built[0] == built[1]); + } + + { + // Including a basic heap type + std::vector<HeapType> built; + { + TypeBuilder builder(3); + Type temp0 = builder.getTempRefType(builder[0], Nullable); + Type anyref = builder.getTempRefType(builder[2], Nullable); + builder[0] = Signature(anyref, temp0); + builder[1] = Signature(anyref, temp0); + builder[2] = HeapType::any; built = builder.build(); } std::cout << built[0] << "\n"; std::cout << built[1] << "\n\n"; assert(built[0].getSignature().results.getHeapType() == built[0]); assert(built[1].getSignature().results.getHeapType() == built[0]); + assert(built[0].getSignature().params == Type::anyref); + assert(built[1].getSignature().params == Type::anyref); assert(built[0] == built[1]); + assert(built[2] == HeapType::any); } } diff --git a/test/example/type-builder.txt b/test/example/type-builder.txt index a2985c2ed..f65a94321 100644 --- a/test/example/type-builder.txt +++ b/test/example/type-builder.txt @@ -43,3 +43,6 @@ After building types: (func (result (ref null ...1))) (func (result (ref null ...1))) +(func (param anyref) (result (ref null ...1))) +(func (param anyref) (result (ref null ...1))) + |