summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/passes/OptimizeInstructions.cpp42
-rw-r--r--src/wasm/wasm-binary.cpp8
-rw-r--r--src/wasm/wasm-s-parser.cpp9
-rw-r--r--src/wasm/wasm-validator.cpp10
-rw-r--r--src/wasm/wasm.cpp1
-rw-r--r--test/lit/passes/optimize-instructions-gc.wast44
-rw-r--r--test/spec/ref_cast.wast33
7 files changed, 74 insertions, 73 deletions
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index 097e0299a..e61373711 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -1872,28 +1872,28 @@ struct OptimizeInstructions
auto intendedType = curr->type.getHeapType();
- // If the value is a null, it will just flow through, and we do not need
- // the cast. However, if that would change the type, then things are less
- // simple: if the original type was non-nullable, replacing it with a null
- // would change the type, which can happen in e.g.
- // (ref.cast (ref.as_non_null (.. (ref.null)
+ // If the value is a null, then we know a nullable cast will succeed and a
+ // non-nullable cast will fail. Either way, we do not need the cast.
+ // However, we have to avoid changing the type when replacing a cast with
+ // its potentially more refined child, e.g.
+ // (ref.cast null (ref.as_non_null (.. (ref.null)))
if (fallthrough->is<RefNull>()) {
- // Replace the expression with drops of the inputs, and a null. Note
- // that we provide a null of the previous type, so that we do not alter
- // the type received by our parent.
- Expression* rep = builder.makeSequence(builder.makeDrop(curr->ref),
- builder.makeRefNull(intendedType));
- if (curr->ref->type.isNonNullable()) {
- // Avoid a type change by forcing to be non-nullable. In practice,
- // this would have trapped before we get here, so this is just for
- // validation.
- rep = builder.makeRefAs(RefAsNonNull, rep);
- }
- replaceCurrent(rep);
- return;
- // TODO: The optimal ordering of this and the other ref.as_non_null
- // stuff later down in this functions is unclear and may be worth
- // looking into.
+ if (curr->type.isNullable()) {
+ // Replace the expression to drop the input and directly produce the
+ // null.
+ replaceCurrent(builder.makeSequence(builder.makeDrop(curr->ref),
+ builder.makeRefNull(intendedType)));
+ return;
+ // TODO: The optimal ordering of this and the other ref.as_non_null
+ // stuff later down in this functions is unclear and may be worth
+ // looking into.
+ } else {
+ // The cast will trap on the null, so replace it with an unreachable
+ // wrapped in a block of the original type.
+ replaceCurrent(builder.makeSequence(
+ builder.makeDrop(curr->ref), builder.makeUnreachable(), curr->type));
+ return;
+ }
}
// For the cast to be able to succeed, the value being cast must be a
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp
index 386f89060..c6fb361bd 100644
--- a/src/wasm/wasm-binary.cpp
+++ b/src/wasm/wasm-binary.cpp
@@ -6920,14 +6920,6 @@ bool WasmBinaryBuilder::maybeVisitRefCast(Expression*& out, uint32_t code) {
} else {
nullability = code == BinaryConsts::RefCast ? NonNullable : Nullable;
}
- // Only accept instructions emulating the legacy behavior for now.
- if (ref->type.isRef()) {
- if (nullability == NonNullable && ref->type.isNullable()) {
- throwError("ref.cast on nullable input not yet supported");
- } else if (nullability == Nullable && ref->type.isNonNullable()) {
- throwError("ref.cast null on non-nullable input not yet supported");
- }
- }
auto safety =
code == BinaryConsts::RefCastNop ? RefCast::Unsafe : RefCast::Safe;
auto type = Type(heapType, nullability);
diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp
index 4fdc29a66..2f791609a 100644
--- a/src/wasm/wasm-s-parser.cpp
+++ b/src/wasm/wasm-s-parser.cpp
@@ -2805,15 +2805,6 @@ Expression* SExpressionWasmBuilder::makeRefCast(Element& s) {
if (legacy) {
// Legacy polymorphic behavior.
nullability = ref->type.getNullability();
- } else if (ref->type.isRef()) {
- // Only accept instructions emulating the legacy behavior for now.
- if (nullability == NonNullable && ref->type.isNullable()) {
- throw ParseException(
- "ref.cast on nullable input not yet supported", s.line, s.col);
- } else if (nullability == Nullable && ref->type.isNonNullable()) {
- throw ParseException(
- "ref.cast null on non-nullable input not yet supported", s.line, s.col);
- }
}
auto type = Type(heapType, nullability);
return Builder(wasm).makeRefCast(ref, type, RefCast::Safe);
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index 68fb58ec4..bec9b0ad0 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -2535,11 +2535,11 @@ void FunctionValidator::visitRefCast(RefCast* curr) {
curr,
"ref.cast target type and ref type must have a common supertype");
- // TODO: Remove this restriction
- shouldBeEqual(curr->type.getNullability(),
- curr->ref->type.getNullability(),
- curr,
- "ref.cast to a different nullability not yet implemented");
+ // We should never have a nullable cast of a non-nullable reference, since
+ // that unnecessarily loses type information.
+ shouldBeTrue(curr->ref->type.isNullable() || curr->type.isNonNullable(),
+ curr,
+ "ref.cast null of non-nullable references are not allowed");
}
void FunctionValidator::visitBrOn(BrOn* curr) {
diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp
index cf71689b1..682936461 100644
--- a/src/wasm/wasm.cpp
+++ b/src/wasm/wasm.cpp
@@ -945,6 +945,7 @@ void RefTest::finalize() {
void RefCast::finalize() {
if (ref->type == Type::unreachable) {
type = Type::unreachable;
+ return;
}
// Do not unnecessarily lose non-nullability information.
if (ref->type.isNonNullable() && type.isNullable()) {
diff --git a/test/lit/passes/optimize-instructions-gc.wast b/test/lit/passes/optimize-instructions-gc.wast
index 1bf72ae04..fe18c015b 100644
--- a/test/lit/passes/optimize-instructions-gc.wast
+++ b/test/lit/passes/optimize-instructions-gc.wast
@@ -1791,59 +1791,57 @@
;; CHECK: (func $incompatible-cast-of-null (type $void)
;; CHECK-NEXT: (drop
- ;; CHECK-NEXT: (block (result nullref)
+ ;; CHECK-NEXT: (block (result (ref $array))
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (ref.null none)
;; CHECK-NEXT: )
- ;; CHECK-NEXT: (ref.null none)
+ ;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (drop
- ;; CHECK-NEXT: (ref.as_non_null
- ;; CHECK-NEXT: (block (result nullref)
- ;; CHECK-NEXT: (drop
- ;; CHECK-NEXT: (ref.as_non_null
- ;; CHECK-NEXT: (ref.null none)
- ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (block (result (ref $array))
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (ref.as_non_null
+ ;; CHECK-NEXT: (ref.null none)
;; CHECK-NEXT: )
- ;; CHECK-NEXT: (ref.null none)
;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; NOMNL: (func $incompatible-cast-of-null (type $void)
;; NOMNL-NEXT: (drop
- ;; NOMNL-NEXT: (block (result nullref)
+ ;; NOMNL-NEXT: (block (result (ref $array))
;; NOMNL-NEXT: (drop
;; NOMNL-NEXT: (ref.null none)
;; NOMNL-NEXT: )
- ;; NOMNL-NEXT: (ref.null none)
+ ;; NOMNL-NEXT: (unreachable)
;; NOMNL-NEXT: )
;; NOMNL-NEXT: )
;; NOMNL-NEXT: (drop
- ;; NOMNL-NEXT: (ref.as_non_null
- ;; NOMNL-NEXT: (block (result nullref)
- ;; NOMNL-NEXT: (drop
- ;; NOMNL-NEXT: (ref.as_non_null
- ;; NOMNL-NEXT: (ref.null none)
- ;; NOMNL-NEXT: )
+ ;; NOMNL-NEXT: (block (result (ref $array))
+ ;; NOMNL-NEXT: (drop
+ ;; NOMNL-NEXT: (ref.as_non_null
+ ;; NOMNL-NEXT: (ref.null none)
;; NOMNL-NEXT: )
- ;; NOMNL-NEXT: (ref.null none)
;; NOMNL-NEXT: )
+ ;; NOMNL-NEXT: (unreachable)
;; NOMNL-NEXT: )
;; NOMNL-NEXT: )
;; NOMNL-NEXT: )
(func $incompatible-cast-of-null
(drop
- (ref.cast null $array
- (ref.null $struct)
+ (ref.cast $array
+ ;; The child is null, so the cast will trap. Replace it with an
+ ;; unreachable.
+ (ref.null none)
)
)
(drop
(ref.cast $array
- ;; The fallthrough is null, but the node's child's type is non-nullable,
- ;; so we must add a ref.as_non_null on the outside to keep the type
- ;; identical.
+ ;; Even though the child type is non-null, it is still valid to do this
+ ;; transformation. In practice this code will trap before getting to our
+ ;; new unreachable.
(ref.as_non_null
(ref.null $struct)
)
diff --git a/test/spec/ref_cast.wast b/test/spec/ref_cast.wast
index 658f20c23..704063ec6 100644
--- a/test/spec/ref_cast.wast
+++ b/test/spec/ref_cast.wast
@@ -30,23 +30,34 @@
(call $init)
(drop (ref.cast null $t0 (ref.null data)))
+ (drop (ref.cast null $t0 (struct.new_default $t0)))
(drop (ref.cast null $t0 (global.get $tab.0)))
(drop (ref.cast null $t0 (global.get $tab.1)))
(drop (ref.cast null $t0 (global.get $tab.2)))
(drop (ref.cast null $t0 (global.get $tab.3)))
(drop (ref.cast null $t0 (global.get $tab.4)))
-
- (drop (ref.cast null $t0 (ref.null data)))
+ (drop (ref.cast $t0 (global.get $tab.0)))
+ (drop (ref.cast $t0 (global.get $tab.1)))
+ (drop (ref.cast $t0 (global.get $tab.2)))
+ (drop (ref.cast $t0 (global.get $tab.3)))
+ (drop (ref.cast $t0 (global.get $tab.4)))
+
+ (drop (ref.cast null $t1 (ref.null data)))
+ (drop (ref.cast null $t1 (struct.new_default $t1)))
(drop (ref.cast null $t1 (global.get $tab.1)))
(drop (ref.cast null $t1 (global.get $tab.2)))
+ (drop (ref.cast $t1 (global.get $tab.1)))
+ (drop (ref.cast $t1 (global.get $tab.2)))
- (drop (ref.cast null $t0 (ref.null data)))
+ (drop (ref.cast null $t2 (ref.null data)))
+ (drop (ref.cast null $t2 (struct.new_default $t2)))
(drop (ref.cast null $t2 (global.get $tab.2)))
+ (drop (ref.cast $t2 (global.get $tab.2)))
- (drop (ref.cast null $t0 (ref.null data)))
+ (drop (ref.cast null $t3 (ref.null data)))
+ (drop (ref.cast null $t3 (struct.new_default $t3)))
(drop (ref.cast null $t3 (global.get $tab.3)))
-
- (drop (ref.cast null $t0 (ref.null data)))
+ (drop (ref.cast $t3 (global.get $tab.3)))
)
(func (export "test-canon")
@@ -103,8 +114,15 @@
)
(i32.const 1)
)
-)
+ (func (export "test-trap-null")
+ (drop
+ (ref.cast $t0
+ (ref.null $t0)
+ )
+ )
+ )
+)
(invoke "test-sub")
(invoke "test-canon")
@@ -114,6 +132,7 @@
(assert_return (invoke "test-ref-cast-struct"))
(assert_return (invoke "test-br-on-cast-struct") (i32.const 1))
(assert_return (invoke "test-br-on-cast-fail-struct") (i32.const 0))
+(assert_trap (invoke "test-trap-null"))
(assert_invalid
(module