diff options
-rw-r--r-- | src/passes/TypeRefining.cpp | 63 | ||||
-rw-r--r-- | src/wasm-builder.h | 4 | ||||
-rw-r--r-- | test/lit/passes/type-refining.wast | 109 |
3 files changed, 175 insertions, 1 deletions
diff --git a/src/passes/TypeRefining.cpp b/src/passes/TypeRefining.cpp index 864cb81b3..9a0b289e5 100644 --- a/src/passes/TypeRefining.cpp +++ b/src/passes/TypeRefining.cpp @@ -353,6 +353,69 @@ struct TypeRefining : public Pass { TypeRewriter(wasm, *this).update(); ReFinalize().run(getPassRunner(), &wasm); + + // After refinalizing, we may still have situations that do not validate. + // In some cases we can infer something more precise than can be represented + // in wasm, like here: + // + // (try (result A) + // (struct.get ..) ;; returns B. + // (catch + // (const A) + // ) + // + // The try body cannot throw, so the catch is never reached, and we can + // infer the fallthrough has the subtype B. But in wasm the type of the try + // must remain the supertype A. If that try is written into a StructSet that + // we refined, that will error. + // + // To fix this, we add a cast here, and expect that other passes will remove + // the cast after other optimizations simplify things (in this example, the + // catch can be removed). + struct WriteUpdater : public WalkerPass<PostWalker<WriteUpdater>> { + bool isFunctionParallel() override { return true; } + + // Only affects struct.new/sets. + bool requiresNonNullableLocalFixups() override { return false; } + + std::unique_ptr<Pass> create() override { + return std::make_unique<WriteUpdater>(); + } + + void visitStructNew(StructNew* curr) { + if (curr->type == Type::unreachable || curr->isWithDefault()) { + return; + } + + auto& fields = curr->type.getHeapType().getStruct().fields; + + for (Index i = 0; i < fields.size(); i++) { + auto*& operand = curr->operands[i]; + auto fieldType = fields[i].type; + if (!Type::isSubType(operand->type, fieldType)) { + operand = Builder(*getModule()).makeRefCast(operand, fieldType); + } + } + } + + void visitStructSet(StructSet* curr) { + if (curr->type == Type::unreachable) { + return; + } + + auto fieldType = + curr->ref->type.getHeapType().getStruct().fields[curr->index].type; + + if (!Type::isSubType(curr->value->type, fieldType)) { + curr->value = + Builder(*getModule()).makeRefCast(curr->value, fieldType); + } + } + }; + + WriteUpdater updater; + updater.run(getPassRunner(), &wasm); + updater.runOnModuleCode(getPassRunner(), &wasm); } }; diff --git a/src/wasm-builder.h b/src/wasm-builder.h index e870d6b63..a3bf22fdc 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -876,7 +876,9 @@ public: ret->finalize(); return ret; } - RefCast* makeRefCast(Expression* ref, Type type, RefCast::Safety safety) { + RefCast* makeRefCast(Expression* ref, + Type type, + RefCast::Safety safety = RefCast::Safe) { auto* ret = wasm.allocator.alloc<RefCast>(); ret->ref = ref; ret->type = type; diff --git a/test/lit/passes/type-refining.wast b/test/lit/passes/type-refining.wast index 9c2241edd..fb9f29568 100644 --- a/test/lit/passes/type-refining.wast +++ b/test/lit/passes/type-refining.wast @@ -1222,3 +1222,112 @@ ) ) ) + +(module + ;; CHECK: (rec + ;; CHECK-NEXT: (type $ref|$A|_externref_=>_none (func (param (ref $A) externref))) + + ;; CHECK: (type $A (struct (field (mut (ref noextern))))) + (type $A (struct (field (mut externref)))) + + ;; CHECK: (type $externref_=>_anyref (func (param externref) (result anyref))) + + ;; CHECK: (type $none_=>_none (func)) + + ;; CHECK: (type $none_=>_none (func)) + + ;; CHECK: (tag $tag (param)) + (tag $tag) + + ;; CHECK: (func $struct.new (type $externref_=>_anyref) (param $extern externref) (result anyref) + ;; CHECK-NEXT: (struct.new $A + ;; CHECK-NEXT: (ref.cast noextern + ;; CHECK-NEXT: (try $try (result externref) + ;; CHECK-NEXT: (do + ;; CHECK-NEXT: (struct.get $A 0 + ;; CHECK-NEXT: (struct.new $A + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (ref.null noextern) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (catch $tag + ;; CHECK-NEXT: (local.get $extern) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $struct.new (param $extern externref) (result anyref) + ;; A noextern is written into the struct field and then read. Note that the + ;; try's catch is never reached, since the body cannot throw, so the + ;; fallthrough of the try is the struct.get, which leads into a struct.new, so + ;; we have a copy of that field. For that reason TypeRefining thinks it can + ;; refine the type of the field from externref to noextern. However, the + ;; validation rule for try-catch prevents the try from being refined so, + ;; since the catch has to be taken into account, and it has a less refined + ;; type than the body. + ;; + ;; In such situations we rely on other optimizations to improve things, like + ;; getting rid of the catch in this case. In this pass we add a cast to get + ;; things to validate, which should be removable by other passes later on. + (struct.new $A + (try (result externref) + (do + (struct.get $A 0 + (struct.new $A + (ref.as_non_null + (ref.null noextern) + ) + ) + ) + ) + (catch $tag + (local.get $extern) + ) + ) + ) + ) + + ;; CHECK: (func $struct.set (type $ref|$A|_externref_=>_none) (param $ref (ref $A)) (param $extern externref) + ;; CHECK-NEXT: (struct.set $A 0 + ;; CHECK-NEXT: (local.get $ref) + ;; CHECK-NEXT: (ref.cast noextern + ;; CHECK-NEXT: (try $try (result externref) + ;; CHECK-NEXT: (do + ;; CHECK-NEXT: (struct.get $A 0 + ;; CHECK-NEXT: (struct.new $A + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (ref.null noextern) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (catch $tag + ;; CHECK-NEXT: (local.get $extern) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $struct.set (param $ref (ref $A)) (param $extern externref) + (struct.set $A 0 + (local.get $ref) + (try (result externref) + (do + (struct.get $A 0 + (struct.new $A + (ref.as_non_null + (ref.null noextern) + ) + ) + ) + ) + (catch $tag + (local.get $extern) + ) + ) + ) + ) +) |