summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/binaryen-c.cpp20
-rw-r--r--src/binaryen-c.h11
-rw-r--r--src/ir/module-utils.cpp2
-rw-r--r--src/passes/OptimizeInstructions.cpp9
-rw-r--r--src/passes/Print.cpp11
-rw-r--r--src/passes/TypeMerging.cpp6
-rw-r--r--src/wasm-builder.h5
-rw-r--r--src/wasm-delegations-fields.def1
-rw-r--r--src/wasm-interpreter.h42
-rw-r--r--src/wasm.h10
-rw-r--r--src/wasm/wasm-binary.cpp20
-rw-r--r--src/wasm/wasm-s-parser.cpp24
-rw-r--r--src/wasm/wasm-stack.cpp7
-rw-r--r--src/wasm/wasm-validator.cpp8
-rw-r--r--src/wasm/wasm.cpp8
15 files changed, 96 insertions, 88 deletions
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp
index e3d1505fa..3a42b6f25 100644
--- a/src/binaryen-c.cpp
+++ b/src/binaryen-c.cpp
@@ -1749,11 +1749,10 @@ BinaryenExpressionRef BinaryenRefTest(BinaryenModuleRef module,
}
BinaryenExpressionRef BinaryenRefCast(BinaryenModuleRef module,
BinaryenExpressionRef ref,
- BinaryenHeapType intendedType) {
- return static_cast<Expression*>(Builder(*(Module*)module)
- .makeRefCast((Expression*)ref,
- HeapType(intendedType),
- RefCast::Safety::Safe));
+ BinaryenType type) {
+ return static_cast<Expression*>(
+ Builder(*(Module*)module)
+ .makeRefCast((Expression*)ref, Type(type), RefCast::Safety::Safe));
}
BinaryenExpressionRef BinaryenBrOn(BinaryenModuleRef module,
BinaryenOp op,
@@ -4087,17 +4086,6 @@ void BinaryenRefCastSetRef(BinaryenExpressionRef expr,
assert(refExpr);
static_cast<RefCast*>(expression)->ref = (Expression*)refExpr;
}
-BinaryenHeapType BinaryenRefCastGetIntendedType(BinaryenExpressionRef expr) {
- auto* expression = (Expression*)expr;
- assert(expression->is<RefCast>());
- return static_cast<RefCast*>(expression)->intendedType.getID();
-}
-void BinaryenRefCastSetIntendedType(BinaryenExpressionRef expr,
- BinaryenHeapType intendedType) {
- auto* expression = (Expression*)expr;
- assert(expression->is<RefCast>());
- static_cast<RefCast*>(expression)->intendedType = HeapType(intendedType);
-}
// BrOn
BinaryenOp BinaryenBrOnGetOp(BinaryenExpressionRef expr) {
auto* expression = (Expression*)expr;
diff --git a/src/binaryen-c.h b/src/binaryen-c.h
index c9331d49d..9fc0d8b04 100644
--- a/src/binaryen-c.h
+++ b/src/binaryen-c.h
@@ -1042,10 +1042,9 @@ BINARYEN_API BinaryenExpressionRef
BinaryenRefTest(BinaryenModuleRef module,
BinaryenExpressionRef ref,
BinaryenHeapType intendedType);
-BINARYEN_API BinaryenExpressionRef
-BinaryenRefCast(BinaryenModuleRef module,
- BinaryenExpressionRef ref,
- BinaryenHeapType intendedType);
+BINARYEN_API BinaryenExpressionRef BinaryenRefCast(BinaryenModuleRef module,
+ BinaryenExpressionRef ref,
+ BinaryenType type);
BINARYEN_API BinaryenExpressionRef BinaryenBrOn(BinaryenModuleRef module,
BinaryenOp op,
const char* name,
@@ -2374,10 +2373,6 @@ BINARYEN_API BinaryenExpressionRef
BinaryenRefCastGetRef(BinaryenExpressionRef expr);
BINARYEN_API void BinaryenRefCastSetRef(BinaryenExpressionRef expr,
BinaryenExpressionRef refExpr);
-BINARYEN_API BinaryenHeapType
-BinaryenRefCastGetIntendedType(BinaryenExpressionRef expr);
-BINARYEN_API void BinaryenRefCastSetIntendedType(BinaryenExpressionRef expr,
- BinaryenHeapType intendedType);
// BrOn
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp
index f3d9201b7..8e6ac0d1a 100644
--- a/src/ir/module-utils.cpp
+++ b/src/ir/module-utils.cpp
@@ -71,7 +71,7 @@ struct CodeScanner
} else if (curr->is<ArrayInit>()) {
counts.note(curr->type);
} else if (auto* cast = curr->dynCast<RefCast>()) {
- counts.note(cast->intendedType);
+ counts.note(cast->type);
} else if (auto* cast = curr->dynCast<RefTest>()) {
counts.note(cast->intendedType);
} else if (auto* cast = curr->dynCast<BrOn>()) {
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index d218f5627..69a3a1fdb 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -1854,7 +1854,7 @@ struct OptimizeInstructions
auto fallthrough =
Properties::getFallthrough(curr->ref, getPassOptions(), *getModule());
- auto intendedType = curr->intendedType;
+ auto intendedType = curr->type.getHeapType();
// If the value is a null, it will just flow through, and we do not need
// the cast. However, if that would change the type, then things are less
@@ -1942,7 +1942,7 @@ struct OptimizeInstructions
if (auto* child = ref->dynCast<RefCast>()) {
// Repeated casts can be removed, leaving just the most demanding of
// them.
- auto childIntendedType = child->intendedType;
+ auto childIntendedType = child->type.getHeapType();
if (HeapType::isSubType(intendedType, childIntendedType)) {
// Skip the child.
if (curr->ref == child) {
@@ -2000,6 +2000,11 @@ struct OptimizeInstructions
if (auto* as = curr->ref->dynCast<RefAs>()) {
if (as->op == RefAsNonNull) {
curr->ref = as->value;
+ // Match the nullability of the new child.
+ // TODO: Combine the ref.as_non_null into the cast once we allow that.
+ if (curr->ref->type.isNullable()) {
+ curr->type = Type(curr->type.getHeapType(), Nullable);
+ }
curr->finalize();
as->value = curr;
as->finalize();
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index 68a2d16dc..2cd021b6f 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -2113,17 +2113,19 @@ struct PrintExpressionContents
printHeapType(o, curr->intendedType, wasm);
}
void visitRefCast(RefCast* curr) {
+ if (printUnreachableReplacement(curr)) {
+ return;
+ }
if (curr->safety == RefCast::Unsafe) {
printMedium(o, "ref.cast_nop ");
} else {
- // Emulate legacy polymorphic behavior for now.
- if (curr->ref->type.isNullable()) {
+ if (curr->type.isNullable()) {
printMedium(o, "ref.cast null ");
} else {
printMedium(o, "ref.cast ");
}
}
- printHeapType(o, curr->intendedType, wasm);
+ printHeapType(o, curr->type.getHeapType(), wasm);
}
void visitBrOn(BrOn* curr) {
@@ -2825,6 +2827,9 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
void visitCallRef(CallRef* curr) {
maybePrintUnreachableOrNullReplacement(curr, curr->target->type);
}
+ void visitRefCast(RefCast* curr) {
+ maybePrintUnreachableReplacement(curr, curr->type);
+ }
void visitStructNew(StructNew* curr) {
maybePrintUnreachableReplacement(curr, curr->type);
}
diff --git a/src/passes/TypeMerging.cpp b/src/passes/TypeMerging.cpp
index 0fc766f4e..d873152f2 100644
--- a/src/passes/TypeMerging.cpp
+++ b/src/passes/TypeMerging.cpp
@@ -83,6 +83,12 @@ struct CastFinder
#include "wasm-delegations-fields.def"
}
+
+ void visitRefCast(Expression* curr) {
+ if (curr->type != Type::unreachable) {
+ referredTypes.insert(curr->type.getHeapType());
+ }
+ }
};
struct TypeMerging : public Pass {
diff --git a/src/wasm-builder.h b/src/wasm-builder.h
index ff834013d..173e5976f 100644
--- a/src/wasm-builder.h
+++ b/src/wasm-builder.h
@@ -875,11 +875,10 @@ public:
ret->finalize();
return ret;
}
- RefCast*
- makeRefCast(Expression* ref, HeapType intendedType, RefCast::Safety safety) {
+ RefCast* makeRefCast(Expression* ref, Type type, RefCast::Safety safety) {
auto* ret = wasm.allocator.alloc<RefCast>();
ret->ref = ref;
- ret->intendedType = intendedType;
+ ret->type = type;
ret->safety = safety;
ret->finalize();
return ret;
diff --git a/src/wasm-delegations-fields.def b/src/wasm-delegations-fields.def
index 85a088dd8..dc0a13f41 100644
--- a/src/wasm-delegations-fields.def
+++ b/src/wasm-delegations-fields.def
@@ -621,7 +621,6 @@ switch (DELEGATE_ID) {
}
case Expression::Id::RefCastId: {
DELEGATE_START(RefCast);
- DELEGATE_FIELD_HEAPTYPE(RefCast, intendedType);
DELEGATE_FIELD_CHILD(RefCast, ref);
DELEGATE_END(RefCast);
break;
diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h
index e06d550b1..17c8a6456 100644
--- a/src/wasm-interpreter.h
+++ b/src/wasm-interpreter.h
@@ -1430,10 +1430,6 @@ public:
struct Breaking : Flow {
Breaking(Flow breaking) : Flow(breaking) {}
};
- // The null input to the cast.
- struct Null : Literal {
- Null(Literal original) : Literal(original) {}
- };
// The result of the successful cast.
struct Success : Literal {
Success(Literal result) : Literal(result) {}
@@ -1443,20 +1439,12 @@ public:
Failure(Literal original) : Literal(original) {}
};
- std::variant<Breaking, Null, Success, Failure> state;
+ std::variant<Breaking, Success, Failure> state;
template<class T> Cast(T state) : state(state) {}
Flow* getBreaking() { return std::get_if<Breaking>(&state); }
- Literal* getNull() { return std::get_if<Null>(&state); }
Literal* getSuccess() { return std::get_if<Success>(&state); }
Literal* getFailure() { return std::get_if<Failure>(&state); }
- Literal* getNullOrFailure() {
- if (auto* original = getNull()) {
- return original;
- } else {
- return getFailure();
- }
- }
};
template<typename T> Cast doCast(T* curr) {
@@ -1464,21 +1452,19 @@ public:
if (ref.breaking()) {
return typename Cast::Breaking{ref};
}
- Literal original = ref.getSingleValue();
- if (original.isNull()) {
- return typename Cast::Null{original};
- }
- // The input may not be GC data or a function; for example it could be an
- // anyref or an i31. The cast definitely fails in these cases.
- if (!original.isData() && !original.isFunction()) {
- return typename Cast::Failure{original};
+ Literal val = ref.getSingleValue();
+ Type castType = curr->getCastType();
+ if (val.isNull()) {
+ if (castType.isNullable()) {
+ return typename Cast::Success{val};
+ } else {
+ return typename Cast::Failure{val};
+ }
}
- HeapType actualType = original.type.getHeapType();
- // We have the actual and intended types, so perform the cast.
- if (HeapType::isSubType(actualType, curr->intendedType)) {
- return typename Cast::Success{original};
+ if (HeapType::isSubType(val.type.getHeapType(), castType.getHeapType())) {
+ return typename Cast::Success{val};
} else {
- return typename Cast::Failure{original};
+ return typename Cast::Failure{val};
}
}
@@ -1496,8 +1482,6 @@ public:
auto cast = doCast(curr);
if (auto* breaking = cast.getBreaking()) {
return *breaking;
- } else if (cast.getNull()) {
- return Literal::makeNull(curr->type.getHeapType());
} else if (auto* result = cast.getSuccess()) {
return *result;
}
@@ -1512,7 +1496,7 @@ public:
auto cast = doCast(curr);
if (auto* breaking = cast.getBreaking()) {
return *breaking;
- } else if (auto* original = cast.getNullOrFailure()) {
+ } else if (auto* original = cast.getFailure()) {
if (curr->op == BrOnCast) {
return *original;
} else {
diff --git a/src/wasm.h b/src/wasm.h
index 56dfa1ff2..3a44556d5 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -1525,6 +1525,9 @@ public:
HeapType intendedType;
void finalize();
+
+ // TODO: Support ref.test null as well.
+ Type getCastType() { return Type(intendedType, NonNullable); }
};
class RefCast : public SpecificExpression<Expression::RefCastId> {
@@ -1533,14 +1536,14 @@ public:
Expression* ref;
- HeapType intendedType;
-
// Support the unsafe `ref.cast_nop_static` to enable precise cast overhead
// measurements.
enum Safety { Safe, Unsafe };
Safety safety = Safe;
void finalize();
+
+ Type getCastType() { return type; }
};
class BrOn : public SpecificExpression<Expression::BrOnId> {
@@ -1555,6 +1558,9 @@ public:
void finalize();
+ // TODO: Support br_on_cast* null as well.
+ Type getCastType() { return Type(intendedType, NonNullable); }
+
// Returns the type sent on the branch, if it is taken.
Type getSentType();
};
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp
index 0bad4d5ef..0160ffaf6 100644
--- a/src/wasm/wasm-binary.cpp
+++ b/src/wasm/wasm-binary.cpp
@@ -6908,21 +6908,27 @@ bool WasmBinaryBuilder::maybeVisitRefCast(Expression*& out, uint32_t code) {
code == BinaryConsts::RefCastNull || code == BinaryConsts::RefCastNop) {
bool legacy =
code == BinaryConsts::RefCastStatic || code == BinaryConsts::RefCastNop;
- auto intendedType = legacy ? getIndexedHeapType() : getHeapType();
+ auto heapType = legacy ? getIndexedHeapType() : getHeapType();
auto* ref = popNonVoidExpression();
- // Even though we're parsing new instructions, we only support those that
- // emulate the legacy polymorphic behavior for now.
+ Nullability nullability;
+ if (legacy) {
+ // Legacy polymorphic behavior.
+ nullability = ref->type.getNullability();
+ } else {
+ nullability = code == BinaryConsts::RefCast ? NonNullable : Nullable;
+ }
+ // Only accept instructions emulating the legacy behavior for now.
if (ref->type.isRef()) {
- if (code == BinaryConsts::RefCast && ref->type.isNullable()) {
+ if (nullability == NonNullable && ref->type.isNullable()) {
throwError("ref.cast on nullable input not yet supported");
- } else if (code == BinaryConsts::RefCastNull &&
- ref->type.isNonNullable()) {
+ } else if (nullability == Nullable && ref->type.isNonNullable()) {
throwError("ref.cast null on non-nullable input not yet supported");
}
}
auto safety =
code == BinaryConsts::RefCastNop ? RefCast::Unsafe : RefCast::Safe;
- out = Builder(wasm).makeRefCast(ref, intendedType, safety);
+ auto type = Type(heapType, nullability);
+ out = Builder(wasm).makeRefCast(ref, type, safety);
return true;
}
return false;
diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp
index 98f7c1e87..cf287d8a1 100644
--- a/src/wasm/wasm-s-parser.cpp
+++ b/src/wasm/wasm-s-parser.cpp
@@ -2783,8 +2783,11 @@ Expression* SExpressionWasmBuilder::makeRefTest(Element& s) {
Expression* SExpressionWasmBuilder::makeRefCast(Element& s) {
int i = 1;
- std::optional<Nullability> nullability;
- if (s[0]->str().str != "ref.cast_static") {
+ Nullability nullability;
+ bool legacy = false;
+ if (s[0]->str().str == "ref.cast_static") {
+ legacy = true;
+ } else {
nullability = NonNullable;
if (s[i]->str().str == "null") {
nullability = Nullable;
@@ -2793,22 +2796,29 @@ Expression* SExpressionWasmBuilder::makeRefCast(Element& s) {
}
auto heapType = parseHeapType(*s[i++]);
auto* ref = parseExpression(*s[i++]);
- if (nullability && ref->type.isRef()) {
- if (*nullability == NonNullable && ref->type.isNullable()) {
+ if (legacy) {
+ // Legacy polymorphic behavior.
+ nullability = ref->type.getNullability();
+ } else if (ref->type.isRef()) {
+ // Only accept instructions emulating the legacy behavior for now.
+ if (nullability == NonNullable && ref->type.isNullable()) {
throw ParseException(
"ref.cast on nullable input not yet supported", s.line, s.col);
- } else if (*nullability == Nullable && ref->type.isNonNullable()) {
+ } else if (nullability == Nullable && ref->type.isNonNullable()) {
throw ParseException(
"ref.cast null on non-nullable input not yet supported", s.line, s.col);
}
}
- return Builder(wasm).makeRefCast(ref, heapType, RefCast::Safe);
+ auto type = Type(heapType, nullability);
+ return Builder(wasm).makeRefCast(ref, type, RefCast::Safe);
}
Expression* SExpressionWasmBuilder::makeRefCastNop(Element& s) {
auto heapType = parseHeapType(*s[1]);
auto* ref = parseExpression(*s[2]);
- return Builder(wasm).makeRefCast(ref, heapType, RefCast::Unsafe);
+ // Legacy polymorphic behavior.
+ auto type = Type(heapType, ref->type.getNullability());
+ return Builder(wasm).makeRefCast(ref, type, RefCast::Unsafe);
}
Expression* SExpressionWasmBuilder::makeBrOn(Element& s, BrOnOp op) {
diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp
index db40a5980..911447b48 100644
--- a/src/wasm/wasm-stack.cpp
+++ b/src/wasm/wasm-stack.cpp
@@ -2033,15 +2033,14 @@ void BinaryInstWriter::visitRefCast(RefCast* curr) {
o << int8_t(BinaryConsts::GCPrefix);
if (curr->safety == RefCast::Unsafe) {
o << U32LEB(BinaryConsts::RefCastNop);
- parent.writeIndexedHeapType(curr->intendedType);
+ parent.writeIndexedHeapType(curr->type.getHeapType());
} else {
- // Emulate legacy polymorphic behavior for now.
- if (curr->ref->type.isNullable()) {
+ if (curr->type.isNullable()) {
o << U32LEB(BinaryConsts::RefCastNull);
} else {
o << U32LEB(BinaryConsts::RefCast);
}
- parent.writeHeapType(curr->intendedType);
+ parent.writeHeapType(curr->type.getHeapType());
}
}
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index 726c70c7c..b102ecb1c 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -2530,10 +2530,16 @@ void FunctionValidator::visitRefCast(RefCast* curr) {
return;
}
shouldBeEqual(
- curr->intendedType.getBottom(),
+ curr->type.getHeapType().getBottom(),
curr->ref->type.getHeapType().getBottom(),
curr,
"ref.cast target type and ref type must have a common supertype");
+
+ // TODO: Remove this restriction
+ shouldBeEqual(curr->type.getNullability(),
+ curr->ref->type.getNullability(),
+ curr,
+ "ref.cast to a different nullability not yet implemented");
}
void FunctionValidator::visitBrOn(BrOn* curr) {
diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp
index 604854878..cf71689b1 100644
--- a/src/wasm/wasm.cpp
+++ b/src/wasm/wasm.cpp
@@ -945,10 +945,10 @@ void RefTest::finalize() {
void RefCast::finalize() {
if (ref->type == Type::unreachable) {
type = Type::unreachable;
- } else {
- // The output of ref.cast may be null if the input is null (in that case the
- // null is passed through).
- type = Type(intendedType, ref->type.getNullability());
+ }
+ // Do not unnecessarily lose non-nullability information.
+ if (ref->type.isNonNullable() && type.isNullable()) {
+ type = Type(type.getHeapType(), NonNullable);
}
}