diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/gen-s-parser.inc | 26 | ||||
-rw-r--r-- | src/passes/Print.cpp | 13 | ||||
-rw-r--r-- | src/tools/tool-options.h | 1 | ||||
-rw-r--r-- | src/wasm-binary.h | 3 | ||||
-rw-r--r-- | src/wasm-builder.h | 24 | ||||
-rw-r--r-- | src/wasm-features.h | 11 | ||||
-rw-r--r-- | src/wasm-s-parser.h | 4 | ||||
-rw-r--r-- | src/wasm-stack.h | 10 | ||||
-rw-r--r-- | src/wasm.h | 2 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 18 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 7 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 6 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 1 |
13 files changed, 99 insertions, 27 deletions
diff --git a/src/gen-s-parser.inc b/src/gen-s-parser.inc index 5750981f5..93d2f7578 100644 --- a/src/gen-s-parser.inc +++ b/src/gen-s-parser.inc @@ -40,10 +40,10 @@ switch (op[0]) { case 'c': { switch (op[4]) { case '\0': - if (strcmp(op, "call") == 0) { return makeCall(s); } + if (strcmp(op, "call") == 0) { return makeCall(s, /*isReturn=*/false); } goto parse_error; case '_': - if (strcmp(op, "call_indirect") == 0) { return makeCallIndirect(s); } + if (strcmp(op, "call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/false); } goto parse_error; default: goto parse_error; } @@ -2228,9 +2228,25 @@ switch (op[0]) { case 'p': if (strcmp(op, "push") == 0) { return makePush(s); } goto parse_error; - case 'r': - if (strcmp(op, "return") == 0) { return makeReturn(s); } - goto parse_error; + case 'r': { + switch (op[6]) { + case '\0': + if (strcmp(op, "return") == 0) { return makeReturn(s); } + goto parse_error; + case '_': { + switch (op[11]) { + case '\0': + if (strcmp(op, "return_call") == 0) { return makeCall(s, /*isReturn=*/true); } + goto parse_error; + case '_': + if (strcmp(op, "return_call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/true); } + goto parse_error; + default: goto parse_error; + } + } + default: goto parse_error; + } + } case 's': if (strcmp(op, "select") == 0) { return makeSelect(s); } goto parse_error; diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index e1ed788c4..bae5b25d4 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -106,11 +106,20 @@ struct PrintExpressionContents o << ' ' << curr->default_; } void visitCall(Call* curr) { - printMedium(o, "call "); + if (curr->isReturn) { + printMedium(o, "return_call "); + } else { + printMedium(o, "call "); + } printName(curr->target, o); } void visitCallIndirect(CallIndirect* curr) { - printMedium(o, "call_indirect (type ") << curr->fullType << ')'; + if (curr->isReturn) { + printMedium(o, "return_call_indirect (type "); + } else { + printMedium(o, "call_indirect (type "); + } + o << curr->fullType << ')'; } void visitLocalGet(LocalGet* curr) { printMedium(o, "local.get ") << printableLocal(curr->index, currFunction); diff --git a/src/tools/tool-options.h b/src/tools/tool-options.h index 016154252..7c1151c7f 100644 --- a/src/tools/tool-options.h +++ b/src/tools/tool-options.h @@ -70,6 +70,7 @@ struct ToolOptions : public Options { .addFeature(FeatureSet::BulkMemory, "bulk memory operations") .addFeature(FeatureSet::ExceptionHandling, "exception handling operations") + .addFeature(FeatureSet::TailCall, "tail call operations") .add("--no-validation", "-n", "Disables validation, assumes inputs are correct", diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 8f1cbb7de..5e94f5095 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -405,6 +405,7 @@ extern const char* TruncSatFeature; extern const char* SignExtFeature; extern const char* SIMD128Feature; extern const char* ExceptionHandlingFeature; +extern const char* TailCallFeature; enum Subsection { NameFunction = 1, @@ -429,6 +430,8 @@ enum ASTNodes { CallFunction = 0x10, CallIndirect = 0x11, + RetCallFunction = 0x12, + RetCallIndirect = 0x13, Drop = 0x1a, Select = 0x1b, diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 8bf60e59c..14d308d91 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -187,41 +187,45 @@ public: ret->condition = condition; return ret; } - Call* makeCall(Name target, const std::vector<Expression*>& args, Type type) { + Call* makeCall(Name target, + const std::vector<Expression*>& args, + Type type, + bool isReturn = false) { auto* call = allocator.alloc<Call>(); // not all functions may exist yet, so type must be provided call->type = type; call->target = target; call->operands.set(args); + call->isReturn = isReturn; return call; } - template<typename T> Call* makeCall(Name target, const T& args, Type type) { + template<typename T> + Call* makeCall(Name target, const T& args, Type type, bool isReturn = false) { auto* call = allocator.alloc<Call>(); // not all functions may exist yet, so type must be provided call->type = type; call->target = target; call->operands.set(args); + call->isReturn = isReturn; return call; } CallIndirect* makeCallIndirect(FunctionType* type, Expression* target, - const std::vector<Expression*>& args) { - auto* call = allocator.alloc<CallIndirect>(); - call->fullType = type->name; - call->type = type->result; - call->target = target; - call->operands.set(args); - return call; + const std::vector<Expression*>& args, + bool isReturn = false) { + return makeCallIndirect(type->name, target, args, type->result, isReturn); } CallIndirect* makeCallIndirect(Name fullType, Expression* target, const std::vector<Expression*>& args, - Type type) { + Type type, + bool isReturn = false) { auto* call = allocator.alloc<CallIndirect>(); call->fullType = fullType; call->type = type; call->target = target; call->operands.set(args); + call->isReturn = isReturn; return call; } // FunctionType diff --git a/src/wasm-features.h b/src/wasm-features.h index 325f413a1..d1789a67a 100644 --- a/src/wasm-features.h +++ b/src/wasm-features.h @@ -32,8 +32,8 @@ struct FeatureSet { BulkMemory = 1 << 4, SignExt = 1 << 5, ExceptionHandling = 1 << 6, - All = Atomics | MutableGlobals | TruncSat | SIMD | BulkMemory | SignExt | - ExceptionHandling + TailCall = 1 << 7, + All = (1 << 8) - 1 }; static std::string toString(Feature f) { @@ -52,6 +52,8 @@ struct FeatureSet { return "sign-ext"; case ExceptionHandling: return "exception-handling"; + case TailCall: + return "tail-call"; default: WASM_UNREACHABLE(); } @@ -69,6 +71,7 @@ struct FeatureSet { bool hasBulkMemory() const { return features & BulkMemory; } bool hasSignExt() const { return features & SignExt; } bool hasExceptionHandling() const { return features & ExceptionHandling; } + bool hasTailCall() const { return features & TailCall; } bool hasAll() const { return features & All; } void makeMVP() { features = MVP; } @@ -82,6 +85,7 @@ struct FeatureSet { void setBulkMemory(bool v = true) { set(BulkMemory, v); } void setSignExt(bool v = true) { set(SignExt, v); } void setExceptionHandling(bool v = true) { set(ExceptionHandling, v); } + void setTailCall(bool v = true) { set(TailCall, v); } void setAll(bool v = true) { features = v ? All : MVP; } void enable(const FeatureSet& other) { features |= other.features; } @@ -111,6 +115,9 @@ struct FeatureSet { if (hasSIMD()) { f(SIMD); } + if (hasTailCall()) { + f(TailCall); + } } bool operator<=(const FeatureSet& other) const { diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index e87eaae4e..c07cc49a9 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -208,8 +208,8 @@ private: Expression* makeIf(Element& s); Expression* makeMaybeBlock(Element& s, size_t i, Type type); Expression* makeLoop(Element& s); - Expression* makeCall(Element& s); - Expression* makeCallIndirect(Element& s); + Expression* makeCall(Element& s, bool isReturn); + Expression* makeCallIndirect(Element& s, bool isReturn); template<class T> void parseCallOperands(Element& s, Index i, Index j, T* call) { while (i < j) { diff --git a/src/wasm-stack.h b/src/wasm-stack.h index 7003b7b95..007e794b9 100644 --- a/src/wasm-stack.h +++ b/src/wasm-stack.h @@ -626,8 +626,9 @@ void StackWriter<Mode, Parent>::visitCall(Call* curr) { visitChild(operand); } if (!justAddToStack(curr)) { - o << int8_t(BinaryConsts::CallFunction) - << U32LEB(parent.getFunctionIndex(curr->target)); + int8_t op = curr->isReturn ? BinaryConsts::RetCallFunction + : BinaryConsts::CallFunction; + o << op << U32LEB(parent.getFunctionIndex(curr->target)); } // TODO FIXME: this and similar can be removed if (curr->type == unreachable) { @@ -642,8 +643,9 @@ void StackWriter<Mode, Parent>::visitCallIndirect(CallIndirect* curr) { } visitChild(curr->target); if (!justAddToStack(curr)) { - o << int8_t(BinaryConsts::CallIndirect) - << U32LEB(parent.getFunctionTypeIndex(curr->fullType)) + int8_t op = curr->isReturn ? BinaryConsts::RetCallIndirect + : BinaryConsts::CallIndirect; + o << op << U32LEB(parent.getFunctionTypeIndex(curr->fullType)) << U32LEB(0); // Reserved flags field } if (curr->type == unreachable) { diff --git a/src/wasm.h b/src/wasm.h index 0c4576f23..3ac89b91d 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -621,6 +621,7 @@ public: ExpressionList operands; Name target; + bool isReturn = false; void finalize(); }; @@ -647,6 +648,7 @@ public: ExpressionList operands; Name fullType; Expression* target; + bool isReturn = false; void finalize(); }; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index f1c6eeb7c..d15ac493f 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -679,6 +679,8 @@ void WasmBinaryWriter::writeFeaturesSection() { return BinaryConsts::UserSections::SignExtFeature; case FeatureSet::ExceptionHandling: return BinaryConsts::UserSections::ExceptionHandlingFeature; + case FeatureSet::TailCall: + return BinaryConsts::UserSections::TailCallFeature; default: WASM_UNREACHABLE(); } @@ -2162,6 +2164,8 @@ void WasmBinaryBuilder::readFeatures(size_t payloadLen) { wasm.features.setSignExt(); } else if (name == BinaryConsts::UserSections::SIMD128Feature) { wasm.features.setSIMD(); + } else if (name == BinaryConsts::UserSections::TailCallFeature) { + wasm.features.setTailCall(); } } } @@ -2210,6 +2214,20 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { visitCallIndirect( (curr = allocator.alloc<CallIndirect>())->cast<CallIndirect>()); break; + case BinaryConsts::RetCallFunction: { + auto call = allocator.alloc<Call>(); + call->isReturn = true; + curr = call; + visitCall(call); + break; + } + case BinaryConsts::RetCallIndirect: { + auto call = allocator.alloc<CallIndirect>(); + call->isReturn = true; + curr = call; + visitCallIndirect(call); + break; + } case BinaryConsts::LocalGet: visitLocalGet((curr = allocator.alloc<LocalGet>())->cast<LocalGet>()); break; diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 04c31815f..7a2d8a009 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -1622,17 +1622,19 @@ Expression* SExpressionWasmBuilder::makeLoop(Element& s) { return ret; } -Expression* SExpressionWasmBuilder::makeCall(Element& s) { +Expression* SExpressionWasmBuilder::makeCall(Element& s, bool isReturn) { auto target = getFunctionName(*s[1]); auto ret = allocator.alloc<Call>(); ret->target = target; ret->type = functionTypes[ret->target]; parseCallOperands(s, 2, s.size(), ret); + ret->isReturn = isReturn; ret->finalize(); return ret; } -Expression* SExpressionWasmBuilder::makeCallIndirect(Element& s) { +Expression* SExpressionWasmBuilder::makeCallIndirect(Element& s, + bool isReturn) { if (!wasm.table.exists) { throw ParseException("no table"); } @@ -1645,6 +1647,7 @@ Expression* SExpressionWasmBuilder::makeCallIndirect(Element& s) { ret->type = functionType->result; parseCallOperands(s, i, s.size() - 1, ret); ret->target = parseExpression(s[s.size() - 1]); + ret->isReturn = isReturn; ret->finalize(); return ret; } diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 1d12c2452..01cb1e2a8 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -569,6 +569,9 @@ void FunctionValidator::visitSwitch(Switch* curr) { } void FunctionValidator::visitCall(Call* curr) { + shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(), + curr, + "return_call requires tail calls to be enabled"); if (!info.validateGlobally) { return; } @@ -593,6 +596,9 @@ void FunctionValidator::visitCall(Call* curr) { } void FunctionValidator::visitCallIndirect(CallIndirect* curr) { + shouldBeTrue(!curr->isReturn || getModule()->features.hasTailCall(), + curr, + "return_call_indirect requires tail calls to be enabled"); if (!info.validateGlobally) { return; } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index a307d95f4..e4c9813b3 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -40,6 +40,7 @@ const char* MutableGlobalsFeature = "mutable-globals"; const char* TruncSatFeature = "nontrapping-fptoint"; const char* SignExtFeature = "sign-ext"; const char* SIMD128Feature = "simd128"; +const char* TailCallFeature = "tail-call"; } // namespace UserSections } // namespace BinaryConsts |