summaryrefslogtreecommitdiff
path: root/src/wasm/wasm-validator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm/wasm-validator.cpp')
-rw-r--r--src/wasm/wasm-validator.cpp156
1 files changed, 62 insertions, 94 deletions
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index 78e123a90..5faa8b2f5 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -339,6 +339,7 @@ public:
void visitBrOnExn(BrOnExn* curr);
void visitTupleMake(TupleMake* curr);
void visitTupleExtract(TupleExtract* curr);
+ void visitCallRef(CallRef* curr);
void visitI31New(I31New* curr);
void visitI31Get(I31Get* curr);
void visitRefTest(RefTest* curr);
@@ -406,6 +407,49 @@ private:
size_t align, Type type, Index bytes, bool isAtomic, Expression* curr);
void validateMemBytes(uint8_t bytes, Type type, Expression* curr);
+ template<typename T> void validateReturnCall(T* curr) {
+ shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(),
+ curr,
+ "return_call* requires tail calls to be enabled");
+ }
+
+ template<typename T>
+ void validateCallParamsAndResult(T* curr, Signature sig) {
+ if (!shouldBeTrue(curr->operands.size() == sig.params.size(),
+ curr,
+ "call* param number must match")) {
+ return;
+ }
+ size_t i = 0;
+ for (const auto& param : sig.params) {
+ if (!shouldBeSubTypeOrFirstIsUnreachable(curr->operands[i]->type,
+ param,
+ curr,
+ "call param types must match") &&
+ !info.quiet) {
+ getStream() << "(on argument " << i << ")\n";
+ }
+ ++i;
+ }
+ if (curr->isReturn) {
+ shouldBeEqual(curr->type,
+ Type(Type::unreachable),
+ curr,
+ "return_call* should have unreachable type");
+ shouldBeEqual(
+ getFunction()->sig.results,
+ sig.results,
+ curr,
+ "return_call* callee return type must match caller return type");
+ } else {
+ shouldBeEqualOrFirstIsUnreachable(
+ curr->type,
+ sig.results,
+ curr,
+ "call* type must match callee return type");
+ }
+ }
+
Type indexType() { return getModule()->memory.indexType; }
};
@@ -748,9 +792,7 @@ void FunctionValidator::visitSwitch(Switch* curr) {
}
void FunctionValidator::visitCall(Call* curr) {
- shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(),
- curr,
- "return_call requires tail calls to be enabled");
+ validateReturnCall(curr);
if (!info.validateGlobally) {
return;
}
@@ -758,104 +800,16 @@ void FunctionValidator::visitCall(Call* curr) {
if (!shouldBeTrue(!!target, curr, "call target must exist")) {
return;
}
- if (!shouldBeTrue(curr->operands.size() == target->sig.params.size(),
- curr,
- "call param number must match")) {
- return;
- }
- size_t i = 0;
- for (const auto& param : target->sig.params) {
- if (!shouldBeSubTypeOrFirstIsUnreachable(curr->operands[i]->type,
- param,
- curr,
- "call param types must match") &&
- !info.quiet) {
- getStream() << "(on argument " << i << ")\n";
- }
- ++i;
- }
- if (curr->isReturn) {
- shouldBeEqual(curr->type,
- Type(Type::unreachable),
- curr,
- "return_call should have unreachable type");
- shouldBeEqual(
- getFunction()->sig.results,
- target->sig.results,
- curr,
- "return_call callee return type must match caller return type");
- } else {
- if (curr->type == Type::unreachable) {
- bool hasUnreachableOperand = std::any_of(
- curr->operands.begin(), curr->operands.end(), [](Expression* op) {
- return op->type == Type::unreachable;
- });
- shouldBeTrue(
- hasUnreachableOperand,
- curr,
- "calls may only be unreachable if they have unreachable operands");
- } else {
- shouldBeEqual(curr->type,
- target->sig.results,
- curr,
- "call type must match callee return type");
- }
- }
+ validateCallParamsAndResult(curr, target->sig);
}
void FunctionValidator::visitCallIndirect(CallIndirect* curr) {
- shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(),
- curr,
- "return_call_indirect requires tail calls to be enabled");
+ validateReturnCall(curr);
shouldBeEqualOrFirstIsUnreachable(curr->target->type,
Type(Type::i32),
curr,
"indirect call target must be an i32");
- if (!shouldBeTrue(curr->operands.size() == curr->sig.params.size(),
- curr,
- "call param number must match")) {
- return;
- }
- size_t i = 0;
- for (const auto& param : curr->sig.params) {
- if (!shouldBeSubTypeOrFirstIsUnreachable(curr->operands[i]->type,
- param,
- curr,
- "call param types must match") &&
- !info.quiet) {
- getStream() << "(on argument " << i << ")\n";
- }
- ++i;
- }
- if (curr->isReturn) {
- shouldBeEqual(curr->type,
- Type(Type::unreachable),
- curr,
- "return_call_indirect should have unreachable type");
- shouldBeEqual(
- getFunction()->sig.results,
- curr->sig.results,
- curr,
- "return_call_indirect callee return type must match caller return type");
- } else {
- if (curr->type == Type::unreachable) {
- if (curr->target->type != Type::unreachable) {
- bool hasUnreachableOperand = std::any_of(
- curr->operands.begin(), curr->operands.end(), [](Expression* op) {
- return op->type == Type::unreachable;
- });
- shouldBeTrue(hasUnreachableOperand,
- curr,
- "call_indirects may only be unreachable if they have "
- "unreachable operands");
- }
- } else {
- shouldBeEqual(curr->type,
- curr->sig.results,
- curr,
- "call_indirect type must match callee return type");
- }
- }
+ validateCallParamsAndResult(curr, curr->sig);
}
void FunctionValidator::visitConst(Const* curr) {
@@ -2199,6 +2153,20 @@ void FunctionValidator::visitTupleExtract(TupleExtract* curr) {
}
}
+void FunctionValidator::visitCallRef(CallRef* curr) {
+ validateReturnCall(curr);
+ shouldBeTrue(getModule()->features.hasTypedFunctionReferences(),
+ curr,
+ "call_ref requires typed-function-references to be enabled");
+ shouldBeTrue(curr->target->type.isFunction(),
+ curr,
+ "call_ref target must be a function reference");
+ if (curr->target->type != Type::unreachable) {
+ validateCallParamsAndResult(
+ curr, curr->target->type.getHeapType().getSignature());
+ }
+}
+
void FunctionValidator::visitI31New(I31New* curr) {
shouldBeTrue(
getModule()->features.hasGC(), curr, "i31.new requires gc to be enabled");