diff options
Diffstat (limited to 'src/wasm/wasm-validator.cpp')
-rw-r--r-- | src/wasm/wasm-validator.cpp | 156 |
1 files changed, 62 insertions, 94 deletions
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 78e123a90..5faa8b2f5 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -339,6 +339,7 @@ public: void visitBrOnExn(BrOnExn* curr); void visitTupleMake(TupleMake* curr); void visitTupleExtract(TupleExtract* curr); + void visitCallRef(CallRef* curr); void visitI31New(I31New* curr); void visitI31Get(I31Get* curr); void visitRefTest(RefTest* curr); @@ -406,6 +407,49 @@ private: size_t align, Type type, Index bytes, bool isAtomic, Expression* curr); void validateMemBytes(uint8_t bytes, Type type, Expression* curr); + template<typename T> void validateReturnCall(T* curr) { + shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(), + curr, + "return_call* requires tail calls to be enabled"); + } + + template<typename T> + void validateCallParamsAndResult(T* curr, Signature sig) { + if (!shouldBeTrue(curr->operands.size() == sig.params.size(), + curr, + "call* param number must match")) { + return; + } + size_t i = 0; + for (const auto& param : sig.params) { + if (!shouldBeSubTypeOrFirstIsUnreachable(curr->operands[i]->type, + param, + curr, + "call param types must match") && + !info.quiet) { + getStream() << "(on argument " << i << ")\n"; + } + ++i; + } + if (curr->isReturn) { + shouldBeEqual(curr->type, + Type(Type::unreachable), + curr, + "return_call* should have unreachable type"); + shouldBeEqual( + getFunction()->sig.results, + sig.results, + curr, + "return_call* callee return type must match caller return type"); + } else { + shouldBeEqualOrFirstIsUnreachable( + curr->type, + sig.results, + curr, + "call* type must match callee return type"); + } + } + Type indexType() { return getModule()->memory.indexType; } }; @@ -748,9 +792,7 @@ void FunctionValidator::visitSwitch(Switch* curr) { } void FunctionValidator::visitCall(Call* curr) { - shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(), - curr, - "return_call requires tail calls to be enabled"); + validateReturnCall(curr); if (!info.validateGlobally) { return; } @@ -758,104 +800,16 @@ void FunctionValidator::visitCall(Call* curr) { if (!shouldBeTrue(!!target, curr, "call target must exist")) { return; } - if (!shouldBeTrue(curr->operands.size() == target->sig.params.size(), - curr, - "call param number must match")) { - return; - } - size_t i = 0; - for (const auto& param : target->sig.params) { - if (!shouldBeSubTypeOrFirstIsUnreachable(curr->operands[i]->type, - param, - curr, - "call param types must match") && - !info.quiet) { - getStream() << "(on argument " << i << ")\n"; - } - ++i; - } - if (curr->isReturn) { - shouldBeEqual(curr->type, - Type(Type::unreachable), - curr, - "return_call should have unreachable type"); - shouldBeEqual( - getFunction()->sig.results, - target->sig.results, - curr, - "return_call callee return type must match caller return type"); - } else { - if (curr->type == Type::unreachable) { - bool hasUnreachableOperand = std::any_of( - curr->operands.begin(), curr->operands.end(), [](Expression* op) { - return op->type == Type::unreachable; - }); - shouldBeTrue( - hasUnreachableOperand, - curr, - "calls may only be unreachable if they have unreachable operands"); - } else { - shouldBeEqual(curr->type, - target->sig.results, - curr, - "call type must match callee return type"); - } - } + validateCallParamsAndResult(curr, target->sig); } void FunctionValidator::visitCallIndirect(CallIndirect* curr) { - shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(), - curr, - "return_call_indirect requires tail calls to be enabled"); + validateReturnCall(curr); shouldBeEqualOrFirstIsUnreachable(curr->target->type, Type(Type::i32), curr, "indirect call target must be an i32"); - if (!shouldBeTrue(curr->operands.size() == curr->sig.params.size(), - curr, - "call param number must match")) { - return; - } - size_t i = 0; - for (const auto& param : curr->sig.params) { - if (!shouldBeSubTypeOrFirstIsUnreachable(curr->operands[i]->type, - param, - curr, - "call param types must match") && - !info.quiet) { - getStream() << "(on argument " << i << ")\n"; - } - ++i; - } - if (curr->isReturn) { - shouldBeEqual(curr->type, - Type(Type::unreachable), - curr, - "return_call_indirect should have unreachable type"); - shouldBeEqual( - getFunction()->sig.results, - curr->sig.results, - curr, - "return_call_indirect callee return type must match caller return type"); - } else { - if (curr->type == Type::unreachable) { - if (curr->target->type != Type::unreachable) { - bool hasUnreachableOperand = std::any_of( - curr->operands.begin(), curr->operands.end(), [](Expression* op) { - return op->type == Type::unreachable; - }); - shouldBeTrue(hasUnreachableOperand, - curr, - "call_indirects may only be unreachable if they have " - "unreachable operands"); - } - } else { - shouldBeEqual(curr->type, - curr->sig.results, - curr, - "call_indirect type must match callee return type"); - } - } + validateCallParamsAndResult(curr, curr->sig); } void FunctionValidator::visitConst(Const* curr) { @@ -2199,6 +2153,20 @@ void FunctionValidator::visitTupleExtract(TupleExtract* curr) { } } +void FunctionValidator::visitCallRef(CallRef* curr) { + validateReturnCall(curr); + shouldBeTrue(getModule()->features.hasTypedFunctionReferences(), + curr, + "call_ref requires typed-function-references to be enabled"); + shouldBeTrue(curr->target->type.isFunction(), + curr, + "call_ref target must be a function reference"); + if (curr->target->type != Type::unreachable) { + validateCallParamsAndResult( + curr, curr->target->type.getHeapType().getSignature()); + } +} + void FunctionValidator::visitI31New(I31New* curr) { shouldBeTrue( getModule()->features.hasGC(), curr, "i31.new requires gc to be enabled"); |