diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/binaryen-c.cpp | 8 | ||||
-rw-r--r-- | src/binaryen-c.h | 3 | ||||
-rw-r--r-- | src/ir/ReFinalize.cpp | 6 | ||||
-rw-r--r-- | src/ir/module-utils.h | 78 | ||||
-rw-r--r-- | src/js/binaryen.js-post.js | 4 | ||||
-rw-r--r-- | src/passes/InstrumentLocals.cpp | 81 | ||||
-rw-r--r-- | src/tools/fuzzing.h | 54 | ||||
-rw-r--r-- | src/tools/tool-options.h | 2 | ||||
-rw-r--r-- | src/wasm-binary.h | 86 | ||||
-rw-r--r-- | src/wasm-builder.h | 14 | ||||
-rw-r--r-- | src/wasm-features.h | 11 | ||||
-rw-r--r-- | src/wasm.h | 1 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 145 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 5 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 18 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 8 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 1 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 8 |
18 files changed, 367 insertions, 166 deletions
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index 443549de9..7365a3863 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -1187,9 +1187,11 @@ BinaryenExpressionRef BinaryenRefIsNull(BinaryenModuleRef module, Builder(*(Module*)module).makeRefIsNull((Expression*)value)); } -BinaryenExpressionRef BinaryenRefFunc(BinaryenModuleRef module, - const char* func) { - return static_cast<Expression*>(Builder(*(Module*)module).makeRefFunc(func)); +BinaryenExpressionRef +BinaryenRefFunc(BinaryenModuleRef module, const char* func, BinaryenType type) { + Type type_(type); + return static_cast<Expression*>( + Builder(*(Module*)module).makeRefFunc(func, type_)); } BinaryenExpressionRef BinaryenRefEq(BinaryenModuleRef module, diff --git a/src/binaryen-c.h b/src/binaryen-c.h index 45beb3657..c4517257a 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -792,7 +792,8 @@ BINARYEN_API BinaryenExpressionRef BinaryenRefNull(BinaryenModuleRef module, BINARYEN_API BinaryenExpressionRef BinaryenRefIsNull(BinaryenModuleRef module, BinaryenExpressionRef value); BINARYEN_API BinaryenExpressionRef BinaryenRefFunc(BinaryenModuleRef module, - const char* func); + const char* func, + BinaryenType type); BINARYEN_API BinaryenExpressionRef BinaryenRefEq(BinaryenModuleRef module, BinaryenExpressionRef left, BinaryenExpressionRef right); diff --git a/src/ir/ReFinalize.cpp b/src/ir/ReFinalize.cpp index 19fed54a7..448c1f30a 100644 --- a/src/ir/ReFinalize.cpp +++ b/src/ir/ReFinalize.cpp @@ -126,7 +126,11 @@ void ReFinalize::visitMemorySize(MemorySize* curr) { curr->finalize(); } void ReFinalize::visitMemoryGrow(MemoryGrow* curr) { curr->finalize(); } void ReFinalize::visitRefNull(RefNull* curr) { curr->finalize(); } void ReFinalize::visitRefIsNull(RefIsNull* curr) { curr->finalize(); } -void ReFinalize::visitRefFunc(RefFunc* curr) { curr->finalize(); } +void ReFinalize::visitRefFunc(RefFunc* curr) { + // TODO: should we look up the function and update the type from there? 
This + // could handle a change to the function's type, but is also not really what + // this class has been meant to do. +} void ReFinalize::visitRefEq(RefEq* curr) { curr->finalize(); } void ReFinalize::visitTry(Try* curr) { curr->finalize(); } void ReFinalize::visitThrow(Throw* curr) { curr->finalize(); } diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h index a2d256073..c8776297c 100644 --- a/src/ir/module-utils.h +++ b/src/ir/module-utils.h @@ -414,16 +414,29 @@ collectSignatures(Module& wasm, Counts& counts; TypeCounter(Counts& counts) : counts(counts) {} + void visitExpression(Expression* curr) { - if (auto* call = curr->dynCast<CallIndirect>()) { + if (curr->is<RefNull>()) { + maybeNote(curr->type); + } else if (auto* call = curr->dynCast<CallIndirect>()) { counts[call->sig]++; } else if (Properties::isControlFlowStructure(curr)) { - // TODO: Allow control flow to have input types as well + maybeNote(curr->type); if (curr->type.isTuple()) { + // TODO: Allow control flow to have input types as well counts[Signature(Type::none, curr->type)]++; } } } + + void maybeNote(Type type) { + if (type.isRef()) { + auto heapType = type.getHeapType(); + if (heapType.isSignature()) { + counts[heapType.getSignature()]++; + } + } + } }; TypeCounter(counts).walk(func->body); }; @@ -434,6 +447,14 @@ collectSignatures(Module& wasm, Counts counts; for (auto& curr : wasm.functions) { counts[curr->sig]++; + for (auto type : curr->vars) { + if (type.isRef()) { + auto heapType = type.getHeapType(); + if (heapType.isSignature()) { + counts[heapType.getSignature()]++; + } + } + } } for (auto& curr : wasm.events) { counts[curr->sig]++; @@ -444,10 +465,61 @@ collectSignatures(Module& wasm, counts[innerPair.first] += innerPair.second; } } + + // TODO: recursively traverse each reference type, which may have a child type + // that is itself a reference type. + + // We must sort all the dependencies of a signature before it. 
For example, + // (func (param (ref (func)))) must appear after (func). To do that, find the + // depth of dependencies of each signature. For example, if A depends on B + // which depends on C, then A's depth is 2, B's is 1, and C's is 0 (assuming + // no other dependencies). + Counts depthOfDependencies; + std::unordered_map<Signature, std::unordered_set<Signature>> isDependencyOf; + // To calculate the depth of dependencies, we'll do a flow analysis, visiting + // each signature as we find out new things about it. + std::set<Signature> toVisit; + for (auto& pair : counts) { + auto sig = pair.first; + depthOfDependencies[sig] = 0; + toVisit.insert(sig); + for (Type type : {sig.params, sig.results}) { + for (auto element : type) { + if (element.isRef()) { + auto heapType = element.getHeapType(); + if (heapType.isSignature()) { + isDependencyOf[heapType.getSignature()].insert(sig); + } + } + } + } + } + while (!toVisit.empty()) { + auto iter = toVisit.begin(); + auto sig = *iter; + toVisit.erase(iter); + // Anything that depends on this has a depth of dependencies equal to this + // signature's, plus this signature itself. + auto newDepth = depthOfDependencies[sig] + 1; + if (newDepth > counts.size()) { + Fatal() << "Cyclic signatures detected, cannot sort them."; + } + for (auto& other : isDependencyOf[sig]) { + if (depthOfDependencies[other] < newDepth) { + // We found something new to propagate. + depthOfDependencies[other] = newDepth; + toVisit.insert(other); + } + } + } + // Sort by frequency and then simplicity, and also keeping every signature + // before things that depend on it. 
std::vector<std::pair<Signature, size_t>> sorted(counts.begin(), counts.end()); std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) { - // order by frequency then simplicity + if (depthOfDependencies[a.first] != depthOfDependencies[b.first]) { + return depthOfDependencies[a.first] < depthOfDependencies[b.first]; + } if (a.second != b.second) { return a.second > b.second; } diff --git a/src/js/binaryen.js-post.js b/src/js/binaryen.js-post.js index b78fdd996..9abe5e718 100644 --- a/src/js/binaryen.js-post.js +++ b/src/js/binaryen.js-post.js @@ -2112,8 +2112,8 @@ function wrapModule(module, self = {}) { 'is_null'(value) { return Module['_BinaryenRefIsNull'](module, value); }, - 'func'(func) { - return preserveStack(() => Module['_BinaryenRefFunc'](module, strToStack(func))); + 'func'(func, type) { + return preserveStack(() => Module['_BinaryenRefFunc'](module, strToStack(func), type)); }, 'eq'(left, right) { return Module['_BinaryenRefEq'](module, left, right); diff --git a/src/passes/InstrumentLocals.cpp b/src/passes/InstrumentLocals.cpp index 81494463f..004bfba74 100644 --- a/src/passes/InstrumentLocals.cpp +++ b/src/passes/InstrumentLocals.cpp @@ -135,45 +135,48 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> { Builder builder(*getModule()); Name import; auto type = curr->value->type; - if (type.isFunction()) { - import = set_funcref; - } else { - TODO_SINGLE_COMPOUND(curr->value->type); - switch (type.getBasic()) { - case Type::i32: - import = set_i32; - break; - case Type::i64: - return; // TODO - case Type::f32: - import = set_f32; - break; - case Type::f64: - import = set_f64; - break; - case Type::v128: - import = set_v128; - break; - case Type::externref: - import = set_externref; - break; - case Type::exnref: - import = set_exnref; - break; - case Type::anyref: - import = set_anyref; - break; - case Type::eqref: - import = set_eqref; - break; - case Type::i31ref: - import = set_i31ref; - break; - case Type::unreachable: - 
return; // nothing to do here - default: - WASM_UNREACHABLE("unexpected type"); - } + if (type.isFunction() && type != Type::funcref) { + // FIXME: support typed function references + return; + } + TODO_SINGLE_COMPOUND(curr->value->type); + switch (type.getBasic()) { + case Type::i32: + import = set_i32; + break; + case Type::i64: + return; // TODO + case Type::f32: + import = set_f32; + break; + case Type::f64: + import = set_f64; + break; + case Type::v128: + import = set_v128; + break; + case Type::funcref: + import = set_funcref; + break; + case Type::externref: + import = set_externref; + break; + case Type::exnref: + import = set_exnref; + break; + case Type::anyref: + import = set_anyref; + break; + case Type::eqref: + import = set_eqref; + break; + case Type::i31ref: + import = set_i31ref; + break; + case Type::unreachable: + return; // nothing to do here + default: + WASM_UNREACHABLE("unexpected type"); } curr->value = builder.makeCall(import, {builder.makeConst(int32_t(id++)), diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h index df7cf2226..1c1359586 100644 --- a/src/tools/fuzzing.h +++ b/src/tools/fuzzing.h @@ -321,6 +321,10 @@ private: } return Type(types); } + if (type.isFunction() && type != Type::funcref) { + // TODO: specific typed function references types. + return type; + } SmallVector<Type, 2> options; options.push_back(type); // includes itself TODO_SINGLE_COMPOUND(type); @@ -653,6 +657,10 @@ private: Index numVars = upToSquared(MAX_VARS); for (Index i = 0; i < numVars; i++) { auto type = getConcreteType(); + if (type.isRef() && !type.isNullable()) { + // We can't use a non-nullable type as a var, which is null-initialized. 
+ continue; + } funcContext->typeLocals[type].push_back(params.size() + func->vars.size()); func->vars.push_back(type); @@ -1371,7 +1379,6 @@ private: } Expression* makeCall(Type type) { - // seems ok, go on int tries = TRIES; bool isReturn; while (tries-- > 0) { @@ -1392,7 +1399,7 @@ private: return builder.makeCall(target->name, args, type, isReturn); } // we failed to find something - return make(type); + return makeTrivial(type); } Expression* makeCallIndirect(Type type) { @@ -1418,7 +1425,7 @@ private: i = 0; } if (i == start) { - return make(type); + return makeTrivial(type); } } // with high probability, make sure the type is valid otherwise, most are @@ -2018,12 +2025,28 @@ private: if (!wasm.functions.empty() && !oneIn(wasm.functions.size())) { target = pick(wasm.functions).get(); } - return builder.makeRefFunc(target->name); + auto type = Type(HeapType(target->sig), /* nullable = */ true); + return builder.makeRefFunc(target->name, type); } if (type == Type::i31ref) { return builder.makeI31New(makeConst(Type::i32)); } - return builder.makeRefNull(type); + if (oneIn(2) && type.isNullable()) { + return builder.makeRefNull(type); + } + // TODO: randomize the order + for (auto& func : wasm.functions) { + // FIXME: RefFunc type should be non-nullable, but we emit nullable + // types for now. + if (type == Type(HeapType(func->sig), /* nullable = */ true)) { + return builder.makeRefFunc(func->name, type); + } + } + // We failed to find a function, so create a null reference if we can. 
+ if (type.isNullable()) { + return builder.makeRefNull(type); + } + WASM_UNREACHABLE("un-handleable non-nullable type"); } if (type.isTuple()) { std::vector<Expression*> operands; @@ -2972,6 +2995,7 @@ private: Type::anyref, Type::eqref, Type::i31ref)); + // TODO: emit typed function references types } Type getSingleConcreteType() { return pick(getSingleConcreteTypes()); } @@ -2997,12 +3021,24 @@ private: Type getEqReferenceType() { return pick(getEqReferenceTypes()); } + Type getMVPType() { + return pick(items(FeatureOptions<Type>().add( + FeatureSet::MVP, Type::i32, Type::i64, Type::f32, Type::f64))); + } + Type getTupleType() { std::vector<Type> elements; - size_t numElements = 2 + upTo(MAX_TUPLE_SIZE - 1); - elements.resize(numElements); - for (size_t i = 0; i < numElements; ++i) { - elements[i] = getSingleConcreteType(); + size_t maxElements = 2 + upTo(MAX_TUPLE_SIZE - 1); + for (size_t i = 0; i < maxElements; ++i) { + auto type = getSingleConcreteType(); + // Don't add a non-nullable type into a tuple, as currently we can't spill + // them into locals (that would require a "let"). 
+ if (!type.isNullable()) { + elements.push_back(type); + } + } + while (elements.size() < 2) { + elements.push_back(getMVPType()); } return Type(elements); } diff --git a/src/tools/tool-options.h b/src/tools/tool-options.h index 4b084e191..70ce4efc0 100644 --- a/src/tools/tool-options.h +++ b/src/tools/tool-options.h @@ -89,6 +89,8 @@ struct ToolOptions : public Options { .addFeature(FeatureSet::Multivalue, "multivalue functions") .addFeature(FeatureSet::GC, "garbage collection") .addFeature(FeatureSet::Memory64, "memory64") + .addFeature(FeatureSet::TypedFunctionReferences, + "typed function references") .add("--no-validation", "-n", "Disables validation, assumes inputs are correct", diff --git a/src/wasm-binary.h b/src/wasm-binary.h index ef3f9c9d1..0918151c5 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -346,6 +346,10 @@ enum EncodedType { anyref = -0x12, // 0x6e // comparable reference type eqref = -0x13, // 0x6d + // nullable typed function reference type, with parameter + nullable = -0x14, // 0x6c + // non-nullable typed function reference type, with parameter + nonnullable = -0x15, // 0x6b // integer reference type i31ref = -0x16, // 0x6a // exception reference type @@ -386,6 +390,7 @@ extern const char* ReferenceTypesFeature; extern const char* MultivalueFeature; extern const char* GCFeature; extern const char* Memory64Feature; +extern const char* TypedFunctionReferencesFeature; enum Subsection { NameModule = 0, @@ -1009,82 +1014,6 @@ enum FeaturePrefix { } // namespace BinaryConsts -inline S32LEB binaryType(Type type) { - int ret = 0; - TODO_SINGLE_COMPOUND(type); - switch (type.getBasic()) { - // None only used for block signatures. TODO: Separate out? 
- case Type::none: - ret = BinaryConsts::EncodedType::Empty; - break; - case Type::i32: - ret = BinaryConsts::EncodedType::i32; - break; - case Type::i64: - ret = BinaryConsts::EncodedType::i64; - break; - case Type::f32: - ret = BinaryConsts::EncodedType::f32; - break; - case Type::f64: - ret = BinaryConsts::EncodedType::f64; - break; - case Type::v128: - ret = BinaryConsts::EncodedType::v128; - break; - case Type::funcref: - ret = BinaryConsts::EncodedType::funcref; - break; - case Type::externref: - ret = BinaryConsts::EncodedType::externref; - break; - case Type::exnref: - ret = BinaryConsts::EncodedType::exnref; - break; - case Type::anyref: - ret = BinaryConsts::EncodedType::anyref; - break; - case Type::eqref: - ret = BinaryConsts::EncodedType::eqref; - break; - case Type::i31ref: - ret = BinaryConsts::EncodedType::i31ref; - break; - case Type::unreachable: - WASM_UNREACHABLE("unexpected type"); - } - return S32LEB(ret); -} - -inline S32LEB binaryHeapType(HeapType type) { - int ret = 0; - switch (type.kind) { - case HeapType::FuncKind: - ret = BinaryConsts::EncodedHeapType::func; - break; - case HeapType::ExternKind: - ret = BinaryConsts::EncodedHeapType::extern_; - break; - case HeapType::ExnKind: - ret = BinaryConsts::EncodedHeapType::exn; - break; - case HeapType::AnyKind: - ret = BinaryConsts::EncodedHeapType::any; - break; - case HeapType::EqKind: - ret = BinaryConsts::EncodedHeapType::eq; - break; - case HeapType::I31Kind: - ret = BinaryConsts::EncodedHeapType::i31; - break; - case HeapType::SignatureKind: - case HeapType::StructKind: - case HeapType::ArrayKind: - WASM_UNREACHABLE("TODO: compound GC types"); - } - return S32LEB(ret); // TODO: Actually encoded as s33 -} - // Writes out wasm to the binary format class WasmBinaryWriter { @@ -1234,6 +1163,9 @@ public: Module* getModule() { return wasm; } + void writeType(Type type); + void writeHeapType(HeapType type); + private: Module* wasm; BufferWithRandomAccess& o; @@ -1342,6 +1274,8 @@ public: 
std::vector<Signature> functionSignatures; void readFunctionSignatures(); + Signature getFunctionSignatureByIndex(Index index); + size_t nextLabel; Name getNextLabel(); diff --git a/src/wasm-builder.h b/src/wasm-builder.h index d3af93896..6800aa2ed 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -588,10 +588,10 @@ public: ret->finalize(); return ret; } - RefFunc* makeRefFunc(Name func) { + RefFunc* makeRefFunc(Name func, Type type) { auto* ret = wasm.allocator.alloc<RefFunc>(); ret->func = func; - ret->finalize(); + ret->finalize(type); return ret; } RefEq* makeRefEq(Expression* left, Expression* right) { @@ -769,8 +769,7 @@ public: } if (type.isFunction()) { if (!value.isNull()) { - // TODO: with typed function references we need to do more for the type - return makeRefFunc(value.getFunc()); + return makeRefFunc(value.getFunc(), type); } return makeRefNull(type); } @@ -951,7 +950,12 @@ public: return makeConstantExpression(Literal::makeZeros(curr->type)); } if (curr->type.isFunction()) { - return ExpressionManipulator::refNull(curr, curr->type); + if (curr->type.isNullable()) { + return ExpressionManipulator::refNull(curr, curr->type); + } else { + // We can't do any better, keep the original. 
+ return curr; + } } Literal value; // TODO: reuse node conditionally when possible for literals diff --git a/src/wasm-features.h b/src/wasm-features.h index a2bb52971..d2e3f343f 100644 --- a/src/wasm-features.h +++ b/src/wasm-features.h @@ -38,7 +38,8 @@ struct FeatureSet { Multivalue = 1 << 9, GC = 1 << 10, Memory64 = 1 << 11, - All = (1 << 12) - 1 + TypedFunctionReferences = 1 << 12, + All = (1 << 13) - 1 }; static std::string toString(Feature f) { @@ -67,6 +68,8 @@ struct FeatureSet { return "gc"; case Memory64: return "memory64"; + case TypedFunctionReferences: + return "typed-function-references"; default: WASM_UNREACHABLE("unexpected feature"); } @@ -92,6 +95,9 @@ struct FeatureSet { bool hasMultivalue() const { return (features & Multivalue) != 0; } bool hasGC() const { return (features & GC) != 0; } bool hasMemory64() const { return (features & Memory64) != 0; } + bool hasTypedFunctionReferences() const { + return (features & TypedFunctionReferences) != 0; + } bool hasAll() const { return (features & All) != 0; } void makeMVP() { features = MVP; } @@ -110,6 +116,9 @@ struct FeatureSet { void setMultivalue(bool v = true) { set(Multivalue, v); } void setGC(bool v = true) { set(GC, v); } void setMemory64(bool v = true) { set(Memory64, v); } + void setTypedFunctionReferences(bool v = true) { + set(TypedFunctionReferences, v); + } void setAll(bool v = true) { features = v ? 
All : MVP; } void enable(const FeatureSet& other) { features |= other.features; } diff --git a/src/wasm.h b/src/wasm.h index 1204eee0f..e9fb4461b 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -1198,6 +1198,7 @@ public: Name func; void finalize(); + void finalize(Type type_); }; class RefEq : public SpecificExpression<Expression::RefEqId> { diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index b343c6caf..a96039bc2 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -221,7 +221,7 @@ void WasmBinaryWriter::writeTypes() { for (auto& sigType : {sig.params, sig.results}) { o << U32LEB(sigType.size()); for (const auto& type : sigType) { - o << binaryType(type); + writeType(type); } } } @@ -250,7 +250,7 @@ void WasmBinaryWriter::writeImports() { BYN_TRACE("write one global\n"); writeImportHeader(global); o << U32LEB(int32_t(ExternalKind::Global)); - o << binaryType(global->type); + writeType(global->type); o << U32LEB(global->mutable_); }); ModuleUtils::iterImportedEvents(*wasm, [&](Event* event) { @@ -389,7 +389,7 @@ void WasmBinaryWriter::writeGlobals() { BYN_TRACE("write one\n"); size_t i = 0; for (const auto& t : global->type) { - o << binaryType(t); + writeType(t); o << U32LEB(global->mutable_); if (global->type.size() == 1) { writeExpression(global->init); @@ -492,7 +492,12 @@ uint32_t WasmBinaryWriter::getEventIndex(Name name) const { uint32_t WasmBinaryWriter::getTypeIndex(Signature sig) const { auto it = typeIndices.find(sig); - assert(it != typeIndices.end()); +#ifndef NDEBUG + if (it == typeIndices.end()) { + std::cout << "Missing signature: " << sig << '\n'; + assert(0); + } +#endif return it->second; } @@ -799,6 +804,8 @@ void WasmBinaryWriter::writeFeaturesSection() { return BinaryConsts::UserSections::GCFeature; case FeatureSet::Memory64: return BinaryConsts::UserSections::Memory64Feature; + case FeatureSet::TypedFunctionReferences: + return BinaryConsts::UserSections::TypedFunctionReferencesFeature; default: 
WASM_UNREACHABLE("unexpected feature flag"); } @@ -950,6 +957,100 @@ void WasmBinaryWriter::finishUp() { } } +void WasmBinaryWriter::writeType(Type type) { + if (type.isRef()) { + auto heapType = type.getHeapType(); + // TODO: fully handle non-signature reference types (GC), and in reading + if (heapType.isSignature()) { + if (type.isNullable()) { + o << S32LEB(BinaryConsts::EncodedType::nullable); + } else { + o << S32LEB(BinaryConsts::EncodedType::nonnullable); + } + writeHeapType(heapType); + return; + } + } + int ret = 0; + TODO_SINGLE_COMPOUND(type); + switch (type.getBasic()) { + // None only used for block signatures. TODO: Separate out? + case Type::none: + ret = BinaryConsts::EncodedType::Empty; + break; + case Type::i32: + ret = BinaryConsts::EncodedType::i32; + break; + case Type::i64: + ret = BinaryConsts::EncodedType::i64; + break; + case Type::f32: + ret = BinaryConsts::EncodedType::f32; + break; + case Type::f64: + ret = BinaryConsts::EncodedType::f64; + break; + case Type::v128: + ret = BinaryConsts::EncodedType::v128; + break; + case Type::funcref: + ret = BinaryConsts::EncodedType::funcref; + break; + case Type::externref: + ret = BinaryConsts::EncodedType::externref; + break; + case Type::exnref: + ret = BinaryConsts::EncodedType::exnref; + break; + case Type::anyref: + ret = BinaryConsts::EncodedType::anyref; + break; + case Type::eqref: + ret = BinaryConsts::EncodedType::eqref; + break; + case Type::i31ref: + ret = BinaryConsts::EncodedType::i31ref; + break; + default: + WASM_UNREACHABLE("unexpected type"); + } + o << S32LEB(ret); +} + +void WasmBinaryWriter::writeHeapType(HeapType type) { + if (type.isSignature()) { + auto sig = type.getSignature(); + o << S32LEB(getTypeIndex(sig)); + return; + } + int ret = 0; + switch (type.kind) { + case HeapType::FuncKind: + ret = BinaryConsts::EncodedHeapType::func; + break; + case HeapType::ExternKind: + ret = BinaryConsts::EncodedHeapType::extern_; + break; + case HeapType::ExnKind: + ret = 
BinaryConsts::EncodedHeapType::exn; + break; + case HeapType::AnyKind: + ret = BinaryConsts::EncodedHeapType::any; + break; + case HeapType::EqKind: + ret = BinaryConsts::EncodedHeapType::eq; + break; + case HeapType::I31Kind: + ret = BinaryConsts::EncodedHeapType::i31; + break; + case HeapType::SignatureKind: + case HeapType::StructKind: + case HeapType::ArrayKind: + WASM_UNREACHABLE("TODO: compound GC types"); + } + o << S32LEB(ret); // TODO: Actually encoded as s33 +} + // reader bool WasmBinaryBuilder::hasDWARFSections() { @@ -1253,6 +1354,10 @@ Type WasmBinaryBuilder::getType() { return Type::anyref; case BinaryConsts::EncodedType::eqref: return Type::eqref; + case BinaryConsts::EncodedType::nullable: + return Type(getHeapType(), /* nullable = */ true); + case BinaryConsts::EncodedType::nonnullable: + return Type(getHeapType(), /* nullable = */ false); case BinaryConsts::EncodedType::i31ref: return Type::i31ref; default: @@ -1581,6 +1686,18 @@ void WasmBinaryBuilder::readFunctionSignatures() { } } +Signature WasmBinaryBuilder::getFunctionSignatureByIndex(Index index) { + Signature sig; + if (index < functionImports.size()) { + return functionImports[index]->sig; + } + Index adjustedIndex = index - functionImports.size(); + if (adjustedIndex >= functionSignatures.size()) { + throwError("invalid function index"); + } + return functionSignatures[adjustedIndex]; +} + void WasmBinaryBuilder::readFunctions() { BYN_TRACE("== readFunctions\n"); size_t total = getU32LEB(); @@ -2471,6 +2588,9 @@ void WasmBinaryBuilder::readFeatures(size_t payloadLen) { wasm.features.setGC(); } else if (name == BinaryConsts::UserSections::Memory64Feature) { wasm.features.setMemory64(); + } else if (name == + BinaryConsts::UserSections::TypedFunctionReferencesFeature) { + wasm.features.setTypedFunctionReferences(); } } } @@ -3042,17 +3162,7 @@ void WasmBinaryBuilder::visitSwitch(Switch* curr) { void WasmBinaryBuilder::visitCall(Call* curr) { BYN_TRACE("zz node: Call\n"); auto index = 
getU32LEB(); - Signature sig; - if (index < functionImports.size()) { - auto* import = functionImports[index]; - sig = import->sig; - } else { - Index adjustedIndex = index - functionImports.size(); - if (adjustedIndex >= functionSignatures.size()) { - throwError("invalid call index"); - } - sig = functionSignatures[adjustedIndex]; - } + auto sig = getFunctionSignatureByIndex(index); auto num = sig.params.size(); curr->operands.resize(num); for (size_t i = 0; i < num; i++) { @@ -5169,7 +5279,10 @@ void WasmBinaryBuilder::visitRefFunc(RefFunc* curr) { throwError("ref.func: invalid call index"); } functionRefs[index].push_back(curr); // we don't know function names yet - curr->finalize(); + // To support typed function refs, we give the reference not just a general + // funcref, but a specific subtype with the actual signature. + curr->finalize( + Type(HeapType(getFunctionSignatureByIndex(index)), /* nullable = */ true)); } void WasmBinaryBuilder::visitRefEq(RefEq* curr) { diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 0636836d7..6286ae090 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -1890,7 +1890,10 @@ Expression* SExpressionWasmBuilder::makeRefFunc(Element& s) { auto func = getFunctionName(*s[1]); auto ret = allocator.alloc<RefFunc>(); ret->func = func; - ret->finalize(); + // To support typed function refs, we give the reference not just a general + // funcref, but a specific subtype with the actual signature. 
+ ret->finalize( + Type(HeapType(functionSignatures[func]), /* nullable = */ true)); return ret; } diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index c8a4f7a90..021b05cb6 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -24,11 +24,11 @@ static Name IMPOSSIBLE_CONTINUE("impossible-continue"); void BinaryInstWriter::emitResultType(Type type) { if (type == Type::unreachable) { - o << binaryType(Type::none); + parent.writeType(Type::none); } else if (type.isTuple()) { o << S32LEB(parent.getTypeIndex(Signature(Type::none, type))); } else { - o << binaryType(type); + parent.writeType(type); } } @@ -1756,8 +1756,8 @@ void BinaryInstWriter::visitSelect(Select* curr) { if (curr->type.isRef()) { o << int8_t(BinaryConsts::SelectWithType) << U32LEB(curr->type.size()); for (size_t i = 0; i < curr->type.size(); i++) { - o << binaryType(curr->type != Type::unreachable ? curr->type - : Type::none); + parent.writeType(curr->type != Type::unreachable ? curr->type + : Type::none); } } else { o << int8_t(BinaryConsts::Select); @@ -1779,8 +1779,8 @@ void BinaryInstWriter::visitMemoryGrow(MemoryGrow* curr) { } void BinaryInstWriter::visitRefNull(RefNull* curr) { - o << int8_t(BinaryConsts::RefNull) - << binaryHeapType(curr->type.getHeapType()); + o << int8_t(BinaryConsts::RefNull); + parent.writeHeapType(curr->type.getHeapType()); } void BinaryInstWriter::visitRefIsNull(RefIsNull* curr) { @@ -1966,7 +1966,8 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() { o << U32LEB(func->getNumVars()); for (Index i = varStart; i < varEnd; i++) { mappedLocals[std::make_pair(i, 0)] = i; - o << U32LEB(1) << binaryType(func->getLocalType(i)); + o << U32LEB(1); + parent.writeType(func->getLocalType(i)); } return; } @@ -1995,7 +1996,8 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() { setScratchLocals(); o << U32LEB(numLocalsByType.size()); for (auto& typeCount : numLocalsByType) { - o << U32LEB(typeCount.second) << binaryType(typeCount.first); + o << 
U32LEB(typeCount.second); + parent.writeType(typeCount.first); } } diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index dc4d50ef4..cf4404739 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -460,6 +460,14 @@ Type Type::reinterpret() const { FeatureSet Type::getFeatures() const { auto getSingleFeatures = [](Type t) -> FeatureSet { + if (t != Type::funcref && t.isFunction()) { + // Strictly speaking, typed function references require the typed function + // references feature, however, we use these types internally regardless + // of the presence of features (in particular, since during load of the + // wasm we don't know the features yet, so we apply the more refined + // types). + return FeatureSet::ReferenceTypes; + } TODO_SINGLE_COMPOUND(t); switch (t.getBasic()) { case Type::v128: diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 809ca5a6a..78e123a90 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -2313,6 +2313,7 @@ void FunctionValidator::visitFunction(Function* curr) { for (const auto& var : curr->vars) { features |= var.getFeatures(); shouldBeTrue(var.isConcrete(), curr, "vars must be concretely typed"); + // TODO: check for nullability } shouldBeTrue(features <= getModule()->features, curr->name, diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index c7a187b43..6245a3575 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -47,6 +47,7 @@ const char* ReferenceTypesFeature = "reference-types"; const char* MultivalueFeature = "multivalue"; const char* GCFeature = "gc"; const char* Memory64Feature = "memory64"; +const char* TypedFunctionReferencesFeature = "typed-function-references"; } // namespace UserSections } // namespace BinaryConsts @@ -984,7 +985,12 @@ void RefIsNull::finalize() { type = Type::i32; } -void RefFunc::finalize() { type = Type::funcref; } +void RefFunc::finalize() { + // No-op. 
We assume that the full proper typed function type has been applied + // previously. +} + +void RefFunc::finalize(Type type_) { type = type_; } void RefEq::finalize() { if (left->type == Type::unreachable || right->type == Type::unreachable) { |