summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/gen-s-parser.inc28
-rw-r--r--src/ir/ReFinalize.cpp1
-rw-r--r--src/ir/cost.h7
-rw-r--r--src/ir/effects.h11
-rw-r--r--src/ir/module-utils.h2
-rw-r--r--src/js/binaryen.js-post.js1
-rw-r--r--src/passes/DeadArgumentElimination.cpp6
-rw-r--r--src/passes/Directize.cpp44
-rw-r--r--src/passes/Inlining.cpp5
-rw-r--r--src/passes/MergeBlocks.cpp6
-rw-r--r--src/passes/Print.cpp42
-rw-r--r--src/shared-constants.h2
-rw-r--r--src/tools/fuzzing.h57
-rw-r--r--src/wasm-binary.h6
-rw-r--r--src/wasm-builder.h13
-rw-r--r--src/wasm-delegations-fields.h8
-rw-r--r--src/wasm-delegations.h1
-rw-r--r--src/wasm-interpreter.h34
-rw-r--r--src/wasm-s-parser.h8
-rw-r--r--src/wasm.h12
-rw-r--r--src/wasm/wasm-binary.cpp30
-rw-r--r--src/wasm/wasm-s-parser.cpp96
-rw-r--r--src/wasm/wasm-stack.cpp5
-rw-r--r--src/wasm/wasm-type.cpp3
-rw-r--r--src/wasm/wasm-validator.cpp156
-rw-r--r--src/wasm/wasm.cpp19
-rw-r--r--src/wasm2js.h4
27 files changed, 450 insertions, 157 deletions
diff --git a/src/gen-s-parser.inc b/src/gen-s-parser.inc
index a62d9fdc6..8afcea917 100644
--- a/src/gen-s-parser.inc
+++ b/src/gen-s-parser.inc
@@ -99,9 +99,17 @@ switch (op[0]) {
case '\0':
if (strcmp(op, "call") == 0) { return makeCall(s, /*isReturn=*/false); }
goto parse_error;
- case '_':
- if (strcmp(op, "call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/false); }
- goto parse_error;
+ case '_': {
+ switch (op[5]) {
+ case 'i':
+ if (strcmp(op, "call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/false); }
+ goto parse_error;
+ case 'r':
+ if (strcmp(op, "call_ref") == 0) { return makeCallRef(s, /*isReturn=*/false); }
+ goto parse_error;
+ default: goto parse_error;
+ }
+ }
default: goto parse_error;
}
}
@@ -2747,9 +2755,17 @@ switch (op[0]) {
case '\0':
if (strcmp(op, "return_call") == 0) { return makeCall(s, /*isReturn=*/true); }
goto parse_error;
- case '_':
- if (strcmp(op, "return_call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/true); }
- goto parse_error;
+ case '_': {
+ switch (op[12]) {
+ case 'i':
+ if (strcmp(op, "return_call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/true); }
+ goto parse_error;
+ case 'r':
+ if (strcmp(op, "return_call_ref") == 0) { return makeCallRef(s, /*isReturn=*/true); }
+ goto parse_error;
+ default: goto parse_error;
+ }
+ }
default: goto parse_error;
}
}
diff --git a/src/ir/ReFinalize.cpp b/src/ir/ReFinalize.cpp
index 448c1f30a..54dde65bc 100644
--- a/src/ir/ReFinalize.cpp
+++ b/src/ir/ReFinalize.cpp
@@ -150,6 +150,7 @@ void ReFinalize::visitTupleMake(TupleMake* curr) { curr->finalize(); }
void ReFinalize::visitTupleExtract(TupleExtract* curr) { curr->finalize(); }
void ReFinalize::visitI31New(I31New* curr) { curr->finalize(); }
void ReFinalize::visitI31Get(I31Get* curr) { curr->finalize(); }
+void ReFinalize::visitCallRef(CallRef* curr) { curr->finalize(); }
void ReFinalize::visitRefTest(RefTest* curr) { curr->finalize(); }
void ReFinalize::visitRefCast(RefCast* curr) { curr->finalize(); }
void ReFinalize::visitBrOnCast(BrOnCast* curr) {
diff --git a/src/ir/cost.h b/src/ir/cost.h
index c0845f7e2..333f599ee 100644
--- a/src/ir/cost.h
+++ b/src/ir/cost.h
@@ -65,6 +65,13 @@ struct CostAnalyzer : public OverriddenVisitor<CostAnalyzer, Index> {
}
return ret;
}
+ Index visitCallRef(CallRef* curr) {
+ Index ret = 5 + visit(curr->target);
+ for (auto* child : curr->operands) {
+ ret += visit(child);
+ }
+ return ret;
+ }
Index visitLocalGet(LocalGet* curr) { return 0; }
Index visitLocalSet(LocalSet* curr) { return 1 + visit(curr->value); }
Index visitGlobalGet(GlobalGet* curr) { return 1; }
diff --git a/src/ir/effects.h b/src/ir/effects.h
index ab8cafcb1..c0210c221 100644
--- a/src/ir/effects.h
+++ b/src/ir/effects.h
@@ -534,6 +534,17 @@ private:
void visitTupleExtract(TupleExtract* curr) {}
void visitI31New(I31New* curr) {}
void visitI31Get(I31Get* curr) {}
+ void visitCallRef(CallRef* curr) {
+ parent.calls = true;
+ if (parent.features.hasExceptionHandling() && parent.tryDepth == 0) {
+ parent.throws = true;
+ }
+ if (curr->isReturn) {
+ parent.branchesOut = true;
+ }
+ // traps when the arg is null
+ parent.implicitTrap = true;
+ }
void visitRefTest(RefTest* curr) {
WASM_UNREACHABLE("TODO (gc): ref.test");
}
diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h
index c8776297c..2b1c6812c 100644
--- a/src/ir/module-utils.h
+++ b/src/ir/module-utils.h
@@ -331,10 +331,10 @@ template<typename T> struct CallGraphPropertyAnalysis {
void visitCall(Call* curr) {
info.callsTo.insert(module->getFunction(curr->target));
}
-
void visitCallIndirect(CallIndirect* curr) {
info.hasNonDirectCall = true;
}
+ void visitCallRef(CallRef* curr) { info.hasNonDirectCall = true; }
private:
Module* module;
diff --git a/src/js/binaryen.js-post.js b/src/js/binaryen.js-post.js
index 9abe5e718..bbf2ae237 100644
--- a/src/js/binaryen.js-post.js
+++ b/src/js/binaryen.js-post.js
@@ -99,6 +99,7 @@ function initializeConstants() {
'Pop',
'I31New',
'I31Get',
+ 'CallRef',
'RefTest',
'RefCast',
'BrOnCast',
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp
index 89d03f461..34637cf5a 100644
--- a/src/passes/DeadArgumentElimination.cpp
+++ b/src/passes/DeadArgumentElimination.cpp
@@ -143,6 +143,12 @@ struct DAEScanner
}
}
+ void visitCallRef(CallRef* curr) {
+ if (curr->isReturn) {
+ info->hasTailCalls = true;
+ }
+ }
+
void visitDrop(Drop* curr) {
if (auto* call = curr->value->dynCast<Call>()) {
info->droppedCalls[call] = getCurrentPointer();
diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp
index 0c1132b04..f966d1a5a 100644
--- a/src/passes/Directize.cpp
+++ b/src/passes/Directize.cpp
@@ -41,6 +41,9 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
FunctionDirectizer(TableUtils::FlatTable* flatTable) : flatTable(flatTable) {}
void visitCallIndirect(CallIndirect* curr) {
+ if (!flatTable) {
+ return;
+ }
if (auto* c = curr->target->dynCast<Const>()) {
Index index = c->value.geti32();
// If the index is invalid, or the type is wrong, we can
@@ -68,6 +71,15 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
}
}
+ void visitCallRef(CallRef* curr) {
+ if (auto* ref = curr->target->dynCast<RefFunc>()) {
+ // We know the target!
+ replaceCurrent(
+ Builder(*getModule())
+ .makeCall(ref->func, curr->operands, curr->type, curr->isReturn));
+ }
+ }
+
void doWalkFunction(Function* func) {
WalkerPass<PostWalker<FunctionDirectizer>>::doWalkFunction(func);
if (changedTypes) {
@@ -76,7 +88,9 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
}
private:
+ // If null, then we cannot optimize call_indirects.
TableUtils::FlatTable* flatTable;
+
bool changedTypes = false;
void replaceWithUnreachable(CallIndirect* call) {
@@ -92,23 +106,31 @@ private:
struct Directize : public Pass {
void run(PassRunner* runner, Module* module) override {
+ bool canOptimizeCallIndirect = true;
+ TableUtils::FlatTable flatTable(module->table);
if (!module->table.exists) {
- return;
- }
- if (module->table.imported()) {
- return;
- }
- for (auto& ex : module->exports) {
- if (ex->kind == ExternalKind::Table) {
- return;
+ canOptimizeCallIndirect = false;
+ } else if (module->table.imported()) {
+ canOptimizeCallIndirect = false;
+ } else {
+ for (auto& ex : module->exports) {
+ if (ex->kind == ExternalKind::Table) {
+ canOptimizeCallIndirect = false;
+ }
+ }
+ if (!flatTable.valid) {
+ canOptimizeCallIndirect = false;
}
}
- TableUtils::FlatTable flatTable(module->table);
- if (!flatTable.valid) {
+ // Without typed function references, all we can do is optimize table
+ // accesses, so if we can't do that, stop.
+ if (!canOptimizeCallIndirect &&
+ !module->features.hasTypedFunctionReferences()) {
return;
}
// The table exists and is constant, so this is possible.
- FunctionDirectizer(&flatTable).run(runner, module);
+ FunctionDirectizer(canOptimizeCallIndirect ? &flatTable : nullptr)
+ .run(runner, module);
}
};
diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp
index bcab7318f..a44f02426 100644
--- a/src/passes/Inlining.cpp
+++ b/src/passes/Inlining.cpp
@@ -211,6 +211,11 @@ struct Updater : public PostWalker<Updater> {
handleReturnCall(curr, curr->sig.results);
}
}
+ void visitCallRef(CallRef* curr) {
+ if (curr->isReturn) {
+ handleReturnCall(curr, curr->target->type);
+ }
+ }
void visitLocalGet(LocalGet* curr) {
curr->index = localMapping[curr->index];
}
diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp
index 4ecec6669..33dbec77c 100644
--- a/src/passes/MergeBlocks.cpp
+++ b/src/passes/MergeBlocks.cpp
@@ -564,7 +564,7 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> {
void visitCall(Call* curr) { handleCall(curr); }
- void visitCallIndirect(CallIndirect* curr) {
+ template<typename T> void handleNonDirectCall(T* curr) {
FeatureSet features = getModule()->features;
Block* outer = nullptr;
for (Index i = 0; i < curr->operands.size(); i++) {
@@ -581,6 +581,10 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> {
optimize(curr, curr->target, outer);
}
+ void visitCallIndirect(CallIndirect* curr) { handleNonDirectCall(curr); }
+
+ void visitCallRef(CallRef* curr) { handleNonDirectCall(curr); }
+
void visitThrow(Throw* curr) {
Block* outer = nullptr;
for (Index i = 0; i < curr->operands.size(); i++) {
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index e512d398f..864a46362 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -87,14 +87,35 @@ struct SigName {
};
std::ostream& operator<<(std::ostream& os, SigName sigName) {
- auto printType = [&](Type type) {
+ std::function<void(Type)> printType = [&](Type type) {
if (type == Type::none) {
os << "none";
} else {
auto sep = "";
for (const auto& t : type) {
- os << sep << t;
+ os << sep;
sep = "_";
+ if (t.isRef()) {
+ auto heapType = t.getHeapType();
+ if (heapType.isSignature()) {
+ auto sig = heapType.getSignature();
+ os << "ref";
+ if (t.isNullable()) {
+ os << "_null";
+ }
+ os << "<";
+ for (auto s : sig.params) {
+ printType(s);
+ }
+ os << "_->_";
+ for (auto s : sig.results) {
+ printType(s);
+ }
+ os << ">";
+ continue;
+ }
+ }
+ os << t;
}
}
};
@@ -1561,6 +1582,13 @@ struct PrintExpressionContents
void visitI31Get(I31Get* curr) {
printMedium(o, curr->signed_ ? "i31.get_s" : "i31.get_u");
}
+ void visitCallRef(CallRef* curr) {
+ if (curr->isReturn) {
+ printMedium(o, "return_call_ref");
+ } else {
+ printMedium(o, "call_ref");
+ }
+ }
void visitRefTest(RefTest* curr) {
printMedium(o, "ref.test");
WASM_UNREACHABLE("TODO (gc): ref.test");
@@ -2216,6 +2244,16 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
printFullLine(curr->i31);
decIndent();
}
+ void visitCallRef(CallRef* curr) {
+ o << '(';
+ PrintExpressionContents(currFunction, o).visit(curr);
+ incIndent();
+ for (auto operand : curr->operands) {
+ printFullLine(operand);
+ }
+ printFullLine(curr->target);
+ decIndent();
+ }
void visitRefTest(RefTest* curr) {
o << '(';
PrintExpressionContents(currFunction, o).visit(curr);
diff --git a/src/shared-constants.h b/src/shared-constants.h
index e3c34e62f..569fd792d 100644
--- a/src/shared-constants.h
+++ b/src/shared-constants.h
@@ -43,6 +43,8 @@ extern Name GLOBAL;
extern Name ELEM;
extern Name LOCAL;
extern Name TYPE;
+extern Name REF;
+extern Name NULL_;
extern Name CALL;
extern Name CALL_IMPORT;
extern Name CALL_INDIRECT;
diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h
index 1c1359586..39298ce21 100644
--- a/src/tools/fuzzing.h
+++ b/src/tools/fuzzing.h
@@ -1081,13 +1081,15 @@ private:
WeightedOption{&Self::makeGlobalGet, Important},
WeightedOption{&Self::makeConst, Important});
if (canMakeControlFlow) {
- options.add(FeatureSet::MVP,
- WeightedOption{&Self::makeBlock, Important},
- WeightedOption{&Self::makeIf, Important},
- WeightedOption{&Self::makeLoop, Important},
- WeightedOption{&Self::makeBreak, Important},
- &Self::makeCall,
- &Self::makeCallIndirect);
+ options
+ .add(FeatureSet::MVP,
+ WeightedOption{&Self::makeBlock, Important},
+ WeightedOption{&Self::makeIf, Important},
+ WeightedOption{&Self::makeLoop, Important},
+ WeightedOption{&Self::makeBreak, Important},
+ &Self::makeCall,
+ &Self::makeCallIndirect)
+ .add(FeatureSet::TypedFunctionReferences, &Self::makeCallRef);
}
if (type.isSingle()) {
options
@@ -1146,7 +1148,8 @@ private:
&Self::makeNop,
&Self::makeGlobalSet)
.add(FeatureSet::BulkMemory, &Self::makeBulkMemory)
- .add(FeatureSet::Atomics, &Self::makeAtomic);
+ .add(FeatureSet::Atomics, &Self::makeAtomic)
+ .add(FeatureSet::TypedFunctionReferences, &Self::makeCallRef);
return (this->*pick(options))(Type::none);
}
@@ -1154,22 +1157,24 @@ private:
using Self = TranslateToFuzzReader;
auto options = FeatureOptions<Expression* (Self::*)(Type)>();
using WeightedOption = decltype(options)::WeightedOption;
- options.add(FeatureSet::MVP,
- WeightedOption{&Self::makeLocalSet, VeryImportant},
- WeightedOption{&Self::makeBlock, Important},
- WeightedOption{&Self::makeIf, Important},
- WeightedOption{&Self::makeLoop, Important},
- WeightedOption{&Self::makeBreak, Important},
- WeightedOption{&Self::makeStore, Important},
- WeightedOption{&Self::makeUnary, Important},
- WeightedOption{&Self::makeBinary, Important},
- WeightedOption{&Self::makeUnreachable, Important},
- &Self::makeCall,
- &Self::makeCallIndirect,
- &Self::makeSelect,
- &Self::makeSwitch,
- &Self::makeDrop,
- &Self::makeReturn);
+ options
+ .add(FeatureSet::MVP,
+ WeightedOption{&Self::makeLocalSet, VeryImportant},
+ WeightedOption{&Self::makeBlock, Important},
+ WeightedOption{&Self::makeIf, Important},
+ WeightedOption{&Self::makeLoop, Important},
+ WeightedOption{&Self::makeBreak, Important},
+ WeightedOption{&Self::makeStore, Important},
+ WeightedOption{&Self::makeUnary, Important},
+ WeightedOption{&Self::makeBinary, Important},
+ WeightedOption{&Self::makeUnreachable, Important},
+ &Self::makeCall,
+ &Self::makeCallIndirect,
+ &Self::makeSelect,
+ &Self::makeSwitch,
+ &Self::makeDrop,
+ &Self::makeReturn)
+ .add(FeatureSet::TypedFunctionReferences, &Self::makeCallRef);
return (this->*pick(options))(Type::unreachable);
}
@@ -1443,6 +1448,10 @@ private:
return builder.makeCallIndirect(target, args, targetFn->sig, isReturn);
}
+ Expression* makeCallRef(Type type) {
+ return makeTrivial(type); // FIXME
+ }
+
Expression* makeLocalGet(Type type) {
auto& locals = funcContext->typeLocals[type];
if (locals.empty()) {
diff --git a/src/wasm-binary.h b/src/wasm-binary.h
index 0918151c5..b0f41e69c 100644
--- a/src/wasm-binary.h
+++ b/src/wasm-binary.h
@@ -972,6 +972,11 @@ enum ASTNodes {
Rethrow = 0x09,
BrOnExn = 0x0a,
+ // typed function references opcodes
+
+ CallRef = 0x14,
+ RetCallRef = 0x15,
+
// gc opcodes
RefEq = 0xd5,
@@ -1479,6 +1484,7 @@ public:
void visitThrow(Throw* curr);
void visitRethrow(Rethrow* curr);
void visitBrOnExn(BrOnExn* curr);
+ void visitCallRef(CallRef* curr);
void throwError(std::string text);
diff --git a/src/wasm-builder.h b/src/wasm-builder.h
index 6800aa2ed..50a6e97fb 100644
--- a/src/wasm-builder.h
+++ b/src/wasm-builder.h
@@ -257,6 +257,19 @@ public:
call->finalize();
return call;
}
+ template<typename T>
+ CallRef* makeCallRef(Expression* target,
+ const T& args,
+ Type type,
+ bool isReturn = false) {
+ auto* call = wasm.allocator.alloc<CallRef>();
+ call->type = type;
+ call->target = target;
+ call->operands.set(args);
+ call->isReturn = isReturn;
+ call->finalize();
+ return call;
+ }
LocalGet* makeLocalGet(Index index, Type type) {
auto* ret = wasm.allocator.alloc<LocalGet>();
ret->index = index;
diff --git a/src/wasm-delegations-fields.h b/src/wasm-delegations-fields.h
index 7f6e43d75..ca0a8f7cb 100644
--- a/src/wasm-delegations-fields.h
+++ b/src/wasm-delegations-fields.h
@@ -549,6 +549,14 @@ switch (DELEGATE_ID) {
DELEGATE_END(I31Get);
break;
}
+ case Expression::Id::CallRefId: {
+ DELEGATE_START(CallRef);
+ DELEGATE_FIELD_CHILD(CallRef, target);
+ DELEGATE_FIELD_CHILD_VECTOR(CallRef, operands);
+ DELEGATE_FIELD_INT(CallRef, isReturn);
+ DELEGATE_END(CallRef);
+ break;
+ }
case Expression::Id::RefTestId: {
DELEGATE_START(RefTest);
WASM_UNREACHABLE("TODO (gc): ref.test");
diff --git a/src/wasm-delegations.h b/src/wasm-delegations.h
index 7212cbee9..50ee8247b 100644
--- a/src/wasm-delegations.h
+++ b/src/wasm-delegations.h
@@ -66,6 +66,7 @@ DELEGATE(TupleMake);
DELEGATE(TupleExtract);
DELEGATE(I31New);
DELEGATE(I31Get);
+DELEGATE(CallRef);
DELEGATE(RefTest);
DELEGATE(RefCast);
DELEGATE(BrOnCast);
diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h
index 406938a56..37719d4d9 100644
--- a/src/wasm-interpreter.h
+++ b/src/wasm-interpreter.h
@@ -1272,6 +1272,7 @@ public:
WASM_UNREACHABLE("unimp");
}
Flow visitPop(Pop* curr) { WASM_UNREACHABLE("unimp"); }
+ Flow visitCallRef(CallRef* curr) { WASM_UNREACHABLE("unimp"); }
Flow visitRefNull(RefNull* curr) {
NOTE_ENTER("RefNull");
return Literal::makeNull(curr->type);
@@ -1593,11 +1594,14 @@ public:
}
return Flow(NONCONSTANT_FLOW);
}
-
Flow visitCallIndirect(CallIndirect* curr) {
NOTE_ENTER("CallIndirect");
return Flow(NONCONSTANT_FLOW);
}
+ Flow visitCallRef(CallRef* curr) {
+ NOTE_ENTER("CallRef");
+ return Flow(NONCONSTANT_FLOW);
+ }
Flow visitLoad(Load* curr) {
NOTE_ENTER("Load");
return Flow(NONCONSTANT_FLOW);
@@ -2095,6 +2099,34 @@ private:
}
return ret;
}
+ Flow visitCallRef(CallRef* curr) {
+ NOTE_ENTER("CallRef");
+ LiteralList arguments;
+ Flow flow = this->generateArguments(curr->operands, arguments);
+ if (flow.breaking()) {
+ return flow;
+ }
+ Flow target = this->visit(curr->target);
+ if (target.breaking()) {
+ return target;
+ }
+ Name funcName = target.getSingleValue().getFunc();
+ auto* func = instance.wasm.getFunction(funcName);
+ Flow ret;
+ if (func->imported()) {
+ ret.values = instance.externalInterface->callImport(func, arguments);
+ } else {
+ ret.values = instance.callFunctionInternal(funcName, arguments);
+ }
+#ifdef WASM_INTERPRETER_DEBUG
+ std::cout << "(returned to " << scope.function->name << ")\n";
+#endif
+ // TODO: make this a proper tail call (return first)
+ if (curr->isReturn) {
+ ret.breakTo = RETURN_FLOW;
+ }
+ return ret;
+ }
Flow visitLocalGet(LocalGet* curr) {
NOTE_ENTER("LocalGet");
diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h
index 085b58ba0..9a501171d 100644
--- a/src/wasm-s-parser.h
+++ b/src/wasm-s-parser.h
@@ -77,6 +77,11 @@ public:
Element* setString(cashew::IString str__, bool dollared__, bool quoted__);
Element* setMetadata(size_t line_, size_t col_, SourceLocation* startLoc_);
+ // comparisons
+ bool operator==(Name name) { return isStr() && str() == name; }
+
+ template<typename T> bool operator!=(T t) { return !(*this == t); }
+
// printing
friend std::ostream& operator<<(std::ostream& o, Element& e);
void dump();
@@ -144,6 +149,7 @@ private:
UniqueNameMapper nameMapper;
+ // Given a function signature type's name, return the signature
Signature getFunctionSignature(Element& s);
Name getFunctionName(Element& s);
Name getGlobalName(Element& s);
@@ -246,6 +252,7 @@ private:
Expression* makeBrOnExn(Element& s);
Expression* makeTupleMake(Element& s);
Expression* makeTupleExtract(Element& s);
+ Expression* makeCallRef(Element& s, bool isReturn);
Expression* makeI31New(Element& s);
Expression* makeI31Get(Element& s, bool signed_);
Expression* makeRefTest(Element& s);
@@ -288,6 +295,7 @@ private:
void parseTable(Element& s, bool preParseImport = false);
void parseElem(Element& s);
void parseInnerElem(Element& s, Index i = 1, Expression* offset = nullptr);
+ Signature parseInlineFunctionSignature(Element& s);
void parseType(Element& s);
void parseEvent(Element& s, bool preParseImport = false);
diff --git a/src/wasm.h b/src/wasm.h
index e9fb4461b..6367fa345 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -588,6 +588,7 @@ public:
TupleExtractId,
I31NewId,
I31GetId,
+ CallRefId,
RefTestId,
RefCastId,
BrOnCastId,
@@ -1294,6 +1295,17 @@ public:
void finalize();
};
+class CallRef : public SpecificExpression<Expression::CallRefId> {
+public:
+ CallRef(MixedArena& allocator) : operands(allocator) {}
+ ExpressionList operands;
+ Expression* target;
+ bool isReturn = false;
+
+ void finalize();
+ void finalize(Type type_);
+};
+
class RefTest : public SpecificExpression<Expression::RefTestId> {
public:
RefTest(MixedArena& allocator) {}
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp
index a96039bc2..20b0899a5 100644
--- a/src/wasm/wasm-binary.cpp
+++ b/src/wasm/wasm-binary.cpp
@@ -2760,6 +2760,16 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) {
visitMemoryGrow(grow);
break;
}
+ case BinaryConsts::CallRef:
+ visitCallRef((curr = allocator.alloc<CallRef>())->cast<CallRef>());
+ break;
+ case BinaryConsts::RetCallRef: {
+ auto call = allocator.alloc<CallRef>();
+ call->isReturn = true;
+ curr = call;
+ visitCallRef(call);
+ break;
+ }
case BinaryConsts::AtomicPrefix: {
code = static_cast<uint8_t>(getU32LEB());
if (maybeVisitLoad(curr, code, /*isAtomic=*/true)) {
@@ -5426,6 +5436,26 @@ void WasmBinaryBuilder::visitBrOnExn(BrOnExn* curr) {
curr->finalize();
}
+void WasmBinaryBuilder::visitCallRef(CallRef* curr) {
+ BYN_TRACE("zz node: CallRef\n");
+ curr->target = popNonVoidExpression();
+ auto type = curr->target->type;
+ if (!type.isRef()) {
+ throwError("Non-ref type for a call_ref: " + type.toString());
+ }
+ auto heapType = type.getHeapType();
+ if (!heapType.isSignature()) {
+ throwError("Invalid reference type for a call_ref: " + type.toString());
+ }
+ auto sig = heapType.getSignature();
+ auto num = sig.params.size();
+ curr->operands.resize(num);
+ for (size_t i = 0; i < num; i++) {
+ curr->operands[num - i - 1] = popNonVoidExpression();
+ }
+ curr->finalize(sig.results);
+}
+
bool WasmBinaryBuilder::maybeVisitI31New(Expression*& out, uint32_t code) {
if (code != BinaryConsts::I31New) {
return false;
diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp
index 6286ae090..d8d9fa779 100644
--- a/src/wasm/wasm-s-parser.cpp
+++ b/src/wasm/wasm-s-parser.cpp
@@ -539,11 +539,11 @@ SExpressionWasmBuilder::parseParamOrLocal(Element& s, size_t& localIndex) {
if (s[i]->isStr()) {
type = stringToType(s[i]->str());
} else {
- if (elementStartsWith(s, PARAM)) {
+ type = elementToType(*s[i]);
+ if (elementStartsWith(s, PARAM) && type.isTuple()) {
throw ParseException(
"params may not have tuple types", s[i]->line, s[i]->col);
}
- type = elementToType(*s[i]);
}
namedParams.emplace_back(name, type);
}
@@ -925,10 +925,48 @@ Type SExpressionWasmBuilder::elementToType(Element& s) {
if (s.isStr()) {
return stringToType(s.str(), false, false);
}
- auto& tuple = s.list();
+ auto& list = s.list();
+ auto size = list.size();
+ if (size > 0 && elementStartsWith(s, REF)) {
+ // It's a reference. It should be in the form
+ // (ref $name)
+ // or
+ // (ref null $name)
+ // and also $name can be the expanded structure of the type and not a name,
+ // so something like (ref (func (result i32))), etc.
+ if (size != 2 && size != 3) {
+ throw ParseException(
+ std::string("invalid reference type size"), s.line, s.col);
+ }
+ if (size == 3 && *list[1] != NULL_) {
+ throw ParseException(
+ std::string("invalid reference type qualifier"), s.line, s.col);
+ }
+ bool nullable = false;
+ size_t i = 1;
+ if (size == 3) {
+ nullable = true;
+ i++;
+ }
+ Signature sig;
+ auto& last = *s[i];
+ if (last.isStr()) {
+ // A string name of a signature.
+ sig = getFunctionSignature(last);
+ } else {
+ // A signature written out in full in-line.
+ if (*last[0] != FUNC) {
+ throw ParseException(
+ std::string("invalid reference type type"), s.line, s.col);
+ }
+ sig = parseInlineFunctionSignature(last);
+ }
+ return Type(HeapType(sig), nullable);
+ }
+ // It's a tuple.
std::vector<Type> types;
for (size_t i = 0; i < s.size(); ++i) {
- types.push_back(stringToType(tuple[i]->str()));
+ types.push_back(stringToType(list[i]->str()));
}
return Type(types);
}
@@ -2026,6 +2064,24 @@ Expression* SExpressionWasmBuilder::makeTupleExtract(Element& s) {
return ret;
}
+Expression* SExpressionWasmBuilder::makeCallRef(Element& s, bool isReturn) {
+ auto ret = allocator.alloc<CallRef>();
+ parseCallOperands(s, 1, s.size() - 1, ret);
+ ret->target = parseExpression(s[s.size() - 1]);
+ ret->isReturn = isReturn;
+ if (!ret->target->type.isRef()) {
+ throw ParseException("Non-reference type for a call_ref", s.line, s.col);
+ }
+ auto heapType = ret->target->type.getHeapType();
+ if (!heapType.isSignature()) {
+ throw ParseException(
+ "Invalid reference type for a call_ref", s.line, s.col);
+ }
+ auto sig = heapType.getSignature();
+ ret->finalize(sig.results);
+ return ret;
+}
+
Expression* SExpressionWasmBuilder::makeI31New(Element& s) {
auto ret = allocator.alloc<I31New>();
ret->value = parseExpression(s[1]);
@@ -2710,9 +2766,26 @@ void SExpressionWasmBuilder::parseInnerElem(Element& s,
wasm.table.segments.push_back(segment);
}
-void SExpressionWasmBuilder::parseType(Element& s) {
+Signature SExpressionWasmBuilder::parseInlineFunctionSignature(Element& s) {
+ if (*s[0] != FUNC) {
+ throw ParseException("invalid inline function signature", s.line, s.col);
+ }
std::vector<Type> params;
std::vector<Type> results;
+ for (size_t k = 1; k < s.size(); k++) {
+ Element& curr = *s[k];
+ if (elementStartsWith(curr, PARAM)) {
+ auto newParams = parseParamOrLocal(curr);
+ params.insert(params.end(), newParams.begin(), newParams.end());
+ } else if (elementStartsWith(curr, RESULT)) {
+ auto newResults = parseResults(curr);
+ results.insert(results.end(), newResults.begin(), newResults.end());
+ }
+ }
+ return Signature(Type(params), Type(results));
+}
+
+void SExpressionWasmBuilder::parseType(Element& s) {
size_t i = 1;
if (s[i]->isStr()) {
std::string name = s[i]->str().str;
@@ -2722,18 +2795,7 @@ void SExpressionWasmBuilder::parseType(Element& s) {
signatureIndices[name] = signatures.size();
i++;
}
- Element& func = *s[i];
- for (size_t k = 1; k < func.size(); k++) {
- Element& curr = *func[k];
- if (elementStartsWith(curr, PARAM)) {
- auto newParams = parseParamOrLocal(curr);
- params.insert(params.end(), newParams.begin(), newParams.end());
- } else if (elementStartsWith(curr, RESULT)) {
- auto newResults = parseResults(curr);
- results.insert(results.end(), newResults.begin(), newResults.end());
- }
- }
- signatures.emplace_back(Type(params), Type(results));
+ signatures.emplace_back(parseInlineFunctionSignature(*s[i]));
}
void SExpressionWasmBuilder::parseEvent(Element& s, bool preParseImport) {
diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp
index 021b05cb6..a6a4ce171 100644
--- a/src/wasm/wasm-stack.cpp
+++ b/src/wasm/wasm-stack.cpp
@@ -1875,6 +1875,11 @@ void BinaryInstWriter::visitI31Get(I31Get* curr) {
<< U32LEB(curr->signed_ ? BinaryConsts::I31GetS : BinaryConsts::I31GetU);
}
+void BinaryInstWriter::visitCallRef(CallRef* curr) {
+ o << int8_t(curr->isReturn ? BinaryConsts::RetCallRef
+ : BinaryConsts::CallRef);
+}
+
void BinaryInstWriter::visitRefTest(RefTest* curr) {
o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::RefTest);
WASM_UNREACHABLE("TODO (gc): ref.test");
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp
index cf4404739..7de274e9c 100644
--- a/src/wasm/wasm-type.cpp
+++ b/src/wasm/wasm-type.cpp
@@ -394,6 +394,9 @@ bool Type::operator<(const Type& other) const {
return false;
}
// Both are compound.
+ if (a.isNullable() != b.isNullable()) {
+ return a.isNullable();
+ }
auto aHeap = a.getHeapType();
auto bHeap = b.getHeapType();
if (aHeap.isSignature() && bHeap.isSignature()) {
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index 78e123a90..5faa8b2f5 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -339,6 +339,7 @@ public:
void visitBrOnExn(BrOnExn* curr);
void visitTupleMake(TupleMake* curr);
void visitTupleExtract(TupleExtract* curr);
+ void visitCallRef(CallRef* curr);
void visitI31New(I31New* curr);
void visitI31Get(I31Get* curr);
void visitRefTest(RefTest* curr);
@@ -406,6 +407,49 @@ private:
size_t align, Type type, Index bytes, bool isAtomic, Expression* curr);
void validateMemBytes(uint8_t bytes, Type type, Expression* curr);
+ template<typename T> void validateReturnCall(T* curr) {
+ shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(),
+ curr,
+ "return_call* requires tail calls to be enabled");
+ }
+
+ template<typename T>
+ void validateCallParamsAndResult(T* curr, Signature sig) {
+ if (!shouldBeTrue(curr->operands.size() == sig.params.size(),
+ curr,
+ "call* param number must match")) {
+ return;
+ }
+ size_t i = 0;
+ for (const auto& param : sig.params) {
+ if (!shouldBeSubTypeOrFirstIsUnreachable(curr->operands[i]->type,
+ param,
+ curr,
+ "call param types must match") &&
+ !info.quiet) {
+ getStream() << "(on argument " << i << ")\n";
+ }
+ ++i;
+ }
+ if (curr->isReturn) {
+ shouldBeEqual(curr->type,
+ Type(Type::unreachable),
+ curr,
+ "return_call* should have unreachable type");
+ shouldBeEqual(
+ getFunction()->sig.results,
+ sig.results,
+ curr,
+ "return_call* callee return type must match caller return type");
+ } else {
+ shouldBeEqualOrFirstIsUnreachable(
+ curr->type,
+ sig.results,
+ curr,
+ "call* type must match callee return type");
+ }
+ }
+
Type indexType() { return getModule()->memory.indexType; }
};
@@ -748,9 +792,7 @@ void FunctionValidator::visitSwitch(Switch* curr) {
}
void FunctionValidator::visitCall(Call* curr) {
- shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(),
- curr,
- "return_call requires tail calls to be enabled");
+ validateReturnCall(curr);
if (!info.validateGlobally) {
return;
}
@@ -758,104 +800,16 @@ void FunctionValidator::visitCall(Call* curr) {
if (!shouldBeTrue(!!target, curr, "call target must exist")) {
return;
}
- if (!shouldBeTrue(curr->operands.size() == target->sig.params.size(),
- curr,
- "call param number must match")) {
- return;
- }
- size_t i = 0;
- for (const auto& param : target->sig.params) {
- if (!shouldBeSubTypeOrFirstIsUnreachable(curr->operands[i]->type,
- param,
- curr,
- "call param types must match") &&
- !info.quiet) {
- getStream() << "(on argument " << i << ")\n";
- }
- ++i;
- }
- if (curr->isReturn) {
- shouldBeEqual(curr->type,
- Type(Type::unreachable),
- curr,
- "return_call should have unreachable type");
- shouldBeEqual(
- getFunction()->sig.results,
- target->sig.results,
- curr,
- "return_call callee return type must match caller return type");
- } else {
- if (curr->type == Type::unreachable) {
- bool hasUnreachableOperand = std::any_of(
- curr->operands.begin(), curr->operands.end(), [](Expression* op) {
- return op->type == Type::unreachable;
- });
- shouldBeTrue(
- hasUnreachableOperand,
- curr,
- "calls may only be unreachable if they have unreachable operands");
- } else {
- shouldBeEqual(curr->type,
- target->sig.results,
- curr,
- "call type must match callee return type");
- }
- }
+ validateCallParamsAndResult(curr, target->sig);
}
void FunctionValidator::visitCallIndirect(CallIndirect* curr) {
- shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(),
- curr,
- "return_call_indirect requires tail calls to be enabled");
+ validateReturnCall(curr);
shouldBeEqualOrFirstIsUnreachable(curr->target->type,
Type(Type::i32),
curr,
"indirect call target must be an i32");
- if (!shouldBeTrue(curr->operands.size() == curr->sig.params.size(),
- curr,
- "call param number must match")) {
- return;
- }
- size_t i = 0;
- for (const auto& param : curr->sig.params) {
- if (!shouldBeSubTypeOrFirstIsUnreachable(curr->operands[i]->type,
- param,
- curr,
- "call param types must match") &&
- !info.quiet) {
- getStream() << "(on argument " << i << ")\n";
- }
- ++i;
- }
- if (curr->isReturn) {
- shouldBeEqual(curr->type,
- Type(Type::unreachable),
- curr,
- "return_call_indirect should have unreachable type");
- shouldBeEqual(
- getFunction()->sig.results,
- curr->sig.results,
- curr,
- "return_call_indirect callee return type must match caller return type");
- } else {
- if (curr->type == Type::unreachable) {
- if (curr->target->type != Type::unreachable) {
- bool hasUnreachableOperand = std::any_of(
- curr->operands.begin(), curr->operands.end(), [](Expression* op) {
- return op->type == Type::unreachable;
- });
- shouldBeTrue(hasUnreachableOperand,
- curr,
- "call_indirects may only be unreachable if they have "
- "unreachable operands");
- }
- } else {
- shouldBeEqual(curr->type,
- curr->sig.results,
- curr,
- "call_indirect type must match callee return type");
- }
- }
+ validateCallParamsAndResult(curr, curr->sig);
}
void FunctionValidator::visitConst(Const* curr) {
@@ -2199,6 +2153,20 @@ void FunctionValidator::visitTupleExtract(TupleExtract* curr) {
}
}
+void FunctionValidator::visitCallRef(CallRef* curr) {
+ validateReturnCall(curr);
+ shouldBeTrue(getModule()->features.hasTypedFunctionReferences(),
+ curr,
+ "call_ref requires typed-function-references to be enabled");
+ shouldBeTrue(curr->target->type.isFunction(),
+ curr,
+ "call_ref target must be a function reference");
+ if (curr->target->type != Type::unreachable) {
+ validateCallParamsAndResult(
+ curr, curr->target->type.getHeapType().getSignature());
+ }
+}
+
void FunctionValidator::visitI31New(I31New* curr) {
shouldBeTrue(
getModule()->features.hasGC(), curr, "i31.new requires gc to be enabled");
diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp
index 6245a3575..ac76a63ac 100644
--- a/src/wasm/wasm.cpp
+++ b/src/wasm/wasm.cpp
@@ -73,6 +73,8 @@ Name TABLE("table");
Name ELEM("elem");
Name LOCAL("local");
Name TYPE("type");
+Name REF("ref");
+Name NULL_("null");
Name CALL("call");
Name CALL_INDIRECT("call_indirect");
Name BLOCK("block");
@@ -212,6 +214,8 @@ const char* getExpressionName(Expression* curr) {
return "i31.new";
case Expression::Id::I31GetId:
return "i31.get";
+ case Expression::Id::CallRefId:
+ return "call_ref";
case Expression::Id::RefTestId:
return "ref.test";
case Expression::Id::RefCastId:
@@ -1060,6 +1064,21 @@ void I31Get::finalize() {
}
}
+void CallRef::finalize() {
+ handleUnreachableOperands(this);
+ if (isReturn) {
+ type = Type::unreachable;
+ }
+ if (target->type == Type::unreachable) {
+ type = Type::unreachable;
+ }
+}
+
+void CallRef::finalize(Type type_) {
+ type = type_;
+ finalize();
+}
+
// TODO (gc): ref.test
// TODO (gc): ref.cast
// TODO (gc): br_on_cast
diff --git a/src/wasm2js.h b/src/wasm2js.h
index 7a0f4692d..ddc3d7313 100644
--- a/src/wasm2js.h
+++ b/src/wasm2js.h
@@ -2198,6 +2198,10 @@ Ref Wasm2JSBuilder::processFunctionBody(Module* m,
unimplemented(curr);
WASM_UNREACHABLE("unimp");
}
+ Ref visitCallRef(CallRef* curr) {
+ unimplemented(curr);
+ WASM_UNREACHABLE("unimp");
+ }
Ref visitRefTest(RefTest* curr) {
unimplemented(curr);
WASM_UNREACHABLE("unimp");