diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/gen-s-parser.inc | 28 | ||||
-rw-r--r-- | src/ir/ReFinalize.cpp | 1 | ||||
-rw-r--r-- | src/ir/cost.h | 7 | ||||
-rw-r--r-- | src/ir/effects.h | 11 | ||||
-rw-r--r-- | src/ir/module-utils.h | 2 | ||||
-rw-r--r-- | src/js/binaryen.js-post.js | 1 | ||||
-rw-r--r-- | src/passes/DeadArgumentElimination.cpp | 6 | ||||
-rw-r--r-- | src/passes/Directize.cpp | 44 | ||||
-rw-r--r-- | src/passes/Inlining.cpp | 5 | ||||
-rw-r--r-- | src/passes/MergeBlocks.cpp | 6 | ||||
-rw-r--r-- | src/passes/Print.cpp | 42 | ||||
-rw-r--r-- | src/shared-constants.h | 2 | ||||
-rw-r--r-- | src/tools/fuzzing.h | 57 | ||||
-rw-r--r-- | src/wasm-binary.h | 6 | ||||
-rw-r--r-- | src/wasm-builder.h | 13 | ||||
-rw-r--r-- | src/wasm-delegations-fields.h | 8 | ||||
-rw-r--r-- | src/wasm-delegations.h | 1 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 34 | ||||
-rw-r--r-- | src/wasm-s-parser.h | 8 | ||||
-rw-r--r-- | src/wasm.h | 12 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 30 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 96 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 5 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 3 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 156 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 19 | ||||
-rw-r--r-- | src/wasm2js.h | 4 |
27 files changed, 450 insertions, 157 deletions
diff --git a/src/gen-s-parser.inc b/src/gen-s-parser.inc index a62d9fdc6..8afcea917 100644 --- a/src/gen-s-parser.inc +++ b/src/gen-s-parser.inc @@ -99,9 +99,17 @@ switch (op[0]) { case '\0': if (strcmp(op, "call") == 0) { return makeCall(s, /*isReturn=*/false); } goto parse_error; - case '_': - if (strcmp(op, "call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/false); } - goto parse_error; + case '_': { + switch (op[5]) { + case 'i': + if (strcmp(op, "call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/false); } + goto parse_error; + case 'r': + if (strcmp(op, "call_ref") == 0) { return makeCallRef(s, /*isReturn=*/false); } + goto parse_error; + default: goto parse_error; + } + } default: goto parse_error; } } @@ -2747,9 +2755,17 @@ switch (op[0]) { case '\0': if (strcmp(op, "return_call") == 0) { return makeCall(s, /*isReturn=*/true); } goto parse_error; - case '_': - if (strcmp(op, "return_call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/true); } - goto parse_error; + case '_': { + switch (op[12]) { + case 'i': + if (strcmp(op, "return_call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/true); } + goto parse_error; + case 'r': + if (strcmp(op, "return_call_ref") == 0) { return makeCallRef(s, /*isReturn=*/true); } + goto parse_error; + default: goto parse_error; + } + } default: goto parse_error; } } diff --git a/src/ir/ReFinalize.cpp b/src/ir/ReFinalize.cpp index 448c1f30a..54dde65bc 100644 --- a/src/ir/ReFinalize.cpp +++ b/src/ir/ReFinalize.cpp @@ -150,6 +150,7 @@ void ReFinalize::visitTupleMake(TupleMake* curr) { curr->finalize(); } void ReFinalize::visitTupleExtract(TupleExtract* curr) { curr->finalize(); } void ReFinalize::visitI31New(I31New* curr) { curr->finalize(); } void ReFinalize::visitI31Get(I31Get* curr) { curr->finalize(); } +void ReFinalize::visitCallRef(CallRef* curr) { curr->finalize(); } void ReFinalize::visitRefTest(RefTest* curr) { curr->finalize(); } void ReFinalize::visitRefCast(RefCast* curr) { curr->finalize(); } void ReFinalize::visitBrOnCast(BrOnCast* curr) { diff --git a/src/ir/cost.h b/src/ir/cost.h index c0845f7e2..333f599ee 100644 --- a/src/ir/cost.h +++ b/src/ir/cost.h @@ -65,6 +65,13 @@ struct CostAnalyzer : public OverriddenVisitor<CostAnalyzer, Index> { } return ret; } + Index visitCallRef(CallRef* curr) { + Index ret = 5 + visit(curr->target); + for (auto* child : curr->operands) { + ret += visit(child); + } + return ret; + } Index visitLocalGet(LocalGet* curr) { return 0; } Index visitLocalSet(LocalSet* curr) { return 1 + visit(curr->value); } Index visitGlobalGet(GlobalGet* curr) { return 1; } diff --git a/src/ir/effects.h b/src/ir/effects.h index ab8cafcb1..c0210c221 100644 --- a/src/ir/effects.h +++ b/src/ir/effects.h @@ -534,6 +534,17 @@ private: void visitTupleExtract(TupleExtract* curr) {} void visitI31New(I31New* curr) {} void visitI31Get(I31Get* curr) {} + void visitCallRef(CallRef* curr) { + parent.calls = true; + if (parent.features.hasExceptionHandling() && parent.tryDepth == 0) { + parent.throws = true; + } + if (curr->isReturn) { + parent.branchesOut = true; + } + // traps when the arg is null + parent.implicitTrap = true; + } void visitRefTest(RefTest* curr) { WASM_UNREACHABLE("TODO (gc): ref.test"); } diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h index c8776297c..2b1c6812c 100644 --- a/src/ir/module-utils.h +++ b/src/ir/module-utils.h @@ -331,10 +331,10 @@ template<typename T> struct CallGraphPropertyAnalysis { void visitCall(Call* curr) { info.callsTo.insert(module->getFunction(curr->target)); } - void visitCallIndirect(CallIndirect* curr) { info.hasNonDirectCall = true; } + void visitCallRef(CallRef* curr) { info.hasNonDirectCall = true; } private: Module* module; diff --git a/src/js/binaryen.js-post.js b/src/js/binaryen.js-post.js index 9abe5e718..bbf2ae237 100644 --- a/src/js/binaryen.js-post.js +++ b/src/js/binaryen.js-post.js @@ -99,6 +99,7 @@ function initializeConstants() { 'Pop', 'I31New', 'I31Get', + 'CallRef', 'RefTest', 'RefCast', 'BrOnCast', diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index 89d03f461..34637cf5a 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -143,6 +143,12 @@ struct DAEScanner } } + void visitCallRef(CallRef* curr) { + if (curr->isReturn) { + info->hasTailCalls = true; + } + } + void visitDrop(Drop* curr) { if (auto* call = curr->value->dynCast<Call>()) { info->droppedCalls[call] = getCurrentPointer(); diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp index 0c1132b04..f966d1a5a 100644 --- a/src/passes/Directize.cpp +++ b/src/passes/Directize.cpp @@ -41,6 +41,9 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { FunctionDirectizer(TableUtils::FlatTable* flatTable) : flatTable(flatTable) {} void visitCallIndirect(CallIndirect* curr) { + if (!flatTable) { + return; + } if (auto* c = curr->target->dynCast<Const>()) { Index index = c->value.geti32(); // If the index is invalid, or the type is wrong, we can @@ -68,6 +71,15 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { } } + void visitCallRef(CallRef* curr) { + if (auto* ref = curr->target->dynCast<RefFunc>()) { + // We know the target! + replaceCurrent( + Builder(*getModule()) + .makeCall(ref->func, curr->operands, curr->type, curr->isReturn)); + } + } + void doWalkFunction(Function* func) { WalkerPass<PostWalker<FunctionDirectizer>>::doWalkFunction(func); if (changedTypes) { @@ -76,7 +88,9 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { } private: + // If null, then we cannot optimize call_indirects. TableUtils::FlatTable* flatTable; + bool changedTypes = false; void replaceWithUnreachable(CallIndirect* call) { @@ -92,23 +106,31 @@ private: struct Directize : public Pass { void run(PassRunner* runner, Module* module) override { + bool canOptimizeCallIndirect = true; + TableUtils::FlatTable flatTable(module->table); if (!module->table.exists) { - return; - } - if (module->table.imported()) { - return; - } - for (auto& ex : module->exports) { - if (ex->kind == ExternalKind::Table) { - return; + canOptimizeCallIndirect = false; + } else if (module->table.imported()) { + canOptimizeCallIndirect = false; + } else { + for (auto& ex : module->exports) { + if (ex->kind == ExternalKind::Table) { + canOptimizeCallIndirect = false; + } + } + if (!flatTable.valid) { + canOptimizeCallIndirect = false; } } - TableUtils::FlatTable flatTable(module->table); - if (!flatTable.valid) { + // Without typed function references, all we can do is optimize table + // accesses, so if we can't do that, stop. + if (!canOptimizeCallIndirect && + !module->features.hasTypedFunctionReferences()) { return; } // The table exists and is constant, so this is possible. - FunctionDirectizer(&flatTable).run(runner, module); + FunctionDirectizer(canOptimizeCallIndirect ? &flatTable : nullptr) + .run(runner, module); } }; diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index bcab7318f..a44f02426 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -211,6 +211,11 @@ struct Updater : public PostWalker<Updater> { handleReturnCall(curr, curr->sig.results); } } + void visitCallRef(CallRef* curr) { + if (curr->isReturn) { + handleReturnCall(curr, curr->target->type); + } + } void visitLocalGet(LocalGet* curr) { curr->index = localMapping[curr->index]; } diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp index 4ecec6669..33dbec77c 100644 --- a/src/passes/MergeBlocks.cpp +++ b/src/passes/MergeBlocks.cpp @@ -564,7 +564,7 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> { void visitCall(Call* curr) { handleCall(curr); } - void visitCallIndirect(CallIndirect* curr) { + template<typename T> void handleNonDirectCall(T* curr) { FeatureSet features = getModule()->features; Block* outer = nullptr; for (Index i = 0; i < curr->operands.size(); i++) { @@ -581,6 +581,10 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> { optimize(curr, curr->target, outer); } + void visitCallIndirect(CallIndirect* curr) { handleNonDirectCall(curr); } + + void visitCallRef(CallRef* curr) { handleNonDirectCall(curr); } + void visitThrow(Throw* curr) { Block* outer = nullptr; for (Index i = 0; i < curr->operands.size(); i++) { diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index e512d398f..864a46362 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -87,14 +87,35 @@ struct SigName { }; std::ostream& operator<<(std::ostream& os, SigName sigName) { - auto printType = [&](Type type) { + std::function<void(Type)> printType = [&](Type type) { if (type == Type::none) { os << "none"; } else { auto sep = ""; for (const auto& t : type) { - os << sep << t; + os << sep; sep = "_"; + if (t.isRef()) { + auto heapType = t.getHeapType(); + if (heapType.isSignature()) { + auto sig = heapType.getSignature(); + os << "ref"; + if (t.isNullable()) { + os << "_null"; + } + os << "<"; + for (auto s : sig.params) { + printType(s); + } + os << "_->_"; + for (auto s : sig.results) { + printType(s); + } + os << ">"; + continue; + } + } + os << t; } } }; @@ -1561,6 +1582,13 @@ struct PrintExpressionContents void visitI31Get(I31Get* curr) { printMedium(o, curr->signed_ ? "i31.get_s" : "i31.get_u"); } + void visitCallRef(CallRef* curr) { + if (curr->isReturn) { + printMedium(o, "return_call_ref"); + } else { + printMedium(o, "call_ref"); + } + } void visitRefTest(RefTest* curr) { printMedium(o, "ref.test"); WASM_UNREACHABLE("TODO (gc): ref.test"); @@ -2216,6 +2244,16 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { printFullLine(curr->i31); decIndent(); } + void visitCallRef(CallRef* curr) { + o << '('; + PrintExpressionContents(currFunction, o).visit(curr); + incIndent(); + for (auto operand : curr->operands) { + printFullLine(operand); + } + printFullLine(curr->target); + decIndent(); + } void visitRefTest(RefTest* curr) { o << '('; PrintExpressionContents(currFunction, o).visit(curr); diff --git a/src/shared-constants.h b/src/shared-constants.h index e3c34e62f..569fd792d 100644 --- a/src/shared-constants.h +++ b/src/shared-constants.h @@ -43,6 +43,8 @@ extern Name GLOBAL; extern Name ELEM; extern Name LOCAL; extern Name TYPE; +extern Name REF; +extern Name NULL_; extern Name CALL; extern Name CALL_IMPORT; extern Name CALL_INDIRECT; diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h index 1c1359586..39298ce21 100644 --- a/src/tools/fuzzing.h +++ b/src/tools/fuzzing.h @@ -1081,13 +1081,15 @@ private: WeightedOption{&Self::makeGlobalGet, Important}, WeightedOption{&Self::makeConst, Important}); if (canMakeControlFlow) { - options.add(FeatureSet::MVP, - WeightedOption{&Self::makeBlock, Important}, - WeightedOption{&Self::makeIf, Important}, - WeightedOption{&Self::makeLoop, Important}, - WeightedOption{&Self::makeBreak, Important}, - &Self::makeCall, - &Self::makeCallIndirect); + options + .add(FeatureSet::MVP, + WeightedOption{&Self::makeBlock, Important}, + WeightedOption{&Self::makeIf, Important}, + WeightedOption{&Self::makeLoop, Important}, + WeightedOption{&Self::makeBreak, Important}, + &Self::makeCall, + &Self::makeCallIndirect) + .add(FeatureSet::TypedFunctionReferences, &Self::makeCallRef); } if (type.isSingle()) { options @@ -1146,7 +1148,8 @@ private: &Self::makeNop, &Self::makeGlobalSet) .add(FeatureSet::BulkMemory, &Self::makeBulkMemory) - .add(FeatureSet::Atomics, &Self::makeAtomic); + .add(FeatureSet::Atomics, &Self::makeAtomic) + .add(FeatureSet::TypedFunctionReferences, &Self::makeCallRef); return (this->*pick(options))(Type::none); } @@ -1154,22 +1157,24 @@ private: using Self = TranslateToFuzzReader; auto options = FeatureOptions<Expression* (Self::*)(Type)>(); using WeightedOption = decltype(options)::WeightedOption; - options.add(FeatureSet::MVP, - WeightedOption{&Self::makeLocalSet, VeryImportant}, - WeightedOption{&Self::makeBlock, Important}, - WeightedOption{&Self::makeIf, Important}, - WeightedOption{&Self::makeLoop, Important}, - WeightedOption{&Self::makeBreak, Important}, - WeightedOption{&Self::makeStore, Important}, - WeightedOption{&Self::makeUnary, Important}, - WeightedOption{&Self::makeBinary, Important}, - WeightedOption{&Self::makeUnreachable, Important}, - &Self::makeCall, - &Self::makeCallIndirect, - &Self::makeSelect, - &Self::makeSwitch, - &Self::makeDrop, - &Self::makeReturn); + options + .add(FeatureSet::MVP, + WeightedOption{&Self::makeLocalSet, VeryImportant}, + WeightedOption{&Self::makeBlock, Important}, + WeightedOption{&Self::makeIf, Important}, + WeightedOption{&Self::makeLoop, Important}, + WeightedOption{&Self::makeBreak, Important}, + WeightedOption{&Self::makeStore, Important}, + WeightedOption{&Self::makeUnary, Important}, + WeightedOption{&Self::makeBinary, Important}, + WeightedOption{&Self::makeUnreachable, Important}, + &Self::makeCall, + &Self::makeCallIndirect, + &Self::makeSelect, + &Self::makeSwitch, + &Self::makeDrop, + &Self::makeReturn) + .add(FeatureSet::TypedFunctionReferences, &Self::makeCallRef); return (this->*pick(options))(Type::unreachable); } @@ -1443,6 +1448,10 @@ private: return builder.makeCallIndirect(target, args, targetFn->sig, isReturn); } + Expression* makeCallRef(Type type) { + return makeTrivial(type); // FIXME + } + Expression* makeLocalGet(Type type) { auto& locals = funcContext->typeLocals[type]; if (locals.empty()) { diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 0918151c5..b0f41e69c 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -972,6 +972,11 @@ enum ASTNodes { Rethrow = 0x09, BrOnExn = 0x0a, + // typed function references opcodes + + CallRef = 0x14, + RetCallRef = 0x15, + // gc opcodes RefEq = 0xd5, @@ -1479,6 +1484,7 @@ public: void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); void visitBrOnExn(BrOnExn* curr); + void visitCallRef(CallRef* curr); void throwError(std::string text); diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 6800aa2ed..50a6e97fb 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -257,6 +257,19 @@ public: call->finalize(); return call; } + template<typename T> + CallRef* makeCallRef(Expression* target, + const T& args, + Type type, + bool isReturn = false) { + auto* call = wasm.allocator.alloc<CallRef>(); + call->type = type; + call->target = target; + call->operands.set(args); + call->isReturn = isReturn; + call->finalize(); + return call; + } LocalGet* makeLocalGet(Index index, Type type) { auto* ret = wasm.allocator.alloc<LocalGet>(); ret->index = index; diff --git a/src/wasm-delegations-fields.h b/src/wasm-delegations-fields.h index 7f6e43d75..ca0a8f7cb 100644 --- a/src/wasm-delegations-fields.h +++ b/src/wasm-delegations-fields.h @@ -549,6 +549,14 @@ switch (DELEGATE_ID) { DELEGATE_END(I31Get); break; } + case Expression::Id::CallRefId: { + DELEGATE_START(CallRef); + DELEGATE_FIELD_CHILD(CallRef, target); + DELEGATE_FIELD_CHILD_VECTOR(CallRef, operands); + DELEGATE_FIELD_INT(CallRef, isReturn); + DELEGATE_END(CallRef); + break; + } case Expression::Id::RefTestId: { DELEGATE_START(RefTest); WASM_UNREACHABLE("TODO (gc): ref.test"); diff --git a/src/wasm-delegations.h b/src/wasm-delegations.h index 7212cbee9..50ee8247b 100644 --- a/src/wasm-delegations.h +++ b/src/wasm-delegations.h @@ -66,6 +66,7 @@ DELEGATE(TupleMake); DELEGATE(TupleExtract); DELEGATE(I31New); DELEGATE(I31Get); +DELEGATE(CallRef); DELEGATE(RefTest); DELEGATE(RefCast); DELEGATE(BrOnCast); diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 406938a56..37719d4d9 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -1272,6 +1272,7 @@ public: WASM_UNREACHABLE("unimp"); } Flow visitPop(Pop* curr) { WASM_UNREACHABLE("unimp"); } + Flow visitCallRef(CallRef* curr) { WASM_UNREACHABLE("unimp"); } Flow visitRefNull(RefNull* curr) { NOTE_ENTER("RefNull"); return Literal::makeNull(curr->type); @@ -1593,11 +1594,14 @@ public: } return Flow(NONCONSTANT_FLOW); } - Flow visitCallIndirect(CallIndirect* curr) { NOTE_ENTER("CallIndirect"); return Flow(NONCONSTANT_FLOW); } + Flow visitCallRef(CallRef* curr) { + NOTE_ENTER("CallRef"); + return Flow(NONCONSTANT_FLOW); + } Flow visitLoad(Load* curr) { NOTE_ENTER("Load"); return Flow(NONCONSTANT_FLOW); @@ -2095,6 +2099,34 @@ private: } return ret; } + Flow visitCallRef(CallRef* curr) { + NOTE_ENTER("CallRef"); + LiteralList arguments; + Flow flow = this->generateArguments(curr->operands, arguments); + if (flow.breaking()) { + return flow; + } + Flow target = this->visit(curr->target); + if (target.breaking()) { + return target; + } + Name funcName = target.getSingleValue().getFunc(); + auto* func = instance.wasm.getFunction(funcName); + Flow ret; + if (func->imported()) { + ret.values = instance.externalInterface->callImport(func, arguments); + } else { + ret.values = instance.callFunctionInternal(funcName, arguments); + } +#ifdef WASM_INTERPRETER_DEBUG + std::cout << "(returned to " << scope.function->name << ")\n"; +#endif + // TODO: make this a proper tail call (return first) + if (curr->isReturn) { + ret.breakTo = RETURN_FLOW; + } + return ret; + } Flow visitLocalGet(LocalGet* curr) { NOTE_ENTER("LocalGet"); diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index 085b58ba0..9a501171d 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -77,6 +77,11 @@ public: Element* setString(cashew::IString str__, bool dollared__, bool quoted__); Element* setMetadata(size_t line_, size_t col_, SourceLocation* startLoc_); + // comparisons + bool operator==(Name name) { return isStr() && str() == name; } + + template<typename T> bool operator!=(T t) { return !(*this == t); } + // printing friend std::ostream& operator<<(std::ostream& o, Element& e); void dump(); @@ -144,6 +149,7 @@ private: UniqueNameMapper nameMapper; + // Given a function signature type's name, return the signature Signature getFunctionSignature(Element& s); Name getFunctionName(Element& s); Name getGlobalName(Element& s); @@ -246,6 +252,7 @@ private: Expression* makeBrOnExn(Element& s); Expression* makeTupleMake(Element& s); Expression* makeTupleExtract(Element& s); + Expression* makeCallRef(Element& s, bool isReturn); Expression* makeI31New(Element& s); Expression* makeI31Get(Element& s, bool signed_); Expression* makeRefTest(Element& s); @@ -288,6 +295,7 @@ private: void parseTable(Element& s, bool preParseImport = false); void parseElem(Element& s); void parseInnerElem(Element& s, Index i = 1, Expression* offset = nullptr); + Signature parseInlineFunctionSignature(Element& s); void parseType(Element& s); void parseEvent(Element& s, bool preParseImport = false); diff --git a/src/wasm.h b/src/wasm.h index e9fb4461b..6367fa345 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -588,6 +588,7 @@ public: TupleExtractId, I31NewId, I31GetId, + CallRefId, RefTestId, RefCastId, BrOnCastId, @@ -1294,6 +1295,17 @@ public: void finalize(); }; +class CallRef : public SpecificExpression<Expression::CallRefId> { +public: + CallRef(MixedArena& allocator) : operands(allocator) {} + ExpressionList operands; + Expression* target; + bool isReturn = false; + + void finalize(); + void finalize(Type type_); +}; + class RefTest : public SpecificExpression<Expression::RefTestId> { public: RefTest(MixedArena& allocator) {} diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index a96039bc2..20b0899a5 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -2760,6 +2760,16 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { visitMemoryGrow(grow); break; } + case BinaryConsts::CallRef: + visitCallRef((curr = allocator.alloc<CallRef>())->cast<CallRef>()); + break; + case BinaryConsts::RetCallRef: { + auto call = allocator.alloc<CallRef>(); + call->isReturn = true; + curr = call; + visitCallRef(call); + break; + } case BinaryConsts::AtomicPrefix: { code = static_cast<uint8_t>(getU32LEB()); if (maybeVisitLoad(curr, code, /*isAtomic=*/true)) { @@ -5426,6 +5436,26 @@ void WasmBinaryBuilder::visitBrOnExn(BrOnExn* curr) { curr->finalize(); } +void WasmBinaryBuilder::visitCallRef(CallRef* curr) { + BYN_TRACE("zz node: CallRef\n"); + curr->target = popNonVoidExpression(); + auto type = curr->target->type; + if (!type.isRef()) { + throwError("Non-ref type for a call_ref: " + type.toString()); + } + auto heapType = type.getHeapType(); + if (!heapType.isSignature()) { + throwError("Invalid reference type for a call_ref: " + type.toString()); + } + auto sig = heapType.getSignature(); + auto num = sig.params.size(); + curr->operands.resize(num); + for (size_t i = 0; i < num; i++) { + curr->operands[num - i - 1] = popNonVoidExpression(); + } + curr->finalize(sig.results); +} + bool WasmBinaryBuilder::maybeVisitI31New(Expression*& out, uint32_t code) { if (code != BinaryConsts::I31New) { return false; diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 6286ae090..d8d9fa779 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -539,11 +539,11 @@ SExpressionWasmBuilder::parseParamOrLocal(Element& s, size_t& localIndex) { if (s[i]->isStr()) { type = stringToType(s[i]->str()); } else { - if (elementStartsWith(s, PARAM)) { + type = elementToType(*s[i]); + if (elementStartsWith(s, PARAM) && type.isTuple()) { throw ParseException( "params may not have tuple types", s[i]->line, s[i]->col); } - type = elementToType(*s[i]); } namedParams.emplace_back(name, type); } @@ -925,10 +925,48 @@ Type SExpressionWasmBuilder::elementToType(Element& s) { if (s.isStr()) { return stringToType(s.str(), false, false); } - auto& tuple = s.list(); + auto& list = s.list(); + auto size = list.size(); + if (size > 0 && elementStartsWith(s, REF)) { + // It's a reference. It should be in the form + // (ref $name) + // or + // (ref null $name) + // and also $name can be the expanded structure of the type and not a name, + // so something like (ref (func (result i32))), etc. + if (size != 2 && size != 3) { + throw ParseException( + std::string("invalid reference type size"), s.line, s.col); + } + if (size == 3 && *list[1] != NULL_) { + throw ParseException( + std::string("invalid reference type qualifier"), s.line, s.col); + } + bool nullable = false; + size_t i = 1; + if (size == 3) { + nullable = true; + i++; + } + Signature sig; + auto& last = *s[i]; + if (last.isStr()) { + // A string name of a signature. + sig = getFunctionSignature(last); + } else { + // A signature written out in full in-line. + if (*last[0] != FUNC) { + throw ParseException( + std::string("invalid reference type type"), s.line, s.col); + } + sig = parseInlineFunctionSignature(last); + } + return Type(HeapType(sig), nullable); + } + // It's a tuple. std::vector<Type> types; for (size_t i = 0; i < s.size(); ++i) { - types.push_back(stringToType(tuple[i]->str())); + types.push_back(stringToType(list[i]->str())); } return Type(types); } @@ -2026,6 +2064,24 @@ Expression* SExpressionWasmBuilder::makeTupleExtract(Element& s) { return ret; } +Expression* SExpressionWasmBuilder::makeCallRef(Element& s, bool isReturn) { + auto ret = allocator.alloc<CallRef>(); + parseCallOperands(s, 1, s.size() - 1, ret); + ret->target = parseExpression(s[s.size() - 1]); + ret->isReturn = isReturn; + if (!ret->target->type.isRef()) { + throw ParseException("Non-reference type for a call_ref", s.line, s.col); + } + auto heapType = ret->target->type.getHeapType(); + if (!heapType.isSignature()) { + throw ParseException( + "Invalid reference type for a call_ref", s.line, s.col); + } + auto sig = heapType.getSignature(); + ret->finalize(sig.results); + return ret; +} + Expression* SExpressionWasmBuilder::makeI31New(Element& s) { auto ret = allocator.alloc<I31New>(); ret->value = parseExpression(s[1]); @@ -2710,9 +2766,26 @@ void SExpressionWasmBuilder::parseInnerElem(Element& s, wasm.table.segments.push_back(segment); } -void SExpressionWasmBuilder::parseType(Element& s) { +Signature SExpressionWasmBuilder::parseInlineFunctionSignature(Element& s) { + if (*s[0] != FUNC) { + throw ParseException("invalid inline function signature", s.line, s.col); + } std::vector<Type> params; std::vector<Type> results; + for (size_t k = 1; k < s.size(); k++) { + Element& curr = *s[k]; + if (elementStartsWith(curr, PARAM)) { + auto newParams = parseParamOrLocal(curr); + params.insert(params.end(), newParams.begin(), newParams.end()); + } else if (elementStartsWith(curr, RESULT)) { + auto newResults = parseResults(curr); + results.insert(results.end(), newResults.begin(), newResults.end()); + } + } + return Signature(Type(params), Type(results)); +} + +void SExpressionWasmBuilder::parseType(Element& s) { size_t i = 1; if (s[i]->isStr()) { std::string name = s[i]->str().str; @@ -2722,18 +2795,7 @@ void SExpressionWasmBuilder::parseType(Element& s) { signatureIndices[name] = signatures.size(); i++; } - Element& func = *s[i]; - for (size_t k = 1; k < func.size(); k++) { - Element& curr = *func[k]; - if (elementStartsWith(curr, PARAM)) { - auto newParams = parseParamOrLocal(curr); - params.insert(params.end(), newParams.begin(), newParams.end()); - } else if (elementStartsWith(curr, RESULT)) { - auto newResults = parseResults(curr); - results.insert(results.end(), newResults.begin(), newResults.end()); - } - } - signatures.emplace_back(Type(params), Type(results)); + signatures.emplace_back(parseInlineFunctionSignature(*s[i])); } void SExpressionWasmBuilder::parseEvent(Element& s, bool preParseImport) { diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index 021b05cb6..a6a4ce171 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -1875,6 +1875,11 @@ void BinaryInstWriter::visitI31Get(I31Get* curr) { << U32LEB(curr->signed_ ? BinaryConsts::I31GetS : BinaryConsts::I31GetU); } +void BinaryInstWriter::visitCallRef(CallRef* curr) { + o << int8_t(curr->isReturn ? BinaryConsts::RetCallRef + : BinaryConsts::CallRef); +} + void BinaryInstWriter::visitRefTest(RefTest* curr) { o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::RefTest); WASM_UNREACHABLE("TODO (gc): ref.test"); diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index cf4404739..7de274e9c 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -394,6 +394,9 @@ bool Type::operator<(const Type& other) const { return false; } // Both are compound. + if (a.isNullable() != b.isNullable()) { + return a.isNullable(); + } auto aHeap = a.getHeapType(); auto bHeap = b.getHeapType(); if (aHeap.isSignature() && bHeap.isSignature()) { diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 78e123a90..5faa8b2f5 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -339,6 +339,7 @@ public: void visitBrOnExn(BrOnExn* curr); void visitTupleMake(TupleMake* curr); void visitTupleExtract(TupleExtract* curr); + void visitCallRef(CallRef* curr); void visitI31New(I31New* curr); void visitI31Get(I31Get* curr); void visitRefTest(RefTest* curr); @@ -406,6 +407,49 @@ private: size_t align, Type type, Index bytes, bool isAtomic, Expression* curr); void validateMemBytes(uint8_t bytes, Type type, Expression* curr); + template<typename T> void validateReturnCall(T* curr) { + shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(), + curr, + "return_call* requires tail calls to be enabled"); + } + + template<typename T> + void validateCallParamsAndResult(T* curr, Signature sig) { + if (!shouldBeTrue(curr->operands.size() == sig.params.size(), + curr, + "call* param number must match")) { + return; + } + size_t i = 0; + for (const auto& param : sig.params) { + if (!shouldBeSubTypeOrFirstIsUnreachable(curr->operands[i]->type, + param, + curr, + "call param types must match") && + !info.quiet) { + getStream() << "(on argument " << i << ")\n"; + } + ++i; + } + if (curr->isReturn) { + shouldBeEqual(curr->type, + Type(Type::unreachable), + curr, + "return_call* should have unreachable type"); + shouldBeEqual( + getFunction()->sig.results, + sig.results, + curr, + "return_call* callee return type must match caller return type"); + } else { + shouldBeEqualOrFirstIsUnreachable( + curr->type, + sig.results, + curr, + "call* type must match callee return type"); + } + } + Type indexType() { return getModule()->memory.indexType; } }; @@ -748,9 +792,7 @@ void FunctionValidator::visitSwitch(Switch* curr) { } void FunctionValidator::visitCall(Call* curr) { - shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(), - curr, - "return_call requires tail calls to be enabled"); + validateReturnCall(curr); if (!info.validateGlobally) { return; } @@ -758,104 +800,16 @@ void FunctionValidator::visitCall(Call* curr) { if (!shouldBeTrue(!!target, curr, "call target must exist")) { return; } - if (!shouldBeTrue(curr->operands.size() == target->sig.params.size(), - curr, - "call param number must match")) { - return; - } - size_t i = 0; - for (const auto& param : target->sig.params) { - if (!shouldBeSubTypeOrFirstIsUnreachable(curr->operands[i]->type, - param, - curr, - "call param types must match") && - !info.quiet) { - getStream() << "(on argument " << i << ")\n"; - } - ++i; - } - if (curr->isReturn) { - shouldBeEqual(curr->type, - Type(Type::unreachable), - curr, - "return_call should have unreachable type"); - shouldBeEqual( - getFunction()->sig.results, - target->sig.results, - curr, - "return_call callee return type must match caller return type"); - } else { - if (curr->type == Type::unreachable) { - bool hasUnreachableOperand = std::any_of( - curr->operands.begin(), curr->operands.end(), [](Expression* op) { - return op->type == Type::unreachable; - }); - shouldBeTrue( - hasUnreachableOperand, - curr, - "calls may only be unreachable if they have unreachable operands"); - } else { - shouldBeEqual(curr->type, - target->sig.results, - curr, - "call type must match callee return type"); - } - } + validateCallParamsAndResult(curr, target->sig); } void FunctionValidator::visitCallIndirect(CallIndirect* curr) { - shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(), - curr, - "return_call_indirect requires tail calls to be enabled"); + validateReturnCall(curr); shouldBeEqualOrFirstIsUnreachable(curr->target->type, Type(Type::i32), curr, "indirect call target must be an i32"); - if (!shouldBeTrue(curr->operands.size() == curr->sig.params.size(), - curr, - "call param number must match")) { - return; - } - size_t i = 0; - for (const auto& param : curr->sig.params) { - if (!shouldBeSubTypeOrFirstIsUnreachable(curr->operands[i]->type, - param, - curr, - "call param types must match") && - !info.quiet) { - getStream() << "(on argument " << i << ")\n"; - } - ++i; - } - if (curr->isReturn) { - shouldBeEqual(curr->type, - Type(Type::unreachable), - curr, - "return_call_indirect should have unreachable type"); - shouldBeEqual( - getFunction()->sig.results, - curr->sig.results, - curr, - "return_call_indirect callee return type must match caller return type"); - } else { - if (curr->type == Type::unreachable) { - if (curr->target->type != Type::unreachable) { - bool hasUnreachableOperand = std::any_of( - curr->operands.begin(), curr->operands.end(), [](Expression* op) { - return op->type == Type::unreachable; - }); - shouldBeTrue(hasUnreachableOperand, - curr, - "call_indirects may only be unreachable if they have " - "unreachable operands"); - } - } else { - shouldBeEqual(curr->type, - curr->sig.results, - curr, - "call_indirect type must match callee return type"); - } - } + validateCallParamsAndResult(curr, curr->sig); } void FunctionValidator::visitConst(Const* curr) { @@ -2199,6 +2153,20 @@ void FunctionValidator::visitTupleExtract(TupleExtract* curr) { } } +void FunctionValidator::visitCallRef(CallRef* curr) { + validateReturnCall(curr); + shouldBeTrue(getModule()->features.hasTypedFunctionReferences(), + curr, + "call_ref requires typed-function-references to be enabled"); + shouldBeTrue(curr->target->type.isFunction(), + curr, + "call_ref target must be a function reference"); + if (curr->target->type != Type::unreachable) { + validateCallParamsAndResult( + curr, curr->target->type.getHeapType().getSignature()); + } +} + void FunctionValidator::visitI31New(I31New* curr) { shouldBeTrue( getModule()->features.hasGC(), curr, "i31.new requires gc to be enabled"); diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 6245a3575..ac76a63ac 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -73,6 +73,8 @@ Name TABLE("table"); Name ELEM("elem"); Name LOCAL("local"); Name TYPE("type"); +Name REF("ref"); +Name NULL_("null"); Name CALL("call"); Name CALL_INDIRECT("call_indirect"); Name BLOCK("block"); @@ -212,6 +214,8 @@ const char* getExpressionName(Expression* curr) { return "i31.new"; case Expression::Id::I31GetId: return "i31.get"; + case Expression::Id::CallRefId: + return "call_ref"; case Expression::Id::RefTestId: return "ref.test"; case Expression::Id::RefCastId: @@ -1060,6 +1064,21 @@ void I31Get::finalize() { } } +void CallRef::finalize() { + handleUnreachableOperands(this); + if (isReturn) { + type = Type::unreachable; + } + if (target->type == Type::unreachable) { + type = Type::unreachable; + } +} + +void CallRef::finalize(Type type_) { + type = type_; + finalize(); +} + // TODO (gc): ref.test // TODO (gc): ref.cast // TODO (gc): br_on_cast diff --git a/src/wasm2js.h b/src/wasm2js.h index 7a0f4692d..ddc3d7313 100644 --- a/src/wasm2js.h +++ b/src/wasm2js.h @@ -2198,6 +2198,10 @@ Ref Wasm2JSBuilder::processFunctionBody(Module* m, unimplemented(curr); WASM_UNREACHABLE("unimp"); } + Ref visitCallRef(CallRef* curr) { + unimplemented(curr); + WASM_UNREACHABLE("unimp"); + } Ref visitRefTest(RefTest* curr) { unimplemented(curr); WASM_UNREACHABLE("unimp"); |