diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/binaryen-c.cpp | 20 | ||||
-rw-r--r-- | src/binaryen-c.h | 11 | ||||
-rw-r--r-- | src/ir/module-utils.cpp | 2 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 9 | ||||
-rw-r--r-- | src/passes/Print.cpp | 11 | ||||
-rw-r--r-- | src/passes/TypeMerging.cpp | 6 | ||||
-rw-r--r-- | src/wasm-builder.h | 5 | ||||
-rw-r--r-- | src/wasm-delegations-fields.def | 1 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 42 | ||||
-rw-r--r-- | src/wasm.h | 10 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 20 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 24 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 7 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 8 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 8 |
15 files changed, 96 insertions, 88 deletions
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index e3d1505fa..3a42b6f25 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -1749,11 +1749,10 @@ BinaryenExpressionRef BinaryenRefTest(BinaryenModuleRef module, } BinaryenExpressionRef BinaryenRefCast(BinaryenModuleRef module, BinaryenExpressionRef ref, - BinaryenHeapType intendedType) { - return static_cast<Expression*>(Builder(*(Module*)module) - .makeRefCast((Expression*)ref, - HeapType(intendedType), - RefCast::Safety::Safe)); + BinaryenType type) { + return static_cast<Expression*>( + Builder(*(Module*)module) + .makeRefCast((Expression*)ref, Type(type), RefCast::Safety::Safe)); } BinaryenExpressionRef BinaryenBrOn(BinaryenModuleRef module, BinaryenOp op, @@ -4087,17 +4086,6 @@ void BinaryenRefCastSetRef(BinaryenExpressionRef expr, assert(refExpr); static_cast<RefCast*>(expression)->ref = (Expression*)refExpr; } -BinaryenHeapType BinaryenRefCastGetIntendedType(BinaryenExpressionRef expr) { - auto* expression = (Expression*)expr; - assert(expression->is<RefCast>()); - return static_cast<RefCast*>(expression)->intendedType.getID(); -} -void BinaryenRefCastSetIntendedType(BinaryenExpressionRef expr, - BinaryenHeapType intendedType) { - auto* expression = (Expression*)expr; - assert(expression->is<RefCast>()); - static_cast<RefCast*>(expression)->intendedType = HeapType(intendedType); -} // BrOn BinaryenOp BinaryenBrOnGetOp(BinaryenExpressionRef expr) { auto* expression = (Expression*)expr; diff --git a/src/binaryen-c.h b/src/binaryen-c.h index c9331d49d..9fc0d8b04 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -1042,10 +1042,9 @@ BINARYEN_API BinaryenExpressionRef BinaryenRefTest(BinaryenModuleRef module, BinaryenExpressionRef ref, BinaryenHeapType intendedType); -BINARYEN_API BinaryenExpressionRef -BinaryenRefCast(BinaryenModuleRef module, - BinaryenExpressionRef ref, - BinaryenHeapType intendedType); +BINARYEN_API BinaryenExpressionRef BinaryenRefCast(BinaryenModuleRef module, + BinaryenExpressionRef ref, + BinaryenType type); BINARYEN_API BinaryenExpressionRef BinaryenBrOn(BinaryenModuleRef module, BinaryenOp op, const char* name, @@ -2374,10 +2373,6 @@ BINARYEN_API BinaryenExpressionRef BinaryenRefCastGetRef(BinaryenExpressionRef expr); BINARYEN_API void BinaryenRefCastSetRef(BinaryenExpressionRef expr, BinaryenExpressionRef refExpr); -BINARYEN_API BinaryenHeapType -BinaryenRefCastGetIntendedType(BinaryenExpressionRef expr); -BINARYEN_API void BinaryenRefCastSetIntendedType(BinaryenExpressionRef expr, - BinaryenHeapType intendedType); // BrOn diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp index f3d9201b7..8e6ac0d1a 100644 --- a/src/ir/module-utils.cpp +++ b/src/ir/module-utils.cpp @@ -71,7 +71,7 @@ struct CodeScanner } else if (curr->is<ArrayInit>()) { counts.note(curr->type); } else if (auto* cast = curr->dynCast<RefCast>()) { - counts.note(cast->intendedType); + counts.note(cast->type); } else if (auto* cast = curr->dynCast<RefTest>()) { counts.note(cast->intendedType); } else if (auto* cast = curr->dynCast<BrOn>()) { diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index d218f5627..69a3a1fdb 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -1854,7 +1854,7 @@ struct OptimizeInstructions auto fallthrough = Properties::getFallthrough(curr->ref, getPassOptions(), *getModule()); - auto intendedType = curr->intendedType; + auto intendedType = curr->type.getHeapType(); // If the value is a null, it will just flow through, and we do not need // the cast. However, if that would change the type, then things are less @@ -1942,7 +1942,7 @@ struct OptimizeInstructions if (auto* child = ref->dynCast<RefCast>()) { // Repeated casts can be removed, leaving just the most demanding of // them. - auto childIntendedType = child->intendedType; + auto childIntendedType = child->type.getHeapType(); if (HeapType::isSubType(intendedType, childIntendedType)) { // Skip the child. if (curr->ref == child) { @@ -2000,6 +2000,11 @@ struct OptimizeInstructions if (auto* as = curr->ref->dynCast<RefAs>()) { if (as->op == RefAsNonNull) { curr->ref = as->value; + // Match the nullability of the new child. + // TODO: Combine the ref.as_non_null into the cast once we allow that. + if (curr->ref->type.isNullable()) { + curr->type = Type(curr->type.getHeapType(), Nullable); + } curr->finalize(); as->value = curr; as->finalize(); diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 68a2d16dc..2cd021b6f 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -2113,17 +2113,19 @@ struct PrintExpressionContents printHeapType(o, curr->intendedType, wasm); } void visitRefCast(RefCast* curr) { + if (printUnreachableReplacement(curr)) { + return; + } if (curr->safety == RefCast::Unsafe) { printMedium(o, "ref.cast_nop "); } else { - // Emulate legacy polymorphic behavior for now. - if (curr->ref->type.isNullable()) { + if (curr->type.isNullable()) { printMedium(o, "ref.cast null "); } else { printMedium(o, "ref.cast "); } } - printHeapType(o, curr->intendedType, wasm); + printHeapType(o, curr->type.getHeapType(), wasm); } void visitBrOn(BrOn* curr) { @@ -2825,6 +2827,9 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { void visitCallRef(CallRef* curr) { maybePrintUnreachableOrNullReplacement(curr, curr->target->type); } + void visitRefCast(RefCast* curr) { + maybePrintUnreachableReplacement(curr, curr->type); + } void visitStructNew(StructNew* curr) { maybePrintUnreachableReplacement(curr, curr->type); } diff --git a/src/passes/TypeMerging.cpp b/src/passes/TypeMerging.cpp index 0fc766f4e..d873152f2 100644 --- a/src/passes/TypeMerging.cpp +++ b/src/passes/TypeMerging.cpp @@ -83,6 +83,12 @@ struct CastFinder #include "wasm-delegations-fields.def" } + + void visitRefCast(Expression* curr) { + if (curr->type != Type::unreachable) { + referredTypes.insert(curr->type.getHeapType()); + } + } }; struct TypeMerging : public Pass { diff --git a/src/wasm-builder.h b/src/wasm-builder.h index ff834013d..173e5976f 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -875,11 +875,10 @@ public: ret->finalize(); return ret; } - RefCast* - makeRefCast(Expression* ref, HeapType intendedType, RefCast::Safety safety) { + RefCast* makeRefCast(Expression* ref, Type type, RefCast::Safety safety) { auto* ret = wasm.allocator.alloc<RefCast>(); ret->ref = ref; - ret->intendedType = intendedType; + ret->type = type; ret->safety = safety; ret->finalize(); return ret; diff --git a/src/wasm-delegations-fields.def b/src/wasm-delegations-fields.def index 85a088dd8..dc0a13f41 100644 --- a/src/wasm-delegations-fields.def +++ b/src/wasm-delegations-fields.def @@ -621,7 +621,6 @@ switch (DELEGATE_ID) { } case Expression::Id::RefCastId: { DELEGATE_START(RefCast); - DELEGATE_FIELD_HEAPTYPE(RefCast, intendedType); DELEGATE_FIELD_CHILD(RefCast, ref); DELEGATE_END(RefCast); break; diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index e06d550b1..17c8a6456 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -1430,10 +1430,6 @@ public: struct Breaking : Flow { Breaking(Flow breaking) : Flow(breaking) {} }; - // The null input to the cast. - struct Null : Literal { - Null(Literal original) : Literal(original) {} - }; // The result of the successful cast. struct Success : Literal { Success(Literal result) : Literal(result) {} @@ -1443,20 +1439,12 @@ public: Failure(Literal original) : Literal(original) {} }; - std::variant<Breaking, Null, Success, Failure> state; + std::variant<Breaking, Success, Failure> state; template<class T> Cast(T state) : state(state) {} Flow* getBreaking() { return std::get_if<Breaking>(&state); } - Literal* getNull() { return std::get_if<Null>(&state); } Literal* getSuccess() { return std::get_if<Success>(&state); } Literal* getFailure() { return std::get_if<Failure>(&state); } - Literal* getNullOrFailure() { - if (auto* original = getNull()) { - return original; - } else { - return getFailure(); - } - } }; template<typename T> Cast doCast(T* curr) { @@ -1464,21 +1452,19 @@ public: if (ref.breaking()) { return typename Cast::Breaking{ref}; } - Literal original = ref.getSingleValue(); - if (original.isNull()) { - return typename Cast::Null{original}; - } - // The input may not be GC data or a function; for example it could be an - // anyref or an i31. The cast definitely fails in these cases. - if (!original.isData() && !original.isFunction()) { - return typename Cast::Failure{original}; + Literal val = ref.getSingleValue(); + Type castType = curr->getCastType(); + if (val.isNull()) { + if (castType.isNullable()) { + return typename Cast::Success{val}; + } else { + return typename Cast::Failure{val}; + } } - HeapType actualType = original.type.getHeapType(); - // We have the actual and intended types, so perform the cast. - if (HeapType::isSubType(actualType, curr->intendedType)) { - return typename Cast::Success{original}; + if (HeapType::isSubType(val.type.getHeapType(), castType.getHeapType())) { + return typename Cast::Success{val}; } else { - return typename Cast::Failure{original}; + return typename Cast::Failure{val}; } } @@ -1496,8 +1482,6 @@ public: auto cast = doCast(curr); if (auto* breaking = cast.getBreaking()) { return *breaking; - } else if (cast.getNull()) { - return Literal::makeNull(curr->type.getHeapType()); } else if (auto* result = cast.getSuccess()) { return *result; } @@ -1512,7 +1496,7 @@ public: auto cast = doCast(curr); if (auto* breaking = cast.getBreaking()) { return *breaking; - } else if (auto* original = cast.getNullOrFailure()) { + } else if (auto* original = cast.getFailure()) { if (curr->op == BrOnCast) { return *original; } else { diff --git a/src/wasm.h b/src/wasm.h index 56dfa1ff2..3a44556d5 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -1525,6 +1525,9 @@ public: HeapType intendedType; void finalize(); + + // TODO: Support ref.test null as well. + Type getCastType() { return Type(intendedType, NonNullable); } }; class RefCast : public SpecificExpression<Expression::RefCastId> { @@ -1533,14 +1536,14 @@ public: Expression* ref; - HeapType intendedType; - // Support the unsafe `ref.cast_nop_static` to enable precise cast overhead // measurements. enum Safety { Safe, Unsafe }; Safety safety = Safe; void finalize(); + + Type getCastType() { return type; } }; class BrOn : public SpecificExpression<Expression::BrOnId> { @@ -1555,6 +1558,9 @@ public: void finalize(); + // TODO: Support br_on_cast* null as well. + Type getCastType() { return Type(intendedType, NonNullable); } + // Returns the type sent on the branch, if it is taken. Type getSentType(); }; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 0bad4d5ef..0160ffaf6 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -6908,21 +6908,27 @@ bool WasmBinaryBuilder::maybeVisitRefCast(Expression*& out, uint32_t code) { code == BinaryConsts::RefCastNull || code == BinaryConsts::RefCastNop) { bool legacy = code == BinaryConsts::RefCastStatic || code == BinaryConsts::RefCastNop; - auto intendedType = legacy ? getIndexedHeapType() : getHeapType(); + auto heapType = legacy ? getIndexedHeapType() : getHeapType(); auto* ref = popNonVoidExpression(); - // Even though we're parsing new instructions, we only support those that - // emulate the legacy polymorphic behavior for now. + Nullability nullability; + if (legacy) { + // Legacy polymorphic behavior. + nullability = ref->type.getNullability(); + } else { + nullability = code == BinaryConsts::RefCast ? NonNullable : Nullable; + } + // Only accept instructions emulating the legacy behavior for now. if (ref->type.isRef()) { - if (code == BinaryConsts::RefCast && ref->type.isNullable()) { + if (nullability == NonNullable && ref->type.isNullable()) { throwError("ref.cast on nullable input not yet supported"); - } else if (code == BinaryConsts::RefCastNull && - ref->type.isNonNullable()) { + } else if (nullability == Nullable && ref->type.isNonNullable()) { throwError("ref.cast null on non-nullable input not yet supported"); } } auto safety = code == BinaryConsts::RefCastNop ? RefCast::Unsafe : RefCast::Safe; - out = Builder(wasm).makeRefCast(ref, intendedType, safety); + auto type = Type(heapType, nullability); + out = Builder(wasm).makeRefCast(ref, type, safety); return true; } return false; diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 98f7c1e87..cf287d8a1 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -2783,8 +2783,11 @@ Expression* SExpressionWasmBuilder::makeRefTest(Element& s) { Expression* SExpressionWasmBuilder::makeRefCast(Element& s) { int i = 1; - std::optional<Nullability> nullability; - if (s[0]->str().str != "ref.cast_static") { + Nullability nullability; + bool legacy = false; + if (s[0]->str().str == "ref.cast_static") { + legacy = true; + } else { nullability = NonNullable; if (s[i]->str().str == "null") { nullability = Nullable; @@ -2793,22 +2796,29 @@ Expression* SExpressionWasmBuilder::makeRefCast(Element& s) { } auto heapType = parseHeapType(*s[i++]); auto* ref = parseExpression(*s[i++]); - if (nullability && ref->type.isRef()) { - if (*nullability == NonNullable && ref->type.isNullable()) { + if (legacy) { + // Legacy polymorphic behavior. + nullability = ref->type.getNullability(); + } else if (ref->type.isRef()) { + // Only accept instructions emulating the legacy behavior for now. + if (nullability == NonNullable && ref->type.isNullable()) { throw ParseException( "ref.cast on nullable input not yet supported", s.line, s.col); - } else if (*nullability == Nullable && ref->type.isNonNullable()) { + } else if (nullability == Nullable && ref->type.isNonNullable()) { throw ParseException( "ref.cast null on non-nullable input not yet supported", s.line, s.col); } } - return Builder(wasm).makeRefCast(ref, heapType, RefCast::Safe); + auto type = Type(heapType, nullability); + return Builder(wasm).makeRefCast(ref, type, RefCast::Safe); } Expression* SExpressionWasmBuilder::makeRefCastNop(Element& s) { auto heapType = parseHeapType(*s[1]); auto* ref = parseExpression(*s[2]); - return Builder(wasm).makeRefCast(ref, heapType, RefCast::Unsafe); + // Legacy polymorphic behavior. + auto type = Type(heapType, ref->type.getNullability()); + return Builder(wasm).makeRefCast(ref, type, RefCast::Unsafe); } Expression* SExpressionWasmBuilder::makeBrOn(Element& s, BrOnOp op) { diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index db40a5980..911447b48 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -2033,15 +2033,14 @@ void BinaryInstWriter::visitRefCast(RefCast* curr) { o << int8_t(BinaryConsts::GCPrefix); if (curr->safety == RefCast::Unsafe) { o << U32LEB(BinaryConsts::RefCastNop); - parent.writeIndexedHeapType(curr->intendedType); + parent.writeIndexedHeapType(curr->type.getHeapType()); } else { - // Emulate legacy polymorphic behavior for now. - if (curr->ref->type.isNullable()) { + if (curr->type.isNullable()) { o << U32LEB(BinaryConsts::RefCastNull); } else { o << U32LEB(BinaryConsts::RefCast); } - parent.writeHeapType(curr->intendedType); + parent.writeHeapType(curr->type.getHeapType()); } } diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 726c70c7c..b102ecb1c 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -2530,10 +2530,16 @@ void FunctionValidator::visitRefCast(RefCast* curr) { return; } shouldBeEqual( - curr->intendedType.getBottom(), + curr->type.getHeapType().getBottom(), curr->ref->type.getHeapType().getBottom(), curr, "ref.cast target type and ref type must have a common supertype"); + + // TODO: Remove this restriction + shouldBeEqual(curr->type.getNullability(), + curr->ref->type.getNullability(), + curr, + "ref.cast to a different nullability not yet implemented"); } void FunctionValidator::visitBrOn(BrOn* curr) { diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 604854878..cf71689b1 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -945,10 +945,10 @@ void RefTest::finalize() { void RefCast::finalize() { if (ref->type == Type::unreachable) { type = Type::unreachable; - } else { - // The output of ref.cast may be null if the input is null (in that case the - // null is passed through). - type = Type(intendedType, ref->type.getNullability()); + } + // Do not unnecessarily lose non-nullability information. + if (ref->type.isNonNullable() && type.isNullable()) { + type = Type(type.getHeapType(), NonNullable); } } |