summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorThomas Lively <tlively@google.com>2022-12-20 09:52:54 -0600
committerGitHub <noreply@github.com>2022-12-20 07:52:54 -0800
commit569f789622f116177c8a1e32fb62a4e5a5c9dfe0 (patch)
tree3a67097f753f3a22a1ebde6cd4f50c532e773663 /src
parent12ad604c17407f6b36d52c6404f2dab32e5c7960 (diff)
downloadbinaryen-569f789622f116177c8a1e32fb62a4e5a5c9dfe0.tar.gz
binaryen-569f789622f116177c8a1e32fb62a4e5a5c9dfe0.tar.bz2
binaryen-569f789622f116177c8a1e32fb62a4e5a5c9dfe0.zip
Update RefCast representation to drop extra HeapType (#5350)
The latest upstream version of ref.cast is parameterized with a target reference type, not just a heap type, because the nullability of the result is parameterizable. As a first step toward implementing these new, more flexible ref.cast instructions, change the internal representation of ref.cast to use the expression type as the cast target rather than storing a separate heap type field. For now require that the encoded semantics match the previously allowed semantics, though, so that none of the optimization passes need to be updated.
Diffstat (limited to 'src')
-rw-r--r--src/binaryen-c.cpp20
-rw-r--r--src/binaryen-c.h11
-rw-r--r--src/ir/module-utils.cpp2
-rw-r--r--src/passes/OptimizeInstructions.cpp9
-rw-r--r--src/passes/Print.cpp11
-rw-r--r--src/passes/TypeMerging.cpp6
-rw-r--r--src/wasm-builder.h5
-rw-r--r--src/wasm-delegations-fields.def1
-rw-r--r--src/wasm-interpreter.h42
-rw-r--r--src/wasm.h10
-rw-r--r--src/wasm/wasm-binary.cpp20
-rw-r--r--src/wasm/wasm-s-parser.cpp24
-rw-r--r--src/wasm/wasm-stack.cpp7
-rw-r--r--src/wasm/wasm-validator.cpp8
-rw-r--r--src/wasm/wasm.cpp8
15 files changed, 96 insertions, 88 deletions
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp
index e3d1505fa..3a42b6f25 100644
--- a/src/binaryen-c.cpp
+++ b/src/binaryen-c.cpp
@@ -1749,11 +1749,10 @@ BinaryenExpressionRef BinaryenRefTest(BinaryenModuleRef module,
}
BinaryenExpressionRef BinaryenRefCast(BinaryenModuleRef module,
BinaryenExpressionRef ref,
- BinaryenHeapType intendedType) {
- return static_cast<Expression*>(Builder(*(Module*)module)
- .makeRefCast((Expression*)ref,
- HeapType(intendedType),
- RefCast::Safety::Safe));
+ BinaryenType type) {
+ return static_cast<Expression*>(
+ Builder(*(Module*)module)
+ .makeRefCast((Expression*)ref, Type(type), RefCast::Safety::Safe));
}
BinaryenExpressionRef BinaryenBrOn(BinaryenModuleRef module,
BinaryenOp op,
@@ -4087,17 +4086,6 @@ void BinaryenRefCastSetRef(BinaryenExpressionRef expr,
assert(refExpr);
static_cast<RefCast*>(expression)->ref = (Expression*)refExpr;
}
-BinaryenHeapType BinaryenRefCastGetIntendedType(BinaryenExpressionRef expr) {
- auto* expression = (Expression*)expr;
- assert(expression->is<RefCast>());
- return static_cast<RefCast*>(expression)->intendedType.getID();
-}
-void BinaryenRefCastSetIntendedType(BinaryenExpressionRef expr,
- BinaryenHeapType intendedType) {
- auto* expression = (Expression*)expr;
- assert(expression->is<RefCast>());
- static_cast<RefCast*>(expression)->intendedType = HeapType(intendedType);
-}
// BrOn
BinaryenOp BinaryenBrOnGetOp(BinaryenExpressionRef expr) {
auto* expression = (Expression*)expr;
diff --git a/src/binaryen-c.h b/src/binaryen-c.h
index c9331d49d..9fc0d8b04 100644
--- a/src/binaryen-c.h
+++ b/src/binaryen-c.h
@@ -1042,10 +1042,9 @@ BINARYEN_API BinaryenExpressionRef
BinaryenRefTest(BinaryenModuleRef module,
BinaryenExpressionRef ref,
BinaryenHeapType intendedType);
-BINARYEN_API BinaryenExpressionRef
-BinaryenRefCast(BinaryenModuleRef module,
- BinaryenExpressionRef ref,
- BinaryenHeapType intendedType);
+BINARYEN_API BinaryenExpressionRef BinaryenRefCast(BinaryenModuleRef module,
+ BinaryenExpressionRef ref,
+ BinaryenType type);
BINARYEN_API BinaryenExpressionRef BinaryenBrOn(BinaryenModuleRef module,
BinaryenOp op,
const char* name,
@@ -2374,10 +2373,6 @@ BINARYEN_API BinaryenExpressionRef
BinaryenRefCastGetRef(BinaryenExpressionRef expr);
BINARYEN_API void BinaryenRefCastSetRef(BinaryenExpressionRef expr,
BinaryenExpressionRef refExpr);
-BINARYEN_API BinaryenHeapType
-BinaryenRefCastGetIntendedType(BinaryenExpressionRef expr);
-BINARYEN_API void BinaryenRefCastSetIntendedType(BinaryenExpressionRef expr,
- BinaryenHeapType intendedType);
// BrOn
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp
index f3d9201b7..8e6ac0d1a 100644
--- a/src/ir/module-utils.cpp
+++ b/src/ir/module-utils.cpp
@@ -71,7 +71,7 @@ struct CodeScanner
} else if (curr->is<ArrayInit>()) {
counts.note(curr->type);
} else if (auto* cast = curr->dynCast<RefCast>()) {
- counts.note(cast->intendedType);
+ counts.note(cast->type);
} else if (auto* cast = curr->dynCast<RefTest>()) {
counts.note(cast->intendedType);
} else if (auto* cast = curr->dynCast<BrOn>()) {
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index d218f5627..69a3a1fdb 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -1854,7 +1854,7 @@ struct OptimizeInstructions
auto fallthrough =
Properties::getFallthrough(curr->ref, getPassOptions(), *getModule());
- auto intendedType = curr->intendedType;
+ auto intendedType = curr->type.getHeapType();
// If the value is a null, it will just flow through, and we do not need
// the cast. However, if that would change the type, then things are less
@@ -1942,7 +1942,7 @@ struct OptimizeInstructions
if (auto* child = ref->dynCast<RefCast>()) {
// Repeated casts can be removed, leaving just the most demanding of
// them.
- auto childIntendedType = child->intendedType;
+ auto childIntendedType = child->type.getHeapType();
if (HeapType::isSubType(intendedType, childIntendedType)) {
// Skip the child.
if (curr->ref == child) {
@@ -2000,6 +2000,11 @@ struct OptimizeInstructions
if (auto* as = curr->ref->dynCast<RefAs>()) {
if (as->op == RefAsNonNull) {
curr->ref = as->value;
+ // Match the nullability of the new child.
+ // TODO: Combine the ref.as_non_null into the cast once we allow that.
+ if (curr->ref->type.isNullable()) {
+ curr->type = Type(curr->type.getHeapType(), Nullable);
+ }
curr->finalize();
as->value = curr;
as->finalize();
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index 68a2d16dc..2cd021b6f 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -2113,17 +2113,19 @@ struct PrintExpressionContents
printHeapType(o, curr->intendedType, wasm);
}
void visitRefCast(RefCast* curr) {
+ if (printUnreachableReplacement(curr)) {
+ return;
+ }
if (curr->safety == RefCast::Unsafe) {
printMedium(o, "ref.cast_nop ");
} else {
- // Emulate legacy polymorphic behavior for now.
- if (curr->ref->type.isNullable()) {
+ if (curr->type.isNullable()) {
printMedium(o, "ref.cast null ");
} else {
printMedium(o, "ref.cast ");
}
}
- printHeapType(o, curr->intendedType, wasm);
+ printHeapType(o, curr->type.getHeapType(), wasm);
}
void visitBrOn(BrOn* curr) {
@@ -2825,6 +2827,9 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
void visitCallRef(CallRef* curr) {
maybePrintUnreachableOrNullReplacement(curr, curr->target->type);
}
+ void visitRefCast(RefCast* curr) {
+ maybePrintUnreachableReplacement(curr, curr->type);
+ }
void visitStructNew(StructNew* curr) {
maybePrintUnreachableReplacement(curr, curr->type);
}
diff --git a/src/passes/TypeMerging.cpp b/src/passes/TypeMerging.cpp
index 0fc766f4e..d873152f2 100644
--- a/src/passes/TypeMerging.cpp
+++ b/src/passes/TypeMerging.cpp
@@ -83,6 +83,12 @@ struct CastFinder
#include "wasm-delegations-fields.def"
}
+
+ void visitRefCast(Expression* curr) {
+ if (curr->type != Type::unreachable) {
+ referredTypes.insert(curr->type.getHeapType());
+ }
+ }
};
struct TypeMerging : public Pass {
diff --git a/src/wasm-builder.h b/src/wasm-builder.h
index ff834013d..173e5976f 100644
--- a/src/wasm-builder.h
+++ b/src/wasm-builder.h
@@ -875,11 +875,10 @@ public:
ret->finalize();
return ret;
}
- RefCast*
- makeRefCast(Expression* ref, HeapType intendedType, RefCast::Safety safety) {
+ RefCast* makeRefCast(Expression* ref, Type type, RefCast::Safety safety) {
auto* ret = wasm.allocator.alloc<RefCast>();
ret->ref = ref;
- ret->intendedType = intendedType;
+ ret->type = type;
ret->safety = safety;
ret->finalize();
return ret;
diff --git a/src/wasm-delegations-fields.def b/src/wasm-delegations-fields.def
index 85a088dd8..dc0a13f41 100644
--- a/src/wasm-delegations-fields.def
+++ b/src/wasm-delegations-fields.def
@@ -621,7 +621,6 @@ switch (DELEGATE_ID) {
}
case Expression::Id::RefCastId: {
DELEGATE_START(RefCast);
- DELEGATE_FIELD_HEAPTYPE(RefCast, intendedType);
DELEGATE_FIELD_CHILD(RefCast, ref);
DELEGATE_END(RefCast);
break;
diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h
index e06d550b1..17c8a6456 100644
--- a/src/wasm-interpreter.h
+++ b/src/wasm-interpreter.h
@@ -1430,10 +1430,6 @@ public:
struct Breaking : Flow {
Breaking(Flow breaking) : Flow(breaking) {}
};
- // The null input to the cast.
- struct Null : Literal {
- Null(Literal original) : Literal(original) {}
- };
// The result of the successful cast.
struct Success : Literal {
Success(Literal result) : Literal(result) {}
@@ -1443,20 +1439,12 @@ public:
Failure(Literal original) : Literal(original) {}
};
- std::variant<Breaking, Null, Success, Failure> state;
+ std::variant<Breaking, Success, Failure> state;
template<class T> Cast(T state) : state(state) {}
Flow* getBreaking() { return std::get_if<Breaking>(&state); }
- Literal* getNull() { return std::get_if<Null>(&state); }
Literal* getSuccess() { return std::get_if<Success>(&state); }
Literal* getFailure() { return std::get_if<Failure>(&state); }
- Literal* getNullOrFailure() {
- if (auto* original = getNull()) {
- return original;
- } else {
- return getFailure();
- }
- }
};
template<typename T> Cast doCast(T* curr) {
@@ -1464,21 +1452,19 @@ public:
if (ref.breaking()) {
return typename Cast::Breaking{ref};
}
- Literal original = ref.getSingleValue();
- if (original.isNull()) {
- return typename Cast::Null{original};
- }
- // The input may not be GC data or a function; for example it could be an
- // anyref or an i31. The cast definitely fails in these cases.
- if (!original.isData() && !original.isFunction()) {
- return typename Cast::Failure{original};
+ Literal val = ref.getSingleValue();
+ Type castType = curr->getCastType();
+ if (val.isNull()) {
+ if (castType.isNullable()) {
+ return typename Cast::Success{val};
+ } else {
+ return typename Cast::Failure{val};
+ }
}
- HeapType actualType = original.type.getHeapType();
- // We have the actual and intended types, so perform the cast.
- if (HeapType::isSubType(actualType, curr->intendedType)) {
- return typename Cast::Success{original};
+ if (HeapType::isSubType(val.type.getHeapType(), castType.getHeapType())) {
+ return typename Cast::Success{val};
} else {
- return typename Cast::Failure{original};
+ return typename Cast::Failure{val};
}
}
@@ -1496,8 +1482,6 @@ public:
auto cast = doCast(curr);
if (auto* breaking = cast.getBreaking()) {
return *breaking;
- } else if (cast.getNull()) {
- return Literal::makeNull(curr->type.getHeapType());
} else if (auto* result = cast.getSuccess()) {
return *result;
}
@@ -1512,7 +1496,7 @@ public:
auto cast = doCast(curr);
if (auto* breaking = cast.getBreaking()) {
return *breaking;
- } else if (auto* original = cast.getNullOrFailure()) {
+ } else if (auto* original = cast.getFailure()) {
if (curr->op == BrOnCast) {
return *original;
} else {
diff --git a/src/wasm.h b/src/wasm.h
index 56dfa1ff2..3a44556d5 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -1525,6 +1525,9 @@ public:
HeapType intendedType;
void finalize();
+
+ // TODO: Support ref.test null as well.
+ Type getCastType() { return Type(intendedType, NonNullable); }
};
class RefCast : public SpecificExpression<Expression::RefCastId> {
@@ -1533,14 +1536,14 @@ public:
Expression* ref;
- HeapType intendedType;
-
// Support the unsafe `ref.cast_nop_static` to enable precise cast overhead
// measurements.
enum Safety { Safe, Unsafe };
Safety safety = Safe;
void finalize();
+
+ Type getCastType() { return type; }
};
class BrOn : public SpecificExpression<Expression::BrOnId> {
@@ -1555,6 +1558,9 @@ public:
void finalize();
+ // TODO: Support br_on_cast* null as well.
+ Type getCastType() { return Type(intendedType, NonNullable); }
+
// Returns the type sent on the branch, if it is taken.
Type getSentType();
};
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp
index 0bad4d5ef..0160ffaf6 100644
--- a/src/wasm/wasm-binary.cpp
+++ b/src/wasm/wasm-binary.cpp
@@ -6908,21 +6908,27 @@ bool WasmBinaryBuilder::maybeVisitRefCast(Expression*& out, uint32_t code) {
code == BinaryConsts::RefCastNull || code == BinaryConsts::RefCastNop) {
bool legacy =
code == BinaryConsts::RefCastStatic || code == BinaryConsts::RefCastNop;
- auto intendedType = legacy ? getIndexedHeapType() : getHeapType();
+ auto heapType = legacy ? getIndexedHeapType() : getHeapType();
auto* ref = popNonVoidExpression();
- // Even though we're parsing new instructions, we only support those that
- // emulate the legacy polymorphic behavior for now.
+ Nullability nullability;
+ if (legacy) {
+ // Legacy polymorphic behavior.
+ nullability = ref->type.getNullability();
+ } else {
+ nullability = code == BinaryConsts::RefCast ? NonNullable : Nullable;
+ }
+ // Only accept instructions emulating the legacy behavior for now.
if (ref->type.isRef()) {
- if (code == BinaryConsts::RefCast && ref->type.isNullable()) {
+ if (nullability == NonNullable && ref->type.isNullable()) {
throwError("ref.cast on nullable input not yet supported");
- } else if (code == BinaryConsts::RefCastNull &&
- ref->type.isNonNullable()) {
+ } else if (nullability == Nullable && ref->type.isNonNullable()) {
throwError("ref.cast null on non-nullable input not yet supported");
}
}
auto safety =
code == BinaryConsts::RefCastNop ? RefCast::Unsafe : RefCast::Safe;
- out = Builder(wasm).makeRefCast(ref, intendedType, safety);
+ auto type = Type(heapType, nullability);
+ out = Builder(wasm).makeRefCast(ref, type, safety);
return true;
}
return false;
diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp
index 98f7c1e87..cf287d8a1 100644
--- a/src/wasm/wasm-s-parser.cpp
+++ b/src/wasm/wasm-s-parser.cpp
@@ -2783,8 +2783,11 @@ Expression* SExpressionWasmBuilder::makeRefTest(Element& s) {
Expression* SExpressionWasmBuilder::makeRefCast(Element& s) {
int i = 1;
- std::optional<Nullability> nullability;
- if (s[0]->str().str != "ref.cast_static") {
+ Nullability nullability;
+ bool legacy = false;
+ if (s[0]->str().str == "ref.cast_static") {
+ legacy = true;
+ } else {
nullability = NonNullable;
if (s[i]->str().str == "null") {
nullability = Nullable;
@@ -2793,22 +2796,29 @@ Expression* SExpressionWasmBuilder::makeRefCast(Element& s) {
}
auto heapType = parseHeapType(*s[i++]);
auto* ref = parseExpression(*s[i++]);
- if (nullability && ref->type.isRef()) {
- if (*nullability == NonNullable && ref->type.isNullable()) {
+ if (legacy) {
+ // Legacy polymorphic behavior.
+ nullability = ref->type.getNullability();
+ } else if (ref->type.isRef()) {
+ // Only accept instructions emulating the legacy behavior for now.
+ if (nullability == NonNullable && ref->type.isNullable()) {
throw ParseException(
"ref.cast on nullable input not yet supported", s.line, s.col);
- } else if (*nullability == Nullable && ref->type.isNonNullable()) {
+ } else if (nullability == Nullable && ref->type.isNonNullable()) {
throw ParseException(
"ref.cast null on non-nullable input not yet supported", s.line, s.col);
}
}
- return Builder(wasm).makeRefCast(ref, heapType, RefCast::Safe);
+ auto type = Type(heapType, nullability);
+ return Builder(wasm).makeRefCast(ref, type, RefCast::Safe);
}
Expression* SExpressionWasmBuilder::makeRefCastNop(Element& s) {
auto heapType = parseHeapType(*s[1]);
auto* ref = parseExpression(*s[2]);
- return Builder(wasm).makeRefCast(ref, heapType, RefCast::Unsafe);
+ // Legacy polymorphic behavior.
+ auto type = Type(heapType, ref->type.getNullability());
+ return Builder(wasm).makeRefCast(ref, type, RefCast::Unsafe);
}
Expression* SExpressionWasmBuilder::makeBrOn(Element& s, BrOnOp op) {
diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp
index db40a5980..911447b48 100644
--- a/src/wasm/wasm-stack.cpp
+++ b/src/wasm/wasm-stack.cpp
@@ -2033,15 +2033,14 @@ void BinaryInstWriter::visitRefCast(RefCast* curr) {
o << int8_t(BinaryConsts::GCPrefix);
if (curr->safety == RefCast::Unsafe) {
o << U32LEB(BinaryConsts::RefCastNop);
- parent.writeIndexedHeapType(curr->intendedType);
+ parent.writeIndexedHeapType(curr->type.getHeapType());
} else {
- // Emulate legacy polymorphic behavior for now.
- if (curr->ref->type.isNullable()) {
+ if (curr->type.isNullable()) {
o << U32LEB(BinaryConsts::RefCastNull);
} else {
o << U32LEB(BinaryConsts::RefCast);
}
- parent.writeHeapType(curr->intendedType);
+ parent.writeHeapType(curr->type.getHeapType());
}
}
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index 726c70c7c..b102ecb1c 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -2530,10 +2530,16 @@ void FunctionValidator::visitRefCast(RefCast* curr) {
return;
}
shouldBeEqual(
- curr->intendedType.getBottom(),
+ curr->type.getHeapType().getBottom(),
curr->ref->type.getHeapType().getBottom(),
curr,
"ref.cast target type and ref type must have a common supertype");
+
+ // TODO: Remove this restriction
+ shouldBeEqual(curr->type.getNullability(),
+ curr->ref->type.getNullability(),
+ curr,
+ "ref.cast to a different nullability not yet implemented");
}
void FunctionValidator::visitBrOn(BrOn* curr) {
diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp
index 604854878..cf71689b1 100644
--- a/src/wasm/wasm.cpp
+++ b/src/wasm/wasm.cpp
@@ -945,10 +945,10 @@ void RefTest::finalize() {
void RefCast::finalize() {
if (ref->type == Type::unreachable) {
type = Type::unreachable;
- } else {
- // The output of ref.cast may be null if the input is null (in that case the
- // null is passed through).
- type = Type(intendedType, ref->type.getNullability());
+ }
+ // Do not unnecessarily lose non-nullability information.
+ if (ref->type.isNonNullable() && type.isNullable()) {
+ type = Type(type.getHeapType(), NonNullable);
}
}