diff options
author | Thomas Lively <7121787+tlively@users.noreply.github.com> | 2022-09-22 18:00:50 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-22 16:00:50 -0700 |
commit | 58bedde3ac54f82657d5de092e7142ffb2ff735c (patch) | |
tree | e486638a92aeca615439bbcd11b540f4913b98a1 | |
parent | b1ba25732c1a02ae3da726c4b01ca3825ef969ef (diff) | |
download | binaryen-58bedde3ac54f82657d5de092e7142ffb2ff735c.tar.gz binaryen-58bedde3ac54f82657d5de092e7142ffb2ff735c.tar.bz2 binaryen-58bedde3ac54f82657d5de092e7142ffb2ff735c.zip |
Add a type annotation to return_call_ref (#5068)
The GC spec has been updated to have heap type annotations on call_ref and
return_call_ref. To avoid breaking users, we will have a graceful, multi-step
upgrade to the annotated version of call_ref, but since return_call_ref has no
users yet, update it in a single step.
-rw-r--r-- | src/ir/module-utils.cpp | 4 | ||||
-rw-r--r-- | src/passes/Print.cpp | 47 | ||||
-rw-r--r-- | src/wasm-binary.h | 3 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 42 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 17 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 10 | ||||
-rw-r--r-- | test/lit/passes/dae-gc-refine-return.wast | 38 | ||||
-rw-r--r-- | test/lit/passes/inlining_all-features.wast | 2 | ||||
-rw-r--r-- | test/lit/passes/optimize-instructions-call_ref.wast | 2 | ||||
-rw-r--r-- | test/lit/types-function-references.wast | 22 |
10 files changed, 126 insertions, 61 deletions
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp index 33f8ac926..84cd66c60 100644 --- a/src/ir/module-utils.cpp +++ b/src/ir/module-utils.cpp @@ -53,6 +53,10 @@ struct CodeScanner void visitExpression(Expression* curr) { if (auto* call = curr->dynCast<CallIndirect>()) { counts.note(call->heapType); + } else if (auto* call = curr->dynCast<CallRef>()) { + if (call->isReturn && call->target->type.isFunction()) { + counts.note(call->target->type); + } } else if (curr->is<RefNull>()) { counts.note(curr->type); } else if (auto* make = curr->dynCast<StructNew>()) { diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index d1e08e0ec..04d487c08 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -2043,9 +2043,31 @@ struct PrintExpressionContents void visitI31Get(I31Get* curr) { printMedium(o, curr->signed_ ? "i31.get_s" : "i31.get_u"); } + + // If we cannot print a valid unreachable instruction (say, a struct.get, + // where if the ref is unreachable, we don't know what heap type to print), + // then print the children in a block, which is good enough as this + // instruction is never reached anyhow. + // + // This function checks if the input is in fact unreachable, and if so, begins + // to emit a replacement for it and returns true. + bool printUnreachableReplacement(Expression* curr) { + if (curr->type == Type::unreachable) { + printMedium(o, "block"); + return true; + } + return false; + } + void visitCallRef(CallRef* curr) { if (curr->isReturn) { - printMedium(o, "return_call_ref"); + if (printUnreachableReplacement(curr->target)) { + return; + } + printMedium(o, "return_call_ref "); + assert(curr->target->type != Type::unreachable); + // TODO: Workaround if target has bottom type. + printHeapType(o, curr->target->type.getHeapType(), wasm); } else { printMedium(o, "call_ref"); } @@ -2106,22 +2128,6 @@ struct PrintExpressionContents } printName(curr->name, o); } - - // If we cannot print a valid unreachable instruction (say, a struct.get, - // where if the ref is unreachable, we don't know what heap type to print), - // then print the children in a block, which is good enough as this - // instruction is never reached anyhow. - // - // This function checks if the input is in fact unreachable, and if so, begins - // to emit a replacement for it and returns true. - bool printUnreachableReplacement(Expression* curr) { - if (curr->type == Type::unreachable) { - printMedium(o, "block"); - return true; - } - return false; - } - void visitStructNew(StructNew* curr) { if (printUnreachableReplacement(curr)) { return; @@ -2748,6 +2754,13 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { } decIndent(); } + void visitCallRef(CallRef* curr) { + if (curr->isReturn) { + maybePrintUnreachableReplacement(curr, curr->target->type); + } else { + visitExpression(curr); + } + } void visitStructNew(StructNew* curr) { maybePrintUnreachableReplacement(curr, curr->type); } diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 14a77ea41..705770bfc 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -1727,7 +1727,8 @@ public: void visitTryOrTryInBlock(Expression*& out); void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); - void visitCallRef(CallRef* curr); + void visitCallRef(CallRef* curr, + std::optional<HeapType> maybeType = std::nullopt); void visitRefAs(RefAs* curr, uint8_t code); [[noreturn]] void throwError(std::string text); diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 9540bd8f3..1de5e4bd1 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -3783,7 +3783,7 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { auto call = allocator.alloc<CallRef>(); call->isReturn = true; curr = call; - visitCallRef(call); + visitCallRef(call, getTypeByIndex(getU32LEB())); break; } case BinaryConsts::AtomicPrefix: { @@ -6777,22 +6777,32 @@ void WasmBinaryBuilder::visitRethrow(Rethrow* curr) { curr->finalize(); } -void WasmBinaryBuilder::visitCallRef(CallRef* curr) { +void WasmBinaryBuilder::visitCallRef(CallRef* curr, + std::optional<HeapType> maybeType) { BYN_TRACE("zz node: CallRef\n"); curr->target = popNonVoidExpression(); - auto type = curr->target->type; - if (type == Type::unreachable) { - // If our input is unreachable, then we cannot even find out how many inputs - // we have, and just set ourselves to unreachable as well. - curr->finalize(type); - return; - } - if (!type.isRef()) { - throwError("Non-ref type for a call_ref: " + type.toString()); + HeapType heapType; + if (maybeType) { + heapType = *maybeType; + if (!Type::isSubType(curr->target->type, Type(heapType, Nullable))) { + throwError("Call target has invalid type: " + + curr->target->type.toString()); + } + } else { + auto type = curr->target->type; + if (type == Type::unreachable) { + // If our input is unreachable, then we cannot even find out how many + // inputs we have, and just set ourselves to unreachable as well. + curr->finalize(type); + return; + } + if (!type.isRef()) { + throwError("Non-ref type for a call_ref: " + type.toString()); + } + heapType = type.getHeapType(); } - auto heapType = type.getHeapType(); if (!heapType.isSignature()) { - throwError("Invalid reference type for a call_ref: " + type.toString()); + throwError("Invalid reference type for a call_ref: " + heapType.toString()); } auto sig = heapType.getSignature(); auto num = sig.params.size(); @@ -6800,7 +6810,11 @@ void WasmBinaryBuilder::visitCallRef(CallRef* curr) { for (size_t i = 0; i < num; i++) { curr->operands[num - i - 1] = popNonVoidExpression(); } - curr->finalize(sig.results); + if (maybeType) { + curr->finalize(); + } else { + curr->finalize(sig.results); + } } bool WasmBinaryBuilder::maybeVisitI31New(Expression*& out, uint32_t code) { diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index ef41dec93..cf4824323 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -2830,9 +2830,24 @@ Expression* SExpressionWasmBuilder::makeTupleExtract(Element& s) { } Expression* SExpressionWasmBuilder::makeCallRef(Element& s, bool isReturn) { + Index operandsStart = 1; + HeapType sigType; + if (isReturn) { + sigType = parseHeapType(*s[1]); + operandsStart = 2; + } std::vector<Expression*> operands; - parseOperands(s, 1, s.size() - 1, operands); + parseOperands(s, operandsStart, s.size() - 1, operands); auto* target = parseExpression(s[s.size() - 1]); + + if (isReturn) { + if (!sigType.isSignature()) { + throw ParseException( + "return_call_ref type annotation should be a signature", s.line, s.col); + } + return Builder(wasm).makeCallRef( + target, operands, sigType.getSignature().results, isReturn); + } return ValidatingBuilder(wasm, s.line, s.col) .validateAndMakeCallRef(target, operands, isReturn); } diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index d88138075..243b2810b 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -2013,8 +2013,14 @@ void BinaryInstWriter::visitI31Get(I31Get* curr) { } void BinaryInstWriter::visitCallRef(CallRef* curr) { - o << int8_t(curr->isReturn ? BinaryConsts::RetCallRef - : BinaryConsts::CallRef); + if (curr->isReturn) { + assert(curr->target->type != Type::unreachable); + // TODO: `emitUnreachable` if target has bottom type. + o << int8_t(BinaryConsts::RetCallRef); + parent.writeIndexedHeapType(curr->target->type.getHeapType()); + return; + } + o << int8_t(BinaryConsts::CallRef); } void BinaryInstWriter::visitRefTest(RefTest* curr) { diff --git a/test/lit/passes/dae-gc-refine-return.wast b/test/lit/passes/dae-gc-refine-return.wast index ccf11509c..84e161651 100644 --- a/test/lit/passes/dae-gc-refine-return.wast +++ b/test/lit/passes/dae-gc-refine-return.wast @@ -586,20 +586,20 @@ ) ;; CHECK: (func $tail-caller-call_ref-yes (result (ref ${})) ;; CHECK-NEXT: (local $return_{} (ref null $return_{})) - ;; CHECK-NEXT: (return_call_ref + ;; CHECK-NEXT: (return_call_ref $return_{} ;; CHECK-NEXT: (local.get $return_{}) ;; CHECK-NEXT: ) ;; CHECK-NEXT: ) ;; NOMNL: (func $tail-caller-call_ref-yes (type $return_{}) (result (ref ${})) ;; NOMNL-NEXT: (local $return_{} (ref null $return_{})) - ;; NOMNL-NEXT: (return_call_ref + ;; NOMNL-NEXT: (return_call_ref $return_{} ;; NOMNL-NEXT: (local.get $return_{}) ;; NOMNL-NEXT: ) ;; NOMNL-NEXT: ) (func $tail-caller-call_ref-yes (result anyref) (local $return_{} (ref null $return_{})) - (return_call_ref (local.get $return_{})) + (return_call_ref $return_{} (local.get $return_{})) ) ;; CHECK: (func $tail-caller-call_ref-no (result anyref) ;; CHECK-NEXT: (local $any anyref) @@ -610,7 +610,7 @@ ;; CHECK-NEXT: (local.get $any) ;; CHECK-NEXT: ) ;; CHECK-NEXT: ) - ;; CHECK-NEXT: (return_call_ref + ;; CHECK-NEXT: (return_call_ref $return_{} ;; CHECK-NEXT: (local.get $return_{}) ;; CHECK-NEXT: ) ;; CHECK-NEXT: ) @@ -623,7 +623,7 @@ ;; NOMNL-NEXT: (local.get $any) ;; NOMNL-NEXT: ) ;; NOMNL-NEXT: ) - ;; NOMNL-NEXT: (return_call_ref + ;; NOMNL-NEXT: (return_call_ref $return_{} ;; NOMNL-NEXT: (local.get $return_{}) ;; NOMNL-NEXT: ) ;; NOMNL-NEXT: ) @@ -634,18 +634,26 @@ (if (i32.const 1) (return (local.get $any)) ) - (return_call_ref (local.get $return_{})) + (return_call_ref $return_{} (local.get $return_{})) ) - ;; CHECK: (func $tail-caller-call_ref-unreachable - ;; CHECK-NEXT: (unreachable) + ;; CHECK: (func $tail-caller-call_ref-unreachable (result anyref) + ;; CHECK-NEXT: (block ;; (replaces something unreachable we can't emit) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) ;; CHECK-NEXT: ) - ;; NOMNL: (func $tail-caller-call_ref-unreachable (type $none_=>_none) - ;; NOMNL-NEXT: (unreachable) + ;; NOMNL: (func $tail-caller-call_ref-unreachable (type $none_=>_anyref) (result anyref) + ;; NOMNL-NEXT: (block ;; (replaces something unreachable we can't emit) + ;; NOMNL-NEXT: (drop + ;; NOMNL-NEXT: (unreachable) + ;; NOMNL-NEXT: ) + ;; NOMNL-NEXT: ) ;; NOMNL-NEXT: ) (func $tail-caller-call_ref-unreachable (result anyref) ;; An unreachable means there is no function signature to even look at. We ;; should not hit an assertion on such things. - (return_call_ref (unreachable)) + (return_call_ref $return_{} (unreachable)) ) ;; CHECK: (func $tail-call-caller-call_ref ;; CHECK-NEXT: (drop @@ -654,7 +662,9 @@ ;; CHECK-NEXT: (drop ;; CHECK-NEXT: (call $tail-caller-call_ref-no) ;; CHECK-NEXT: ) - ;; CHECK-NEXT: (call $tail-caller-call_ref-unreachable) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (call $tail-caller-call_ref-unreachable) + ;; CHECK-NEXT: ) ;; CHECK-NEXT: ) ;; NOMNL: (func $tail-call-caller-call_ref (type $none_=>_none) ;; NOMNL-NEXT: (drop @@ -663,7 +673,9 @@ ;; NOMNL-NEXT: (drop ;; NOMNL-NEXT: (call $tail-caller-call_ref-no) ;; NOMNL-NEXT: ) - ;; NOMNL-NEXT: (call $tail-caller-call_ref-unreachable) + ;; NOMNL-NEXT: (drop + ;; NOMNL-NEXT: (call $tail-caller-call_ref-unreachable) + ;; NOMNL-NEXT: ) ;; NOMNL-NEXT: ) (func $tail-call-caller-call_ref (drop diff --git a/test/lit/passes/inlining_all-features.wast b/test/lit/passes/inlining_all-features.wast index f820d1826..59afe1ce9 100644 --- a/test/lit/passes/inlining_all-features.wast +++ b/test/lit/passes/inlining_all-features.wast @@ -135,7 +135,7 @@ (export "func_36_invoker" (func $1)) (func $0 - (return_call_ref + (return_call_ref $none_=>_none (ref.null $none_=>_none) ) ) diff --git a/test/lit/passes/optimize-instructions-call_ref.wast b/test/lit/passes/optimize-instructions-call_ref.wast index 7164fda55..95e2af18c 100644 --- a/test/lit/passes/optimize-instructions-call_ref.wast +++ b/test/lit/passes/optimize-instructions-call_ref.wast @@ -316,7 +316,7 @@ (func $return_call_ref-to-select (param $x i32) (param $y i32) ;; As above, but with a return call. We optimize this too, and turn a ;; return_call_ref over a select into an if over return_calls. - (return_call_ref + (return_call_ref $i32_i32_=>_none (local.get $x) (local.get $y) (select diff --git a/test/lit/types-function-references.wast b/test/lit/types-function-references.wast index ffca23132..16ea53997 100644 --- a/test/lit/types-function-references.wast +++ b/test/lit/types-function-references.wast @@ -10,12 +10,16 @@ ;; RUN: cat %t.text.wast | filecheck %s --check-prefix=CHECK-TEXT (module - ;; inline ref type in result - (type $_=>_eqref (func (result eqref))) ;; CHECK-BINARY: (type $mixed_results (func (result anyref f32 anyref f32))) - ;; CHECK-BINARY: (type $none_=>_none (func)) + ;; CHECK-BINARY: (type $void (func)) + ;; CHECK-TEXT: (type $mixed_results (func (result anyref f32 anyref f32))) + ;; CHECK-TEXT: (type $void (func)) + (type $void (func)) + + ;; inline ref type in result + (type $_=>_eqref (func (result eqref))) ;; CHECK-BINARY: (type $i32-i32 (func (param i32) (result i32))) ;; CHECK-BINARY: (type $=>eqref (func (result eqref))) @@ -27,10 +31,6 @@ ;; CHECK-BINARY: (type $none_=>_i32 (func (result i32))) ;; CHECK-BINARY: (type $f64_=>_ref_null<_->_eqref> (func (param f64) (result (ref null $=>eqref)))) - ;; CHECK-TEXT: (type $mixed_results (func (result anyref f32 anyref f32))) - - ;; CHECK-TEXT: (type $none_=>_none (func)) - ;; CHECK-TEXT: (type $i32-i32 (func (param i32) (result i32))) ;; CHECK-TEXT: (type $=>eqref (func (result eqref))) @@ -77,17 +77,17 @@ (call_ref (ref.func $call-ref)) ) ;; CHECK-BINARY: (func $return-call-ref - ;; CHECK-BINARY-NEXT: (return_call_ref + ;; CHECK-BINARY-NEXT: (return_call_ref $void ;; CHECK-BINARY-NEXT: (ref.func $call-ref) ;; CHECK-BINARY-NEXT: ) ;; CHECK-BINARY-NEXT: ) ;; CHECK-TEXT: (func $return-call-ref - ;; CHECK-TEXT-NEXT: (return_call_ref + ;; CHECK-TEXT-NEXT: (return_call_ref $void ;; CHECK-TEXT-NEXT: (ref.func $call-ref) ;; CHECK-TEXT-NEXT: ) ;; CHECK-TEXT-NEXT: ) (func $return-call-ref - (return_call_ref (ref.func $call-ref)) + (return_call_ref $void (ref.func $call-ref)) ) ;; CHECK-BINARY: (func $call-ref-more (param $0 i32) (result i32) ;; CHECK-BINARY-NEXT: (call_ref @@ -405,7 +405,7 @@ ;; CHECK-NODEBUG-NEXT: ) ;; CHECK-NODEBUG: (func $1 -;; CHECK-NODEBUG-NEXT: (return_call_ref +;; CHECK-NODEBUG-NEXT: (return_call_ref $none_=>_none ;; CHECK-NODEBUG-NEXT: (ref.func $0) ;; CHECK-NODEBUG-NEXT: ) ;; CHECK-NODEBUG-NEXT: ) |