From 58bedde3ac54f82657d5de092e7142ffb2ff735c Mon Sep 17 00:00:00 2001 From: Thomas Lively <7121787+tlively@users.noreply.github.com> Date: Thu, 22 Sep 2022 18:00:50 -0500 Subject: Add a type annotation to return_call_ref (#5068) The GC spec has been updated to have heap type annotations on call_ref and return_call_ref. To avoid breaking users, we will have a graceful, multi-step upgrade to the annotated version of call_ref, but since return_call_ref has no users yet, update it in a single step. --- src/ir/module-utils.cpp | 4 ++++ src/passes/Print.cpp | 47 +++++++++++++++++++++++++++++----------------- src/wasm-binary.h | 3 ++- src/wasm/wasm-binary.cpp | 42 +++++++++++++++++++++++++++-------------- src/wasm/wasm-s-parser.cpp | 17 ++++++++++++++++- src/wasm/wasm-stack.cpp | 10 ++++++++-- 6 files changed, 88 insertions(+), 35 deletions(-) (limited to 'src') diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp index 33f8ac926..84cd66c60 100644 --- a/src/ir/module-utils.cpp +++ b/src/ir/module-utils.cpp @@ -53,6 +53,10 @@ struct CodeScanner void visitExpression(Expression* curr) { if (auto* call = curr->dynCast()) { counts.note(call->heapType); + } else if (auto* call = curr->dynCast()) { + if (call->isReturn && call->target->type.isFunction()) { + counts.note(call->target->type); + } } else if (curr->is()) { counts.note(curr->type); } else if (auto* make = curr->dynCast()) { diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index d1e08e0ec..04d487c08 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -2043,9 +2043,31 @@ struct PrintExpressionContents void visitI31Get(I31Get* curr) { printMedium(o, curr->signed_ ? "i31.get_s" : "i31.get_u"); } + + // If we cannot print a valid unreachable instruction (say, a struct.get, + // where if the ref is unreachable, we don't know what heap type to print), + // then print the children in a block, which is good enough as this + // instruction is never reached anyhow. + // + // This function checks if the input is in fact unreachable, and if so, begins + // to emit a replacement for it and returns true. + bool printUnreachableReplacement(Expression* curr) { + if (curr->type == Type::unreachable) { + printMedium(o, "block"); + return true; + } + return false; + } + void visitCallRef(CallRef* curr) { if (curr->isReturn) { - printMedium(o, "return_call_ref"); + if (printUnreachableReplacement(curr->target)) { + return; + } + printMedium(o, "return_call_ref "); + assert(curr->target->type != Type::unreachable); + // TODO: Workaround if target has bottom type. + printHeapType(o, curr->target->type.getHeapType(), wasm); } else { printMedium(o, "call_ref"); } @@ -2106,22 +2128,6 @@ struct PrintExpressionContents } printName(curr->name, o); } - - // If we cannot print a valid unreachable instruction (say, a struct.get, - // where if the ref is unreachable, we don't know what heap type to print), - // then print the children in a block, which is good enough as this - // instruction is never reached anyhow. - // - // This function checks if the input is in fact unreachable, and if so, begins - // to emit a replacement for it and returns true. - bool printUnreachableReplacement(Expression* curr) { - if (curr->type == Type::unreachable) { - printMedium(o, "block"); - return true; - } - return false; - } - void visitStructNew(StructNew* curr) { if (printUnreachableReplacement(curr)) { return; @@ -2748,6 +2754,13 @@ struct PrintSExpression : public UnifiedExpressionVisitor { } decIndent(); } + void visitCallRef(CallRef* curr) { + if (curr->isReturn) { + maybePrintUnreachableReplacement(curr, curr->target->type); + } else { + visitExpression(curr); + } + } void visitStructNew(StructNew* curr) { maybePrintUnreachableReplacement(curr, curr->type); } diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 14a77ea41..705770bfc 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -1727,7 +1727,8 @@ public: void visitTryOrTryInBlock(Expression*& out); void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); - void visitCallRef(CallRef* curr); + void visitCallRef(CallRef* curr, + std::optional maybeType = std::nullopt); void visitRefAs(RefAs* curr, uint8_t code); [[noreturn]] void throwError(std::string text); diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 9540bd8f3..1de5e4bd1 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -3783,7 +3783,7 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { auto call = allocator.alloc(); call->isReturn = true; curr = call; - visitCallRef(call); + visitCallRef(call, getTypeByIndex(getU32LEB())); break; } case BinaryConsts::AtomicPrefix: { @@ -6777,22 +6777,32 @@ void WasmBinaryBuilder::visitRethrow(Rethrow* curr) { curr->finalize(); } -void WasmBinaryBuilder::visitCallRef(CallRef* curr) { +void WasmBinaryBuilder::visitCallRef(CallRef* curr, + std::optional maybeType) { BYN_TRACE("zz node: CallRef\n"); curr->target = popNonVoidExpression(); - auto type = curr->target->type; - if (type == Type::unreachable) { - // If our input is unreachable, then we cannot even find out how many inputs - // we have, and just set ourselves to unreachable as well. - curr->finalize(type); - return; - } - if (!type.isRef()) { - throwError("Non-ref type for a call_ref: " + type.toString()); + HeapType heapType; + if (maybeType) { + heapType = *maybeType; + if (!Type::isSubType(curr->target->type, Type(heapType, Nullable))) { + throwError("Call target has invalid type: " + + curr->target->type.toString()); + } + } else { + auto type = curr->target->type; + if (type == Type::unreachable) { + // If our input is unreachable, then we cannot even find out how many + // inputs we have, and just set ourselves to unreachable as well. + curr->finalize(type); + return; + } + if (!type.isRef()) { + throwError("Non-ref type for a call_ref: " + type.toString()); + } + heapType = type.getHeapType(); } - auto heapType = type.getHeapType(); if (!heapType.isSignature()) { - throwError("Invalid reference type for a call_ref: " + type.toString()); + throwError("Invalid reference type for a call_ref: " + heapType.toString()); } auto sig = heapType.getSignature(); auto num = sig.params.size(); @@ -6800,7 +6810,11 @@ void WasmBinaryBuilder::visitCallRef(CallRef* curr) { for (size_t i = 0; i < num; i++) { curr->operands[num - i - 1] = popNonVoidExpression(); } - curr->finalize(sig.results); + if (maybeType) { + curr->finalize(); + } else { + curr->finalize(sig.results); + } } bool WasmBinaryBuilder::maybeVisitI31New(Expression*& out, uint32_t code) { diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index ef41dec93..cf4824323 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -2830,9 +2830,24 @@ Expression* SExpressionWasmBuilder::makeTupleExtract(Element& s) { } Expression* SExpressionWasmBuilder::makeCallRef(Element& s, bool isReturn) { + Index operandsStart = 1; + HeapType sigType; + if (isReturn) { + sigType = parseHeapType(*s[1]); + operandsStart = 2; + } std::vector operands; - parseOperands(s, 1, s.size() - 1, operands); + parseOperands(s, operandsStart, s.size() - 1, operands); auto* target = parseExpression(s[s.size() - 1]); + + if (isReturn) { + if (!sigType.isSignature()) { + throw ParseException( + "return_call_ref type annotation should be a signature", s.line, s.col); + } + return Builder(wasm).makeCallRef( + target, operands, sigType.getSignature().results, isReturn); + } return ValidatingBuilder(wasm, s.line, s.col) .validateAndMakeCallRef(target, operands, isReturn); } diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index d88138075..243b2810b 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -2013,8 +2013,14 @@ void BinaryInstWriter::visitI31Get(I31Get* curr) { } void BinaryInstWriter::visitCallRef(CallRef* curr) { - o << int8_t(curr->isReturn ? BinaryConsts::RetCallRef - : BinaryConsts::CallRef); + if (curr->isReturn) { + assert(curr->target->type != Type::unreachable); + // TODO: `emitUnreachable` if target has bottom type. + o << int8_t(BinaryConsts::RetCallRef); + parent.writeIndexedHeapType(curr->target->type.getHeapType()); + return; + } + o << int8_t(BinaryConsts::CallRef); } void BinaryInstWriter::visitRefTest(RefTest* curr) { -- cgit v1.2.3