diff options
author | Thomas Lively <tlively@google.com> | 2022-10-07 08:02:09 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-07 06:02:09 -0700 |
commit | 7fc26f3e78f72ecaa5b79ebe042b95a0be422327 (patch) | |
tree | f87c84fc691aaf311fbd71c176ee37723c76ae20 /src | |
parent | e8884de3c880a7de4bb1f8eae3df5f00f4164b4d (diff) | |
download | binaryen-7fc26f3e78f72ecaa5b79ebe042b95a0be422327.tar.gz binaryen-7fc26f3e78f72ecaa5b79ebe042b95a0be422327.tar.bz2 binaryen-7fc26f3e78f72ecaa5b79ebe042b95a0be422327.zip |
Implement bottom heap types (#5115)
These types, `none`, `nofunc`, and `noextern` are uninhabited, so references to
them can only possibly be null. To simplify the IR and increase type precision,
introduce new invariants that all `ref.null` instructions must be typed with one
of these new bottom types and that `Literals` have a bottom type iff they
represent null values. These new invariants requires several additional changes.
First, it is now possible that the `ref` or `target` child of a `StructGet`,
`StructSet`, `ArrayGet`, `ArraySet`, or `CallRef` instruction has a bottom
reference type, so it is not possible to determine what heap type annotation to
emit in the binary or text formats. (The bottom types are not valid type
annotations since they do not have indices in the type section.)
To fix that problem, update the printer and binary emitter to emit unreachables
instead of the instruction with undetermined type annotation. This is a valid
transformation because the only possible value that could flow into those
instructions in that case is null, and all of those instructions trap on nulls.
That fix uncovered a latent bug in the binary parser in which new unreachables
within unreachable code were handled incorrectly. This bug was not previously
found by the fuzzer because we generally stop emitting code once we encounter an
instruction with type `unreachable`. Now, however, it is possible to emit an
`unreachable` for instructions that do not have type `unreachable` (but are
known to trap at runtime), so we will continue emitting code. See the new
test/lit/parse-double-unreachable.wast for details.
Update other miscellaneous code that creates `RefNull` expressions and null
`Literals` to maintain the new invariants as well.
Diffstat (limited to 'src')
-rw-r--r-- | src/binaryen-c.cpp | 188 | ||||
-rw-r--r-- | src/binaryen-c.h | 4 | ||||
-rw-r--r-- | src/ir/effects.h | 28 | ||||
-rw-r--r-- | src/ir/manipulation.h | 1 | ||||
-rw-r--r-- | src/ir/possible-contents.cpp | 34 | ||||
-rw-r--r-- | src/ir/struct-utils.h | 5 | ||||
-rw-r--r-- | src/literal.h | 23 | ||||
-rw-r--r-- | src/passes/Inlining.cpp | 36 | ||||
-rw-r--r-- | src/passes/JSPI.cpp | 6 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 37 | ||||
-rw-r--r-- | src/passes/Precompute.cpp | 4 | ||||
-rw-r--r-- | src/passes/Print.cpp | 55 | ||||
-rw-r--r-- | src/passes/TypeRefining.cpp | 2 | ||||
-rw-r--r-- | src/tools/fuzzing/fuzzing.cpp | 57 | ||||
-rw-r--r-- | src/tools/fuzzing/heap-types.cpp | 37 | ||||
-rw-r--r-- | src/tools/wasm-ctor-eval.cpp | 5 | ||||
-rw-r--r-- | src/tools/wasm-reduce.cpp | 2 | ||||
-rw-r--r-- | src/wasm-binary.h | 8 | ||||
-rw-r--r-- | src/wasm-builder.h | 38 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 7 | ||||
-rw-r--r-- | src/wasm-type.h | 13 | ||||
-rw-r--r-- | src/wasm/literal.cpp | 146 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 168 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 30 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 29 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 156 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 102 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 9 |
28 files changed, 858 insertions, 372 deletions
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index cc7ed5b15..f2908920c 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -51,96 +51,111 @@ static_assert(sizeof(BinaryenLiteral) == sizeof(Literal), BinaryenLiteral toBinaryenLiteral(Literal x) { BinaryenLiteral ret; ret.type = x.type.getID(); - if (x.type.isRef()) { - auto heapType = x.type.getHeapType(); - if (heapType.isBasic()) { - switch (heapType.getBasic()) { - case HeapType::func: - ret.func = x.isNull() ? nullptr : x.getFunc().c_str(); - break; - case HeapType::ext: - case HeapType::eq: - assert(x.isNull() && "unexpected non-null reference type literal"); - break; - case HeapType::any: - case HeapType::i31: - case HeapType::data: - case HeapType::string: - case HeapType::stringview_wtf8: - case HeapType::stringview_wtf16: - case HeapType::stringview_iter: - WASM_UNREACHABLE("TODO: reftypes"); - } - return ret; + assert(x.type.isSingle()); + if (x.type.isBasic()) { + switch (x.type.getBasic()) { + case Type::i32: + ret.i32 = x.geti32(); + return ret; + case Type::i64: + ret.i64 = x.geti64(); + return ret; + case Type::f32: + ret.i32 = x.reinterpreti32(); + return ret; + case Type::f64: + ret.i64 = x.reinterpreti64(); + return ret; + case Type::v128: + memcpy(&ret.v128, x.getv128Ptr(), 16); + return ret; + case Type::none: + case Type::unreachable: + WASM_UNREACHABLE("unexpected type"); } - WASM_UNREACHABLE("TODO: reftypes"); } - TODO_SINGLE_COMPOUND(x.type); - switch (x.type.getBasic()) { - case Type::i32: - ret.i32 = x.geti32(); - break; - case Type::i64: - ret.i64 = x.geti64(); - break; - case Type::f32: - ret.i32 = x.reinterpreti32(); - break; - case Type::f64: - ret.i64 = x.reinterpreti64(); - break; - case Type::v128: - memcpy(&ret.v128, x.getv128Ptr(), 16); - break; - case Type::none: - case Type::unreachable: - WASM_UNREACHABLE("unexpected type"); + assert(x.type.isRef()); + auto heapType = x.type.getHeapType(); + if (heapType.isBasic()) { + switch (heapType.getBasic()) { + case HeapType::i31: + WASM_UNREACHABLE("TODO: i31"); + case HeapType::ext: + case HeapType::any: + WASM_UNREACHABLE("TODO: extern literals"); + case HeapType::eq: + case HeapType::func: + case HeapType::data: + WASM_UNREACHABLE("invalid type"); + case HeapType::string: + case HeapType::stringview_wtf8: + case HeapType::stringview_wtf16: + case HeapType::stringview_iter: + WASM_UNREACHABLE("TODO: string literals"); + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + // Null. + return ret; + } } - return ret; + if (heapType.isSignature()) { + ret.func = x.getFunc().c_str(); + return ret; + } + assert(x.isData()); + WASM_UNREACHABLE("TODO: gc data"); } Literal fromBinaryenLiteral(BinaryenLiteral x) { auto type = Type(x.type); - if (type.isRef()) { - auto heapType = type.getHeapType(); - if (type.isNullable()) { - return Literal::makeNull(heapType); + if (type.isBasic()) { + switch (type.getBasic()) { + case Type::i32: + return Literal(x.i32); + case Type::i64: + return Literal(x.i64); + case Type::f32: + return Literal(x.i32).castToF32(); + case Type::f64: + return Literal(x.i64).castToF64(); + case Type::v128: + return Literal(x.v128); + case Type::none: + case Type::unreachable: + WASM_UNREACHABLE("unexpected type"); } - if (heapType.isBasic()) { - switch (heapType.getBasic()) { - case HeapType::func: - case HeapType::any: - case HeapType::eq: - case HeapType::data: - assert(false && "Literals must have concrete types"); - WASM_UNREACHABLE("no fallthrough here"); - case HeapType::ext: - case HeapType::i31: - case HeapType::string: - case HeapType::stringview_wtf8: - case HeapType::stringview_wtf16: - case HeapType::stringview_iter: - WASM_UNREACHABLE("TODO: reftypes"); - } + } + assert(type.isRef()); + auto heapType = type.getHeapType(); + if (heapType.isBasic()) { + switch (heapType.getBasic()) { + case HeapType::i31: + WASM_UNREACHABLE("TODO: i31"); + case HeapType::ext: + case HeapType::any: + WASM_UNREACHABLE("TODO: extern literals"); + case HeapType::eq: + case HeapType::func: + case HeapType::data: + WASM_UNREACHABLE("invalid type"); + case HeapType::string: + case HeapType::stringview_wtf8: + case HeapType::stringview_wtf16: + case HeapType::stringview_iter: + WASM_UNREACHABLE("TODO: string literals"); + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + assert(type.isNullable()); + return Literal::makeNull(heapType); } } - assert(type.isBasic()); - switch (type.getBasic()) { - case Type::i32: - return Literal(x.i32); - case Type::i64: - return Literal(x.i64); - case Type::f32: - return Literal(x.i32).castToF32(); - case Type::f64: - return Literal(x.i64).castToF64(); - case Type::v128: - return Literal(x.v128); - case Type::none: - case Type::unreachable: - WASM_UNREACHABLE("unexpected type"); + if (heapType.isSignature()) { + return Literal::makeFunc(Name(x.func), heapType); } - WASM_UNREACHABLE("invalid type"); + assert(heapType.isData()); + WASM_UNREACHABLE("TODO: gc data"); } // Mutexes (global for now; in theory if multiple modules @@ -197,6 +212,15 @@ BinaryenType BinaryenTypeStringviewWTF16() { BinaryenType BinaryenTypeStringviewIter() { return Type(HeapType::stringview_iter, Nullable).getID(); } +BinaryenType BinaryenTypeNullref() { + return Type(HeapType::none, Nullable).getID(); +} +BinaryenType BinaryenTypeNullExternref(void) { + return Type(HeapType::noext, Nullable).getID(); +} +BinaryenType BinaryenTypeNullFuncref(void) { + return Type(HeapType::nofunc, Nullable).getID(); +} BinaryenType BinaryenTypeUnreachable(void) { return Type::unreachable; } BinaryenType BinaryenTypeAuto(void) { return uintptr_t(-1); } @@ -1484,7 +1508,8 @@ BinaryenExpressionRef BinaryenRefNull(BinaryenModuleRef module, BinaryenType type) { Type type_(type); assert(type_.isNullable()); - return static_cast<Expression*>(Builder(*(Module*)module).makeRefNull(type_)); + return static_cast<Expression*>( + Builder(*(Module*)module).makeRefNull(type_.getHeapType())); } BinaryenExpressionRef BinaryenRefIs(BinaryenModuleRef module, @@ -1699,10 +1724,11 @@ BinaryenExpressionRef BinaryenArrayInit(BinaryenModuleRef module, BinaryenExpressionRef BinaryenArrayGet(BinaryenModuleRef module, BinaryenExpressionRef ref, BinaryenExpressionRef index, + BinaryenType type, bool signed_) { return static_cast<Expression*>( Builder(*(Module*)module) - .makeArrayGet((Expression*)ref, (Expression*)index, signed_)); + .makeArrayGet((Expression*)ref, (Expression*)index, Type(type), signed_)); } BinaryenExpressionRef BinaryenArraySet(BinaryenModuleRef module, BinaryenExpressionRef ref, diff --git a/src/binaryen-c.h b/src/binaryen-c.h index 5b343e8ba..9cc282721 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -109,6 +109,9 @@ BINARYEN_API BinaryenType BinaryenTypeStringref(void); BINARYEN_API BinaryenType BinaryenTypeStringviewWTF8(void); BINARYEN_API BinaryenType BinaryenTypeStringviewWTF16(void); BINARYEN_API BinaryenType BinaryenTypeStringviewIter(void); +BINARYEN_API BinaryenType BinaryenTypeNullref(void); +BINARYEN_API BinaryenType BinaryenTypeNullExternref(void); +BINARYEN_API BinaryenType BinaryenTypeNullFuncref(void); BINARYEN_API BinaryenType BinaryenTypeUnreachable(void); // Not a real type. Used as the last parameter to BinaryenBlock to let // the API figure out the type instead of providing one. @@ -1044,6 +1047,7 @@ BinaryenArrayInit(BinaryenModuleRef module, BINARYEN_API BinaryenExpressionRef BinaryenArrayGet(BinaryenModuleRef module, BinaryenExpressionRef ref, BinaryenExpressionRef index, + BinaryenType type, bool signed_); BINARYEN_API BinaryenExpressionRef BinaryenArraySet(BinaryenModuleRef module, diff --git a/src/ir/effects.h b/src/ir/effects.h index 63d9fafc5..2dfe616f9 100644 --- a/src/ir/effects.h +++ b/src/ir/effects.h @@ -701,6 +701,10 @@ private: } } void visitCallRef(CallRef* curr) { + if (curr->target->type.isNull()) { + parent.trap = true; + return; + } parent.calls = true; if (parent.features.hasExceptionHandling() && parent.tryDepth == 0) { parent.throws_ = true; @@ -724,6 +728,10 @@ private: if (curr->ref->type == Type::unreachable) { return; } + if (curr->ref->type.isNull()) { + parent.trap = true; + return; + } if (curr->ref->type.getHeapType() .getStruct() .fields[curr->index] @@ -736,6 +744,10 @@ private: } } void visitStructSet(StructSet* curr) { + if (curr->ref->type.isNull()) { + parent.trap = true; + return; + } parent.writesStruct = true; // traps when the arg is null if (curr->ref->type.isNullable()) { @@ -745,22 +757,38 @@ private: void visitArrayNew(ArrayNew* curr) {} void visitArrayInit(ArrayInit* curr) {} void visitArrayGet(ArrayGet* curr) { + if (curr->ref->type.isNull()) { + parent.trap = true; + return; + } parent.readsArray = true; // traps when the arg is null or the index out of bounds parent.implicitTrap = true; } void visitArraySet(ArraySet* curr) { + if (curr->ref->type.isNull()) { + parent.trap = true; + return; + } parent.writesArray = true; // traps when the arg is null or the index out of bounds parent.implicitTrap = true; } void visitArrayLen(ArrayLen* curr) { + if (curr->ref->type.isNull()) { + parent.trap = true; + return; + } // traps when the arg is null if (curr->ref->type.isNullable()) { parent.implicitTrap = true; } } void visitArrayCopy(ArrayCopy* curr) { + if (curr->destRef->type.isNull() || curr->srcRef->type.isNull()) { + parent.trap = true; + return; + } parent.readsArray = true; parent.writesArray = true; // traps when a ref is null, or when out of bounds. diff --git a/src/ir/manipulation.h b/src/ir/manipulation.h index 54822f2bd..33c7d1bd7 100644 --- a/src/ir/manipulation.h +++ b/src/ir/manipulation.h @@ -41,6 +41,7 @@ template<typename InputType> inline Nop* nop(InputType* target) { template<typename InputType> inline RefNull* refNull(InputType* target, Type type) { + assert(type.isNullable() && type.getHeapType().isBottom()); auto* ret = convert<InputType, RefNull>(target); ret->finalize(type); return ret; diff --git a/src/ir/possible-contents.cpp b/src/ir/possible-contents.cpp index 7c69121e6..9ad43f090 100644 --- a/src/ir/possible-contents.cpp +++ b/src/ir/possible-contents.cpp @@ -47,23 +47,6 @@ void PossibleContents::combine(const PossibleContents& other) { // First handle the trivial cases of them being equal, or one of them is // None or Many. if (*this == other) { - // Nulls are a special case, since they compare equal even if their type is - // different. We would like to make this function symmetric, that is, that - // combine(a, b) == combine(b, a) (otherwise, things can be odd and we could - // get nondeterminism in the flow analysis which does not have a - // determinstic order). To fix that, pick the LUB. - if (isNull()) { - assert(other.isNull()); - auto lub = HeapType::getLeastUpperBound(type.getHeapType(), - otherType.getHeapType()); - if (!lub) { - // TODO: Remove this workaround once we have bottom types to assign to - // null literals. - value = Many(); - return; - } - value = Literal::makeNull(*lub); - } return; } if (other.isNone()) { @@ -97,10 +80,18 @@ void PossibleContents::combine(const PossibleContents& other) { // Special handling for references from here. - // Nulls are always equal to each other, even if their types differ. + if (isNull() && other.isNull()) { + // These must be nulls in different hierarchies, otherwise this would have + // been handled by the `*this == other` case above. + assert(type != otherType); + value = Many(); + return; + } + + // Nulls can be combined in by just adding nullability to a type. if (isNull() || other.isNull()) { - // Only one of them can be null here, since we already checked if *this == - // other, which would have been true had both been null. + // Only one of them can be null here, since we already handled the case + // where they were both null. assert(!isNull() || !other.isNull()); // If only one is a null, but the other's type is known exactly, then the // combination is to add nullability (if the type is *not* known exactly, @@ -797,7 +788,8 @@ struct InfoCollector // part of the main IR, which is potentially confusing during debugging, // however, which is a downside. Builder builder(*getModule()); - auto* get = builder.makeArrayGet(curr->srcRef, curr->srcIndex); + auto* get = + builder.makeArrayGet(curr->srcRef, curr->srcIndex, curr->srcRef->type); visitArrayGet(get); auto* set = builder.makeArraySet(curr->destRef, curr->destIndex, get); visitArraySet(set); diff --git a/src/ir/struct-utils.h b/src/ir/struct-utils.h index 9d02bb779..9f880985f 100644 --- a/src/ir/struct-utils.h +++ b/src/ir/struct-utils.h @@ -50,6 +50,7 @@ struct StructValuesMap : public std::unordered_map<HeapType, StructValues<T>> { // When we access an item, if it does not already exist, create it with a // vector of the right length for that type. StructValues<T>& operator[](HeapType type) { + assert(type.isStruct()); auto inserted = this->insert({type, {}}); auto& values = inserted.first->second; if (inserted.second) { @@ -159,7 +160,7 @@ struct StructScanner void visitStructSet(StructSet* curr) { auto type = curr->ref->type; - if (type == Type::unreachable) { + if (type == Type::unreachable || type.isNull()) { return; } @@ -173,7 +174,7 @@ struct StructScanner void visitStructGet(StructGet* curr) { auto type = curr->ref->type; - if (type == Type::unreachable) { + if (type == Type::unreachable || type.isNull()) { return; } diff --git a/src/literal.h b/src/literal.h index 318ab012a..7d7c778bc 100644 --- a/src/literal.h +++ b/src/literal.h @@ -79,7 +79,9 @@ public: explicit Literal(const std::array<Literal, 4>&); explicit Literal(const std::array<Literal, 2>&); explicit Literal(Name func, HeapType type) - : func(func), type(type, NonNullable) {} + : func(func), type(type, NonNullable) { + assert(type.isSignature()); + } explicit Literal(std::shared_ptr<GCData> gcData, HeapType type); Literal(const Literal& other); Literal& operator=(const Literal& other); @@ -90,21 +92,8 @@ public: bool isFunction() const { return type.isFunction(); } bool isData() const { return type.isData(); } - bool isNull() const { - if (type.isNullable()) { - if (type.isFunction()) { - return func.isNull(); - } - if (isData()) { - return !gcData; - } - if (type.getHeapType() == HeapType::i31) { - return i32 == 0; - } - return true; - } - return false; - } + bool isNull() const { return type.isNull(); } + bool isZero() const { switch (type.getBasic()) { case Type::i32: @@ -239,7 +228,7 @@ public: } } static Literal makeNull(HeapType type) { - return Literal(Type(type, Nullable)); + return Literal(Type(type.getBottom(), Nullable)); } static Literal makeFunc(Name func, HeapType type) { return Literal(func, type); diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index c03b404b0..f325f70c6 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -32,6 +32,7 @@ #include "ir/branch-utils.h" #include "ir/debug.h" +#include "ir/drop.h" #include "ir/eh-utils.h" #include "ir/element-utils.h" #include "ir/literal-utils.h" @@ -251,6 +252,10 @@ struct Updater : public PostWalker<Updater> { Name returnName; bool isReturn; Builder* builder; + PassOptions& options; + + Updater(PassOptions& options) : options(options) {} + void visitReturn(Return* curr) { replaceCurrent(builder->makeBreak(returnName, curr->value)); } @@ -259,7 +264,7 @@ struct Updater : public PostWalker<Updater> { // achieve this, make the call a non-return call and add a break. This does // not cause unbounded stack growth because inlining and return calling both // avoid creating a new stack frame. - template<typename T> void handleReturnCall(T* curr, HeapType targetType) { + template<typename T> void handleReturnCall(T* curr, Type results) { if (isReturn) { // If the inlined callsite was already a return_call, then we can keep // return_calls in the inlined function rather than downgrading them. @@ -269,7 +274,7 @@ struct Updater : public PostWalker<Updater> { return; } curr->isReturn = false; - curr->type = targetType.getSignature().results; + curr->type = results; if (curr->type.isConcrete()) { replaceCurrent(builder->makeBreak(returnName, curr)); } else { @@ -278,17 +283,25 @@ struct Updater : public PostWalker<Updater> { } void visitCall(Call* curr) { if (curr->isReturn) { - handleReturnCall(curr, module->getFunction(curr->target)->type); + handleReturnCall(curr, module->getFunction(curr->target)->getResults()); } } void visitCallIndirect(CallIndirect* curr) { if (curr->isReturn) { - handleReturnCall(curr, curr->heapType); + handleReturnCall(curr, curr->heapType.getSignature().results); } } void visitCallRef(CallRef* curr) { + Type targetType = curr->target->type; + if (targetType.isNull()) { + // We don't know what type the call should return, but we can't leave it + // as a potentially-invalid return_call_ref, either. + replaceCurrent(getDroppedChildrenAndAppend( + curr, *module, options, Builder(*module).makeUnreachable())); + return; + } if (curr->isReturn) { - handleReturnCall(curr, curr->target->type.getHeapType()); + handleReturnCall(curr, targetType.getHeapType().getSignature().results); } } void visitLocalGet(LocalGet* curr) { @@ -301,8 +314,10 @@ struct Updater : public PostWalker<Updater> { // Core inlining logic. Modifies the outside function (adding locals as // needed), and returns the inlined code. -static Expression* -doInlining(Module* module, Function* into, const InliningAction& action) { +static Expression* doInlining(Module* module, + Function* into, + const InliningAction& action, + PassOptions& options) { Function* from = action.contents; auto* call = (*action.callSite)->cast<Call>(); // Works for return_call, too @@ -337,7 +352,7 @@ doInlining(Module* module, Function* into, const InliningAction& action) { *action.callSite = block; } // Prepare to update the inlined code's locals and other things. - Updater updater; + Updater updater(options); updater.module = module; updater.returnName = block->name; updater.isReturn = call->isReturn; @@ -1002,7 +1017,7 @@ struct Inlining : public Pass { action.contents = getActuallyInlinedFunction(action.contents); // Perform the inlining and update counts. - doInlining(module, func, action); + doInlining(module, func, action, getPassOptions()); inlinedUses[inlinedName]++; inlinedInto.insert(func); assert(inlinedUses[inlinedName] <= infos[inlinedName].refs); @@ -1116,7 +1131,8 @@ struct InlineMainPass : public Pass { // No call at all. return; } - doInlining(module, main, InliningAction(callSite, originalMain)); + doInlining( + module, main, InliningAction(callSite, originalMain), getPassOptions()); } }; diff --git a/src/passes/JSPI.cpp b/src/passes/JSPI.cpp index 9660b962d..fe24d60f7 100644 --- a/src/passes/JSPI.cpp +++ b/src/passes/JSPI.cpp @@ -45,8 +45,10 @@ struct JSPI : public Pass { // Create a global to store the suspender that is passed into exported // functions and will then need to be passed out to the imported functions. Name suspender = Names::getValidGlobalName(*module, "suspender"); - module->addGlobal(builder.makeGlobal( - suspender, externref, builder.makeRefNull(externref), Builder::Mutable)); + module->addGlobal(builder.makeGlobal(suspender, + externref, + builder.makeRefNull(HeapType::noext), + Builder::Mutable)); // Keep track of already wrapped functions since they can be exported // multiple times, but only one wrapper is needed. diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 4ffd6be67..059ef28fb 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -1267,6 +1267,10 @@ struct OptimizeInstructions } void visitCallRef(CallRef* curr) { + skipNonNullCast(curr->target); + if (trapOnNull(curr, curr->target)) { + return; + } if (curr->target->type == Type::unreachable) { // The call_ref is not reached; leave this for DCE. return; @@ -1509,6 +1513,17 @@ struct OptimizeInstructions return getDroppedChildrenAndAppend(curr, result); } + bool trapOnNull(Expression* curr, Expression* ref) { + if (ref->type.isNull()) { + replaceCurrent(getDroppedChildrenAndAppend( + curr, Builder(*getModule()).makeUnreachable())); + // Propagate the unreachability. + refinalize = true; + return true; + } + return false; + } + void visitRefEq(RefEq* curr) { // The types may prove that the same reference cannot appear on both sides. auto leftType = curr->left->type; @@ -1564,10 +1579,16 @@ struct OptimizeInstructions } } - void visitStructGet(StructGet* curr) { skipNonNullCast(curr->ref); } + void visitStructGet(StructGet* curr) { + skipNonNullCast(curr->ref); + trapOnNull(curr, curr->ref); + } void visitStructSet(StructSet* curr) { skipNonNullCast(curr->ref); + if (trapOnNull(curr, curr->ref)) { + return; + } if (curr->ref->type != Type::unreachable && curr->value->type.isInteger()) { const auto& fields = curr->ref->type.getHeapType().getStruct().fields; @@ -1715,10 +1736,16 @@ struct OptimizeInstructions return true; } - void visitArrayGet(ArrayGet* curr) { skipNonNullCast(curr->ref); } + void visitArrayGet(ArrayGet* curr) { + skipNonNullCast(curr->ref); + trapOnNull(curr, curr->ref); + } void visitArraySet(ArraySet* curr) { skipNonNullCast(curr->ref); + if (trapOnNull(curr, curr->ref)) { + return; + } if (curr->ref->type != Type::unreachable && curr->value->type.isInteger()) { auto element = curr->ref->type.getHeapType().getArray().element; @@ -1726,11 +1753,15 @@ struct OptimizeInstructions } } - void visitArrayLen(ArrayLen* curr) { skipNonNullCast(curr->ref); } + void visitArrayLen(ArrayLen* curr) { + skipNonNullCast(curr->ref); + trapOnNull(curr, curr->ref); + } void visitArrayCopy(ArrayCopy* curr) { skipNonNullCast(curr->destRef); skipNonNullCast(curr->srcRef); + trapOnNull(curr, curr->destRef) || trapOnNull(curr, curr->srcRef); } bool canBeCastTo(HeapType a, HeapType b) { diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp index 466614d14..c90fdf167 100644 --- a/src/passes/Precompute.cpp +++ b/src/passes/Precompute.cpp @@ -130,7 +130,7 @@ public: } Flow visitStructSet(StructSet* curr) { return Flow(NONCONSTANT_FLOW); } Flow visitStructGet(StructGet* curr) { - if (curr->ref->type != Type::unreachable) { + if (curr->ref->type != Type::unreachable && !curr->ref->type.isNull()) { // If this field is immutable then we may be able to precompute this, as // if we also created the data in this function (or it was created in an // immutable global) then we know the value in the field. If it is @@ -164,7 +164,7 @@ public: } Flow visitArraySet(ArraySet* curr) { return Flow(NONCONSTANT_FLOW); } Flow visitArrayGet(ArrayGet* curr) { - if (curr->ref->type != Type::unreachable) { + if (curr->ref->type != Type::unreachable && !curr->ref->type.isNull()) { // See above with struct.get auto element = curr->ref->type.getHeapType().getArray().element; if (element.mutable_ == Immutable) { diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 01b004d97..7ebad3322 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -115,6 +115,15 @@ bool maybePrintRefShorthand(std::ostream& o, Type type) { case HeapType::stringview_iter: o << "stringview_iter"; return true; + case HeapType::none: + o << "nullref"; + return true; + case HeapType::noext: + o << "nullexternref"; + return true; + case HeapType::nofunc: + o << "nullfuncref"; + return true; } } return false; @@ -2058,10 +2067,17 @@ struct PrintExpressionContents } return false; } + bool printUnreachableOrNullReplacement(Expression* curr) { + if (curr->type == Type::unreachable || curr->type.isNull()) { + printMedium(o, "block"); + return true; + } + return false; + } void visitCallRef(CallRef* curr) { // TODO: Workaround if target has bottom type. - if (printUnreachableReplacement(curr->target)) { + if (printUnreachableOrNullReplacement(curr->target)) { return; } printMedium(o, curr->isReturn ? "return_call_ref " : "call_ref "); @@ -2144,7 +2160,7 @@ struct PrintExpressionContents }); } void visitStructGet(StructGet* curr) { - if (printUnreachableReplacement(curr->ref)) { + if (printUnreachableOrNullReplacement(curr->ref)) { return; } auto heapType = curr->ref->type.getHeapType(); @@ -2163,7 +2179,7 @@ struct PrintExpressionContents printFieldName(heapType, curr->index); } void visitStructSet(StructSet* curr) { - if (printUnreachableReplacement(curr->ref)) { + if (printUnreachableOrNullReplacement(curr->ref)) { return; } printMedium(o, "struct.set "); @@ -2192,7 +2208,7 @@ struct PrintExpressionContents TypeNamePrinter(o, wasm).print(curr->type.getHeapType()); } void visitArrayGet(ArrayGet* curr) { - if (printUnreachableReplacement(curr->ref)) { + if (printUnreachableOrNullReplacement(curr->ref)) { return; } const auto& element = curr->ref->type.getHeapType().getArray().element; @@ -2208,22 +2224,22 @@ struct PrintExpressionContents TypeNamePrinter(o, wasm).print(curr->ref->type.getHeapType()); } void visitArraySet(ArraySet* curr) { - if (printUnreachableReplacement(curr->ref)) { + if (printUnreachableOrNullReplacement(curr->ref)) { return; } printMedium(o, "array.set "); TypeNamePrinter(o, wasm).print(curr->ref->type.getHeapType()); } void visitArrayLen(ArrayLen* curr) { - if (printUnreachableReplacement(curr->ref)) { + if (printUnreachableOrNullReplacement(curr->ref)) { return; } printMedium(o, "array.len "); TypeNamePrinter(o, wasm).print(curr->ref->type.getHeapType()); } void visitArrayCopy(ArrayCopy* curr) { - if (printUnreachableReplacement(curr->srcRef) || - printUnreachableReplacement(curr->destRef)) { + if (printUnreachableOrNullReplacement(curr->srcRef) || + printUnreachableOrNullReplacement(curr->destRef)) { return; } printMedium(o, "array.copy "); @@ -2746,19 +2762,29 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { drop.value = child; printFullLine(&drop); } + Unreachable unreachable; + printFullLine(&unreachable); decIndent(); } + // This must be used for the same Expressions that use + // PrintExpressionContents::printUnreachableOrNullReplacement. + void maybePrintUnreachableOrNullReplacement(Expression* curr, Type type) { + if (type.isNull()) { + type = Type::unreachable; + } + maybePrintUnreachableReplacement(curr, type); + } void visitCallRef(CallRef* curr) { - maybePrintUnreachableReplacement(curr, curr->target->type); + maybePrintUnreachableOrNullReplacement(curr, curr->target->type); } void visitStructNew(StructNew* curr) { maybePrintUnreachableReplacement(curr, curr->type); } void visitStructSet(StructSet* curr) { - maybePrintUnreachableReplacement(curr, curr->ref->type); + maybePrintUnreachableOrNullReplacement(curr, curr->ref->type); } void visitStructGet(StructGet* curr) { - maybePrintUnreachableReplacement(curr, curr->ref->type); + maybePrintUnreachableOrNullReplacement(curr, curr->ref->type); } void visitArrayNew(ArrayNew* curr) { maybePrintUnreachableReplacement(curr, curr->type); @@ -2767,10 +2793,13 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { maybePrintUnreachableReplacement(curr, curr->type); } void visitArraySet(ArraySet* curr) { - maybePrintUnreachableReplacement(curr, curr->ref->type); + maybePrintUnreachableOrNullReplacement(curr, curr->ref->type); } void visitArrayGet(ArrayGet* curr) { - maybePrintUnreachableReplacement(curr, curr->ref->type); + maybePrintUnreachableOrNullReplacement(curr, curr->ref->type); + } + void visitArrayLen(ArrayLen* curr) { + maybePrintUnreachableOrNullReplacement(curr, curr->ref->type); } // Module-level visitors void printSupertypeOr(HeapType curr, std::string noSuper) { diff --git a/src/passes/TypeRefining.cpp b/src/passes/TypeRefining.cpp index 6ce503cc0..e9aa07ca6 100644 --- a/src/passes/TypeRefining.cpp +++ b/src/passes/TypeRefining.cpp @@ -251,7 +251,7 @@ struct TypeRefining : public Pass { } void visitStructGet(StructGet* curr) { - if (curr->ref->type == Type::unreachable) { + if (curr->ref->type == Type::unreachable || curr->ref->type.isNull()) { return; } diff --git a/src/tools/fuzzing/fuzzing.cpp b/src/tools/fuzzing/fuzzing.cpp index 2b368a4e8..09df9ac31 100644 --- a/src/tools/fuzzing/fuzzing.cpp +++ b/src/tools/fuzzing/fuzzing.cpp @@ -1976,7 +1976,7 @@ Expression* TranslateToFuzzReader::makeRefFuncConst(Type type) { // to add a ref.as_non_null to validate, and the code will trap when we get // here). if ((type.isNullable() && oneIn(2)) || (type.isNonNullable() && oneIn(16))) { - Expression* ret = builder.makeRefNull(Type(heapType, Nullable)); + Expression* ret = builder.makeRefNull(HeapType::nofunc); if (!type.isNullable()) { ret = builder.makeRefAs(RefAsNonNull, ret); } @@ -2000,7 +2000,7 @@ Expression* TranslateToFuzzReader::makeConst(Type type) { assert(wasm.features.hasReferenceTypes()); // With a low chance, just emit a null if that is valid. if (type.isNullable() && oneIn(8)) { - return builder.makeRefNull(type); + return builder.makeRefNull(type.getHeapType()); } if (type.getHeapType().isBasic()) { return makeConstBasicRef(type); @@ -2050,7 +2050,7 @@ Expression* TranslateToFuzzReader::makeConstBasicRef(Type type) { // a subtype of anyref, but we cannot create constants of it, except // for null. assert(type.isNullable()); - return builder.makeRefNull(type); + return builder.makeRefNull(HeapType::none); } auto nullability = getSubType(type.getNullability()); // i31.new is not allowed in initializer expressions. @@ -2065,7 +2065,7 @@ Expression* TranslateToFuzzReader::makeConstBasicRef(Type type) { case HeapType::i31: { assert(wasm.features.hasGC()); if (type.isNullable() && oneIn(4)) { - return builder.makeRefNull(type); + return builder.makeRefNull(HeapType::none); } return builder.makeI31New(makeConst(Type::i32)); } @@ -2086,10 +2086,22 @@ Expression* TranslateToFuzzReader::makeConstBasicRef(Type type) { return builder.makeArrayInit(trivialArray, {}); } } - default: { - WASM_UNREACHABLE("invalid basic ref type"); + case HeapType::string: + case HeapType::stringview_wtf8: + case HeapType::stringview_wtf16: + case HeapType::stringview_iter: + WASM_UNREACHABLE("TODO: strings"); + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: { + auto null = builder.makeRefNull(heapType); + if (!type.isNullable()) { + return builder.makeRefAs(RefAsNonNull, null); + } + return null; } } + WASM_UNREACHABLE("invalid basic ref type"); } Expression* TranslateToFuzzReader::makeConstCompoundRef(Type type) { @@ -2104,15 +2116,14 @@ Expression* TranslateToFuzzReader::makeConstCompoundRef(Type type) { // We weren't able to directly materialize a non-null constant. Try again to // create a null. if (type.isNullable()) { - return builder.makeRefNull(type); + return builder.makeRefNull(heapType); } // We have to produce a non-null value. Possibly create a null and cast it // to non-null even though that will trap at runtime. We must have a // function context for this because the cast is not allowed in globals. if (funcContext) { - return builder.makeRefAs(RefAsNonNull, - builder.makeRefNull(Type(heapType, Nullable))); + return builder.makeRefAs(RefAsNonNull, builder.makeRefNull(heapType)); } // Otherwise, we are not in a function context. This can happen if we need @@ -3138,33 +3149,49 @@ Nullability TranslateToFuzzReader::getSubType(Nullability nullability) { } HeapType TranslateToFuzzReader::getSubType(HeapType type) { + if (oneIn(2)) { + return type; + } if (type.isBasic()) { switch (type.getBasic()) { case HeapType::func: // TODO: Typed function references. - return HeapType::func; + return pick(FeatureOptions<HeapType>() + .add(FeatureSet::ReferenceTypes, HeapType::func) + .add(FeatureSet::GC, HeapType::nofunc)); case HeapType::ext: - return HeapType::ext; + return pick(FeatureOptions<HeapType>() + .add(FeatureSet::ReferenceTypes, HeapType::ext) + .add(FeatureSet::GC, HeapType::noext)); case HeapType::any: // TODO: nontrivial types as well. assert(wasm.features.hasReferenceTypes()); assert(wasm.features.hasGC()); - return pick(HeapType::any, HeapType::eq, HeapType::i31, HeapType::data); + return pick(HeapType::any, + HeapType::eq, + HeapType::i31, + HeapType::data, + HeapType::none); case HeapType::eq: // TODO: nontrivial types as well. assert(wasm.features.hasReferenceTypes()); assert(wasm.features.hasGC()); - return pick(HeapType::eq, HeapType::i31, HeapType::data); + return pick( + HeapType::eq, HeapType::i31, HeapType::data, HeapType::none); case HeapType::i31: - return HeapType::i31; + return pick(HeapType::i31, HeapType::none); case HeapType::data: // TODO: nontrivial types as well. - return HeapType::data; + return pick(HeapType::data, HeapType::none); case HeapType::string: case HeapType::stringview_wtf8: case HeapType::stringview_wtf16: case HeapType::stringview_iter: WASM_UNREACHABLE("TODO: fuzz strings"); + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + break; } } // TODO: nontrivial types as well. diff --git a/src/tools/fuzzing/heap-types.cpp b/src/tools/fuzzing/heap-types.cpp index ad87c6fab..351035f93 100644 --- a/src/tools/fuzzing/heap-types.cpp +++ b/src/tools/fuzzing/heap-types.cpp @@ -69,7 +69,7 @@ struct HeapTypeGeneratorImpl { typeIndices.insert({builder[i], i}); // Everything is a subtype of itself. subtypeIndices[i].push_back(i); - if (i < numRoots) { + if (i < numRoots || rand.oneIn(2)) { // This is a root type with no supertype. Choose a kind for this type. typeKinds.emplace_back(generateHeapTypeKind()); } else { @@ -148,6 +148,11 @@ struct HeapTypeGeneratorImpl { } HeapType::BasicHeapType generateBasicHeapType() { + // Choose bottom types more rarely. + if (rand.oneIn(16)) { + return rand.pick(HeapType::noext, HeapType::nofunc, HeapType::none); + } + // TODO: strings return rand.pick(HeapType::func, HeapType::ext, HeapType::any, @@ -254,6 +259,8 @@ struct HeapTypeGeneratorImpl { HeapType pickSubFunc() { if (auto type = pickKind<SignatureKind>()) { return *type; + } else if (rand.oneIn(2)) { + return HeapType::nofunc; } else { return HeapType::func; } @@ -262,6 +269,8 @@ struct HeapTypeGeneratorImpl { HeapType pickSubData() { if (auto type = pickKind<DataKind>()) { return *type; + } else if (rand.oneIn(2)) { + return HeapType::none; } else { return HeapType::data; } @@ -292,7 +301,7 @@ struct HeapTypeGeneratorImpl { // can only choose those defined before the end of the current recursion // group. std::vector<Index> candidateIndices; - for (auto i : subtypeIndices[typeIndices[type]]) { + for (auto i : subtypeIndices[it->second]) { if (i < recGroupEnds[index]) { candidateIndices.push_back(i); } @@ -301,6 +310,9 @@ struct HeapTypeGeneratorImpl { } else { // This is not a constructed type, so it must be a basic type. assert(type.isBasic()); + if (rand.oneIn(8)) { + return type.getBottom(); + } switch (type.getBasic()) { case HeapType::ext: return HeapType::ext; @@ -318,7 +330,10 @@ struct HeapTypeGeneratorImpl { case HeapType::stringview_wtf8: case HeapType::stringview_wtf16: case HeapType::stringview_iter: - WASM_UNREACHABLE("TODO: fuzz strings"); + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + return type; } WASM_UNREACHABLE("unexpected kind"); } @@ -403,6 +418,17 @@ struct HeapTypeGeneratorImpl { } HeapTypeKind getSubKind(HeapTypeKind super) { + if (rand.oneIn(16)) { + // Occasionally go directly to the bottom type. + if (auto* basic = std::get_if<BasicKind>(&super)) { + return HeapType(*basic).getBottom(); + } else if (std::get_if<SignatureKind>(&super)) { + return HeapType::nofunc; + } else if (std::get_if<DataKind>(&super)) { + return HeapType::none; + } + WASM_UNREACHABLE("unexpected kind"); + } if (auto* basic = std::get_if<BasicKind>(&super)) { if (rand.oneIn(8)) { return super; @@ -441,7 +467,10 @@ struct HeapTypeGeneratorImpl { case HeapType::stringview_wtf8: case HeapType::stringview_wtf16: case HeapType::stringview_iter: - WASM_UNREACHABLE("TODO: fuzz strings"); + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + return super; } WASM_UNREACHABLE("unexpected kind"); } else { diff --git a/src/tools/wasm-ctor-eval.cpp b/src/tools/wasm-ctor-eval.cpp index 5e9874ccc..45cc63c37 100644 --- a/src/tools/wasm-ctor-eval.cpp +++ b/src/tools/wasm-ctor-eval.cpp @@ -553,10 +553,7 @@ public: // This is GC data, which we must handle in a more careful way. auto* data = value.getGCData().get(); - if (!data) { - // This is a null, so simply emit one. - return builder.makeRefNull(value.type); - } + assert(data); // There was actual GC data allocated here. auto type = value.type; diff --git a/src/tools/wasm-reduce.cpp b/src/tools/wasm-reduce.cpp index 6febcf6b9..557e8a770 100644 --- a/src/tools/wasm-reduce.cpp +++ b/src/tools/wasm-reduce.cpp @@ -1142,7 +1142,7 @@ struct Reducer } // try to replace with a trivial value if (curr->type.isNullable()) { - RefNull* n = builder->makeRefNull(curr->type); + RefNull* n = builder->makeRefNull(curr->type.getHeapType()); return tryToReplaceCurrent(n); } if (curr->type.isTuple() && curr->type.isDefaultable()) { diff --git a/src/wasm-binary.h b/src/wasm-binary.h index e9ad665c4..ca21662e6 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -382,6 +382,10 @@ enum EncodedType { stringview_wtf8 = -0x1d, // 0x63 stringview_wtf16 = -0x1e, // 0x62 stringview_iter = -0x1f, // 0x61 + // bottom types + nullexternref = -0x17, // 0x69 + nullfuncref = -0x18, // 0x68 + nullref = -0x1b, // 0x65 // type forms Func = -0x20, // 0x60 Struct = -0x21, // 0x5f @@ -411,6 +415,10 @@ enum EncodedHeapType { stringview_wtf8_heap = -0x1d, // 0x63 stringview_wtf16_heap = -0x1e, // 0x62 stringview_iter_heap = -0x1f, // 0x61 + // bottom types + noext = -0x17, // 0x69 + nofunc = -0x18, // 0x68 + none = -0x1b, // 0x65 }; namespace UserSections { diff --git a/src/wasm-builder.h b/src/wasm-builder.h index cc3b99138..b07af4627 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -699,10 +699,11 @@ public: } RefNull* makeRefNull(HeapType type) { auto* ret = wasm.allocator.alloc<RefNull>(); - ret->finalize(Type(type, Nullable)); + ret->finalize(Type(type.getBottom(), Nullable)); return ret; } RefNull* makeRefNull(Type type) { + assert(type.isNullable() && type.isNull()); auto* ret = wasm.allocator.alloc<RefNull>(); ret->finalize(type); return ret; @@ -942,11 +943,14 @@ public: ret->finalize(); return ret; } - ArrayGet* - makeArrayGet(Expression* ref, Expression* index, bool signed_ = false) { + ArrayGet* makeArrayGet(Expression* ref, + Expression* index, + Type type, + bool signed_ = false) { auto* ret = wasm.allocator.alloc<ArrayGet>(); ret->ref = ref; ret->index = index; + ret->type = type; ret->signed_ = signed_; ret->finalize(); return ret; @@ -1262,7 +1266,7 @@ public: if (curr->type.isTuple() && curr->type.isDefaultable()) { return makeConstantExpression(Literal::makeZeros(curr->type)); } - if (curr->type.isNullable()) { + if (curr->type.isNullable() && curr->type.isNull()) { return ExpressionManipulator::refNull(curr, curr->type); } if (curr->type.isRef() && curr->type.getHeapType() == HeapType::i31) { @@ -1329,18 +1333,24 @@ public: Expression* validateAndMakeCallRef(Expression* target, const T& args, bool isReturn = false) { - if (!target->type.isRef()) { - if (target->type == Type::unreachable) { - // An unreachable target is not supported. Similiar to br_on_cast, just - // emit an unreachable sequence, since we don't have enough information - // to create a full call_ref. - auto* block = makeBlock(args); - block->list.push_back(target); - block->finalize(Type::unreachable); - return block; - } + if (target->type != Type::unreachable && !target->type.isRef()) { throw ParseException("Non-reference type for a call_ref", line, col); } + // TODO: This won't be necessary once type annotations are mandatory on + // call_ref. + if (target->type == Type::unreachable || + target->type.getHeapType() == HeapType::nofunc) { + // An unreachable target is not supported. Similiar to br_on_cast, just + // emit an unreachable sequence, since we don't have enough information + // to create a full call_ref. + std::vector<Expression*> children; + for (auto* arg : args) { + children.push_back(makeDrop(arg)); + } + children.push_back(makeDrop(target)); + children.push_back(makeUnreachable()); + return makeBlock(children, Type::unreachable); + } auto heapType = target->type.getHeapType(); if (!heapType.isSignature()) { throw ParseException("Invalid reference type for a call_ref", line, col); diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 69434a297..10b84a82f 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -1350,12 +1350,11 @@ public: case RefIsNull: return Literal(value.isNull()); case RefIsFunc: - return Literal(!value.isNull() && value.type.isFunction()); + return Literal(value.type.isFunction()); case RefIsData: - return Literal(!value.isNull() && value.isData()); + return Literal(value.isData()); case RefIsI31: - return Literal(!value.isNull() && - value.type.getHeapType() == HeapType::i31); + return Literal(value.type.getHeapType() == HeapType::i31); default: WASM_UNREACHABLE("unimplemented ref.is_*"); } diff --git a/src/wasm-type.h b/src/wasm-type.h index 940009cbe..dbdbe3d31 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -169,6 +169,8 @@ public: // is irrelevant. (For that reason, this is only the negation of isNullable() // on references, but both return false on non-references.) bool isNonNullable() const; + // Whether this type is only inhabited by null values. + bool isNull() const; bool isStruct() const; bool isArray() const; bool isDefaultable() const; @@ -326,8 +328,11 @@ public: stringview_wtf8, stringview_wtf16, stringview_iter, + none, + noext, + nofunc, }; - static constexpr BasicHeapType _last_basic_type = stringview_iter; + static constexpr BasicHeapType _last_basic_type = nofunc; // BasicHeapType can be implicitly upgraded to HeapType constexpr HeapType(BasicHeapType id) : id(id) {} @@ -358,6 +363,7 @@ public: bool isSignature() const; bool isStruct() const; bool isArray() const; + bool isBottom() const; Signature getSignature() const; const Struct& getStruct() const; @@ -371,6 +377,9 @@ public: // number of supertypes in its supertype chain. size_t getDepth() const; + // Get the bottom heap type for this heap type's hierarchy. + BasicHeapType getBottom() const; + // Get the recursion group for this non-basic type. RecGroup getRecGroup() const; size_t getRecGroupIndex() const; @@ -421,6 +430,8 @@ public: std::string toString() const; }; +inline bool Type::isNull() const { return isRef() && getHeapType().isBottom(); } + // A recursion group consisting of one or more HeapTypes. HeapTypes with single // members are encoded without using any additional memory, which is why // `getHeapTypes` has to return a vector by value; it might have to create one diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index 43f407525..09022eea0 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -44,19 +44,25 @@ Literal::Literal(Type type) : type(type) { memset(&v128, 0, 16); return; case Type::none: - return; case Type::unreachable: - break; + WASM_UNREACHABLE("Invalid literal type"); + return; } } - if (isData()) { - assert(!type.isNonNullable()); + if (type.isNull()) { + assert(type.isNullable()); new (&gcData) std::shared_ptr<GCData>(); - } else { - // For anything else, zero out all the union data. - memset(&v128, 0, 16); + return; } + + if (type.isRef() && type.getHeapType() == HeapType::i31) { + assert(type.isNonNullable()); + i32 = 0; + return; + } + + WASM_UNREACHABLE("Unexpected literal type"); } Literal::Literal(const uint8_t init[16]) : type(Type::v128) { @@ -64,9 +70,9 @@ Literal::Literal(const uint8_t init[16]) : type(Type::v128) { } Literal::Literal(std::shared_ptr<GCData> gcData, HeapType type) - : gcData(gcData), type(type, gcData ? NonNullable : Nullable) { + : gcData(gcData), type(type, NonNullable) { // The type must be a proper type for GC data. - assert(isData()); + assert((isData() && gcData) || (type.isBottom() && !gcData)); } Literal::Literal(const Literal& other) : type(other.type) { @@ -89,6 +95,10 @@ Literal::Literal(const Literal& other) : type(other.type) { break; } } + if (other.isNull()) { + new (&gcData) std::shared_ptr<GCData>(); + return; + } if (other.isData()) { new (&gcData) std::shared_ptr<GCData>(other.gcData); return; @@ -98,23 +108,30 @@ Literal::Literal(const Literal& other) : type(other.type) { return; } if (type.isRef()) { + assert(!type.isNullable()); auto heapType = type.getHeapType(); if (heapType.isBasic()) { switch (heapType.getBasic()) { - case HeapType::ext: - case HeapType::any: - case HeapType::eq: - return; // null case HeapType::i31: i32 = other.i32; return; + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + // Null + return; + case HeapType::ext: + case HeapType::any: + WASM_UNREACHABLE("TODO: extern literals"); + case HeapType::eq: case HeapType::func: case HeapType::data: + WASM_UNREACHABLE("invalid type"); case HeapType::string: case HeapType::stringview_wtf8: case HeapType::stringview_wtf16: case HeapType::stringview_iter: - WASM_UNREACHABLE("invalid type"); + WASM_UNREACHABLE("TODO: string literals"); } } } @@ -125,7 +142,7 @@ Literal::~Literal() { if (type.isBasic()) { return; } - if (isData()) { + if (isNull() || isData()) { gcData.~shared_ptr(); } } @@ -239,7 +256,7 @@ std::array<uint8_t, 16> Literal::getv128() const { } std::shared_ptr<GCData> Literal::getGCData() const { - assert(isData()); + assert(isNull() || isData()); return gcData; } @@ -325,11 +342,6 @@ void Literal::getBits(uint8_t (&buf)[16]) const { } bool Literal::operator==(const Literal& other) const { - // The types must be identical, unless both are references - in that case, - // nulls of different types *do* compare equal. - if (type.isRef() && other.type.isRef() && (isNull() || other.isNull())) { - return isNull() && other.isNull(); - } if (type != other.type) { return false; } @@ -350,7 +362,9 @@ bool Literal::operator==(const Literal& other) const { } } else if (type.isRef()) { assert(type.isRef()); - // Note that we've already handled nulls earlier. + if (type.isNull()) { + return true; + } if (type.isFunction()) { assert(func.is() && other.func.is()); return func == other.func; @@ -361,8 +375,6 @@ bool Literal::operator==(const Literal& other) const { if (type.getHeapType() == HeapType::i31) { return i32 == other.i32; } - // other non-null reference type literals cannot represent concrete values, - // i.e. there is no concrete anyref or eqref other than null. WASM_UNREACHABLE("unexpected type"); } WASM_UNREACHABLE("unexpected type"); @@ -463,52 +475,8 @@ void Literal::printVec128(std::ostream& o, const std::array<uint8_t, 16>& v) { std::ostream& operator<<(std::ostream& o, Literal literal) { prepareMinorColor(o); - if (literal.type.isFunction()) { - if (literal.isNull()) { - o << "funcref(null)"; - } else { - o << "funcref(" << literal.getFunc() << ")"; - } - } else if (literal.type.isRef()) { - if (literal.isData()) { - auto data = literal.getGCData(); - if (data) { - o << "[ref " << data->type << ' ' << data->values << ']'; - } else { - o << "[ref null " << literal.type << ']'; - } - } else { - switch (literal.type.getHeapType().getBasic()) { - case HeapType::ext: - assert(literal.isNull() && "unexpected non-null externref literal"); - o << "externref(null)"; - break; - case HeapType::any: - assert(literal.isNull() && "unexpected non-null anyref literal"); - o << "anyref(null)"; - break; - case HeapType::eq: - assert(literal.isNull() && "unexpected non-null eqref literal"); - o << "eqref(null)"; - break; - case HeapType::i31: - if (literal.isNull()) { - o << "i31ref(null)"; - } else { - o << "i31ref(" << literal.geti31() << ")"; - } - break; - case HeapType::func: - case HeapType::data: - case HeapType::string: - case HeapType::stringview_wtf8: - case HeapType::stringview_wtf16: - case HeapType::stringview_iter: - WASM_UNREACHABLE("type should have been handled above"); - } - } - } else { - TODO_SINGLE_COMPOUND(literal.type); + assert(literal.type.isSingle()); + if (literal.type.isBasic()) { switch (literal.type.getBasic()) { case Type::none: o << "?"; @@ -532,6 +500,44 @@ std::ostream& operator<<(std::ostream& o, Literal literal) { case Type::unreachable: WASM_UNREACHABLE("unexpected type"); } + } else { + assert(literal.type.isRef()); + auto heapType = literal.type.getHeapType(); + if (heapType.isBasic()) { + switch (heapType.getBasic()) { + case HeapType::i31: + o << "i31ref(" << literal.geti31() << ")"; + break; + case HeapType::none: + o << "nullref"; + break; + case HeapType::noext: + o << "nullexternref"; + break; + case HeapType::nofunc: + o << "nullfuncref"; + break; + case HeapType::ext: + case HeapType::any: + WASM_UNREACHABLE("TODO: extern literals"); + case HeapType::eq: + case HeapType::func: + case HeapType::data: + WASM_UNREACHABLE("invalid type"); + case HeapType::string: + case HeapType::stringview_wtf8: + case HeapType::stringview_wtf16: + case HeapType::stringview_iter: + WASM_UNREACHABLE("TODO: string literals"); + } + } else if (heapType.isSignature()) { + o << "funcref(" << literal.getFunc() << ")"; + } else { + assert(literal.isData()); + auto data = literal.getGCData(); + assert(data); + o << "[ref " << data->type << ' ' << data->values << ']'; + } } restoreNormalColor(o); return o; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 55dafadd4..f2698bd79 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -1429,6 +1429,25 @@ void WasmBinaryWriter::writeType(Type type) { case HeapType::stringview_iter: o << S32LEB(BinaryConsts::EncodedType::stringview_iter); return; + case HeapType::none: + o << S32LEB(BinaryConsts::EncodedType::nullref); + return; + case HeapType::noext: + // See comment on writeHeapType. + if (!wasm->features.hasGC()) { + o << S32LEB(BinaryConsts::EncodedType::externref); + } else { + o << S32LEB(BinaryConsts::EncodedType::nullexternref); + } + return; + case HeapType::nofunc: + // See comment on writeHeapType. + if (!wasm->features.hasGC()) { + o << S32LEB(BinaryConsts::EncodedType::funcref); + } else { + o << S32LEB(BinaryConsts::EncodedType::nullfuncref); + } + return; } } if (type.isNullable()) { @@ -1468,46 +1487,63 @@ void WasmBinaryWriter::writeType(Type type) { } void WasmBinaryWriter::writeHeapType(HeapType type) { + // ref.null always has a bottom heap type in Binaryen IR, but those types are + // only actually valid with GC enabled. When GC is not enabled, emit the + // corresponding valid top types instead. + if (!wasm->features.hasGC()) { + if (type == HeapType::nofunc || type.isSignature()) { + type = HeapType::func; + } else if (type == HeapType::noext) { + type = HeapType::ext; + } + } + if (type.isSignature() || type.isStruct() || type.isArray()) { o << S64LEB(getTypeIndex(type)); // TODO: Actually s33 return; } int ret = 0; - if (type.isBasic()) { - switch (type.getBasic()) { - case HeapType::ext: - ret = BinaryConsts::EncodedHeapType::ext; - break; - case HeapType::func: - ret = BinaryConsts::EncodedHeapType::func; - break; - case HeapType::any: - ret = BinaryConsts::EncodedHeapType::any; - break; - case HeapType::eq: - ret = BinaryConsts::EncodedHeapType::eq; - break; - case HeapType::i31: - ret = BinaryConsts::EncodedHeapType::i31; - break; - case HeapType::data: - ret = BinaryConsts::EncodedHeapType::data; - break; - case HeapType::string: - ret = BinaryConsts::EncodedHeapType::string; - break; - case HeapType::stringview_wtf8: - ret = BinaryConsts::EncodedHeapType::stringview_wtf8_heap; - break; - case HeapType::stringview_wtf16: - ret = BinaryConsts::EncodedHeapType::stringview_wtf16_heap; - break; - case HeapType::stringview_iter: - ret = BinaryConsts::EncodedHeapType::stringview_iter_heap; - break; - } - } else { - WASM_UNREACHABLE("TODO: compound GC types"); + assert(type.isBasic()); + switch (type.getBasic()) { + case HeapType::ext: + ret = BinaryConsts::EncodedHeapType::ext; + break; + case HeapType::func: + ret = BinaryConsts::EncodedHeapType::func; + break; + case HeapType::any: + ret = BinaryConsts::EncodedHeapType::any; + break; + case HeapType::eq: + ret = BinaryConsts::EncodedHeapType::eq; + break; + case HeapType::i31: + ret = BinaryConsts::EncodedHeapType::i31; + break; + case HeapType::data: + ret = BinaryConsts::EncodedHeapType::data; + break; + case HeapType::string: + ret = BinaryConsts::EncodedHeapType::string; + break; + case HeapType::stringview_wtf8: + ret = BinaryConsts::EncodedHeapType::stringview_wtf8_heap; + break; + case HeapType::stringview_wtf16: + ret = BinaryConsts::EncodedHeapType::stringview_wtf16_heap; + break; + case HeapType::stringview_iter: + ret = BinaryConsts::EncodedHeapType::stringview_iter_heap; + break; + case HeapType::none: + ret = BinaryConsts::EncodedHeapType::none; + break; + case HeapType::noext: + ret = BinaryConsts::EncodedHeapType::noext; + break; + case HeapType::nofunc: + ret = BinaryConsts::EncodedHeapType::nofunc; + break; } o << S64LEB(ret); // TODO: Actually s33 } @@ -1867,6 +1903,15 @@ bool WasmBinaryBuilder::getBasicType(int32_t code, Type& out) { case BinaryConsts::EncodedType::stringview_iter: out = Type(HeapType::stringview_iter, Nullable); return true; + case BinaryConsts::EncodedType::nullref: + out = Type(HeapType::none, Nullable); + return true; + case BinaryConsts::EncodedType::nullexternref: + out = Type(HeapType::noext, Nullable); + return true; + case BinaryConsts::EncodedType::nullfuncref: + out = Type(HeapType::nofunc, Nullable); + return true; default: return false; } @@ -1904,6 +1949,15 @@ bool WasmBinaryBuilder::getBasicHeapType(int64_t code, HeapType& out) { case BinaryConsts::EncodedHeapType::stringview_iter_heap: out = HeapType::stringview_iter; return true; + case BinaryConsts::EncodedHeapType::none: + out = HeapType::none; + return true; + case BinaryConsts::EncodedHeapType::noext: + out = HeapType::noext; + return true; + case BinaryConsts::EncodedHeapType::nofunc: + out = HeapType::nofunc; + return true; default: return false; } @@ -2849,7 +2903,14 @@ void WasmBinaryBuilder::skipUnreachableCode() { expressionStack = savedStack; return; } - pushExpression(curr); + if (curr->type == Type::unreachable) { + // Nothing before this unreachable should be available to future + // expressions. They will get `(unreachable)`s if they try to pop past + // this point. + expressionStack.clear(); + } else { + pushExpression(curr); + } } } @@ -6530,7 +6591,7 @@ void WasmBinaryBuilder::visitDrop(Drop* curr) { void WasmBinaryBuilder::visitRefNull(RefNull* curr) { BYN_TRACE("zz node: RefNull\n"); - curr->finalize(getHeapType()); + curr->finalize(getHeapType().getBottom()); } void WasmBinaryBuilder::visitRefIs(RefIs* curr, uint8_t code) { @@ -6941,28 +7002,29 @@ bool WasmBinaryBuilder::maybeVisitStructNew(Expression*& out, uint32_t code) { } bool WasmBinaryBuilder::maybeVisitStructGet(Expression*& out, uint32_t code) { - StructGet* curr; + bool signed_ = false; switch (code) { case BinaryConsts::StructGet: - curr = allocator.alloc<StructGet>(); + case BinaryConsts::StructGetU: break; case BinaryConsts::StructGetS: - curr = allocator.alloc<StructGet>(); - curr->signed_ = true; - break; - case BinaryConsts::StructGetU: - curr = allocator.alloc<StructGet>(); - curr->signed_ = false; + signed_ = true; break; default: return false; } auto heapType = getIndexedHeapType(); - curr->index = getU32LEB(); - curr->ref = popNonVoidExpression(); - validateHeapTypeUsingChild(curr->ref, heapType); - curr->finalize(); - out = curr; + if (!heapType.isStruct()) { + throwError("Expected struct heaptype"); + } + auto index = getU32LEB(); + if (index >= heapType.getStruct().fields.size()) { + throwError("Struct field index out of bounds"); + } + auto type = heapType.getStruct().fields[index].type; + auto ref = popNonVoidExpression(); + validateHeapTypeUsingChild(ref, heapType); + out = Builder(wasm).makeStructGet(index, ref, type, signed_); return true; } @@ -7022,10 +7084,14 @@ bool WasmBinaryBuilder::maybeVisitArrayGet(Expression*& out, uint32_t code) { return false; } auto heapType = getIndexedHeapType(); + if (!heapType.isArray()) { + throwError("Expected array heaptype"); + } + auto type = heapType.getArray().element.type; auto* index = popNonVoidExpression(); auto* ref = popNonVoidExpression(); validateHeapTypeUsingChild(ref, heapType); - out = Builder(wasm).makeArrayGet(ref, index, signed_); + out = Builder(wasm).makeArrayGet(ref, index, type, signed_); return true; } diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 6de8744a6..b23419071 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -1196,6 +1196,15 @@ Type SExpressionWasmBuilder::stringToType(const char* str, if (strncmp(str, "stringview_iter", 15) == 0 && (prefix || str[15] == 0)) { return Type(HeapType::stringview_iter, Nullable); } + if (strncmp(str, "nullref", 7) == 0 && (prefix || str[7] == 0)) { + return Type(HeapType::none, Nullable); + } + if (strncmp(str, "nullexternref", 13) == 0 && (prefix || str[13] == 0)) { + return Type(HeapType::noext, Nullable); + } + if (strncmp(str, "nullfuncref", 11) == 0 && (prefix || str[11] == 0)) { + return Type(HeapType::nofunc, Nullable); + } if (allowError) { return Type::none; } @@ -1249,6 +1258,17 @@ HeapType SExpressionWasmBuilder::stringToHeapType(const char* str, return HeapType::stringview_iter; } } + if (str[0] == 'n') { + if (strncmp(str, "none", 4) == 0 && (prefix || str[4] == 0)) { + return HeapType::none; + } + if (strncmp(str, "noextern", 8) == 0 && (prefix || str[8] == 0)) { + return HeapType::noext; + } + if (strncmp(str, "nofunc", 6) == 0 && (prefix || str[6] == 0)) { + return HeapType::nofunc; + } + } throw ParseException(std::string("invalid wasm heap type: ") + str); } @@ -2615,9 +2635,9 @@ Expression* SExpressionWasmBuilder::makeRefNull(Element& s) { // (ref.null func), or it may be the name of a defined type, such as // (ref.null $struct.FOO) if (s[1]->dollared()) { - ret->finalize(parseHeapType(*s[1])); + ret->finalize(parseHeapType(*s[1]).getBottom()); } else { - ret->finalize(stringToHeapType(s[1]->str())); + ret->finalize(stringToHeapType(s[1]->str()).getBottom()); } return ret; } @@ -2990,10 +3010,14 @@ Expression* SExpressionWasmBuilder::makeArrayInitStatic(Element& s) { Expression* SExpressionWasmBuilder::makeArrayGet(Element& s, bool signed_) { auto heapType = parseHeapType(*s[1]); + if (!heapType.isArray()) { + throw ParseException("bad array heap type", s.line, s.col); + } auto ref = parseExpression(*s[2]); + auto type = heapType.getArray().element.type; validateHeapTypeUsingChild(ref, heapType, s); auto index = parseExpression(*s[3]); - return Builder(wasm).makeArrayGet(ref, index, signed_); + return Builder(wasm).makeArrayGet(ref, index, type, signed_); } Expression* SExpressionWasmBuilder::makeArraySet(Element& s) { diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index 71bc98928..13d85d338 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -2014,7 +2014,10 @@ void BinaryInstWriter::visitI31Get(I31Get* curr) { void BinaryInstWriter::visitCallRef(CallRef* curr) { assert(curr->target->type != Type::unreachable); - // TODO: `emitUnreachable` if target has bottom type. + if (curr->target->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(curr->isReturn ? BinaryConsts::RetCallRef : BinaryConsts::CallRef); parent.writeIndexedHeapType(curr->target->type.getHeapType()); @@ -2090,6 +2093,10 @@ void BinaryInstWriter::visitStructNew(StructNew* curr) { } void BinaryInstWriter::visitStructGet(StructGet* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } const auto& heapType = curr->ref->type.getHeapType(); const auto& field = heapType.getStruct().fields[curr->index]; int8_t op; @@ -2106,6 +2113,10 @@ void BinaryInstWriter::visitStructGet(StructGet* curr) { } void BinaryInstWriter::visitStructSet(StructSet* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::StructSet); parent.writeIndexedHeapType(curr->ref->type.getHeapType()); o << U32LEB(curr->index); @@ -2129,6 +2140,10 @@ void BinaryInstWriter::visitArrayInit(ArrayInit* curr) { } void BinaryInstWriter::visitArrayGet(ArrayGet* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } auto heapType = curr->ref->type.getHeapType(); const auto& field = heapType.getArray().element; int8_t op; @@ -2144,16 +2159,28 @@ void BinaryInstWriter::visitArrayGet(ArrayGet* curr) { } void BinaryInstWriter::visitArraySet(ArraySet* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::ArraySet); parent.writeIndexedHeapType(curr->ref->type.getHeapType()); } void BinaryInstWriter::visitArrayLen(ArrayLen* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::ArrayLen); parent.writeIndexedHeapType(curr->ref->type.getHeapType()); } void BinaryInstWriter::visitArrayCopy(ArrayCopy* curr) { + if (curr->srcRef->type.isNull() || curr->destRef->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::ArrayCopy); parent.writeIndexedHeapType(curr->destRef->type.getHeapType()); parent.writeIndexedHeapType(curr->srcRef->type.getHeapType()); diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index d24e42acb..43b381a1b 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -578,6 +578,15 @@ std::optional<HeapType> getBasicHeapTypeLUB(HeapType::BasicHeapType a, if (a == b) { return a; } + if (HeapType(a).getBottom() != HeapType(b).getBottom()) { + return {}; + } + if (HeapType(a).isBottom()) { + return b; + } + if (HeapType(b).isBottom()) { + return a; + } // Canonicalize to have `a` be the lesser type. if (unsigned(a) > unsigned(b)) { std::swap(a, b); @@ -585,7 +594,7 @@ std::optional<HeapType> getBasicHeapTypeLUB(HeapType::BasicHeapType a, switch (a) { case HeapType::ext: case HeapType::func: - return {}; + return std::nullopt; case HeapType::any: return {HeapType::any}; case HeapType::eq: @@ -604,6 +613,11 @@ std::optional<HeapType> getBasicHeapTypeLUB(HeapType::BasicHeapType a, case HeapType::stringview_wtf16: case HeapType::stringview_iter: return {HeapType::any}; + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + // Bottom types already handled. + break; } WASM_UNREACHABLE("unexpected basic type"); } @@ -1085,6 +1099,12 @@ FeatureSet Type::getFeatures() const { case HeapType::stringview_wtf16: case HeapType::stringview_iter: return FeatureSet::ReferenceTypes | FeatureSet::Strings; + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + // Technically introduced in GC, but used internally as part of + // ref.null with just reference types. + return FeatureSet::ReferenceTypes; } } // Note: Technically typed function references also require the typed @@ -1360,6 +1380,29 @@ bool HeapType::isArray() const { } } +bool HeapType::isBottom() const { + if (isBasic()) { + switch (getBasic()) { + case ext: + case func: + case any: + case eq: + case i31: + case data: + case string: + case stringview_wtf8: + case stringview_wtf16: + case stringview_iter: + return false; + case none: + case noext: + case nofunc: + return true; + } + } + return false; +} + Signature HeapType::getSignature() const { assert(isSignature()); return getHeapTypeInfo(*this)->signature; @@ -1420,11 +1463,52 @@ size_t HeapType::getDepth() const { case HeapType::stringview_iter: depth += 2; break; + case HeapType::none: + case HeapType::nofunc: + case HeapType::noext: + // Bottom types are infinitely deep. + depth = size_t(-1l); } } return depth; } +HeapType::BasicHeapType HeapType::getBottom() const { + if (isBasic()) { + switch (getBasic()) { + case ext: + return noext; + case func: + return nofunc; + case any: + case eq: + case i31: + case data: + case string: + case stringview_wtf8: + case stringview_wtf16: + case stringview_iter: + case none: + return none; + case noext: + return noext; + case nofunc: + return nofunc; + } + } + auto* info = getHeapTypeInfo(*this); + switch (info->kind) { + case HeapTypeInfo::BasicKind: + return HeapType(info->basic).getBottom(); + case HeapTypeInfo::SignatureKind: + return nofunc; + case HeapTypeInfo::StructKind: + case HeapTypeInfo::ArrayKind: + return none; + } + WASM_UNREACHABLE("unexpected kind"); +} + bool HeapType::isSubType(HeapType left, HeapType right) { // As an optimization, in the common case do not even construct a SubTyper. if (left == right) { @@ -1451,6 +1535,15 @@ std::optional<HeapType> HeapType::getLeastUpperBound(HeapType a, HeapType b) { if (a == b) { return a; } + if (a.getBottom() != b.getBottom()) { + return {}; + } + if (a.isBottom()) { + return b; + } + if (b.isBottom()) { + return a; + } if (getTypeSystem() == TypeSystem::Equirecursive) { return TypeBounder().getLeastUpperBound(a, b); } @@ -1653,27 +1746,34 @@ bool SubTyper::isSubType(HeapType a, HeapType b) { if (b.isBasic()) { switch (b.getBasic()) { case HeapType::ext: - return a == HeapType::ext; + return a == HeapType::noext; case HeapType::func: - return a.isSignature(); + return a == HeapType::nofunc || a.isSignature(); case HeapType::any: - return a != HeapType::ext && !a.isFunction(); + return a == HeapType::eq || a == HeapType::i31 || a == HeapType::data || + a == HeapType::none || a.isData(); case HeapType::eq: - return a == HeapType::i31 || a.isData(); + return a == HeapType::i31 || a == HeapType::data || + a == HeapType::none || a.isData(); case HeapType::i31: - return false; + return a == HeapType::none; case HeapType::data: - return a.isData(); + return a == HeapType::none || a.isData(); case HeapType::string: case HeapType::stringview_wtf8: case HeapType::stringview_wtf16: case HeapType::stringview_iter: + return a == HeapType::none; + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: return false; } } if (a.isBasic()) { - // Basic HeapTypes are never subtypes of compound HeapTypes. - return false; + // Basic HeapTypes are only subtypes of compound HeapTypes if they are + // bottom types. + return a == b.getBottom(); } if (typeSystem == TypeSystem::Nominal || typeSystem == TypeSystem::Isorecursive) { @@ -1823,6 +1923,15 @@ std::optional<HeapType> TypeBounder::lub(HeapType a, HeapType b) { if (a == b) { return a; } + if (a.getBottom() != b.getBottom()) { + return {}; + } + if (a.isBottom()) { + return b; + } + if (b.isBottom()) { + return a; + } if (a.isBasic() || b.isBasic()) { return getBasicHeapTypeLUB(getBasicHeapSupertype(a), @@ -2000,12 +2109,18 @@ std::ostream& TypePrinter::print(Type type) { // Print shorthands for certain basic heap types. if (type.isNullable()) { switch (heapType.getBasic()) { + case HeapType::ext: + return os << "externref"; case HeapType::func: return os << "funcref"; case HeapType::any: return os << "anyref"; case HeapType::eq: return os << "eqref"; + case HeapType::i31: + return os << "i31ref"; + case HeapType::data: + return os << "dataref"; case HeapType::string: return os << "stringref"; case HeapType::stringview_wtf8: @@ -2014,17 +2129,12 @@ std::ostream& TypePrinter::print(Type type) { return os << "stringview_wtf16"; case HeapType::stringview_iter: return os << "stringview_iter"; - default: - break; - } - } else { - switch (heapType.getBasic()) { - case HeapType::i31: - return os << "i31ref"; - case HeapType::data: - return os << "dataref"; - default: - break; + case HeapType::none: + return os << "nullref"; + case HeapType::noext: + return os << "nullexternref"; + case HeapType::nofunc: + return os << "nullfuncref"; } } } @@ -2063,6 +2173,12 @@ std::ostream& TypePrinter::print(HeapType type) { return os << "stringview_wtf16"; case HeapType::stringview_iter: return os << "stringview_iter"; + case HeapType::none: + return os << "none"; + case HeapType::noext: + return os << "noextern"; + case HeapType::nofunc: + return os << "nofunc"; } } diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 5e3fdc6e7..ba309ddea 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -2110,13 +2110,12 @@ void FunctionValidator::visitRefNull(RefNull* curr) { shouldBeTrue(!getFunction() || getModule()->features.hasReferenceTypes(), curr, "ref.null requires reference-types to be enabled"); + if (!shouldBeTrue( + curr->type.isNullable(), curr, "ref.null types must be nullable")) { + return; + } shouldBeTrue( - curr->type.isNullable(), curr, "ref.null types must be nullable"); - - // The type of the null must also be valid for the features. - shouldBeTrue(curr->type.getFeatures() <= getModule()->features, - curr->type, - "ref.null type should be allowed"); + curr->type.isNull(), curr, "ref.null must have a bottom heap type"); } void FunctionValidator::visitRefIs(RefIs* curr) { @@ -2454,12 +2453,15 @@ void FunctionValidator::visitCallRef(CallRef* curr) { validateReturnCall(curr); shouldBeTrue( getModule()->features.hasGC(), curr, "call_ref requires gc to be enabled"); - if (curr->target->type != Type::unreachable) { - if (shouldBeTrue(curr->target->type.isFunction(), - curr, - "call_ref target must be a function reference")) { - validateCallParamsAndResult(curr, curr->target->type.getHeapType()); - } + if (curr->target->type == Type::unreachable || + (curr->target->type.isRef() && + curr->target->type.getHeapType() == HeapType::nofunc)) { + return; + } + if (shouldBeTrue(curr->target->type.isFunction(), + curr, + "call_ref target must be a function reference")) { + validateCallParamsAndResult(curr, curr->target->type.getHeapType()); } } @@ -2580,7 +2582,7 @@ void FunctionValidator::visitStructGet(StructGet* curr) { shouldBeTrue(getModule()->features.hasGC(), curr, "struct.get requires gc to be enabled"); - if (curr->ref->type == Type::unreachable) { + if (curr->type == Type::unreachable || curr->ref->type.isNull()) { return; } if (!shouldBeTrue(curr->ref->type.isStruct(), @@ -2610,22 +2612,28 @@ void FunctionValidator::visitStructSet(StructSet* curr) { if (curr->ref->type == Type::unreachable) { return; } - if (!shouldBeTrue(curr->ref->type.isStruct(), + if (!shouldBeTrue(curr->ref->type.isRef(), curr->ref, - "struct.set ref must be a struct")) { + "struct.set ref must be a reference type")) { return; } - if (curr->ref->type != Type::unreachable) { - const auto& fields = curr->ref->type.getHeapType().getStruct().fields; - shouldBeTrue(curr->index < fields.size(), curr, "bad struct.get field"); - auto& field = fields[curr->index]; - shouldBeSubType(curr->value->type, - field.type, - curr, - "struct.set must have the proper type"); - shouldBeEqual( - field.mutable_, Mutable, curr, "struct.set field must be mutable"); + auto type = curr->ref->type.getHeapType(); + if (type == HeapType::none) { + return; } + if (!shouldBeTrue( + type.isStruct(), curr->ref, "struct.set ref must be a struct")) { + return; + } + const auto& fields = type.getStruct().fields; + shouldBeTrue(curr->index < fields.size(), curr, "bad struct.get field"); + auto& field = fields[curr->index]; + shouldBeSubType(curr->value->type, + field.type, + curr, + "struct.set must have the proper type"); + shouldBeEqual( + field.mutable_, Mutable, curr, "struct.set field must be mutable"); } void FunctionValidator::visitArrayNew(ArrayNew* curr) { @@ -2688,7 +2696,18 @@ void FunctionValidator::visitArrayGet(ArrayGet* curr) { if (curr->type == Type::unreachable) { return; } - const auto& element = curr->ref->type.getHeapType().getArray().element; + // TODO: array rather than data once we've implemented that. + if (!shouldBeSubType(curr->ref->type, + Type(HeapType::data, Nullable), + curr, + "array.get target should be an array reference")) { + return; + } + auto heapType = curr->ref->type.getHeapType(); + if (heapType == HeapType::none) { + return; + } + const auto& element = heapType.getArray().element; // If the type is not packed, it must be marked internally as unsigned, by // convention. if (element.type != Type::i32 || element.packedType == Field::not_packed) { @@ -2706,6 +2725,17 @@ void FunctionValidator::visitArraySet(ArraySet* curr) { if (curr->type == Type::unreachable) { return; } + // TODO: array rather than data once we've implemented that. + if (!shouldBeSubType(curr->ref->type, + Type(HeapType::data, Nullable), + curr, + "array.set target should be an array reference")) { + return; + } + auto heapType = curr->ref->type.getHeapType(); + if (heapType == HeapType::none) { + return; + } const auto& element = curr->ref->type.getHeapType().getArray().element; shouldBeSubType(curr->value->type, element.type, @@ -2736,9 +2766,23 @@ void FunctionValidator::visitArrayCopy(ArrayCopy* curr) { if (curr->type == Type::unreachable) { return; } - const auto& srcElement = curr->srcRef->type.getHeapType().getArray().element; - const auto& destElement = - curr->destRef->type.getHeapType().getArray().element; + if (!shouldBeSubType(curr->srcRef->type, + Type(HeapType::data, Nullable), + curr, + "array.copy source should be an array reference") || + !shouldBeSubType(curr->destRef->type, + Type(HeapType::data, Nullable), + curr, + "array.copy destination should be an array reference")) { + return; + } + auto srcHeapType = curr->srcRef->type.getHeapType(); + auto destHeapType = curr->destRef->type.getHeapType(); + if (srcHeapType == HeapType::none || destHeapType == HeapType::none) { + return; + } + const auto& srcElement = srcHeapType.getArray().element; + const auto& destElement = destHeapType.getArray().element; shouldBeSubType(srcElement.type, destElement.type, curr, diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 724fc12e2..27690f43e 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -796,7 +796,10 @@ void MemoryGrow::finalize() { } } -void RefNull::finalize(HeapType heapType) { type = Type(heapType, Nullable); } +void RefNull::finalize(HeapType heapType) { + assert(heapType.isBottom()); + type = Type(heapType, Nullable); +} void RefNull::finalize(Type type_) { type = type_; } @@ -1033,7 +1036,7 @@ void StructNew::finalize() { void StructGet::finalize() { if (ref->type == Type::unreachable) { type = Type::unreachable; - } else { + } else if (!ref->type.isNull()) { type = ref->type.getHeapType().getStruct().fields[index].type; } } @@ -1066,7 +1069,7 @@ void ArrayInit::finalize() { void ArrayGet::finalize() { if (ref->type == Type::unreachable || index->type == Type::unreachable) { type = Type::unreachable; - } else { + } else if (!ref->type.isNull()) { type = ref->type.getHeapType().getArray().element.type; } } |