diff options
author | Alon Zakai <azakai@google.com> | 2021-09-20 10:47:14 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-20 17:47:14 +0000 |
commit | 737c22d30798c491eea3b401b948b9327ac979de (patch) | |
tree | f75a72adbd81a85eca19b732378837670c828b23 /src/wasm | |
parent | b5e8c371001de20128453d5064ac0422d481020e (diff) | |
download | binaryen-737c22d30798c491eea3b401b948b9327ac979de.tar.gz binaryen-737c22d30798c491eea3b401b948b9327ac979de.tar.bz2 binaryen-737c22d30798c491eea3b401b948b9327ac979de.zip |
[Wasm GC] Add static variants of ref.test, ref.cast, and br_on_cast* (#4163)
These variants take a HeapType that is the type we intend to cast to,
and do not take an RTT.
These are intended to be more statically optimizable. For now though
this PR just implements the minimum to get them parsing and to get
through the optimizer without crashing.
Spec: https://docs.google.com/document/d/1afthjsL_B9UaMqCA5ekgVmOm75BVFu6duHNsN9-gnXw/edit#
See #4149
Diffstat (limited to 'src/wasm')
-rw-r--r-- | src/wasm/wasm-binary.cpp | 43 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 19 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 33 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 61 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 25 |
5 files changed, 148 insertions, 33 deletions
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index dad6f9320..e1ff43988 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -6419,23 +6419,33 @@ bool WasmBinaryBuilder::maybeVisitI31Get(Expression*& out, uint32_t code) { } bool WasmBinaryBuilder::maybeVisitRefTest(Expression*& out, uint32_t code) { - if (code != BinaryConsts::RefTest) { - return false; + if (code == BinaryConsts::RefTest) { + auto* rtt = popNonVoidExpression(); + auto* ref = popNonVoidExpression(); + out = Builder(wasm).makeRefTest(ref, rtt); + return true; + } else if (code == BinaryConsts::RefTestStatic) { + auto intendedType = getIndexedHeapType(); + auto* ref = popNonVoidExpression(); + out = Builder(wasm).makeRefTest(ref, intendedType); + return true; } - auto* rtt = popNonVoidExpression(); - auto* ref = popNonVoidExpression(); - out = Builder(wasm).makeRefTest(ref, rtt); - return true; + return false; } bool WasmBinaryBuilder::maybeVisitRefCast(Expression*& out, uint32_t code) { - if (code != BinaryConsts::RefCast) { - return false; + if (code == BinaryConsts::RefCast) { + auto* rtt = popNonVoidExpression(); + auto* ref = popNonVoidExpression(); + out = Builder(wasm).makeRefCast(ref, rtt); + return true; + } else if (code == BinaryConsts::RefCastStatic) { + auto intendedType = getIndexedHeapType(); + auto* ref = popNonVoidExpression(); + out = Builder(wasm).makeRefCast(ref, intendedType); + return true; } - auto* rtt = popNonVoidExpression(); - auto* ref = popNonVoidExpression(); - out = Builder(wasm).makeRefCast(ref, rtt); - return true; + return false; } bool WasmBinaryBuilder::maybeVisitBrOn(Expression*& out, uint32_t code) { @@ -6448,9 +6458,11 @@ bool WasmBinaryBuilder::maybeVisitBrOn(Expression*& out, uint32_t code) { op = BrOnNonNull; break; case BinaryConsts::BrOnCast: + case BinaryConsts::BrOnCastStatic: op = BrOnCast; break; case BinaryConsts::BrOnCastFail: + case BinaryConsts::BrOnCastStaticFail: op = BrOnCastFail; break; case BinaryConsts::BrOnFunc: @@ -6475,6 +6487,13 @@ bool WasmBinaryBuilder::maybeVisitBrOn(Expression*& out, uint32_t code) { return false; } auto name = getBreakTarget(getU32LEB()).name; + if (code == BinaryConsts::BrOnCastStatic || + code == BinaryConsts::BrOnCastStaticFail) { + auto intendedType = getIndexedHeapType(); + auto* ref = popNonVoidExpression(); + out = Builder(wasm).makeBrOn(op, name, ref, intendedType); + return true; + } Expression* rtt = nullptr; if (op == BrOnCast || op == BrOnCastFail) { rtt = popNonVoidExpression(); diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 28028c459..3cfe90e26 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -2574,12 +2574,24 @@ Expression* SExpressionWasmBuilder::makeRefTest(Element& s) { return Builder(wasm).makeRefTest(ref, rtt); } +Expression* SExpressionWasmBuilder::makeRefTestStatic(Element& s) { + auto heapType = parseHeapType(*s[1]); + auto* ref = parseExpression(*s[2]); + return Builder(wasm).makeRefTest(ref, heapType); +} + Expression* SExpressionWasmBuilder::makeRefCast(Element& s) { auto* ref = parseExpression(*s[1]); auto* rtt = parseExpression(*s[2]); return Builder(wasm).makeRefCast(ref, rtt); } +Expression* SExpressionWasmBuilder::makeRefCastStatic(Element& s) { + auto heapType = parseHeapType(*s[1]); + auto* ref = parseExpression(*s[2]); + return Builder(wasm).makeRefCast(ref, heapType); +} + Expression* SExpressionWasmBuilder::makeBrOn(Element& s, BrOnOp op) { auto name = getLabel(*s[1]); auto* ref = parseExpression(*s[2]); @@ -2591,6 +2603,13 @@ Expression* SExpressionWasmBuilder::makeBrOn(Element& s, BrOnOp op) { .validateAndMakeBrOn(op, name, ref, rtt); } +Expression* SExpressionWasmBuilder::makeBrOnStatic(Element& s, BrOnOp op) { + auto name = getLabel(*s[1]); + auto heapType = parseHeapType(*s[2]); + auto* ref = parseExpression(*s[3]); + return Builder(wasm).makeBrOn(op, name, ref, heapType); +} + Expression* SExpressionWasmBuilder::makeRttCanon(Element& s) { return Builder(wasm).makeRttCanon(parseHeapType(*s[1])); } diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index e5460cf6f..b6424cdde 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -1942,11 +1942,23 @@ void BinaryInstWriter::visitCallRef(CallRef* curr) { } void BinaryInstWriter::visitRefTest(RefTest* curr) { - o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::RefTest); + o << int8_t(BinaryConsts::GCPrefix); + if (curr->rtt) { + o << U32LEB(BinaryConsts::RefTest); + } else { + o << U32LEB(BinaryConsts::RefTestStatic); + parent.writeIndexedHeapType(curr->intendedType); + } } void BinaryInstWriter::visitRefCast(RefCast* curr) { - o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::RefCast); + o << int8_t(BinaryConsts::GCPrefix); + if (curr->rtt) { + o << U32LEB(BinaryConsts::RefCast); + } else { + o << U32LEB(BinaryConsts::RefCastStatic); + parent.writeIndexedHeapType(curr->intendedType); + } } void BinaryInstWriter::visitBrOn(BrOn* curr) { @@ -1958,10 +1970,20 @@ void BinaryInstWriter::visitBrOn(BrOn* curr) { o << int8_t(BinaryConsts::BrOnNonNull); break; case BrOnCast: - o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::BrOnCast); + o << int8_t(BinaryConsts::GCPrefix); + if (curr->rtt) { + o << U32LEB(BinaryConsts::BrOnCast); + } else { + o << U32LEB(BinaryConsts::BrOnCastStatic); + } break; case BrOnCastFail: - o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::BrOnCastFail); + o << int8_t(BinaryConsts::GCPrefix); + if (curr->rtt) { + o << U32LEB(BinaryConsts::BrOnCastFail); + } else { + o << U32LEB(BinaryConsts::BrOnCastStaticFail); + } break; case BrOnFunc: o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::BrOnFunc); @@ -1985,6 +2007,9 @@ void BinaryInstWriter::visitBrOn(BrOn* curr) { WASM_UNREACHABLE("invalid br_on_*"); } o << U32LEB(getBreakIndex(curr->name)); + if ((curr->op == BrOnCast || curr->op == BrOnCastFail) && !curr->rtt) { + parent.writeIndexedHeapType(curr->intendedType); + } } void BinaryInstWriter::visitRttCanon(RttCanon* curr) { diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 0de417ed7..1f6421609 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -2219,9 +2219,20 @@ void FunctionValidator::visitRefTest(RefTest* curr) { shouldBeTrue( curr->ref->type.isRef(), curr, "ref.test ref must have ref type"); } - if (curr->rtt->type != Type::unreachable) { - shouldBeTrue( - curr->rtt->type.isRtt(), curr, "ref.test rtt must have rtt type"); + if (curr->rtt) { + if (curr->rtt->type != Type::unreachable) { + shouldBeTrue( + curr->rtt->type.isRtt(), curr, "ref.test rtt must have rtt type"); + } + shouldBeEqual(curr->intendedType, + HeapType(), + curr, + "dynamic ref.test must not use intendedType field"); + } else { + shouldBeUnequal(curr->intendedType, + HeapType(), + curr, + "static ref.test must set intendedType field"); } } @@ -2232,9 +2243,20 @@ void FunctionValidator::visitRefCast(RefCast* curr) { shouldBeTrue( curr->ref->type.isRef(), curr, "ref.cast ref must have ref type"); } - if (curr->rtt->type != Type::unreachable) { - shouldBeTrue( - curr->rtt->type.isRtt(), curr, "ref.cast rtt must have rtt type"); + if (curr->rtt) { + if (curr->rtt->type != Type::unreachable) { + shouldBeTrue( + curr->rtt->type.isRtt(), curr, "ref.cast rtt must have rtt type"); + } + shouldBeEqual(curr->intendedType, + HeapType(), + curr, + "dynamic ref.cast must not use intendedType field"); + } else { + shouldBeUnequal(curr->intendedType, + HeapType(), + curr, + "static ref.cast must set intendedType field"); } } @@ -2247,14 +2269,29 @@ void FunctionValidator::visitBrOn(BrOn* curr) { curr->ref->type.isRef(), curr, "br_on_cast ref must have ref type"); } if (curr->op == BrOnCast || curr->op == BrOnCastFail) { - // Note that an unreachable rtt is not supported: the text and binary - // formats do not provide the type, so if it's unreachable we should not - // even create a br_on_cast in such a case, as we'd have no idea what it - // casts to. - shouldBeTrue( - curr->rtt->type.isRtt(), curr, "br_on_cast rtt must have rtt type"); + if (curr->rtt) { + // Note that an unreachable rtt is not supported: the text and binary + // formats do not provide the type, so if it's unreachable we should not + // even create a br_on_cast in such a case, as we'd have no idea what it + // casts to. + shouldBeTrue( + curr->rtt->type.isRtt(), curr, "br_on_cast rtt must have rtt type"); + shouldBeEqual(curr->intendedType, + HeapType(), + curr, + "dynamic br_on_cast* must not use intendedType field"); + } else { + shouldBeUnequal(curr->intendedType, + HeapType(), + curr, + "static br_on_cast* must set intendedType field"); + } } else { shouldBeTrue(curr->rtt == nullptr, curr, "non-cast BrOn must not have rtt"); + shouldBeEqual(curr->intendedType, + HeapType(), + curr, + "non-cast br_on* must not set intendedType field"); } noteBreak(curr->name, curr->getSentType(), curr); } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 2861c4cee..45137fffe 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -895,23 +895,33 @@ void CallRef::finalize(Type type_) { } void RefTest::finalize() { - if (ref->type == Type::unreachable || rtt->type == Type::unreachable) { + if (ref->type == Type::unreachable || + (rtt && rtt->type == Type::unreachable)) { type = Type::unreachable; } else { type = Type::i32; } } +HeapType RefTest::getIntendedType() { + return rtt ? rtt->type.getHeapType() : intendedType; +} + void RefCast::finalize() { - if (ref->type == Type::unreachable || rtt->type == Type::unreachable) { + if (ref->type == Type::unreachable || + (rtt && rtt->type == Type::unreachable)) { type = Type::unreachable; } else { // The output of ref.cast may be null if the input is null (in that case the // null is passed through). - type = Type(rtt->type.getHeapType(), ref->type.getNullability()); + type = Type(getIntendedType(), ref->type.getNullability()); } } +HeapType RefCast::getIntendedType() { + return rtt ? rtt->type.getHeapType() : intendedType; +} + void BrOn::finalize() { if (ref->type == Type::unreachable || (rtt && rtt->type == Type::unreachable)) { @@ -938,7 +948,7 @@ void BrOn::finalize() { case BrOnCastFail: // If we do not branch, the cast worked, and we have something of the cast // type. - type = Type(rtt->type.getHeapType(), NonNullable); + type = Type(getIntendedType(), NonNullable); break; case BrOnNonFunc: type = Type(HeapType::func, NonNullable); @@ -954,6 +964,11 @@ void BrOn::finalize() { } } +HeapType BrOn::getIntendedType() { + assert(op == BrOnCast || op == BrOnCastFail); + return rtt ? rtt->type.getHeapType() : intendedType; +} + Type BrOn::getSentType() { switch (op) { case BrOnNull: @@ -971,7 +986,7 @@ Type BrOn::getSentType() { if (ref->type == Type::unreachable) { return Type::unreachable; } - return Type(rtt->type.getHeapType(), NonNullable); + return Type(getIntendedType(), NonNullable); case BrOnFunc: return Type::funcref; case BrOnData: |