summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/passes/TypeRefining.cpp63
-rw-r--r--src/wasm-builder.h4
-rw-r--r--test/lit/passes/type-refining.wast109
3 files changed, 175 insertions, 1 deletions
diff --git a/src/passes/TypeRefining.cpp b/src/passes/TypeRefining.cpp
index 864cb81b3..9a0b289e5 100644
--- a/src/passes/TypeRefining.cpp
+++ b/src/passes/TypeRefining.cpp
@@ -353,6 +353,69 @@ struct TypeRefining : public Pass {
TypeRewriter(wasm, *this).update();
ReFinalize().run(getPassRunner(), &wasm);
+
+ // After refinalizing, we may still have situations that do not validate.
+ // In some cases we can infer something more precise than can be represented
+ // in wasm, like here:
+ //
+ // (try (result A)
+ // (struct.get ..) ;; returns B.
+ // (catch
+ // (const A)
+ // )
+ //
+ // The try body cannot throw, so the catch is never reached, and we can
+ // infer the fallthrough has the subtype B. But in wasm the type of the try
+ // must remain the supertype A. If that try is written into a StructSet that
+ // we refined, that will error.
+ //
+ // To fix this, we add a cast here, and expect that other passes will remove
+ // the cast after other optimizations simplify things (in this example, the
+ // catch can be removed).
+ struct WriteUpdater : public WalkerPass<PostWalker<WriteUpdater>> {
+ bool isFunctionParallel() override { return true; }
+
+ // Only affects struct.new/sets.
+ bool requiresNonNullableLocalFixups() override { return false; }
+
+ std::unique_ptr<Pass> create() override {
+ return std::make_unique<WriteUpdater>();
+ }
+
+ void visitStructNew(StructNew* curr) {
+ if (curr->type == Type::unreachable || curr->isWithDefault()) {
+ return;
+ }
+
+ auto& fields = curr->type.getHeapType().getStruct().fields;
+
+ for (Index i = 0; i < fields.size(); i++) {
+ auto*& operand = curr->operands[i];
+ auto fieldType = fields[i].type;
+ if (!Type::isSubType(operand->type, fieldType)) {
+ operand = Builder(*getModule()).makeRefCast(operand, fieldType);
+ }
+ }
+ }
+
+ void visitStructSet(StructSet* curr) {
+ if (curr->type == Type::unreachable) {
+ return;
+ }
+
+ auto fieldType =
+ curr->ref->type.getHeapType().getStruct().fields[curr->index].type;
+
+ if (!Type::isSubType(curr->value->type, fieldType)) {
+ curr->value =
+ Builder(*getModule()).makeRefCast(curr->value, fieldType);
+ }
+ }
+ };
+
+ WriteUpdater updater;
+ updater.run(getPassRunner(), &wasm);
+ updater.runOnModuleCode(getPassRunner(), &wasm);
}
};
diff --git a/src/wasm-builder.h b/src/wasm-builder.h
index e870d6b63..a3bf22fdc 100644
--- a/src/wasm-builder.h
+++ b/src/wasm-builder.h
@@ -876,7 +876,9 @@ public:
ret->finalize();
return ret;
}
- RefCast* makeRefCast(Expression* ref, Type type, RefCast::Safety safety) {
+ RefCast* makeRefCast(Expression* ref,
+ Type type,
+ RefCast::Safety safety = RefCast::Safe) {
auto* ret = wasm.allocator.alloc<RefCast>();
ret->ref = ref;
ret->type = type;
diff --git a/test/lit/passes/type-refining.wast b/test/lit/passes/type-refining.wast
index 9c2241edd..fb9f29568 100644
--- a/test/lit/passes/type-refining.wast
+++ b/test/lit/passes/type-refining.wast
@@ -1222,3 +1222,112 @@
)
)
)
+
+(module
+ ;; CHECK: (rec
+ ;; CHECK-NEXT: (type $ref|$A|_externref_=>_none (func (param (ref $A) externref)))
+
+ ;; CHECK: (type $A (struct (field (mut (ref noextern)))))
+ (type $A (struct (field (mut externref))))
+
+ ;; CHECK: (type $externref_=>_anyref (func (param externref) (result anyref)))
+
+ ;; CHECK: (type $none_=>_none (func))
+
+ ;; CHECK: (type $none_=>_none (func))
+
+ ;; CHECK: (tag $tag (param))
+ (tag $tag)
+
+ ;; CHECK: (func $struct.new (type $externref_=>_anyref) (param $extern externref) (result anyref)
+ ;; CHECK-NEXT: (struct.new $A
+ ;; CHECK-NEXT: (ref.cast noextern
+ ;; CHECK-NEXT: (try $try (result externref)
+ ;; CHECK-NEXT: (do
+ ;; CHECK-NEXT: (struct.get $A 0
+ ;; CHECK-NEXT: (struct.new $A
+ ;; CHECK-NEXT: (ref.as_non_null
+ ;; CHECK-NEXT: (ref.null noextern)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (catch $tag
+ ;; CHECK-NEXT: (local.get $extern)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $struct.new (param $extern externref) (result anyref)
+ ;; A noextern is written into the struct field and then read. Note that the
+ ;; try's catch is never reached, since the body cannot throw, so the
+ ;; fallthrough of the try is the struct.get, which leads into a struct.new, so
+ ;; we have a copy of that field. For that reason TypeRefining thinks it can
+ ;; refine the type of the field from externref to noextern. However, the
+ ;; validation rule for try-catch prevents the try from being refined so,
+ ;; since the catch has to be taken into account, and it has a less refined
+ ;; type than the body.
+ ;;
+ ;; In such situations we rely on other optimizations to improve things, like
+ ;; getting rid of the catch in this case. In this pass we add a cast to get
+ ;; things to validate, which should be removable by other passes later on.
+ (struct.new $A
+ (try (result externref)
+ (do
+ (struct.get $A 0
+ (struct.new $A
+ (ref.as_non_null
+ (ref.null noextern)
+ )
+ )
+ )
+ )
+ (catch $tag
+ (local.get $extern)
+ )
+ )
+ )
+ )
+
+ ;; CHECK: (func $struct.set (type $ref|$A|_externref_=>_none) (param $ref (ref $A)) (param $extern externref)
+ ;; CHECK-NEXT: (struct.set $A 0
+ ;; CHECK-NEXT: (local.get $ref)
+ ;; CHECK-NEXT: (ref.cast noextern
+ ;; CHECK-NEXT: (try $try (result externref)
+ ;; CHECK-NEXT: (do
+ ;; CHECK-NEXT: (struct.get $A 0
+ ;; CHECK-NEXT: (struct.new $A
+ ;; CHECK-NEXT: (ref.as_non_null
+ ;; CHECK-NEXT: (ref.null noextern)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (catch $tag
+ ;; CHECK-NEXT: (local.get $extern)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $struct.set (param $ref (ref $A)) (param $extern externref)
+ (struct.set $A 0
+ (local.get $ref)
+ (try (result externref)
+ (do
+ (struct.get $A 0
+ (struct.new $A
+ (ref.as_non_null
+ (ref.null noextern)
+ )
+ )
+ )
+ )
+ (catch $tag
+ (local.get $extern)
+ )
+ )
+ )
+ )
+)