diff options
author | Thomas Lively <tlively@google.com> | 2022-12-20 09:52:54 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-20 07:52:54 -0800 |
commit | 569f789622f116177c8a1e32fb62a4e5a5c9dfe0 (patch) | |
tree | 3a67097f753f3a22a1ebde6cd4f50c532e773663 /src | |
parent | 12ad604c17407f6b36d52c6404f2dab32e5c7960 (diff) | |
download | binaryen-569f789622f116177c8a1e32fb62a4e5a5c9dfe0.tar.gz binaryen-569f789622f116177c8a1e32fb62a4e5a5c9dfe0.tar.bz2 binaryen-569f789622f116177c8a1e32fb62a4e5a5c9dfe0.zip |
Update RefCast representation to drop extra HeapType (#5350)
The latest upstream version of ref.cast is parameterized with a target reference
type, not just a heap type, because the nullability of the result is
parameterizable. As a first step toward implementing these new, more flexible
ref.cast instructions, change the internal representation of ref.cast to use the
expression type as the cast target rather than storing a separate heap type
field. For now require that the encoded semantics match the previously allowed
semantics, though, so that none of the optimization passes need to be updated.
Diffstat (limited to 'src')
-rw-r--r-- | src/binaryen-c.cpp | 20 | ||||
-rw-r--r-- | src/binaryen-c.h | 11 | ||||
-rw-r--r-- | src/ir/module-utils.cpp | 2 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 9 | ||||
-rw-r--r-- | src/passes/Print.cpp | 11 | ||||
-rw-r--r-- | src/passes/TypeMerging.cpp | 6 | ||||
-rw-r--r-- | src/wasm-builder.h | 5 | ||||
-rw-r--r-- | src/wasm-delegations-fields.def | 1 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 42 | ||||
-rw-r--r-- | src/wasm.h | 10 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 20 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 24 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 7 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 8 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 8 |
15 files changed, 96 insertions, 88 deletions
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index e3d1505fa..3a42b6f25 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -1749,11 +1749,10 @@ BinaryenExpressionRef BinaryenRefTest(BinaryenModuleRef module, } BinaryenExpressionRef BinaryenRefCast(BinaryenModuleRef module, BinaryenExpressionRef ref, - BinaryenHeapType intendedType) { - return static_cast<Expression*>(Builder(*(Module*)module) - .makeRefCast((Expression*)ref, - HeapType(intendedType), - RefCast::Safety::Safe)); + BinaryenType type) { + return static_cast<Expression*>( + Builder(*(Module*)module) + .makeRefCast((Expression*)ref, Type(type), RefCast::Safety::Safe)); } BinaryenExpressionRef BinaryenBrOn(BinaryenModuleRef module, BinaryenOp op, @@ -4087,17 +4086,6 @@ void BinaryenRefCastSetRef(BinaryenExpressionRef expr, assert(refExpr); static_cast<RefCast*>(expression)->ref = (Expression*)refExpr; } -BinaryenHeapType BinaryenRefCastGetIntendedType(BinaryenExpressionRef expr) { - auto* expression = (Expression*)expr; - assert(expression->is<RefCast>()); - return static_cast<RefCast*>(expression)->intendedType.getID(); -} -void BinaryenRefCastSetIntendedType(BinaryenExpressionRef expr, - BinaryenHeapType intendedType) { - auto* expression = (Expression*)expr; - assert(expression->is<RefCast>()); - static_cast<RefCast*>(expression)->intendedType = HeapType(intendedType); -} // BrOn BinaryenOp BinaryenBrOnGetOp(BinaryenExpressionRef expr) { auto* expression = (Expression*)expr; diff --git a/src/binaryen-c.h b/src/binaryen-c.h index c9331d49d..9fc0d8b04 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -1042,10 +1042,9 @@ BINARYEN_API BinaryenExpressionRef BinaryenRefTest(BinaryenModuleRef module, BinaryenExpressionRef ref, BinaryenHeapType intendedType); -BINARYEN_API BinaryenExpressionRef -BinaryenRefCast(BinaryenModuleRef module, - BinaryenExpressionRef ref, - BinaryenHeapType intendedType); +BINARYEN_API BinaryenExpressionRef BinaryenRefCast(BinaryenModuleRef module, + BinaryenExpressionRef ref, + BinaryenType type); BINARYEN_API BinaryenExpressionRef BinaryenBrOn(BinaryenModuleRef module, BinaryenOp op, const char* name, @@ -2374,10 +2373,6 @@ BINARYEN_API BinaryenExpressionRef BinaryenRefCastGetRef(BinaryenExpressionRef expr); BINARYEN_API void BinaryenRefCastSetRef(BinaryenExpressionRef expr, BinaryenExpressionRef refExpr); -BINARYEN_API BinaryenHeapType -BinaryenRefCastGetIntendedType(BinaryenExpressionRef expr); -BINARYEN_API void BinaryenRefCastSetIntendedType(BinaryenExpressionRef expr, - BinaryenHeapType intendedType); // BrOn diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp index f3d9201b7..8e6ac0d1a 100644 --- a/src/ir/module-utils.cpp +++ b/src/ir/module-utils.cpp @@ -71,7 +71,7 @@ struct CodeScanner } else if (curr->is<ArrayInit>()) { counts.note(curr->type); } else if (auto* cast = curr->dynCast<RefCast>()) { - counts.note(cast->intendedType); + counts.note(cast->type); } else if (auto* cast = curr->dynCast<RefTest>()) { counts.note(cast->intendedType); } else if (auto* cast = curr->dynCast<BrOn>()) { diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index d218f5627..69a3a1fdb 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -1854,7 +1854,7 @@ struct OptimizeInstructions auto fallthrough = Properties::getFallthrough(curr->ref, getPassOptions(), *getModule()); - auto intendedType = curr->intendedType; + auto intendedType = curr->type.getHeapType(); // If the value is a null, it will just flow through, and we do not need // the cast. However, if that would change the type, then things are less @@ -1942,7 +1942,7 @@ struct OptimizeInstructions if (auto* child = ref->dynCast<RefCast>()) { // Repeated casts can be removed, leaving just the most demanding of // them. - auto childIntendedType = child->intendedType; + auto childIntendedType = child->type.getHeapType(); if (HeapType::isSubType(intendedType, childIntendedType)) { // Skip the child. if (curr->ref == child) { @@ -2000,6 +2000,11 @@ struct OptimizeInstructions if (auto* as = curr->ref->dynCast<RefAs>()) { if (as->op == RefAsNonNull) { curr->ref = as->value; + // Match the nullability of the new child. + // TODO: Combine the ref.as_non_null into the cast once we allow that. + if (curr->ref->type.isNullable()) { + curr->type = Type(curr->type.getHeapType(), Nullable); + } curr->finalize(); as->value = curr; as->finalize(); diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 68a2d16dc..2cd021b6f 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -2113,17 +2113,19 @@ struct PrintExpressionContents printHeapType(o, curr->intendedType, wasm); } void visitRefCast(RefCast* curr) { + if (printUnreachableReplacement(curr)) { + return; + } if (curr->safety == RefCast::Unsafe) { printMedium(o, "ref.cast_nop "); } else { - // Emulate legacy polymorphic behavior for now. - if (curr->ref->type.isNullable()) { + if (curr->type.isNullable()) { printMedium(o, "ref.cast null "); } else { printMedium(o, "ref.cast "); } } - printHeapType(o, curr->intendedType, wasm); + printHeapType(o, curr->type.getHeapType(), wasm); } void visitBrOn(BrOn* curr) { @@ -2825,6 +2827,9 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { void visitCallRef(CallRef* curr) { maybePrintUnreachableOrNullReplacement(curr, curr->target->type); } + void visitRefCast(RefCast* curr) { + maybePrintUnreachableReplacement(curr, curr->type); + } void visitStructNew(StructNew* curr) { maybePrintUnreachableReplacement(curr, curr->type); } diff --git a/src/passes/TypeMerging.cpp b/src/passes/TypeMerging.cpp index 0fc766f4e..d873152f2 100644 --- a/src/passes/TypeMerging.cpp +++ b/src/passes/TypeMerging.cpp @@ -83,6 +83,12 @@ struct CastFinder #include "wasm-delegations-fields.def" } + + void visitRefCast(Expression* curr) { + if (curr->type != Type::unreachable) { + referredTypes.insert(curr->type.getHeapType()); + } + } }; struct TypeMerging : public Pass { diff --git a/src/wasm-builder.h b/src/wasm-builder.h index ff834013d..173e5976f 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -875,11 +875,10 @@ public: ret->finalize(); return ret; } - RefCast* - makeRefCast(Expression* ref, HeapType intendedType, RefCast::Safety safety) { + RefCast* makeRefCast(Expression* ref, Type type, RefCast::Safety safety) { auto* ret = wasm.allocator.alloc<RefCast>(); ret->ref = ref; - ret->intendedType = intendedType; + ret->type = type; ret->safety = safety; ret->finalize(); return ret; diff --git a/src/wasm-delegations-fields.def b/src/wasm-delegations-fields.def index 85a088dd8..dc0a13f41 100644 --- a/src/wasm-delegations-fields.def +++ b/src/wasm-delegations-fields.def @@ -621,7 +621,6 @@ switch (DELEGATE_ID) { } case Expression::Id::RefCastId: { DELEGATE_START(RefCast); - DELEGATE_FIELD_HEAPTYPE(RefCast, intendedType); DELEGATE_FIELD_CHILD(RefCast, ref); DELEGATE_END(RefCast); break; diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index e06d550b1..17c8a6456 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -1430,10 +1430,6 @@ public: struct Breaking : Flow { Breaking(Flow breaking) : Flow(breaking) {} }; - // The null input to the cast. - struct Null : Literal { - Null(Literal original) : Literal(original) {} - }; // The result of the successful cast. struct Success : Literal { Success(Literal result) : Literal(result) {} @@ -1443,20 +1439,12 @@ public: Failure(Literal original) : Literal(original) {} }; - std::variant<Breaking, Null, Success, Failure> state; + std::variant<Breaking, Success, Failure> state; template<class T> Cast(T state) : state(state) {} Flow* getBreaking() { return std::get_if<Breaking>(&state); } - Literal* getNull() { return std::get_if<Null>(&state); } Literal* getSuccess() { return std::get_if<Success>(&state); } Literal* getFailure() { return std::get_if<Failure>(&state); } - Literal* getNullOrFailure() { - if (auto* original = getNull()) { - return original; - } else { - return getFailure(); - } - } }; template<typename T> Cast doCast(T* curr) { @@ -1464,21 +1452,19 @@ public: if (ref.breaking()) { return typename Cast::Breaking{ref}; } - Literal original = ref.getSingleValue(); - if (original.isNull()) { - return typename Cast::Null{original}; - } - // The input may not be GC data or a function; for example it could be an - // anyref or an i31. The cast definitely fails in these cases. - if (!original.isData() && !original.isFunction()) { - return typename Cast::Failure{original}; + Literal val = ref.getSingleValue(); + Type castType = curr->getCastType(); + if (val.isNull()) { + if (castType.isNullable()) { + return typename Cast::Success{val}; + } else { + return typename Cast::Failure{val}; + } } - HeapType actualType = original.type.getHeapType(); - // We have the actual and intended types, so perform the cast. - if (HeapType::isSubType(actualType, curr->intendedType)) { - return typename Cast::Success{original}; + if (HeapType::isSubType(val.type.getHeapType(), castType.getHeapType())) { + return typename Cast::Success{val}; } else { - return typename Cast::Failure{original}; + return typename Cast::Failure{val}; } } @@ -1496,8 +1482,6 @@ public: auto cast = doCast(curr); if (auto* breaking = cast.getBreaking()) { return *breaking; - } else if (cast.getNull()) { - return Literal::makeNull(curr->type.getHeapType()); } else if (auto* result = cast.getSuccess()) { return *result; } @@ -1512,7 +1496,7 @@ public: auto cast = doCast(curr); if (auto* breaking = cast.getBreaking()) { return *breaking; - } else if (auto* original = cast.getNullOrFailure()) { + } else if (auto* original = cast.getFailure()) { if (curr->op == BrOnCast) { return *original; } else { diff --git a/src/wasm.h b/src/wasm.h index 56dfa1ff2..3a44556d5 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -1525,6 +1525,9 @@ public: HeapType intendedType; void finalize(); + + // TODO: Support ref.test null as well. + Type getCastType() { return Type(intendedType, NonNullable); } }; class RefCast : public SpecificExpression<Expression::RefCastId> { @@ -1533,14 +1536,14 @@ public: Expression* ref; - HeapType intendedType; - // Support the unsafe `ref.cast_nop_static` to enable precise cast overhead // measurements. enum Safety { Safe, Unsafe }; Safety safety = Safe; void finalize(); + + Type getCastType() { return type; } }; class BrOn : public SpecificExpression<Expression::BrOnId> { @@ -1555,6 +1558,9 @@ public: void finalize(); + // TODO: Support br_on_cast* null as well. + Type getCastType() { return Type(intendedType, NonNullable); } + // Returns the type sent on the branch, if it is taken. Type getSentType(); }; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 0bad4d5ef..0160ffaf6 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -6908,21 +6908,27 @@ bool WasmBinaryBuilder::maybeVisitRefCast(Expression*& out, uint32_t code) { code == BinaryConsts::RefCastNull || code == BinaryConsts::RefCastNop) { bool legacy = code == BinaryConsts::RefCastStatic || code == BinaryConsts::RefCastNop; - auto intendedType = legacy ? getIndexedHeapType() : getHeapType(); + auto heapType = legacy ? getIndexedHeapType() : getHeapType(); auto* ref = popNonVoidExpression(); - // Even though we're parsing new instructions, we only support those that - // emulate the legacy polymorphic behavior for now. + Nullability nullability; + if (legacy) { + // Legacy polymorphic behavior. + nullability = ref->type.getNullability(); + } else { + nullability = code == BinaryConsts::RefCast ? NonNullable : Nullable; + } + // Only accept instructions emulating the legacy behavior for now. if (ref->type.isRef()) { - if (code == BinaryConsts::RefCast && ref->type.isNullable()) { + if (nullability == NonNullable && ref->type.isNullable()) { throwError("ref.cast on nullable input not yet supported"); - } else if (code == BinaryConsts::RefCastNull && - ref->type.isNonNullable()) { + } else if (nullability == Nullable && ref->type.isNonNullable()) { throwError("ref.cast null on non-nullable input not yet supported"); } } auto safety = code == BinaryConsts::RefCastNop ? RefCast::Unsafe : RefCast::Safe; - out = Builder(wasm).makeRefCast(ref, intendedType, safety); + auto type = Type(heapType, nullability); + out = Builder(wasm).makeRefCast(ref, type, safety); return true; } return false; diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 98f7c1e87..cf287d8a1 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -2783,8 +2783,11 @@ Expression* SExpressionWasmBuilder::makeRefTest(Element& s) { Expression* SExpressionWasmBuilder::makeRefCast(Element& s) { int i = 1; - std::optional<Nullability> nullability; - if (s[0]->str().str != "ref.cast_static") { + Nullability nullability; + bool legacy = false; + if (s[0]->str().str == "ref.cast_static") { + legacy = true; + } else { nullability = NonNullable; if (s[i]->str().str == "null") { nullability = Nullable; @@ -2793,22 +2796,29 @@ Expression* SExpressionWasmBuilder::makeRefCast(Element& s) { } auto heapType = parseHeapType(*s[i++]); auto* ref = parseExpression(*s[i++]); - if (nullability && ref->type.isRef()) { - if (*nullability == NonNullable && ref->type.isNullable()) { + if (legacy) { + // Legacy polymorphic behavior. + nullability = ref->type.getNullability(); + } else if (ref->type.isRef()) { + // Only accept instructions emulating the legacy behavior for now. + if (nullability == NonNullable && ref->type.isNullable()) { throw ParseException( "ref.cast on nullable input not yet supported", s.line, s.col); - } else if (*nullability == Nullable && ref->type.isNonNullable()) { + } else if (nullability == Nullable && ref->type.isNonNullable()) { throw ParseException( "ref.cast null on non-nullable input not yet supported", s.line, s.col); } } - return Builder(wasm).makeRefCast(ref, heapType, RefCast::Safe); + auto type = Type(heapType, nullability); + return Builder(wasm).makeRefCast(ref, type, RefCast::Safe); } Expression* SExpressionWasmBuilder::makeRefCastNop(Element& s) { auto heapType = parseHeapType(*s[1]); auto* ref = parseExpression(*s[2]); - return Builder(wasm).makeRefCast(ref, heapType, RefCast::Unsafe); + // Legacy polymorphic behavior. + auto type = Type(heapType, ref->type.getNullability()); + return Builder(wasm).makeRefCast(ref, type, RefCast::Unsafe); } Expression* SExpressionWasmBuilder::makeBrOn(Element& s, BrOnOp op) { diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index db40a5980..911447b48 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -2033,15 +2033,14 @@ void BinaryInstWriter::visitRefCast(RefCast* curr) { o << int8_t(BinaryConsts::GCPrefix); if (curr->safety == RefCast::Unsafe) { o << U32LEB(BinaryConsts::RefCastNop); - parent.writeIndexedHeapType(curr->intendedType); + parent.writeIndexedHeapType(curr->type.getHeapType()); } else { - // Emulate legacy polymorphic behavior for now. - if (curr->ref->type.isNullable()) { + if (curr->type.isNullable()) { o << U32LEB(BinaryConsts::RefCastNull); } else { o << U32LEB(BinaryConsts::RefCast); } - parent.writeHeapType(curr->intendedType); + parent.writeHeapType(curr->type.getHeapType()); } } diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 726c70c7c..b102ecb1c 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -2530,10 +2530,16 @@ void FunctionValidator::visitRefCast(RefCast* curr) { return; } shouldBeEqual( - curr->intendedType.getBottom(), + curr->type.getHeapType().getBottom(), curr->ref->type.getHeapType().getBottom(), curr, "ref.cast target type and ref type must have a common supertype"); + + // TODO: Remove this restriction + shouldBeEqual(curr->type.getNullability(), + curr->ref->type.getNullability(), + curr, + "ref.cast to a different nullability not yet implemented"); } void FunctionValidator::visitBrOn(BrOn* curr) { diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 604854878..cf71689b1 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -945,10 +945,10 @@ void RefTest::finalize() { void RefCast::finalize() { if (ref->type == Type::unreachable) { type = Type::unreachable; - } else { - // The output of ref.cast may be null if the input is null (in that case the - // null is passed through). - type = Type(intendedType, ref->type.getNullability()); + } + // Do not unnecessarily lose non-nullability information. + if (ref->type.isNonNullable() && type.isNullable()) { + type = Type(type.getHeapType(), NonNullable); } } |