summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/binaryen-c.cpp8
-rw-r--r--src/binaryen-c.h3
-rw-r--r--src/ir/ReFinalize.cpp6
-rw-r--r--src/ir/module-utils.h78
-rw-r--r--src/js/binaryen.js-post.js4
-rw-r--r--src/passes/InstrumentLocals.cpp81
-rw-r--r--src/tools/fuzzing.h54
-rw-r--r--src/tools/tool-options.h2
-rw-r--r--src/wasm-binary.h86
-rw-r--r--src/wasm-builder.h14
-rw-r--r--src/wasm-features.h11
-rw-r--r--src/wasm.h1
-rw-r--r--src/wasm/wasm-binary.cpp145
-rw-r--r--src/wasm/wasm-s-parser.cpp5
-rw-r--r--src/wasm/wasm-stack.cpp18
-rw-r--r--src/wasm/wasm-type.cpp8
-rw-r--r--src/wasm/wasm-validator.cpp1
-rw-r--r--src/wasm/wasm.cpp8
18 files changed, 367 insertions, 166 deletions
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp
index 443549de9..7365a3863 100644
--- a/src/binaryen-c.cpp
+++ b/src/binaryen-c.cpp
@@ -1187,9 +1187,11 @@ BinaryenExpressionRef BinaryenRefIsNull(BinaryenModuleRef module,
Builder(*(Module*)module).makeRefIsNull((Expression*)value));
}
-BinaryenExpressionRef BinaryenRefFunc(BinaryenModuleRef module,
- const char* func) {
- return static_cast<Expression*>(Builder(*(Module*)module).makeRefFunc(func));
+BinaryenExpressionRef
+BinaryenRefFunc(BinaryenModuleRef module, const char* func, BinaryenType type) {
+ Type type_(type);
+ return static_cast<Expression*>(
+ Builder(*(Module*)module).makeRefFunc(func, type_));
}
BinaryenExpressionRef BinaryenRefEq(BinaryenModuleRef module,
diff --git a/src/binaryen-c.h b/src/binaryen-c.h
index 45beb3657..c4517257a 100644
--- a/src/binaryen-c.h
+++ b/src/binaryen-c.h
@@ -792,7 +792,8 @@ BINARYEN_API BinaryenExpressionRef BinaryenRefNull(BinaryenModuleRef module,
BINARYEN_API BinaryenExpressionRef
BinaryenRefIsNull(BinaryenModuleRef module, BinaryenExpressionRef value);
BINARYEN_API BinaryenExpressionRef BinaryenRefFunc(BinaryenModuleRef module,
- const char* func);
+ const char* func,
+ BinaryenType type);
BINARYEN_API BinaryenExpressionRef BinaryenRefEq(BinaryenModuleRef module,
BinaryenExpressionRef left,
BinaryenExpressionRef right);
diff --git a/src/ir/ReFinalize.cpp b/src/ir/ReFinalize.cpp
index 19fed54a7..448c1f30a 100644
--- a/src/ir/ReFinalize.cpp
+++ b/src/ir/ReFinalize.cpp
@@ -126,7 +126,11 @@ void ReFinalize::visitMemorySize(MemorySize* curr) { curr->finalize(); }
void ReFinalize::visitMemoryGrow(MemoryGrow* curr) { curr->finalize(); }
void ReFinalize::visitRefNull(RefNull* curr) { curr->finalize(); }
void ReFinalize::visitRefIsNull(RefIsNull* curr) { curr->finalize(); }
-void ReFinalize::visitRefFunc(RefFunc* curr) { curr->finalize(); }
+void ReFinalize::visitRefFunc(RefFunc* curr) {
+ // TODO: should we look up the function and update the type from there? This
+ // could handle a change to the function's type, but is also not really what
+ // this class has been meant to do.
+}
void ReFinalize::visitRefEq(RefEq* curr) { curr->finalize(); }
void ReFinalize::visitTry(Try* curr) { curr->finalize(); }
void ReFinalize::visitThrow(Throw* curr) { curr->finalize(); }
diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h
index a2d256073..c8776297c 100644
--- a/src/ir/module-utils.h
+++ b/src/ir/module-utils.h
@@ -414,16 +414,29 @@ collectSignatures(Module& wasm,
Counts& counts;
TypeCounter(Counts& counts) : counts(counts) {}
+
void visitExpression(Expression* curr) {
- if (auto* call = curr->dynCast<CallIndirect>()) {
+ if (curr->is<RefNull>()) {
+ maybeNote(curr->type);
+ } else if (auto* call = curr->dynCast<CallIndirect>()) {
counts[call->sig]++;
} else if (Properties::isControlFlowStructure(curr)) {
- // TODO: Allow control flow to have input types as well
+ maybeNote(curr->type);
if (curr->type.isTuple()) {
+ // TODO: Allow control flow to have input types as well
counts[Signature(Type::none, curr->type)]++;
}
}
}
+
+ void maybeNote(Type type) {
+ if (type.isRef()) {
+ auto heapType = type.getHeapType();
+ if (heapType.isSignature()) {
+ counts[heapType.getSignature()]++;
+ }
+ }
+ }
};
TypeCounter(counts).walk(func->body);
};
@@ -434,6 +447,14 @@ collectSignatures(Module& wasm,
Counts counts;
for (auto& curr : wasm.functions) {
counts[curr->sig]++;
+ for (auto type : curr->vars) {
+ if (type.isRef()) {
+ auto heapType = type.getHeapType();
+ if (heapType.isSignature()) {
+ counts[heapType.getSignature()]++;
+ }
+ }
+ }
}
for (auto& curr : wasm.events) {
counts[curr->sig]++;
@@ -444,10 +465,61 @@ collectSignatures(Module& wasm,
counts[innerPair.first] += innerPair.second;
}
}
+
+ // TODO: recursively traverse each reference type, which may have a child type
+ // this is itself a reference type.
+
+ // We must sort all the dependencies of a signature before it. For example,
+ // (func (param (ref (func)))) must appear after (func). To do that, find the
+ // depth of dependencies of each signature. For example, if A depends on B
+ // which depends on C, then A's depth is 2, B's is 1, and C's is 0 (assuming
+ // no other dependencies).
+ Counts depthOfDependencies;
+ std::unordered_map<Signature, std::unordered_set<Signature>> isDependencyOf;
+ // To calculate the depth of dependencies, we'll do a flow analysis, visiting
+ // each signature as we find out new things about it.
+ std::set<Signature> toVisit;
+ for (auto& pair : counts) {
+ auto sig = pair.first;
+ depthOfDependencies[sig] = 0;
+ toVisit.insert(sig);
+ for (Type type : {sig.params, sig.results}) {
+ for (auto element : type) {
+ if (element.isRef()) {
+ auto heapType = element.getHeapType();
+ if (heapType.isSignature()) {
+ isDependencyOf[heapType.getSignature()].insert(sig);
+ }
+ }
+ }
+ }
+ }
+ while (!toVisit.empty()) {
+ auto iter = toVisit.begin();
+ auto sig = *iter;
+ toVisit.erase(iter);
+ // Anything that depends on this has a depth of dependencies equal to this
+ // signature's, plus this signature itself.
+ auto newDepth = depthOfDependencies[sig] + 1;
+ if (newDepth > counts.size()) {
+ Fatal() << "Cyclic signatures detected, cannot sort them.";
+ }
+ for (auto& other : isDependencyOf[sig]) {
+ if (depthOfDependencies[other] < newDepth) {
+ // We found something new to propagate.
+ depthOfDependencies[other] = newDepth;
+ toVisit.insert(other);
+ }
+ }
+ }
+ // Sort by frequency and then simplicity, and also keeping every signature
+ // before things that depend on it.
std::vector<std::pair<Signature, size_t>> sorted(counts.begin(),
counts.end());
std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {
- // order by frequency then simplicity
+ if (depthOfDependencies[a.first] != depthOfDependencies[b.first]) {
+ return depthOfDependencies[a.first] < depthOfDependencies[b.first];
+ }
if (a.second != b.second) {
return a.second > b.second;
}
diff --git a/src/js/binaryen.js-post.js b/src/js/binaryen.js-post.js
index b78fdd996..9abe5e718 100644
--- a/src/js/binaryen.js-post.js
+++ b/src/js/binaryen.js-post.js
@@ -2112,8 +2112,8 @@ function wrapModule(module, self = {}) {
'is_null'(value) {
return Module['_BinaryenRefIsNull'](module, value);
},
- 'func'(func) {
- return preserveStack(() => Module['_BinaryenRefFunc'](module, strToStack(func)));
+ 'func'(func, type) {
+ return preserveStack(() => Module['_BinaryenRefFunc'](module, strToStack(func), type));
},
'eq'(left, right) {
return Module['_BinaryenRefEq'](module, left, right);
diff --git a/src/passes/InstrumentLocals.cpp b/src/passes/InstrumentLocals.cpp
index 81494463f..004bfba74 100644
--- a/src/passes/InstrumentLocals.cpp
+++ b/src/passes/InstrumentLocals.cpp
@@ -135,45 +135,48 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> {
Builder builder(*getModule());
Name import;
auto type = curr->value->type;
- if (type.isFunction()) {
- import = set_funcref;
- } else {
- TODO_SINGLE_COMPOUND(curr->value->type);
- switch (type.getBasic()) {
- case Type::i32:
- import = set_i32;
- break;
- case Type::i64:
- return; // TODO
- case Type::f32:
- import = set_f32;
- break;
- case Type::f64:
- import = set_f64;
- break;
- case Type::v128:
- import = set_v128;
- break;
- case Type::externref:
- import = set_externref;
- break;
- case Type::exnref:
- import = set_exnref;
- break;
- case Type::anyref:
- import = set_anyref;
- break;
- case Type::eqref:
- import = set_eqref;
- break;
- case Type::i31ref:
- import = set_i31ref;
- break;
- case Type::unreachable:
- return; // nothing to do here
- default:
- WASM_UNREACHABLE("unexpected type");
- }
+ if (type.isFunction() && type != Type::funcref) {
+ // FIXME: support typed function references
+ return;
+ }
+ TODO_SINGLE_COMPOUND(curr->value->type);
+ switch (type.getBasic()) {
+ case Type::i32:
+ import = set_i32;
+ break;
+ case Type::i64:
+ return; // TODO
+ case Type::f32:
+ import = set_f32;
+ break;
+ case Type::f64:
+ import = set_f64;
+ break;
+ case Type::v128:
+ import = set_v128;
+ break;
+ case Type::funcref:
+ import = set_funcref;
+ break;
+ case Type::externref:
+ import = set_externref;
+ break;
+ case Type::exnref:
+ import = set_exnref;
+ break;
+ case Type::anyref:
+ import = set_anyref;
+ break;
+ case Type::eqref:
+ import = set_eqref;
+ break;
+ case Type::i31ref:
+ import = set_i31ref;
+ break;
+ case Type::unreachable:
+ return; // nothing to do here
+ default:
+ WASM_UNREACHABLE("unexpected type");
}
curr->value = builder.makeCall(import,
{builder.makeConst(int32_t(id++)),
diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h
index df7cf2226..1c1359586 100644
--- a/src/tools/fuzzing.h
+++ b/src/tools/fuzzing.h
@@ -321,6 +321,10 @@ private:
}
return Type(types);
}
+ if (type.isFunction() && type != Type::funcref) {
+ // TODO: specific typed function references types.
+ return type;
+ }
SmallVector<Type, 2> options;
options.push_back(type); // includes itself
TODO_SINGLE_COMPOUND(type);
@@ -653,6 +657,10 @@ private:
Index numVars = upToSquared(MAX_VARS);
for (Index i = 0; i < numVars; i++) {
auto type = getConcreteType();
+ if (type.isRef() && !type.isNullable()) {
+ // We can't use a nullable type as a var, which is null-initialized.
+ continue;
+ }
funcContext->typeLocals[type].push_back(params.size() +
func->vars.size());
func->vars.push_back(type);
@@ -1371,7 +1379,6 @@ private:
}
Expression* makeCall(Type type) {
- // seems ok, go on
int tries = TRIES;
bool isReturn;
while (tries-- > 0) {
@@ -1392,7 +1399,7 @@ private:
return builder.makeCall(target->name, args, type, isReturn);
}
// we failed to find something
- return make(type);
+ return makeTrivial(type);
}
Expression* makeCallIndirect(Type type) {
@@ -1418,7 +1425,7 @@ private:
i = 0;
}
if (i == start) {
- return make(type);
+ return makeTrivial(type);
}
}
// with high probability, make sure the type is valid otherwise, most are
@@ -2018,12 +2025,28 @@ private:
if (!wasm.functions.empty() && !oneIn(wasm.functions.size())) {
target = pick(wasm.functions).get();
}
- return builder.makeRefFunc(target->name);
+ auto type = Type(HeapType(target->sig), /* nullable = */ true);
+ return builder.makeRefFunc(target->name, type);
}
if (type == Type::i31ref) {
return builder.makeI31New(makeConst(Type::i32));
}
- return builder.makeRefNull(type);
+ if (oneIn(2) && type.isNullable()) {
+ return builder.makeRefNull(type);
+ }
+ // TODO: randomize the order
+ for (auto& func : wasm.functions) {
+ // FIXME: RefFunc type should be non-nullable, but we emit nullable
+ // types for now.
+ if (type == Type(HeapType(func->sig), /* nullable = */ true)) {
+ return builder.makeRefFunc(func->name, type);
+ }
+ }
+ // We failed to find a function, so create a null reference if we can.
+ if (type.isNullable()) {
+ return builder.makeRefNull(type);
+ }
+ WASM_UNREACHABLE("un-handleable non-nullable type");
}
if (type.isTuple()) {
std::vector<Expression*> operands;
@@ -2972,6 +2995,7 @@ private:
Type::anyref,
Type::eqref,
Type::i31ref));
+ // TODO: emit typed function references types
}
Type getSingleConcreteType() { return pick(getSingleConcreteTypes()); }
@@ -2997,12 +3021,24 @@ private:
Type getEqReferenceType() { return pick(getEqReferenceTypes()); }
+ Type getMVPType() {
+ return pick(items(FeatureOptions<Type>().add(
+ FeatureSet::MVP, Type::i32, Type::i64, Type::f32, Type::f64)));
+ }
+
Type getTupleType() {
std::vector<Type> elements;
- size_t numElements = 2 + upTo(MAX_TUPLE_SIZE - 1);
- elements.resize(numElements);
- for (size_t i = 0; i < numElements; ++i) {
- elements[i] = getSingleConcreteType();
+ size_t maxElements = 2 + upTo(MAX_TUPLE_SIZE - 1);
+ for (size_t i = 0; i < maxElements; ++i) {
+ auto type = getSingleConcreteType();
+ // Don't add a non-nullable type into a tuple, as currently we can't spill
+ // them into locals (that would require a "let").
+ if (!type.isNullable()) {
+ elements.push_back(type);
+ }
+ }
+ while (elements.size() < 2) {
+ elements.push_back(getMVPType());
}
return Type(elements);
}
diff --git a/src/tools/tool-options.h b/src/tools/tool-options.h
index 4b084e191..70ce4efc0 100644
--- a/src/tools/tool-options.h
+++ b/src/tools/tool-options.h
@@ -89,6 +89,8 @@ struct ToolOptions : public Options {
.addFeature(FeatureSet::Multivalue, "multivalue functions")
.addFeature(FeatureSet::GC, "garbage collection")
.addFeature(FeatureSet::Memory64, "memory64")
+ .addFeature(FeatureSet::TypedFunctionReferences,
+ "typed function references")
.add("--no-validation",
"-n",
"Disables validation, assumes inputs are correct",
diff --git a/src/wasm-binary.h b/src/wasm-binary.h
index ef3f9c9d1..0918151c5 100644
--- a/src/wasm-binary.h
+++ b/src/wasm-binary.h
@@ -346,6 +346,10 @@ enum EncodedType {
anyref = -0x12, // 0x6e
// comparable reference type
eqref = -0x13, // 0x6d
+ // nullable typed function reference type, with parameter
+ nullable = -0x14, // 0x6c
+ // non-nullable typed function reference type, with parameter
+ nonnullable = -0x15, // 0x6b
// integer reference type
i31ref = -0x16, // 0x6a
// exception reference type
@@ -386,6 +390,7 @@ extern const char* ReferenceTypesFeature;
extern const char* MultivalueFeature;
extern const char* GCFeature;
extern const char* Memory64Feature;
+extern const char* TypedFunctionReferencesFeature;
enum Subsection {
NameModule = 0,
@@ -1009,82 +1014,6 @@ enum FeaturePrefix {
} // namespace BinaryConsts
-inline S32LEB binaryType(Type type) {
- int ret = 0;
- TODO_SINGLE_COMPOUND(type);
- switch (type.getBasic()) {
- // None only used for block signatures. TODO: Separate out?
- case Type::none:
- ret = BinaryConsts::EncodedType::Empty;
- break;
- case Type::i32:
- ret = BinaryConsts::EncodedType::i32;
- break;
- case Type::i64:
- ret = BinaryConsts::EncodedType::i64;
- break;
- case Type::f32:
- ret = BinaryConsts::EncodedType::f32;
- break;
- case Type::f64:
- ret = BinaryConsts::EncodedType::f64;
- break;
- case Type::v128:
- ret = BinaryConsts::EncodedType::v128;
- break;
- case Type::funcref:
- ret = BinaryConsts::EncodedType::funcref;
- break;
- case Type::externref:
- ret = BinaryConsts::EncodedType::externref;
- break;
- case Type::exnref:
- ret = BinaryConsts::EncodedType::exnref;
- break;
- case Type::anyref:
- ret = BinaryConsts::EncodedType::anyref;
- break;
- case Type::eqref:
- ret = BinaryConsts::EncodedType::eqref;
- break;
- case Type::i31ref:
- ret = BinaryConsts::EncodedType::i31ref;
- break;
- case Type::unreachable:
- WASM_UNREACHABLE("unexpected type");
- }
- return S32LEB(ret);
-}
-
-inline S32LEB binaryHeapType(HeapType type) {
- int ret = 0;
- switch (type.kind) {
- case HeapType::FuncKind:
- ret = BinaryConsts::EncodedHeapType::func;
- break;
- case HeapType::ExternKind:
- ret = BinaryConsts::EncodedHeapType::extern_;
- break;
- case HeapType::ExnKind:
- ret = BinaryConsts::EncodedHeapType::exn;
- break;
- case HeapType::AnyKind:
- ret = BinaryConsts::EncodedHeapType::any;
- break;
- case HeapType::EqKind:
- ret = BinaryConsts::EncodedHeapType::eq;
- break;
- case HeapType::I31Kind:
- ret = BinaryConsts::EncodedHeapType::i31;
- break;
- case HeapType::SignatureKind:
- case HeapType::StructKind:
- case HeapType::ArrayKind:
- WASM_UNREACHABLE("TODO: compound GC types");
- }
- return S32LEB(ret); // TODO: Actually encoded as s33
-}
-
// Writes out wasm to the binary format
class WasmBinaryWriter {
@@ -1234,6 +1163,9 @@ public:
Module* getModule() { return wasm; }
+ void writeType(Type type);
+ void writeHeapType(HeapType type);
+
private:
Module* wasm;
BufferWithRandomAccess& o;
@@ -1342,6 +1274,8 @@ public:
std::vector<Signature> functionSignatures;
void readFunctionSignatures();
+ Signature getFunctionSignatureByIndex(Index index);
+
size_t nextLabel;
Name getNextLabel();
diff --git a/src/wasm-builder.h b/src/wasm-builder.h
index d3af93896..6800aa2ed 100644
--- a/src/wasm-builder.h
+++ b/src/wasm-builder.h
@@ -588,10 +588,10 @@ public:
ret->finalize();
return ret;
}
- RefFunc* makeRefFunc(Name func) {
+ RefFunc* makeRefFunc(Name func, Type type) {
auto* ret = wasm.allocator.alloc<RefFunc>();
ret->func = func;
- ret->finalize();
+ ret->finalize(type);
return ret;
}
RefEq* makeRefEq(Expression* left, Expression* right) {
@@ -769,8 +769,7 @@ public:
}
if (type.isFunction()) {
if (!value.isNull()) {
- // TODO: with typed function references we need to do more for the type
- return makeRefFunc(value.getFunc());
+ return makeRefFunc(value.getFunc(), type);
}
return makeRefNull(type);
}
@@ -951,7 +950,12 @@ public:
return makeConstantExpression(Literal::makeZeros(curr->type));
}
if (curr->type.isFunction()) {
- return ExpressionManipulator::refNull(curr, curr->type);
+ if (curr->type.isNullable()) {
+ return ExpressionManipulator::refNull(curr, curr->type);
+ } else {
+ // We can't do any better, keep the original.
+ return curr;
+ }
}
Literal value;
// TODO: reuse node conditionally when possible for literals
diff --git a/src/wasm-features.h b/src/wasm-features.h
index a2bb52971..d2e3f343f 100644
--- a/src/wasm-features.h
+++ b/src/wasm-features.h
@@ -38,7 +38,8 @@ struct FeatureSet {
Multivalue = 1 << 9,
GC = 1 << 10,
Memory64 = 1 << 11,
- All = (1 << 12) - 1
+ TypedFunctionReferences = 1 << 12,
+ All = (1 << 13) - 1
};
static std::string toString(Feature f) {
@@ -67,6 +68,8 @@ struct FeatureSet {
return "gc";
case Memory64:
return "memory64";
+ case TypedFunctionReferences:
+ return "typed-function-references";
default:
WASM_UNREACHABLE("unexpected feature");
}
@@ -92,6 +95,9 @@ struct FeatureSet {
bool hasMultivalue() const { return (features & Multivalue) != 0; }
bool hasGC() const { return (features & GC) != 0; }
bool hasMemory64() const { return (features & Memory64) != 0; }
+ bool hasTypedFunctionReferences() const {
+ return (features & TypedFunctionReferences) != 0;
+ }
bool hasAll() const { return (features & All) != 0; }
void makeMVP() { features = MVP; }
@@ -110,6 +116,9 @@ struct FeatureSet {
void setMultivalue(bool v = true) { set(Multivalue, v); }
void setGC(bool v = true) { set(GC, v); }
void setMemory64(bool v = true) { set(Memory64, v); }
+ void setTypedFunctionReferences(bool v = true) {
+ set(TypedFunctionReferences, v);
+ }
void setAll(bool v = true) { features = v ? All : MVP; }
void enable(const FeatureSet& other) { features |= other.features; }
diff --git a/src/wasm.h b/src/wasm.h
index 1204eee0f..e9fb4461b 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -1198,6 +1198,7 @@ public:
Name func;
void finalize();
+ void finalize(Type type_);
};
class RefEq : public SpecificExpression<Expression::RefEqId> {
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp
index b343c6caf..a96039bc2 100644
--- a/src/wasm/wasm-binary.cpp
+++ b/src/wasm/wasm-binary.cpp
@@ -221,7 +221,7 @@ void WasmBinaryWriter::writeTypes() {
for (auto& sigType : {sig.params, sig.results}) {
o << U32LEB(sigType.size());
for (const auto& type : sigType) {
- o << binaryType(type);
+ writeType(type);
}
}
}
@@ -250,7 +250,7 @@ void WasmBinaryWriter::writeImports() {
BYN_TRACE("write one global\n");
writeImportHeader(global);
o << U32LEB(int32_t(ExternalKind::Global));
- o << binaryType(global->type);
+ writeType(global->type);
o << U32LEB(global->mutable_);
});
ModuleUtils::iterImportedEvents(*wasm, [&](Event* event) {
@@ -389,7 +389,7 @@ void WasmBinaryWriter::writeGlobals() {
BYN_TRACE("write one\n");
size_t i = 0;
for (const auto& t : global->type) {
- o << binaryType(t);
+ writeType(t);
o << U32LEB(global->mutable_);
if (global->type.size() == 1) {
writeExpression(global->init);
@@ -492,7 +492,12 @@ uint32_t WasmBinaryWriter::getEventIndex(Name name) const {
uint32_t WasmBinaryWriter::getTypeIndex(Signature sig) const {
auto it = typeIndices.find(sig);
- assert(it != typeIndices.end());
+#ifndef NDEBUG
+ if (it == typeIndices.end()) {
+ std::cout << "Missing signature: " << sig << '\n';
+ assert(0);
+ }
+#endif
return it->second;
}
@@ -799,6 +804,8 @@ void WasmBinaryWriter::writeFeaturesSection() {
return BinaryConsts::UserSections::GCFeature;
case FeatureSet::Memory64:
return BinaryConsts::UserSections::Memory64Feature;
+ case FeatureSet::TypedFunctionReferences:
+ return BinaryConsts::UserSections::TypedFunctionReferencesFeature;
default:
WASM_UNREACHABLE("unexpected feature flag");
}
@@ -950,6 +957,100 @@ void WasmBinaryWriter::finishUp() {
}
}
+void WasmBinaryWriter::writeType(Type type) {
+ if (type.isRef()) {
+ auto heapType = type.getHeapType();
+ // TODO: fully handle non-signature reference types (GC), and in reading
+ if (heapType.isSignature()) {
+ if (type.isNullable()) {
+ o << S32LEB(BinaryConsts::EncodedType::nullable);
+ } else {
+ o << S32LEB(BinaryConsts::EncodedType::nonnullable);
+ }
+ writeHeapType(heapType);
+ return;
+ }
+ }
+ int ret = 0;
+ TODO_SINGLE_COMPOUND(type);
+ switch (type.getBasic()) {
+ // None only used for block signatures. TODO: Separate out?
+ case Type::none:
+ ret = BinaryConsts::EncodedType::Empty;
+ break;
+ case Type::i32:
+ ret = BinaryConsts::EncodedType::i32;
+ break;
+ case Type::i64:
+ ret = BinaryConsts::EncodedType::i64;
+ break;
+ case Type::f32:
+ ret = BinaryConsts::EncodedType::f32;
+ break;
+ case Type::f64:
+ ret = BinaryConsts::EncodedType::f64;
+ break;
+ case Type::v128:
+ ret = BinaryConsts::EncodedType::v128;
+ break;
+ case Type::funcref:
+ ret = BinaryConsts::EncodedType::funcref;
+ break;
+ case Type::externref:
+ ret = BinaryConsts::EncodedType::externref;
+ break;
+ case Type::exnref:
+ ret = BinaryConsts::EncodedType::exnref;
+ break;
+ case Type::anyref:
+ ret = BinaryConsts::EncodedType::anyref;
+ break;
+ case Type::eqref:
+ ret = BinaryConsts::EncodedType::eqref;
+ break;
+ case Type::i31ref:
+ ret = BinaryConsts::EncodedType::i31ref;
+ break;
+ default:
+ WASM_UNREACHABLE("unexpected type");
+ }
+ o << S32LEB(ret);
+}
+
+void WasmBinaryWriter::writeHeapType(HeapType type) {
+ if (type.isSignature()) {
+ auto sig = type.getSignature();
+ o << S32LEB(getTypeIndex(sig));
+ return;
+ }
+ int ret = 0;
+ switch (type.kind) {
+ case HeapType::FuncKind:
+ ret = BinaryConsts::EncodedHeapType::func;
+ break;
+ case HeapType::ExternKind:
+ ret = BinaryConsts::EncodedHeapType::extern_;
+ break;
+ case HeapType::ExnKind:
+ ret = BinaryConsts::EncodedHeapType::exn;
+ break;
+ case HeapType::AnyKind:
+ ret = BinaryConsts::EncodedHeapType::any;
+ break;
+ case HeapType::EqKind:
+ ret = BinaryConsts::EncodedHeapType::eq;
+ break;
+ case HeapType::I31Kind:
+ ret = BinaryConsts::EncodedHeapType::i31;
+ break;
+ case HeapType::SignatureKind:
+ case HeapType::StructKind:
+ case HeapType::ArrayKind:
+ WASM_UNREACHABLE("TODO: compound GC types");
+ }
+ o << S32LEB(ret); // TODO: Actually encoded as s33
+}
+
// reader
bool WasmBinaryBuilder::hasDWARFSections() {
@@ -1253,6 +1354,10 @@ Type WasmBinaryBuilder::getType() {
return Type::anyref;
case BinaryConsts::EncodedType::eqref:
return Type::eqref;
+ case BinaryConsts::EncodedType::nullable:
+ return Type(getHeapType(), /* nullable = */ true);
+ case BinaryConsts::EncodedType::nonnullable:
+ return Type(getHeapType(), /* nullable = */ false);
case BinaryConsts::EncodedType::i31ref:
return Type::i31ref;
default:
@@ -1581,6 +1686,18 @@ void WasmBinaryBuilder::readFunctionSignatures() {
}
}
+Signature WasmBinaryBuilder::getFunctionSignatureByIndex(Index index) {
+ Signature sig;
+ if (index < functionImports.size()) {
+ return functionImports[index]->sig;
+ }
+ Index adjustedIndex = index - functionImports.size();
+ if (adjustedIndex >= functionSignatures.size()) {
+ throwError("invalid function index");
+ }
+ return functionSignatures[adjustedIndex];
+}
+
void WasmBinaryBuilder::readFunctions() {
BYN_TRACE("== readFunctions\n");
size_t total = getU32LEB();
@@ -2471,6 +2588,9 @@ void WasmBinaryBuilder::readFeatures(size_t payloadLen) {
wasm.features.setGC();
} else if (name == BinaryConsts::UserSections::Memory64Feature) {
wasm.features.setMemory64();
+ } else if (name ==
+ BinaryConsts::UserSections::TypedFunctionReferencesFeature) {
+ wasm.features.setTypedFunctionReferences();
}
}
}
@@ -3042,17 +3162,7 @@ void WasmBinaryBuilder::visitSwitch(Switch* curr) {
void WasmBinaryBuilder::visitCall(Call* curr) {
BYN_TRACE("zz node: Call\n");
auto index = getU32LEB();
- Signature sig;
- if (index < functionImports.size()) {
- auto* import = functionImports[index];
- sig = import->sig;
- } else {
- Index adjustedIndex = index - functionImports.size();
- if (adjustedIndex >= functionSignatures.size()) {
- throwError("invalid call index");
- }
- sig = functionSignatures[adjustedIndex];
- }
+ auto sig = getFunctionSignatureByIndex(index);
auto num = sig.params.size();
curr->operands.resize(num);
for (size_t i = 0; i < num; i++) {
@@ -5169,7 +5279,10 @@ void WasmBinaryBuilder::visitRefFunc(RefFunc* curr) {
throwError("ref.func: invalid call index");
}
functionRefs[index].push_back(curr); // we don't know function names yet
- curr->finalize();
+ // To support typed function refs, we give the reference not just a general
+ // funcref, but a specific subtype with the actual signature.
+ curr->finalize(
+ Type(HeapType(getFunctionSignatureByIndex(index)), /* nullable = */ true));
}
void WasmBinaryBuilder::visitRefEq(RefEq* curr) {
diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp
index 0636836d7..6286ae090 100644
--- a/src/wasm/wasm-s-parser.cpp
+++ b/src/wasm/wasm-s-parser.cpp
@@ -1890,7 +1890,10 @@ Expression* SExpressionWasmBuilder::makeRefFunc(Element& s) {
auto func = getFunctionName(*s[1]);
auto ret = allocator.alloc<RefFunc>();
ret->func = func;
- ret->finalize();
+ // To support typed function refs, we give the reference not just a general
+ // funcref, but a specific subtype with the actual signature.
+ ret->finalize(
+ Type(HeapType(functionSignatures[func]), /* nullable = */ true));
return ret;
}
diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp
index c8a4f7a90..021b05cb6 100644
--- a/src/wasm/wasm-stack.cpp
+++ b/src/wasm/wasm-stack.cpp
@@ -24,11 +24,11 @@ static Name IMPOSSIBLE_CONTINUE("impossible-continue");
void BinaryInstWriter::emitResultType(Type type) {
if (type == Type::unreachable) {
- o << binaryType(Type::none);
+ parent.writeType(Type::none);
} else if (type.isTuple()) {
o << S32LEB(parent.getTypeIndex(Signature(Type::none, type)));
} else {
- o << binaryType(type);
+ parent.writeType(type);
}
}
@@ -1756,8 +1756,8 @@ void BinaryInstWriter::visitSelect(Select* curr) {
if (curr->type.isRef()) {
o << int8_t(BinaryConsts::SelectWithType) << U32LEB(curr->type.size());
for (size_t i = 0; i < curr->type.size(); i++) {
- o << binaryType(curr->type != Type::unreachable ? curr->type
- : Type::none);
+ parent.writeType(curr->type != Type::unreachable ? curr->type
+ : Type::none);
}
} else {
o << int8_t(BinaryConsts::Select);
@@ -1779,8 +1779,8 @@ void BinaryInstWriter::visitMemoryGrow(MemoryGrow* curr) {
}
void BinaryInstWriter::visitRefNull(RefNull* curr) {
- o << int8_t(BinaryConsts::RefNull)
- << binaryHeapType(curr->type.getHeapType());
+ o << int8_t(BinaryConsts::RefNull);
+ parent.writeHeapType(curr->type.getHeapType());
}
void BinaryInstWriter::visitRefIsNull(RefIsNull* curr) {
@@ -1966,7 +1966,8 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() {
o << U32LEB(func->getNumVars());
for (Index i = varStart; i < varEnd; i++) {
mappedLocals[std::make_pair(i, 0)] = i;
- o << U32LEB(1) << binaryType(func->getLocalType(i));
+ o << U32LEB(1);
+ parent.writeType(func->getLocalType(i));
}
return;
}
@@ -1995,7 +1996,8 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() {
setScratchLocals();
o << U32LEB(numLocalsByType.size());
for (auto& typeCount : numLocalsByType) {
- o << U32LEB(typeCount.second) << binaryType(typeCount.first);
+ o << U32LEB(typeCount.second);
+ parent.writeType(typeCount.first);
}
}
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp
index dc4d50ef4..cf4404739 100644
--- a/src/wasm/wasm-type.cpp
+++ b/src/wasm/wasm-type.cpp
@@ -460,6 +460,14 @@ Type Type::reinterpret() const {
FeatureSet Type::getFeatures() const {
auto getSingleFeatures = [](Type t) -> FeatureSet {
+ if (t != Type::funcref && t.isFunction()) {
+ // Strictly speaking, typed function references require the typed function
+ // references feature, however, we use these types internally regardless
+ // of the presence of features (in particular, since during load of the
+ // wasm we don't know the features yet, so we apply the more refined
+ // types).
+ return FeatureSet::ReferenceTypes;
+ }
TODO_SINGLE_COMPOUND(t);
switch (t.getBasic()) {
case Type::v128:
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index 809ca5a6a..78e123a90 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -2313,6 +2313,7 @@ void FunctionValidator::visitFunction(Function* curr) {
for (const auto& var : curr->vars) {
features |= var.getFeatures();
shouldBeTrue(var.isConcrete(), curr, "vars must be concretely typed");
+ // TODO: check for nullability
}
shouldBeTrue(features <= getModule()->features,
curr->name,
diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp
index c7a187b43..6245a3575 100644
--- a/src/wasm/wasm.cpp
+++ b/src/wasm/wasm.cpp
@@ -47,6 +47,7 @@ const char* ReferenceTypesFeature = "reference-types";
const char* MultivalueFeature = "multivalue";
const char* GCFeature = "gc";
const char* Memory64Feature = "memory64";
+const char* TypedFunctionReferencesFeature = "typed-function-references";
} // namespace UserSections
} // namespace BinaryConsts
@@ -984,7 +985,12 @@ void RefIsNull::finalize() {
type = Type::i32;
}
-void RefFunc::finalize() { type = Type::funcref; }
+void RefFunc::finalize() {
+ // No-op. We assume that the full proper typed function type has been applied
+ // previously.
+}
+
+void RefFunc::finalize(Type type_) { type = type_; }
void RefEq::finalize() {
if (left->type == Type::unreachable || right->type == Type::unreachable) {