diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/binaryen-c.cpp | 188 | ||||
-rw-r--r-- | src/binaryen-c.h | 4 | ||||
-rw-r--r-- | src/ir/effects.h | 28 | ||||
-rw-r--r-- | src/ir/manipulation.h | 1 | ||||
-rw-r--r-- | src/ir/possible-contents.cpp | 34 | ||||
-rw-r--r-- | src/ir/struct-utils.h | 5 | ||||
-rw-r--r-- | src/literal.h | 23 | ||||
-rw-r--r-- | src/passes/Inlining.cpp | 36 | ||||
-rw-r--r-- | src/passes/JSPI.cpp | 6 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 37 | ||||
-rw-r--r-- | src/passes/Precompute.cpp | 4 | ||||
-rw-r--r-- | src/passes/Print.cpp | 55 | ||||
-rw-r--r-- | src/passes/TypeRefining.cpp | 2 | ||||
-rw-r--r-- | src/tools/fuzzing/fuzzing.cpp | 57 | ||||
-rw-r--r-- | src/tools/fuzzing/heap-types.cpp | 37 | ||||
-rw-r--r-- | src/tools/wasm-ctor-eval.cpp | 5 | ||||
-rw-r--r-- | src/tools/wasm-reduce.cpp | 2 | ||||
-rw-r--r-- | src/wasm-binary.h | 8 | ||||
-rw-r--r-- | src/wasm-builder.h | 38 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 7 | ||||
-rw-r--r-- | src/wasm-type.h | 13 | ||||
-rw-r--r-- | src/wasm/literal.cpp | 146 | ||||
-rw-r--r-- | src/wasm/wasm-binary.cpp | 168 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 30 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 29 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 156 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 102 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 9 |
28 files changed, 858 insertions, 372 deletions
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index cc7ed5b15..f2908920c 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -51,96 +51,111 @@ static_assert(sizeof(BinaryenLiteral) == sizeof(Literal), BinaryenLiteral toBinaryenLiteral(Literal x) { BinaryenLiteral ret; ret.type = x.type.getID(); - if (x.type.isRef()) { - auto heapType = x.type.getHeapType(); - if (heapType.isBasic()) { - switch (heapType.getBasic()) { - case HeapType::func: - ret.func = x.isNull() ? nullptr : x.getFunc().c_str(); - break; - case HeapType::ext: - case HeapType::eq: - assert(x.isNull() && "unexpected non-null reference type literal"); - break; - case HeapType::any: - case HeapType::i31: - case HeapType::data: - case HeapType::string: - case HeapType::stringview_wtf8: - case HeapType::stringview_wtf16: - case HeapType::stringview_iter: - WASM_UNREACHABLE("TODO: reftypes"); - } - return ret; + assert(x.type.isSingle()); + if (x.type.isBasic()) { + switch (x.type.getBasic()) { + case Type::i32: + ret.i32 = x.geti32(); + return ret; + case Type::i64: + ret.i64 = x.geti64(); + return ret; + case Type::f32: + ret.i32 = x.reinterpreti32(); + return ret; + case Type::f64: + ret.i64 = x.reinterpreti64(); + return ret; + case Type::v128: + memcpy(&ret.v128, x.getv128Ptr(), 16); + return ret; + case Type::none: + case Type::unreachable: + WASM_UNREACHABLE("unexpected type"); } - WASM_UNREACHABLE("TODO: reftypes"); } - TODO_SINGLE_COMPOUND(x.type); - switch (x.type.getBasic()) { - case Type::i32: - ret.i32 = x.geti32(); - break; - case Type::i64: - ret.i64 = x.geti64(); - break; - case Type::f32: - ret.i32 = x.reinterpreti32(); - break; - case Type::f64: - ret.i64 = x.reinterpreti64(); - break; - case Type::v128: - memcpy(&ret.v128, x.getv128Ptr(), 16); - break; - case Type::none: - case Type::unreachable: - WASM_UNREACHABLE("unexpected type"); + assert(x.type.isRef()); + auto heapType = x.type.getHeapType(); + if (heapType.isBasic()) { + switch (heapType.getBasic()) { + case HeapType::i31: + WASM_UNREACHABLE("TODO: i31"); + case HeapType::ext: + case HeapType::any: + WASM_UNREACHABLE("TODO: extern literals"); + case HeapType::eq: + case HeapType::func: + case HeapType::data: + WASM_UNREACHABLE("invalid type"); + case HeapType::string: + case HeapType::stringview_wtf8: + case HeapType::stringview_wtf16: + case HeapType::stringview_iter: + WASM_UNREACHABLE("TODO: string literals"); + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + // Null. + return ret; + } } - return ret; + if (heapType.isSignature()) { + ret.func = x.getFunc().c_str(); + return ret; + } + assert(x.isData()); + WASM_UNREACHABLE("TODO: gc data"); } Literal fromBinaryenLiteral(BinaryenLiteral x) { auto type = Type(x.type); - if (type.isRef()) { - auto heapType = type.getHeapType(); - if (type.isNullable()) { - return Literal::makeNull(heapType); + if (type.isBasic()) { + switch (type.getBasic()) { + case Type::i32: + return Literal(x.i32); + case Type::i64: + return Literal(x.i64); + case Type::f32: + return Literal(x.i32).castToF32(); + case Type::f64: + return Literal(x.i64).castToF64(); + case Type::v128: + return Literal(x.v128); + case Type::none: + case Type::unreachable: + WASM_UNREACHABLE("unexpected type"); } - if (heapType.isBasic()) { - switch (heapType.getBasic()) { - case HeapType::func: - case HeapType::any: - case HeapType::eq: - case HeapType::data: - assert(false && "Literals must have concrete types"); - WASM_UNREACHABLE("no fallthrough here"); - case HeapType::ext: - case HeapType::i31: - case HeapType::string: - case HeapType::stringview_wtf8: - case HeapType::stringview_wtf16: - case HeapType::stringview_iter: - WASM_UNREACHABLE("TODO: reftypes"); - } + } + assert(type.isRef()); + auto heapType = type.getHeapType(); + if (heapType.isBasic()) { + switch (heapType.getBasic()) { + case HeapType::i31: + WASM_UNREACHABLE("TODO: i31"); + case HeapType::ext: + case HeapType::any: + WASM_UNREACHABLE("TODO: extern literals"); + case HeapType::eq: + case HeapType::func: + case HeapType::data: + WASM_UNREACHABLE("invalid type"); + case HeapType::string: + case HeapType::stringview_wtf8: + case HeapType::stringview_wtf16: + case HeapType::stringview_iter: + WASM_UNREACHABLE("TODO: string literals"); + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + assert(type.isNullable()); + return Literal::makeNull(heapType); } } - assert(type.isBasic()); - switch (type.getBasic()) { - case Type::i32: - return Literal(x.i32); - case Type::i64: - return Literal(x.i64); - case Type::f32: - return Literal(x.i32).castToF32(); - case Type::f64: - return Literal(x.i64).castToF64(); - case Type::v128: - return Literal(x.v128); - case Type::none: - case Type::unreachable: - WASM_UNREACHABLE("unexpected type"); + if (heapType.isSignature()) { + return Literal::makeFunc(Name(x.func), heapType); } - WASM_UNREACHABLE("invalid type"); + assert(heapType.isData()); + WASM_UNREACHABLE("TODO: gc data"); } // Mutexes (global for now; in theory if multiple modules @@ -197,6 +212,15 @@ BinaryenType BinaryenTypeStringviewWTF16() { BinaryenType BinaryenTypeStringviewIter() { return Type(HeapType::stringview_iter, Nullable).getID(); } +BinaryenType BinaryenTypeNullref() { + return Type(HeapType::none, Nullable).getID(); +} +BinaryenType BinaryenTypeNullExternref(void) { + return Type(HeapType::noext, Nullable).getID(); +} +BinaryenType BinaryenTypeNullFuncref(void) { + return Type(HeapType::nofunc, Nullable).getID(); +} BinaryenType BinaryenTypeUnreachable(void) { return Type::unreachable; } BinaryenType BinaryenTypeAuto(void) { return uintptr_t(-1); } @@ -1484,7 +1508,8 @@ BinaryenExpressionRef BinaryenRefNull(BinaryenModuleRef module, BinaryenType type) { Type type_(type); assert(type_.isNullable()); - return static_cast<Expression*>(Builder(*(Module*)module).makeRefNull(type_)); + return static_cast<Expression*>( + Builder(*(Module*)module).makeRefNull(type_.getHeapType())); } BinaryenExpressionRef BinaryenRefIs(BinaryenModuleRef module, @@ -1699,10 +1724,11 @@ BinaryenExpressionRef BinaryenArrayInit(BinaryenModuleRef module, BinaryenExpressionRef BinaryenArrayGet(BinaryenModuleRef module, BinaryenExpressionRef ref, BinaryenExpressionRef index, + BinaryenType type, bool signed_) { return static_cast<Expression*>( Builder(*(Module*)module) - .makeArrayGet((Expression*)ref, (Expression*)index, signed_)); + .makeArrayGet((Expression*)ref, (Expression*)index, Type(type), signed_)); } BinaryenExpressionRef BinaryenArraySet(BinaryenModuleRef module, BinaryenExpressionRef ref, diff --git a/src/binaryen-c.h b/src/binaryen-c.h index 5b343e8ba..9cc282721 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -109,6 +109,9 @@ BINARYEN_API BinaryenType BinaryenTypeStringref(void); BINARYEN_API BinaryenType BinaryenTypeStringviewWTF8(void); BINARYEN_API BinaryenType BinaryenTypeStringviewWTF16(void); BINARYEN_API BinaryenType BinaryenTypeStringviewIter(void); +BINARYEN_API BinaryenType BinaryenTypeNullref(void); +BINARYEN_API BinaryenType BinaryenTypeNullExternref(void); +BINARYEN_API BinaryenType BinaryenTypeNullFuncref(void); BINARYEN_API BinaryenType BinaryenTypeUnreachable(void); // Not a real type. Used as the last parameter to BinaryenBlock to let // the API figure out the type instead of providing one. @@ -1044,6 +1047,7 @@ BinaryenArrayInit(BinaryenModuleRef module, BINARYEN_API BinaryenExpressionRef BinaryenArrayGet(BinaryenModuleRef module, BinaryenExpressionRef ref, BinaryenExpressionRef index, + BinaryenType type, bool signed_); BINARYEN_API BinaryenExpressionRef BinaryenArraySet(BinaryenModuleRef module, diff --git a/src/ir/effects.h b/src/ir/effects.h index 63d9fafc5..2dfe616f9 100644 --- a/src/ir/effects.h +++ b/src/ir/effects.h @@ -701,6 +701,10 @@ private: } } void visitCallRef(CallRef* curr) { + if (curr->target->type.isNull()) { + parent.trap = true; + return; + } parent.calls = true; if (parent.features.hasExceptionHandling() && parent.tryDepth == 0) { parent.throws_ = true; @@ -724,6 +728,10 @@ private: if (curr->ref->type == Type::unreachable) { return; } + if (curr->ref->type.isNull()) { + parent.trap = true; + return; + } if (curr->ref->type.getHeapType() .getStruct() .fields[curr->index] @@ -736,6 +744,10 @@ private: } } void visitStructSet(StructSet* curr) { + if (curr->ref->type.isNull()) { + parent.trap = true; + return; + } parent.writesStruct = true; // traps when the arg is null if (curr->ref->type.isNullable()) { @@ -745,22 +757,38 @@ private: void visitArrayNew(ArrayNew* curr) {} void visitArrayInit(ArrayInit* curr) {} void visitArrayGet(ArrayGet* curr) { + if (curr->ref->type.isNull()) { + parent.trap = true; + return; + } parent.readsArray = true; // traps when the arg is null or the index out of bounds parent.implicitTrap = true; } void visitArraySet(ArraySet* curr) { + if (curr->ref->type.isNull()) { + parent.trap = true; + return; + } parent.writesArray = true; // traps when the arg is null or the index out of bounds parent.implicitTrap = true; } void visitArrayLen(ArrayLen* curr) { + if (curr->ref->type.isNull()) { + parent.trap = true; + return; + } // traps when the arg is null if (curr->ref->type.isNullable()) { parent.implicitTrap = true; } } void visitArrayCopy(ArrayCopy* curr) { + if (curr->destRef->type.isNull() || curr->srcRef->type.isNull()) { + parent.trap = true; + return; + } parent.readsArray = true; parent.writesArray = true; // traps when a ref is null, or when out of bounds. diff --git a/src/ir/manipulation.h b/src/ir/manipulation.h index 54822f2bd..33c7d1bd7 100644 --- a/src/ir/manipulation.h +++ b/src/ir/manipulation.h @@ -41,6 +41,7 @@ template<typename InputType> inline Nop* nop(InputType* target) { template<typename InputType> inline RefNull* refNull(InputType* target, Type type) { + assert(type.isNullable() && type.getHeapType().isBottom()); auto* ret = convert<InputType, RefNull>(target); ret->finalize(type); return ret; diff --git a/src/ir/possible-contents.cpp b/src/ir/possible-contents.cpp index 7c69121e6..9ad43f090 100644 --- a/src/ir/possible-contents.cpp +++ b/src/ir/possible-contents.cpp @@ -47,23 +47,6 @@ void PossibleContents::combine(const PossibleContents& other) { // First handle the trivial cases of them being equal, or one of them is // None or Many. if (*this == other) { - // Nulls are a special case, since they compare equal even if their type is - // different. We would like to make this function symmetric, that is, that - // combine(a, b) == combine(b, a) (otherwise, things can be odd and we could - // get nondeterminism in the flow analysis which does not have a - // determinstic order). To fix that, pick the LUB. - if (isNull()) { - assert(other.isNull()); - auto lub = HeapType::getLeastUpperBound(type.getHeapType(), - otherType.getHeapType()); - if (!lub) { - // TODO: Remove this workaround once we have bottom types to assign to - // null literals. - value = Many(); - return; - } - value = Literal::makeNull(*lub); - } return; } if (other.isNone()) { @@ -97,10 +80,18 @@ void PossibleContents::combine(const PossibleContents& other) { // Special handling for references from here. - // Nulls are always equal to each other, even if their types differ. + if (isNull() && other.isNull()) { + // These must be nulls in different hierarchies, otherwise this would have + // been handled by the `*this == other` case above. + assert(type != otherType); + value = Many(); + return; + } + + // Nulls can be combined in by just adding nullability to a type. if (isNull() || other.isNull()) { - // Only one of them can be null here, since we already checked if *this == - // other, which would have been true had both been null. + // Only one of them can be null here, since we already handled the case + // where they were both null. assert(!isNull() || !other.isNull()); // If only one is a null, but the other's type is known exactly, then the // combination is to add nullability (if the type is *not* known exactly, @@ -797,7 +788,8 @@ struct InfoCollector // part of the main IR, which is potentially confusing during debugging, // however, which is a downside. Builder builder(*getModule()); - auto* get = builder.makeArrayGet(curr->srcRef, curr->srcIndex); + auto* get = + builder.makeArrayGet(curr->srcRef, curr->srcIndex, curr->srcRef->type); visitArrayGet(get); auto* set = builder.makeArraySet(curr->destRef, curr->destIndex, get); visitArraySet(set); diff --git a/src/ir/struct-utils.h b/src/ir/struct-utils.h index 9d02bb779..9f880985f 100644 --- a/src/ir/struct-utils.h +++ b/src/ir/struct-utils.h @@ -50,6 +50,7 @@ struct StructValuesMap : public std::unordered_map<HeapType, StructValues<T>> { // When we access an item, if it does not already exist, create it with a // vector of the right length for that type. StructValues<T>& operator[](HeapType type) { + assert(type.isStruct()); auto inserted = this->insert({type, {}}); auto& values = inserted.first->second; if (inserted.second) { @@ -159,7 +160,7 @@ struct StructScanner void visitStructSet(StructSet* curr) { auto type = curr->ref->type; - if (type == Type::unreachable) { + if (type == Type::unreachable || type.isNull()) { return; } @@ -173,7 +174,7 @@ struct StructScanner void visitStructGet(StructGet* curr) { auto type = curr->ref->type; - if (type == Type::unreachable) { + if (type == Type::unreachable || type.isNull()) { return; } diff --git a/src/literal.h b/src/literal.h index 318ab012a..7d7c778bc 100644 --- a/src/literal.h +++ b/src/literal.h @@ -79,7 +79,9 @@ public: explicit Literal(const std::array<Literal, 4>&); explicit Literal(const std::array<Literal, 2>&); explicit Literal(Name func, HeapType type) - : func(func), type(type, NonNullable) {} + : func(func), type(type, NonNullable) { + assert(type.isSignature()); + } explicit Literal(std::shared_ptr<GCData> gcData, HeapType type); Literal(const Literal& other); Literal& operator=(const Literal& other); @@ -90,21 +92,8 @@ public: bool isFunction() const { return type.isFunction(); } bool isData() const { return type.isData(); } - bool isNull() const { - if (type.isNullable()) { - if (type.isFunction()) { - return func.isNull(); - } - if (isData()) { - return !gcData; - } - if (type.getHeapType() == HeapType::i31) { - return i32 == 0; - } - return true; - } - return false; - } + bool isNull() const { return type.isNull(); } + bool isZero() const { switch (type.getBasic()) { case Type::i32: @@ -239,7 +228,7 @@ public: } } static Literal makeNull(HeapType type) { - return Literal(Type(type, Nullable)); + return Literal(Type(type.getBottom(), Nullable)); } static Literal makeFunc(Name func, HeapType type) { return Literal(func, type); diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index c03b404b0..f325f70c6 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -32,6 +32,7 @@ #include "ir/branch-utils.h" #include "ir/debug.h" +#include "ir/drop.h" #include "ir/eh-utils.h" #include "ir/element-utils.h" #include "ir/literal-utils.h" @@ -251,6 +252,10 @@ struct Updater : public PostWalker<Updater> { Name returnName; bool isReturn; Builder* builder; + PassOptions& options; + + Updater(PassOptions& options) : options(options) {} + void visitReturn(Return* curr) { replaceCurrent(builder->makeBreak(returnName, curr->value)); } @@ -259,7 +264,7 @@ struct Updater : public PostWalker<Updater> { // achieve this, make the call a non-return call and add a break. This does // not cause unbounded stack growth because inlining and return calling both // avoid creating a new stack frame. - template<typename T> void handleReturnCall(T* curr, HeapType targetType) { + template<typename T> void handleReturnCall(T* curr, Type results) { if (isReturn) { // If the inlined callsite was already a return_call, then we can keep // return_calls in the inlined function rather than downgrading them. @@ -269,7 +274,7 @@ struct Updater : public PostWalker<Updater> { return; } curr->isReturn = false; - curr->type = targetType.getSignature().results; + curr->type = results; if (curr->type.isConcrete()) { replaceCurrent(builder->makeBreak(returnName, curr)); } else { @@ -278,17 +283,25 @@ struct Updater : public PostWalker<Updater> { } void visitCall(Call* curr) { if (curr->isReturn) { - handleReturnCall(curr, module->getFunction(curr->target)->type); + handleReturnCall(curr, module->getFunction(curr->target)->getResults()); } } void visitCallIndirect(CallIndirect* curr) { if (curr->isReturn) { - handleReturnCall(curr, curr->heapType); + handleReturnCall(curr, curr->heapType.getSignature().results); } } void visitCallRef(CallRef* curr) { + Type targetType = curr->target->type; + if (targetType.isNull()) { + // We don't know what type the call should return, but we can't leave it + // as a potentially-invalid return_call_ref, either. + replaceCurrent(getDroppedChildrenAndAppend( + curr, *module, options, Builder(*module).makeUnreachable())); + return; + } if (curr->isReturn) { - handleReturnCall(curr, curr->target->type.getHeapType()); + handleReturnCall(curr, targetType.getHeapType().getSignature().results); } } void visitLocalGet(LocalGet* curr) { @@ -301,8 +314,10 @@ struct Updater : public PostWalker<Updater> { // Core inlining logic. Modifies the outside function (adding locals as // needed), and returns the inlined code. -static Expression* -doInlining(Module* module, Function* into, const InliningAction& action) { +static Expression* doInlining(Module* module, + Function* into, + const InliningAction& action, + PassOptions& options) { Function* from = action.contents; auto* call = (*action.callSite)->cast<Call>(); // Works for return_call, too @@ -337,7 +352,7 @@ doInlining(Module* module, Function* into, const InliningAction& action) { *action.callSite = block; } // Prepare to update the inlined code's locals and other things. - Updater updater; + Updater updater(options); updater.module = module; updater.returnName = block->name; updater.isReturn = call->isReturn; @@ -1002,7 +1017,7 @@ struct Inlining : public Pass { action.contents = getActuallyInlinedFunction(action.contents); // Perform the inlining and update counts. - doInlining(module, func, action); + doInlining(module, func, action, getPassOptions()); inlinedUses[inlinedName]++; inlinedInto.insert(func); assert(inlinedUses[inlinedName] <= infos[inlinedName].refs); @@ -1116,7 +1131,8 @@ struct InlineMainPass : public Pass { // No call at all. return; } - doInlining(module, main, InliningAction(callSite, originalMain)); + doInlining( + module, main, InliningAction(callSite, originalMain), getPassOptions()); } }; diff --git a/src/passes/JSPI.cpp b/src/passes/JSPI.cpp index 9660b962d..fe24d60f7 100644 --- a/src/passes/JSPI.cpp +++ b/src/passes/JSPI.cpp @@ -45,8 +45,10 @@ struct JSPI : public Pass { // Create a global to store the suspender that is passed into exported // functions and will then need to be passed out to the imported functions. Name suspender = Names::getValidGlobalName(*module, "suspender"); - module->addGlobal(builder.makeGlobal( - suspender, externref, builder.makeRefNull(externref), Builder::Mutable)); + module->addGlobal(builder.makeGlobal(suspender, + externref, + builder.makeRefNull(HeapType::noext), + Builder::Mutable)); // Keep track of already wrapped functions since they can be exported // multiple times, but only one wrapper is needed. diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 4ffd6be67..059ef28fb 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -1267,6 +1267,10 @@ struct OptimizeInstructions } void visitCallRef(CallRef* curr) { + skipNonNullCast(curr->target); + if (trapOnNull(curr, curr->target)) { + return; + } if (curr->target->type == Type::unreachable) { // The call_ref is not reached; leave this for DCE. return; @@ -1509,6 +1513,17 @@ struct OptimizeInstructions return getDroppedChildrenAndAppend(curr, result); } + bool trapOnNull(Expression* curr, Expression* ref) { + if (ref->type.isNull()) { + replaceCurrent(getDroppedChildrenAndAppend( + curr, Builder(*getModule()).makeUnreachable())); + // Propagate the unreachability. + refinalize = true; + return true; + } + return false; + } + void visitRefEq(RefEq* curr) { // The types may prove that the same reference cannot appear on both sides. auto leftType = curr->left->type; @@ -1564,10 +1579,16 @@ struct OptimizeInstructions } } - void visitStructGet(StructGet* curr) { skipNonNullCast(curr->ref); } + void visitStructGet(StructGet* curr) { + skipNonNullCast(curr->ref); + trapOnNull(curr, curr->ref); + } void visitStructSet(StructSet* curr) { skipNonNullCast(curr->ref); + if (trapOnNull(curr, curr->ref)) { + return; + } if (curr->ref->type != Type::unreachable && curr->value->type.isInteger()) { const auto& fields = curr->ref->type.getHeapType().getStruct().fields; @@ -1715,10 +1736,16 @@ struct OptimizeInstructions return true; } - void visitArrayGet(ArrayGet* curr) { skipNonNullCast(curr->ref); } + void visitArrayGet(ArrayGet* curr) { + skipNonNullCast(curr->ref); + trapOnNull(curr, curr->ref); + } void visitArraySet(ArraySet* curr) { skipNonNullCast(curr->ref); + if (trapOnNull(curr, curr->ref)) { + return; + } if (curr->ref->type != Type::unreachable && curr->value->type.isInteger()) { auto element = curr->ref->type.getHeapType().getArray().element; @@ -1726,11 +1753,15 @@ struct OptimizeInstructions } } - void visitArrayLen(ArrayLen* curr) { skipNonNullCast(curr->ref); } + void visitArrayLen(ArrayLen* curr) { + skipNonNullCast(curr->ref); + trapOnNull(curr, curr->ref); + } void visitArrayCopy(ArrayCopy* curr) { skipNonNullCast(curr->destRef); skipNonNullCast(curr->srcRef); + trapOnNull(curr, curr->destRef) || trapOnNull(curr, curr->srcRef); } bool canBeCastTo(HeapType a, HeapType b) { diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp index 466614d14..c90fdf167 100644 --- a/src/passes/Precompute.cpp +++ b/src/passes/Precompute.cpp @@ -130,7 +130,7 @@ public: } Flow visitStructSet(StructSet* curr) { return Flow(NONCONSTANT_FLOW); } Flow visitStructGet(StructGet* curr) { - if (curr->ref->type != Type::unreachable) { + if (curr->ref->type != Type::unreachable && !curr->ref->type.isNull()) { // If this field is immutable then we may be able to precompute this, as // if we also created the data in this function (or it was created in an // immutable global) then we know the value in the field. If it is @@ -164,7 +164,7 @@ public: } Flow visitArraySet(ArraySet* curr) { return Flow(NONCONSTANT_FLOW); } Flow visitArrayGet(ArrayGet* curr) { - if (curr->ref->type != Type::unreachable) { + if (curr->ref->type != Type::unreachable && !curr->ref->type.isNull()) { // See above with struct.get auto element = curr->ref->type.getHeapType().getArray().element; if (element.mutable_ == Immutable) { diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 01b004d97..7ebad3322 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -115,6 +115,15 @@ bool maybePrintRefShorthand(std::ostream& o, Type type) { case HeapType::stringview_iter: o << "stringview_iter"; return true; + case HeapType::none: + o << "nullref"; + return true; + case HeapType::noext: + o << "nullexternref"; + return true; + case HeapType::nofunc: + o << "nullfuncref"; + return true; } } return false; @@ -2058,10 +2067,17 @@ struct PrintExpressionContents } return false; } + bool printUnreachableOrNullReplacement(Expression* curr) { + if (curr->type == Type::unreachable || curr->type.isNull()) { + printMedium(o, "block"); + return true; + } + return false; + } void visitCallRef(CallRef* curr) { // TODO: Workaround if target has bottom type. - if (printUnreachableReplacement(curr->target)) { + if (printUnreachableOrNullReplacement(curr->target)) { return; } printMedium(o, curr->isReturn ? "return_call_ref " : "call_ref "); @@ -2144,7 +2160,7 @@ struct PrintExpressionContents }); } void visitStructGet(StructGet* curr) { - if (printUnreachableReplacement(curr->ref)) { + if (printUnreachableOrNullReplacement(curr->ref)) { return; } auto heapType = curr->ref->type.getHeapType(); @@ -2163,7 +2179,7 @@ struct PrintExpressionContents printFieldName(heapType, curr->index); } void visitStructSet(StructSet* curr) { - if (printUnreachableReplacement(curr->ref)) { + if (printUnreachableOrNullReplacement(curr->ref)) { return; } printMedium(o, "struct.set "); @@ -2192,7 +2208,7 @@ struct PrintExpressionContents TypeNamePrinter(o, wasm).print(curr->type.getHeapType()); } void visitArrayGet(ArrayGet* curr) { - if (printUnreachableReplacement(curr->ref)) { + if (printUnreachableOrNullReplacement(curr->ref)) { return; } const auto& element = curr->ref->type.getHeapType().getArray().element; @@ -2208,22 +2224,22 @@ struct PrintExpressionContents TypeNamePrinter(o, wasm).print(curr->ref->type.getHeapType()); } void visitArraySet(ArraySet* curr) { - if (printUnreachableReplacement(curr->ref)) { + if (printUnreachableOrNullReplacement(curr->ref)) { return; } printMedium(o, "array.set "); TypeNamePrinter(o, wasm).print(curr->ref->type.getHeapType()); } void visitArrayLen(ArrayLen* curr) { - if (printUnreachableReplacement(curr->ref)) { + if (printUnreachableOrNullReplacement(curr->ref)) { return; } printMedium(o, "array.len "); TypeNamePrinter(o, wasm).print(curr->ref->type.getHeapType()); } void visitArrayCopy(ArrayCopy* curr) { - if (printUnreachableReplacement(curr->srcRef) || - printUnreachableReplacement(curr->destRef)) { + if (printUnreachableOrNullReplacement(curr->srcRef) || + printUnreachableOrNullReplacement(curr->destRef)) { return; } printMedium(o, "array.copy "); @@ -2746,19 +2762,29 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { drop.value = child; printFullLine(&drop); } + Unreachable unreachable; + printFullLine(&unreachable); decIndent(); } + // This must be used for the same Expressions that use + // PrintExpressionContents::printUnreachableOrNullReplacement. + void maybePrintUnreachableOrNullReplacement(Expression* curr, Type type) { + if (type.isNull()) { + type = Type::unreachable; + } + maybePrintUnreachableReplacement(curr, type); + } void visitCallRef(CallRef* curr) { - maybePrintUnreachableReplacement(curr, curr->target->type); + maybePrintUnreachableOrNullReplacement(curr, curr->target->type); } void visitStructNew(StructNew* curr) { maybePrintUnreachableReplacement(curr, curr->type); } void visitStructSet(StructSet* curr) { - maybePrintUnreachableReplacement(curr, curr->ref->type); + maybePrintUnreachableOrNullReplacement(curr, curr->ref->type); } void visitStructGet(StructGet* curr) { - maybePrintUnreachableReplacement(curr, curr->ref->type); + maybePrintUnreachableOrNullReplacement(curr, curr->ref->type); } void visitArrayNew(ArrayNew* curr) { maybePrintUnreachableReplacement(curr, curr->type); @@ -2767,10 +2793,13 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> { maybePrintUnreachableReplacement(curr, curr->type); } void visitArraySet(ArraySet* curr) { - maybePrintUnreachableReplacement(curr, curr->ref->type); + maybePrintUnreachableOrNullReplacement(curr, curr->ref->type); } void visitArrayGet(ArrayGet* curr) { - maybePrintUnreachableReplacement(curr, curr->ref->type); + maybePrintUnreachableOrNullReplacement(curr, curr->ref->type); + } + void visitArrayLen(ArrayLen* curr) { + maybePrintUnreachableOrNullReplacement(curr, curr->ref->type); } // Module-level visitors void printSupertypeOr(HeapType curr, std::string noSuper) { diff --git a/src/passes/TypeRefining.cpp b/src/passes/TypeRefining.cpp index 6ce503cc0..e9aa07ca6 100644 --- a/src/passes/TypeRefining.cpp +++ b/src/passes/TypeRefining.cpp @@ -251,7 +251,7 @@ struct TypeRefining : public Pass { } void visitStructGet(StructGet* curr) { - if (curr->ref->type == Type::unreachable) { + if (curr->ref->type == Type::unreachable || curr->ref->type.isNull()) { return; } diff --git a/src/tools/fuzzing/fuzzing.cpp b/src/tools/fuzzing/fuzzing.cpp index 2b368a4e8..09df9ac31 100644 --- a/src/tools/fuzzing/fuzzing.cpp +++ b/src/tools/fuzzing/fuzzing.cpp @@ -1976,7 +1976,7 @@ Expression* TranslateToFuzzReader::makeRefFuncConst(Type type) { // to add a ref.as_non_null to validate, and the code will trap when we get // here). if ((type.isNullable() && oneIn(2)) || (type.isNonNullable() && oneIn(16))) { - Expression* ret = builder.makeRefNull(Type(heapType, Nullable)); + Expression* ret = builder.makeRefNull(HeapType::nofunc); if (!type.isNullable()) { ret = builder.makeRefAs(RefAsNonNull, ret); } @@ -2000,7 +2000,7 @@ Expression* TranslateToFuzzReader::makeConst(Type type) { assert(wasm.features.hasReferenceTypes()); // With a low chance, just emit a null if that is valid. if (type.isNullable() && oneIn(8)) { - return builder.makeRefNull(type); + return builder.makeRefNull(type.getHeapType()); } if (type.getHeapType().isBasic()) { return makeConstBasicRef(type); @@ -2050,7 +2050,7 @@ Expression* TranslateToFuzzReader::makeConstBasicRef(Type type) { // a subtype of anyref, but we cannot create constants of it, except // for null. assert(type.isNullable()); - return builder.makeRefNull(type); + return builder.makeRefNull(HeapType::none); } auto nullability = getSubType(type.getNullability()); // i31.new is not allowed in initializer expressions. @@ -2065,7 +2065,7 @@ Expression* TranslateToFuzzReader::makeConstBasicRef(Type type) { case HeapType::i31: { assert(wasm.features.hasGC()); if (type.isNullable() && oneIn(4)) { - return builder.makeRefNull(type); + return builder.makeRefNull(HeapType::none); } return builder.makeI31New(makeConst(Type::i32)); } @@ -2086,10 +2086,22 @@ Expression* TranslateToFuzzReader::makeConstBasicRef(Type type) { return builder.makeArrayInit(trivialArray, {}); } } - default: { - WASM_UNREACHABLE("invalid basic ref type"); + case HeapType::string: + case HeapType::stringview_wtf8: + case HeapType::stringview_wtf16: + case HeapType::stringview_iter: + WASM_UNREACHABLE("TODO: strings"); + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: { + auto null = builder.makeRefNull(heapType); + if (!type.isNullable()) { + return builder.makeRefAs(RefAsNonNull, null); + } + return null; } } + WASM_UNREACHABLE("invalid basic ref type"); } Expression* TranslateToFuzzReader::makeConstCompoundRef(Type type) { @@ -2104,15 +2116,14 @@ Expression* TranslateToFuzzReader::makeConstCompoundRef(Type type) { // We weren't able to directly materialize a non-null constant. Try again to // create a null. if (type.isNullable()) { - return builder.makeRefNull(type); + return builder.makeRefNull(heapType); } // We have to produce a non-null value. Possibly create a null and cast it // to non-null even though that will trap at runtime. We must have a // function context for this because the cast is not allowed in globals. if (funcContext) { - return builder.makeRefAs(RefAsNonNull, - builder.makeRefNull(Type(heapType, Nullable))); + return builder.makeRefAs(RefAsNonNull, builder.makeRefNull(heapType)); } // Otherwise, we are not in a function context. This can happen if we need @@ -3138,33 +3149,49 @@ Nullability TranslateToFuzzReader::getSubType(Nullability nullability) { } HeapType TranslateToFuzzReader::getSubType(HeapType type) { + if (oneIn(2)) { + return type; + } if (type.isBasic()) { switch (type.getBasic()) { case HeapType::func: // TODO: Typed function references. - return HeapType::func; + return pick(FeatureOptions<HeapType>() + .add(FeatureSet::ReferenceTypes, HeapType::func) + .add(FeatureSet::GC, HeapType::nofunc)); case HeapType::ext: - return HeapType::ext; + return pick(FeatureOptions<HeapType>() + .add(FeatureSet::ReferenceTypes, HeapType::ext) + .add(FeatureSet::GC, HeapType::noext)); case HeapType::any: // TODO: nontrivial types as well. assert(wasm.features.hasReferenceTypes()); assert(wasm.features.hasGC()); - return pick(HeapType::any, HeapType::eq, HeapType::i31, HeapType::data); + return pick(HeapType::any, + HeapType::eq, + HeapType::i31, + HeapType::data, + HeapType::none); case HeapType::eq: // TODO: nontrivial types as well. assert(wasm.features.hasReferenceTypes()); assert(wasm.features.hasGC()); - return pick(HeapType::eq, HeapType::i31, HeapType::data); + return pick( + HeapType::eq, HeapType::i31, HeapType::data, HeapType::none); case HeapType::i31: - return HeapType::i31; + return pick(HeapType::i31, HeapType::none); case HeapType::data: // TODO: nontrivial types as well. - return HeapType::data; + return pick(HeapType::data, HeapType::none); case HeapType::string: case HeapType::stringview_wtf8: case HeapType::stringview_wtf16: case HeapType::stringview_iter: WASM_UNREACHABLE("TODO: fuzz strings"); + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + break; } } // TODO: nontrivial types as well. diff --git a/src/tools/fuzzing/heap-types.cpp b/src/tools/fuzzing/heap-types.cpp index ad87c6fab..351035f93 100644 --- a/src/tools/fuzzing/heap-types.cpp +++ b/src/tools/fuzzing/heap-types.cpp @@ -69,7 +69,7 @@ struct HeapTypeGeneratorImpl { typeIndices.insert({builder[i], i}); // Everything is a subtype of itself. subtypeIndices[i].push_back(i); - if (i < numRoots) { + if (i < numRoots || rand.oneIn(2)) { // This is a root type with no supertype. Choose a kind for this type. typeKinds.emplace_back(generateHeapTypeKind()); } else { @@ -148,6 +148,11 @@ struct HeapTypeGeneratorImpl { } HeapType::BasicHeapType generateBasicHeapType() { + // Choose bottom types more rarely. + if (rand.oneIn(16)) { + return rand.pick(HeapType::noext, HeapType::nofunc, HeapType::none); + } + // TODO: strings return rand.pick(HeapType::func, HeapType::ext, HeapType::any, @@ -254,6 +259,8 @@ struct HeapTypeGeneratorImpl { HeapType pickSubFunc() { if (auto type = pickKind<SignatureKind>()) { return *type; + } else if (rand.oneIn(2)) { + return HeapType::nofunc; } else { return HeapType::func; } @@ -262,6 +269,8 @@ struct HeapTypeGeneratorImpl { HeapType pickSubData() { if (auto type = pickKind<DataKind>()) { return *type; + } else if (rand.oneIn(2)) { + return HeapType::none; } else { return HeapType::data; } @@ -292,7 +301,7 @@ struct HeapTypeGeneratorImpl { // can only choose those defined before the end of the current recursion // group. std::vector<Index> candidateIndices; - for (auto i : subtypeIndices[typeIndices[type]]) { + for (auto i : subtypeIndices[it->second]) { if (i < recGroupEnds[index]) { candidateIndices.push_back(i); } @@ -301,6 +310,9 @@ struct HeapTypeGeneratorImpl { } else { // This is not a constructed type, so it must be a basic type. assert(type.isBasic()); + if (rand.oneIn(8)) { + return type.getBottom(); + } switch (type.getBasic()) { case HeapType::ext: return HeapType::ext; @@ -318,7 +330,10 @@ struct HeapTypeGeneratorImpl { case HeapType::stringview_wtf8: case HeapType::stringview_wtf16: case HeapType::stringview_iter: - WASM_UNREACHABLE("TODO: fuzz strings"); + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + return type; } WASM_UNREACHABLE("unexpected kind"); } @@ -403,6 +418,17 @@ struct HeapTypeGeneratorImpl { } HeapTypeKind getSubKind(HeapTypeKind super) { + if (rand.oneIn(16)) { + // Occasionally go directly to the bottom type. + if (auto* basic = std::get_if<BasicKind>(&super)) { + return HeapType(*basic).getBottom(); + } else if (std::get_if<SignatureKind>(&super)) { + return HeapType::nofunc; + } else if (std::get_if<DataKind>(&super)) { + return HeapType::none; + } + WASM_UNREACHABLE("unexpected kind"); + } if (auto* basic = std::get_if<BasicKind>(&super)) { if (rand.oneIn(8)) { return super; @@ -441,7 +467,10 @@ struct HeapTypeGeneratorImpl { case HeapType::stringview_wtf8: case HeapType::stringview_wtf16: case HeapType::stringview_iter: - WASM_UNREACHABLE("TODO: fuzz strings"); + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + return super; } WASM_UNREACHABLE("unexpected kind"); } else { diff --git a/src/tools/wasm-ctor-eval.cpp b/src/tools/wasm-ctor-eval.cpp index 5e9874ccc..45cc63c37 100644 --- a/src/tools/wasm-ctor-eval.cpp +++ b/src/tools/wasm-ctor-eval.cpp @@ -553,10 +553,7 @@ public: // This is GC data, which we must handle in a more careful way. auto* data = value.getGCData().get(); - if (!data) { - // This is a null, so simply emit one. - return builder.makeRefNull(value.type); - } + assert(data); // There was actual GC data allocated here. auto type = value.type; diff --git a/src/tools/wasm-reduce.cpp b/src/tools/wasm-reduce.cpp index 6febcf6b9..557e8a770 100644 --- a/src/tools/wasm-reduce.cpp +++ b/src/tools/wasm-reduce.cpp @@ -1142,7 +1142,7 @@ struct Reducer } // try to replace with a trivial value if (curr->type.isNullable()) { - RefNull* n = builder->makeRefNull(curr->type); + RefNull* n = builder->makeRefNull(curr->type.getHeapType()); return tryToReplaceCurrent(n); } if (curr->type.isTuple() && curr->type.isDefaultable()) { diff --git a/src/wasm-binary.h b/src/wasm-binary.h index e9ad665c4..ca21662e6 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -382,6 +382,10 @@ enum EncodedType { stringview_wtf8 = -0x1d, // 0x63 stringview_wtf16 = -0x1e, // 0x62 stringview_iter = -0x1f, // 0x61 + // bottom types + nullexternref = -0x17, // 0x69 + nullfuncref = -0x18, // 0x68 + nullref = -0x1b, // 0x65 // type forms Func = -0x20, // 0x60 Struct = -0x21, // 0x5f @@ -411,6 +415,10 @@ enum EncodedHeapType { stringview_wtf8_heap = -0x1d, // 0x63 stringview_wtf16_heap = -0x1e, // 0x62 stringview_iter_heap = -0x1f, // 0x61 + // bottom types + noext = -0x17, // 0x69 + nofunc = -0x18, // 0x68 + none = -0x1b, // 0x65 }; namespace UserSections { diff --git a/src/wasm-builder.h b/src/wasm-builder.h index cc3b99138..b07af4627 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -699,10 +699,11 @@ public: } RefNull* makeRefNull(HeapType type) { auto* ret = wasm.allocator.alloc<RefNull>(); - ret->finalize(Type(type, Nullable)); + ret->finalize(Type(type.getBottom(), Nullable)); return ret; } RefNull* makeRefNull(Type type) { + assert(type.isNullable() && type.isNull()); auto* ret = wasm.allocator.alloc<RefNull>(); ret->finalize(type); return ret; @@ -942,11 +943,14 @@ public: ret->finalize(); return ret; } - ArrayGet* - makeArrayGet(Expression* ref, Expression* index, bool signed_ = false) { + ArrayGet* makeArrayGet(Expression* ref, + Expression* index, + Type type, + bool signed_ = false) { auto* ret = wasm.allocator.alloc<ArrayGet>(); ret->ref = ref; ret->index = index; + ret->type = type; ret->signed_ = signed_; ret->finalize(); return ret; @@ -1262,7 +1266,7 @@ public: if (curr->type.isTuple() && curr->type.isDefaultable()) { return makeConstantExpression(Literal::makeZeros(curr->type)); } - if (curr->type.isNullable()) { + if (curr->type.isNullable() && curr->type.isNull()) { return ExpressionManipulator::refNull(curr, curr->type); } if (curr->type.isRef() && curr->type.getHeapType() == HeapType::i31) { @@ -1329,18 +1333,24 @@ public: Expression* validateAndMakeCallRef(Expression* target, const T& args, bool isReturn = false) { - if (!target->type.isRef()) { - if (target->type == Type::unreachable) { - // An unreachable target is not supported. Similiar to br_on_cast, just - // emit an unreachable sequence, since we don't have enough information - // to create a full call_ref. - auto* block = makeBlock(args); - block->list.push_back(target); - block->finalize(Type::unreachable); - return block; - } + if (target->type != Type::unreachable && !target->type.isRef()) { throw ParseException("Non-reference type for a call_ref", line, col); } + // TODO: This won't be necessary once type annotations are mandatory on + // call_ref. + if (target->type == Type::unreachable || + target->type.getHeapType() == HeapType::nofunc) { + // An unreachable target is not supported. Similiar to br_on_cast, just + // emit an unreachable sequence, since we don't have enough information + // to create a full call_ref. + std::vector<Expression*> children; + for (auto* arg : args) { + children.push_back(makeDrop(arg)); + } + children.push_back(makeDrop(target)); + children.push_back(makeUnreachable()); + return makeBlock(children, Type::unreachable); + } auto heapType = target->type.getHeapType(); if (!heapType.isSignature()) { throw ParseException("Invalid reference type for a call_ref", line, col); diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 69434a297..10b84a82f 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -1350,12 +1350,11 @@ public: case RefIsNull: return Literal(value.isNull()); case RefIsFunc: - return Literal(!value.isNull() && value.type.isFunction()); + return Literal(value.type.isFunction()); case RefIsData: - return Literal(!value.isNull() && value.isData()); + return Literal(value.isData()); case RefIsI31: - return Literal(!value.isNull() && - value.type.getHeapType() == HeapType::i31); + return Literal(value.type.getHeapType() == HeapType::i31); default: WASM_UNREACHABLE("unimplemented ref.is_*"); } diff --git a/src/wasm-type.h b/src/wasm-type.h index 940009cbe..dbdbe3d31 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -169,6 +169,8 @@ public: // is irrelevant. (For that reason, this is only the negation of isNullable() // on references, but both return false on non-references.) bool isNonNullable() const; + // Whether this type is only inhabited by null values. + bool isNull() const; bool isStruct() const; bool isArray() const; bool isDefaultable() const; @@ -326,8 +328,11 @@ public: stringview_wtf8, stringview_wtf16, stringview_iter, + none, + noext, + nofunc, }; - static constexpr BasicHeapType _last_basic_type = stringview_iter; + static constexpr BasicHeapType _last_basic_type = nofunc; // BasicHeapType can be implicitly upgraded to HeapType constexpr HeapType(BasicHeapType id) : id(id) {} @@ -358,6 +363,7 @@ public: bool isSignature() const; bool isStruct() const; bool isArray() const; + bool isBottom() const; Signature getSignature() const; const Struct& getStruct() const; @@ -371,6 +377,9 @@ public: // number of supertypes in its supertype chain. size_t getDepth() const; + // Get the bottom heap type for this heap type's hierarchy. + BasicHeapType getBottom() const; + // Get the recursion group for this non-basic type. RecGroup getRecGroup() const; size_t getRecGroupIndex() const; @@ -421,6 +430,8 @@ public: std::string toString() const; }; +inline bool Type::isNull() const { return isRef() && getHeapType().isBottom(); } + // A recursion group consisting of one or more HeapTypes. HeapTypes with single // members are encoded without using any additional memory, which is why // `getHeapTypes` has to return a vector by value; it might have to create one diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index 43f407525..09022eea0 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -44,19 +44,25 @@ Literal::Literal(Type type) : type(type) { memset(&v128, 0, 16); return; case Type::none: - return; case Type::unreachable: - break; + WASM_UNREACHABLE("Invalid literal type"); + return; } } - if (isData()) { - assert(!type.isNonNullable()); + if (type.isNull()) { + assert(type.isNullable()); new (&gcData) std::shared_ptr<GCData>(); - } else { - // For anything else, zero out all the union data. - memset(&v128, 0, 16); + return; } + + if (type.isRef() && type.getHeapType() == HeapType::i31) { + assert(type.isNonNullable()); + i32 = 0; + return; + } + + WASM_UNREACHABLE("Unexpected literal type"); } Literal::Literal(const uint8_t init[16]) : type(Type::v128) { @@ -64,9 +70,9 @@ Literal::Literal(const uint8_t init[16]) : type(Type::v128) { } Literal::Literal(std::shared_ptr<GCData> gcData, HeapType type) - : gcData(gcData), type(type, gcData ? NonNullable : Nullable) { + : gcData(gcData), type(type, NonNullable) { // The type must be a proper type for GC data. - assert(isData()); + assert((isData() && gcData) || (type.isBottom() && !gcData)); } Literal::Literal(const Literal& other) : type(other.type) { @@ -89,6 +95,10 @@ Literal::Literal(const Literal& other) : type(other.type) { break; } } + if (other.isNull()) { + new (&gcData) std::shared_ptr<GCData>(); + return; + } if (other.isData()) { new (&gcData) std::shared_ptr<GCData>(other.gcData); return; @@ -98,23 +108,30 @@ Literal::Literal(const Literal& other) : type(other.type) { return; } if (type.isRef()) { + assert(!type.isNullable()); auto heapType = type.getHeapType(); if (heapType.isBasic()) { switch (heapType.getBasic()) { - case HeapType::ext: - case HeapType::any: - case HeapType::eq: - return; // null case HeapType::i31: i32 = other.i32; return; + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + // Null + return; + case HeapType::ext: + case HeapType::any: + WASM_UNREACHABLE("TODO: extern literals"); + case HeapType::eq: case HeapType::func: case HeapType::data: + WASM_UNREACHABLE("invalid type"); case HeapType::string: case HeapType::stringview_wtf8: case HeapType::stringview_wtf16: case HeapType::stringview_iter: - WASM_UNREACHABLE("invalid type"); + WASM_UNREACHABLE("TODO: string literals"); } } } @@ -125,7 +142,7 @@ Literal::~Literal() { if (type.isBasic()) { return; } - if (isData()) { + if (isNull() || isData()) { gcData.~shared_ptr(); } } @@ -239,7 +256,7 @@ std::array<uint8_t, 16> Literal::getv128() const { } std::shared_ptr<GCData> Literal::getGCData() const { - assert(isData()); + assert(isNull() || isData()); return gcData; } @@ -325,11 +342,6 @@ void Literal::getBits(uint8_t (&buf)[16]) const { } bool Literal::operator==(const Literal& other) const { - // The types must be identical, unless both are references - in that case, - // nulls of different types *do* compare equal. - if (type.isRef() && other.type.isRef() && (isNull() || other.isNull())) { - return isNull() && other.isNull(); - } if (type != other.type) { return false; } @@ -350,7 +362,9 @@ bool Literal::operator==(const Literal& other) const { } } else if (type.isRef()) { assert(type.isRef()); - // Note that we've already handled nulls earlier. + if (type.isNull()) { + return true; + } if (type.isFunction()) { assert(func.is() && other.func.is()); return func == other.func; @@ -361,8 +375,6 @@ bool Literal::operator==(const Literal& other) const { if (type.getHeapType() == HeapType::i31) { return i32 == other.i32; } - // other non-null reference type literals cannot represent concrete values, - // i.e. there is no concrete anyref or eqref other than null. WASM_UNREACHABLE("unexpected type"); } WASM_UNREACHABLE("unexpected type"); @@ -463,52 +475,8 @@ void Literal::printVec128(std::ostream& o, const std::array<uint8_t, 16>& v) { std::ostream& operator<<(std::ostream& o, Literal literal) { prepareMinorColor(o); - if (literal.type.isFunction()) { - if (literal.isNull()) { - o << "funcref(null)"; - } else { - o << "funcref(" << literal.getFunc() << ")"; - } - } else if (literal.type.isRef()) { - if (literal.isData()) { - auto data = literal.getGCData(); - if (data) { - o << "[ref " << data->type << ' ' << data->values << ']'; - } else { - o << "[ref null " << literal.type << ']'; - } - } else { - switch (literal.type.getHeapType().getBasic()) { - case HeapType::ext: - assert(literal.isNull() && "unexpected non-null externref literal"); - o << "externref(null)"; - break; - case HeapType::any: - assert(literal.isNull() && "unexpected non-null anyref literal"); - o << "anyref(null)"; - break; - case HeapType::eq: - assert(literal.isNull() && "unexpected non-null eqref literal"); - o << "eqref(null)"; - break; - case HeapType::i31: - if (literal.isNull()) { - o << "i31ref(null)"; - } else { - o << "i31ref(" << literal.geti31() << ")"; - } - break; - case HeapType::func: - case HeapType::data: - case HeapType::string: - case HeapType::stringview_wtf8: - case HeapType::stringview_wtf16: - case HeapType::stringview_iter: - WASM_UNREACHABLE("type should have been handled above"); - } - } - } else { - TODO_SINGLE_COMPOUND(literal.type); + assert(literal.type.isSingle()); + if (literal.type.isBasic()) { switch (literal.type.getBasic()) { case Type::none: o << "?"; @@ -532,6 +500,44 @@ std::ostream& operator<<(std::ostream& o, Literal literal) { case Type::unreachable: WASM_UNREACHABLE("unexpected type"); } + } else { + assert(literal.type.isRef()); + auto heapType = literal.type.getHeapType(); + if (heapType.isBasic()) { + switch (heapType.getBasic()) { + case HeapType::i31: + o << "i31ref(" << literal.geti31() << ")"; + break; + case HeapType::none: + o << "nullref"; + break; + case HeapType::noext: + o << "nullexternref"; + break; + case HeapType::nofunc: + o << "nullfuncref"; + break; + case HeapType::ext: + case HeapType::any: + WASM_UNREACHABLE("TODO: extern literals"); + case HeapType::eq: + case HeapType::func: + case HeapType::data: + WASM_UNREACHABLE("invalid type"); + case HeapType::string: + case HeapType::stringview_wtf8: + case HeapType::stringview_wtf16: + case HeapType::stringview_iter: + WASM_UNREACHABLE("TODO: string literals"); + } + } else if (heapType.isSignature()) { + o << "funcref(" << literal.getFunc() << ")"; + } else { + assert(literal.isData()); + auto data = literal.getGCData(); + assert(data); + o << "[ref " << data->type << ' ' << data->values << ']'; + } } restoreNormalColor(o); return o; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 55dafadd4..f2698bd79 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -1429,6 +1429,25 @@ void WasmBinaryWriter::writeType(Type type) { case HeapType::stringview_iter: o << S32LEB(BinaryConsts::EncodedType::stringview_iter); return; + case HeapType::none: + o << S32LEB(BinaryConsts::EncodedType::nullref); + return; + case HeapType::noext: + // See comment on writeHeapType. + if (!wasm->features.hasGC()) { + o << S32LEB(BinaryConsts::EncodedType::externref); + } else { + o << S32LEB(BinaryConsts::EncodedType::nullexternref); + } + return; + case HeapType::nofunc: + // See comment on writeHeapType. + if (!wasm->features.hasGC()) { + o << S32LEB(BinaryConsts::EncodedType::funcref); + } else { + o << S32LEB(BinaryConsts::EncodedType::nullfuncref); + } + return; } } if (type.isNullable()) { @@ -1468,46 +1487,63 @@ void WasmBinaryWriter::writeType(Type type) { } void WasmBinaryWriter::writeHeapType(HeapType type) { + // ref.null always has a bottom heap type in Binaryen IR, but those types are + // only actually valid with GC enabled. When GC is not enabled, emit the + // corresponding valid top types instead. + if (!wasm->features.hasGC()) { + if (type == HeapType::nofunc || type.isSignature()) { + type = HeapType::func; + } else if (type == HeapType::noext) { + type = HeapType::ext; + } + } + if (type.isSignature() || type.isStruct() || type.isArray()) { o << S64LEB(getTypeIndex(type)); // TODO: Actually s33 return; } int ret = 0; - if (type.isBasic()) { - switch (type.getBasic()) { - case HeapType::ext: - ret = BinaryConsts::EncodedHeapType::ext; - break; - case HeapType::func: - ret = BinaryConsts::EncodedHeapType::func; - break; - case HeapType::any: - ret = BinaryConsts::EncodedHeapType::any; - break; - case HeapType::eq: - ret = BinaryConsts::EncodedHeapType::eq; - break; - case HeapType::i31: - ret = BinaryConsts::EncodedHeapType::i31; - break; - case HeapType::data: - ret = BinaryConsts::EncodedHeapType::data; - break; - case HeapType::string: - ret = BinaryConsts::EncodedHeapType::string; - break; - case HeapType::stringview_wtf8: - ret = BinaryConsts::EncodedHeapType::stringview_wtf8_heap; - break; - case HeapType::stringview_wtf16: - ret = BinaryConsts::EncodedHeapType::stringview_wtf16_heap; - break; - case HeapType::stringview_iter: - ret = BinaryConsts::EncodedHeapType::stringview_iter_heap; - break; - } - } else { - WASM_UNREACHABLE("TODO: compound GC types"); + assert(type.isBasic()); + switch (type.getBasic()) { + case HeapType::ext: + ret = BinaryConsts::EncodedHeapType::ext; + break; + case HeapType::func: + ret = BinaryConsts::EncodedHeapType::func; + break; + case HeapType::any: + ret = BinaryConsts::EncodedHeapType::any; + break; + case HeapType::eq: + ret = BinaryConsts::EncodedHeapType::eq; + break; + case HeapType::i31: + ret = BinaryConsts::EncodedHeapType::i31; + break; + case HeapType::data: + ret = BinaryConsts::EncodedHeapType::data; + break; + case HeapType::string: + ret = BinaryConsts::EncodedHeapType::string; + break; + case HeapType::stringview_wtf8: + ret = BinaryConsts::EncodedHeapType::stringview_wtf8_heap; + break; + case HeapType::stringview_wtf16: + ret = BinaryConsts::EncodedHeapType::stringview_wtf16_heap; + break; + case HeapType::stringview_iter: + ret = BinaryConsts::EncodedHeapType::stringview_iter_heap; + break; + case HeapType::none: + ret = BinaryConsts::EncodedHeapType::none; + break; + case HeapType::noext: + ret = BinaryConsts::EncodedHeapType::noext; + break; + case HeapType::nofunc: + ret = BinaryConsts::EncodedHeapType::nofunc; + break; } o << S64LEB(ret); // TODO: Actually s33 } @@ -1867,6 +1903,15 @@ bool WasmBinaryBuilder::getBasicType(int32_t code, Type& out) { case BinaryConsts::EncodedType::stringview_iter: out = Type(HeapType::stringview_iter, Nullable); return true; + case BinaryConsts::EncodedType::nullref: + out = Type(HeapType::none, Nullable); + return true; + case BinaryConsts::EncodedType::nullexternref: + out = Type(HeapType::noext, Nullable); + return true; + case BinaryConsts::EncodedType::nullfuncref: + out = Type(HeapType::nofunc, Nullable); + return true; default: return false; } @@ -1904,6 +1949,15 @@ bool WasmBinaryBuilder::getBasicHeapType(int64_t code, HeapType& out) { case BinaryConsts::EncodedHeapType::stringview_iter_heap: out = HeapType::stringview_iter; return true; + case BinaryConsts::EncodedHeapType::none: + out = HeapType::none; + return true; + case BinaryConsts::EncodedHeapType::noext: + out = HeapType::noext; + return true; + case BinaryConsts::EncodedHeapType::nofunc: + out = HeapType::nofunc; + return true; default: return false; } @@ -2849,7 +2903,14 @@ void WasmBinaryBuilder::skipUnreachableCode() { expressionStack = savedStack; return; } - pushExpression(curr); + if (curr->type == Type::unreachable) { + // Nothing before this unreachable should be available to future + // expressions. They will get `(unreachable)`s if they try to pop past + // this point. + expressionStack.clear(); + } else { + pushExpression(curr); + } } } @@ -6530,7 +6591,7 @@ void WasmBinaryBuilder::visitDrop(Drop* curr) { void WasmBinaryBuilder::visitRefNull(RefNull* curr) { BYN_TRACE("zz node: RefNull\n"); - curr->finalize(getHeapType()); + curr->finalize(getHeapType().getBottom()); } void WasmBinaryBuilder::visitRefIs(RefIs* curr, uint8_t code) { @@ -6941,28 +7002,29 @@ bool WasmBinaryBuilder::maybeVisitStructNew(Expression*& out, uint32_t code) { } bool WasmBinaryBuilder::maybeVisitStructGet(Expression*& out, uint32_t code) { - StructGet* curr; + bool signed_ = false; switch (code) { case BinaryConsts::StructGet: - curr = allocator.alloc<StructGet>(); + case BinaryConsts::StructGetU: break; case BinaryConsts::StructGetS: - curr = allocator.alloc<StructGet>(); - curr->signed_ = true; - break; - case BinaryConsts::StructGetU: - curr = allocator.alloc<StructGet>(); - curr->signed_ = false; + signed_ = true; break; default: return false; } auto heapType = getIndexedHeapType(); - curr->index = getU32LEB(); - curr->ref = popNonVoidExpression(); - validateHeapTypeUsingChild(curr->ref, heapType); - curr->finalize(); - out = curr; + if (!heapType.isStruct()) { + throwError("Expected struct heaptype"); + } + auto index = getU32LEB(); + if (index >= heapType.getStruct().fields.size()) { + throwError("Struct field index out of bounds"); + } + auto type = heapType.getStruct().fields[index].type; + auto ref = popNonVoidExpression(); + validateHeapTypeUsingChild(ref, heapType); + out = Builder(wasm).makeStructGet(index, ref, type, signed_); return true; } @@ -7022,10 +7084,14 @@ bool WasmBinaryBuilder::maybeVisitArrayGet(Expression*& out, uint32_t code) { return false; } auto heapType = getIndexedHeapType(); + if (!heapType.isArray()) { + throwError("Expected array heaptype"); + } + auto type = heapType.getArray().element.type; auto* index = popNonVoidExpression(); auto* ref = popNonVoidExpression(); validateHeapTypeUsingChild(ref, heapType); - out = Builder(wasm).makeArrayGet(ref, index, signed_); + out = Builder(wasm).makeArrayGet(ref, index, type, signed_); return true; } diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 6de8744a6..b23419071 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -1196,6 +1196,15 @@ Type SExpressionWasmBuilder::stringToType(const char* str, if (strncmp(str, "stringview_iter", 15) == 0 && (prefix || str[15] == 0)) { return Type(HeapType::stringview_iter, Nullable); } + if (strncmp(str, "nullref", 7) == 0 && (prefix || str[7] == 0)) { + return Type(HeapType::none, Nullable); + } + if (strncmp(str, "nullexternref", 13) == 0 && (prefix || str[13] == 0)) { + return Type(HeapType::noext, Nullable); + } + if (strncmp(str, "nullfuncref", 11) == 0 && (prefix || str[11] == 0)) { + return Type(HeapType::nofunc, Nullable); + } if (allowError) { return Type::none; } @@ -1249,6 +1258,17 @@ HeapType SExpressionWasmBuilder::stringToHeapType(const char* str, return HeapType::stringview_iter; } } + if (str[0] == 'n') { + if (strncmp(str, "none", 4) == 0 && (prefix || str[4] == 0)) { + return HeapType::none; + } + if (strncmp(str, "noextern", 8) == 0 && (prefix || str[8] == 0)) { + return HeapType::noext; + } + if (strncmp(str, "nofunc", 6) == 0 && (prefix || str[6] == 0)) { + return HeapType::nofunc; + } + } throw ParseException(std::string("invalid wasm heap type: ") + str); } @@ -2615,9 +2635,9 @@ Expression* SExpressionWasmBuilder::makeRefNull(Element& s) { // (ref.null func), or it may be the name of a defined type, such as // (ref.null $struct.FOO) if (s[1]->dollared()) { - ret->finalize(parseHeapType(*s[1])); + ret->finalize(parseHeapType(*s[1]).getBottom()); } else { - ret->finalize(stringToHeapType(s[1]->str())); + ret->finalize(stringToHeapType(s[1]->str()).getBottom()); } return ret; } @@ -2990,10 +3010,14 @@ Expression* SExpressionWasmBuilder::makeArrayInitStatic(Element& s) { Expression* SExpressionWasmBuilder::makeArrayGet(Element& s, bool signed_) { auto heapType = parseHeapType(*s[1]); + if (!heapType.isArray()) { + throw ParseException("bad array heap type", s.line, s.col); + } auto ref = parseExpression(*s[2]); + auto type = heapType.getArray().element.type; validateHeapTypeUsingChild(ref, heapType, s); auto index = parseExpression(*s[3]); - return Builder(wasm).makeArrayGet(ref, index, signed_); + return Builder(wasm).makeArrayGet(ref, index, type, signed_); } Expression* SExpressionWasmBuilder::makeArraySet(Element& s) { diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index 71bc98928..13d85d338 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -2014,7 +2014,10 @@ void BinaryInstWriter::visitI31Get(I31Get* curr) { void BinaryInstWriter::visitCallRef(CallRef* curr) { assert(curr->target->type != Type::unreachable); - // TODO: `emitUnreachable` if target has bottom type. + if (curr->target->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(curr->isReturn ? BinaryConsts::RetCallRef : BinaryConsts::CallRef); parent.writeIndexedHeapType(curr->target->type.getHeapType()); @@ -2090,6 +2093,10 @@ void BinaryInstWriter::visitStructNew(StructNew* curr) { } void BinaryInstWriter::visitStructGet(StructGet* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } const auto& heapType = curr->ref->type.getHeapType(); const auto& field = heapType.getStruct().fields[curr->index]; int8_t op; @@ -2106,6 +2113,10 @@ void BinaryInstWriter::visitStructGet(StructGet* curr) { } void BinaryInstWriter::visitStructSet(StructSet* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::StructSet); parent.writeIndexedHeapType(curr->ref->type.getHeapType()); o << U32LEB(curr->index); @@ -2129,6 +2140,10 @@ void BinaryInstWriter::visitArrayInit(ArrayInit* curr) { } void BinaryInstWriter::visitArrayGet(ArrayGet* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } auto heapType = curr->ref->type.getHeapType(); const auto& field = heapType.getArray().element; int8_t op; @@ -2144,16 +2159,28 @@ void BinaryInstWriter::visitArrayGet(ArrayGet* curr) { } void BinaryInstWriter::visitArraySet(ArraySet* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::ArraySet); parent.writeIndexedHeapType(curr->ref->type.getHeapType()); } void BinaryInstWriter::visitArrayLen(ArrayLen* curr) { + if (curr->ref->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::ArrayLen); parent.writeIndexedHeapType(curr->ref->type.getHeapType()); } void BinaryInstWriter::visitArrayCopy(ArrayCopy* curr) { + if (curr->srcRef->type.isNull() || curr->destRef->type.isNull()) { + emitUnreachable(); + return; + } o << int8_t(BinaryConsts::GCPrefix) << U32LEB(BinaryConsts::ArrayCopy); parent.writeIndexedHeapType(curr->destRef->type.getHeapType()); parent.writeIndexedHeapType(curr->srcRef->type.getHeapType()); diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index d24e42acb..43b381a1b 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -578,6 +578,15 @@ std::optional<HeapType> getBasicHeapTypeLUB(HeapType::BasicHeapType a, if (a == b) { return a; } + if (HeapType(a).getBottom() != HeapType(b).getBottom()) { + return {}; + } + if (HeapType(a).isBottom()) { + return b; + } + if (HeapType(b).isBottom()) { + return a; + } // Canonicalize to have `a` be the lesser type. if (unsigned(a) > unsigned(b)) { std::swap(a, b); @@ -585,7 +594,7 @@ std::optional<HeapType> getBasicHeapTypeLUB(HeapType::BasicHeapType a, switch (a) { case HeapType::ext: case HeapType::func: - return {}; + return std::nullopt; case HeapType::any: return {HeapType::any}; case HeapType::eq: @@ -604,6 +613,11 @@ std::optional<HeapType> getBasicHeapTypeLUB(HeapType::BasicHeapType a, case HeapType::stringview_wtf16: case HeapType::stringview_iter: return {HeapType::any}; + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + // Bottom types already handled. + break; } WASM_UNREACHABLE("unexpected basic type"); } @@ -1085,6 +1099,12 @@ FeatureSet Type::getFeatures() const { case HeapType::stringview_wtf16: case HeapType::stringview_iter: return FeatureSet::ReferenceTypes | FeatureSet::Strings; + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: + // Technically introduced in GC, but used internally as part of + // ref.null with just reference types. + return FeatureSet::ReferenceTypes; } } // Note: Technically typed function references also require the typed @@ -1360,6 +1380,29 @@ bool HeapType::isArray() const { } } +bool HeapType::isBottom() const { + if (isBasic()) { + switch (getBasic()) { + case ext: + case func: + case any: + case eq: + case i31: + case data: + case string: + case stringview_wtf8: + case stringview_wtf16: + case stringview_iter: + return false; + case none: + case noext: + case nofunc: + return true; + } + } + return false; +} + Signature HeapType::getSignature() const { assert(isSignature()); return getHeapTypeInfo(*this)->signature; @@ -1420,11 +1463,52 @@ size_t HeapType::getDepth() const { case HeapType::stringview_iter: depth += 2; break; + case HeapType::none: + case HeapType::nofunc: + case HeapType::noext: + // Bottom types are infinitely deep. + depth = size_t(-1l); } } return depth; } +HeapType::BasicHeapType HeapType::getBottom() const { + if (isBasic()) { + switch (getBasic()) { + case ext: + return noext; + case func: + return nofunc; + case any: + case eq: + case i31: + case data: + case string: + case stringview_wtf8: + case stringview_wtf16: + case stringview_iter: + case none: + return none; + case noext: + return noext; + case nofunc: + return nofunc; + } + } + auto* info = getHeapTypeInfo(*this); + switch (info->kind) { + case HeapTypeInfo::BasicKind: + return HeapType(info->basic).getBottom(); + case HeapTypeInfo::SignatureKind: + return nofunc; + case HeapTypeInfo::StructKind: + case HeapTypeInfo::ArrayKind: + return none; + } + WASM_UNREACHABLE("unexpected kind"); +} + bool HeapType::isSubType(HeapType left, HeapType right) { // As an optimization, in the common case do not even construct a SubTyper. if (left == right) { @@ -1451,6 +1535,15 @@ std::optional<HeapType> HeapType::getLeastUpperBound(HeapType a, HeapType b) { if (a == b) { return a; } + if (a.getBottom() != b.getBottom()) { + return {}; + } + if (a.isBottom()) { + return b; + } + if (b.isBottom()) { + return a; + } if (getTypeSystem() == TypeSystem::Equirecursive) { return TypeBounder().getLeastUpperBound(a, b); } @@ -1653,27 +1746,34 @@ bool SubTyper::isSubType(HeapType a, HeapType b) { if (b.isBasic()) { switch (b.getBasic()) { case HeapType::ext: - return a == HeapType::ext; + return a == HeapType::noext; case HeapType::func: - return a.isSignature(); + return a == HeapType::nofunc || a.isSignature(); case HeapType::any: - return a != HeapType::ext && !a.isFunction(); + return a == HeapType::eq || a == HeapType::i31 || a == HeapType::data || + a == HeapType::none || a.isData(); case HeapType::eq: - return a == HeapType::i31 || a.isData(); + return a == HeapType::i31 || a == HeapType::data || + a == HeapType::none || a.isData(); case HeapType::i31: - return false; + return a == HeapType::none; case HeapType::data: - return a.isData(); + return a == HeapType::none || a.isData(); case HeapType::string: case HeapType::stringview_wtf8: case HeapType::stringview_wtf16: case HeapType::stringview_iter: + return a == HeapType::none; + case HeapType::none: + case HeapType::noext: + case HeapType::nofunc: return false; } } if (a.isBasic()) { - // Basic HeapTypes are never subtypes of compound HeapTypes. - return false; + // Basic HeapTypes are only subtypes of compound HeapTypes if they are + // bottom types. + return a == b.getBottom(); } if (typeSystem == TypeSystem::Nominal || typeSystem == TypeSystem::Isorecursive) { @@ -1823,6 +1923,15 @@ std::optional<HeapType> TypeBounder::lub(HeapType a, HeapType b) { if (a == b) { return a; } + if (a.getBottom() != b.getBottom()) { + return {}; + } + if (a.isBottom()) { + return b; + } + if (b.isBottom()) { + return a; + } if (a.isBasic() || b.isBasic()) { return getBasicHeapTypeLUB(getBasicHeapSupertype(a), @@ -2000,12 +2109,18 @@ std::ostream& TypePrinter::print(Type type) { // Print shorthands for certain basic heap types. if (type.isNullable()) { switch (heapType.getBasic()) { + case HeapType::ext: + return os << "externref"; case HeapType::func: return os << "funcref"; case HeapType::any: return os << "anyref"; case HeapType::eq: return os << "eqref"; + case HeapType::i31: + return os << "i31ref"; + case HeapType::data: + return os << "dataref"; case HeapType::string: return os << "stringref"; case HeapType::stringview_wtf8: @@ -2014,17 +2129,12 @@ std::ostream& TypePrinter::print(Type type) { return os << "stringview_wtf16"; case HeapType::stringview_iter: return os << "stringview_iter"; - default: - break; - } - } else { - switch (heapType.getBasic()) { - case HeapType::i31: - return os << "i31ref"; - case HeapType::data: - return os << "dataref"; - default: - break; + case HeapType::none: + return os << "nullref"; + case HeapType::noext: + return os << "nullexternref"; + case HeapType::nofunc: + return os << "nullfuncref"; } } } @@ -2063,6 +2173,12 @@ std::ostream& TypePrinter::print(HeapType type) { return os << "stringview_wtf16"; case HeapType::stringview_iter: return os << "stringview_iter"; + case HeapType::none: + return os << "none"; + case HeapType::noext: + return os << "noextern"; + case HeapType::nofunc: + return os << "nofunc"; } } diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 5e3fdc6e7..ba309ddea 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -2110,13 +2110,12 @@ void FunctionValidator::visitRefNull(RefNull* curr) { shouldBeTrue(!getFunction() || getModule()->features.hasReferenceTypes(), curr, "ref.null requires reference-types to be enabled"); + if (!shouldBeTrue( + curr->type.isNullable(), curr, "ref.null types must be nullable")) { + return; + } shouldBeTrue( - curr->type.isNullable(), curr, "ref.null types must be nullable"); - - // The type of the null must also be valid for the features. - shouldBeTrue(curr->type.getFeatures() <= getModule()->features, - curr->type, - "ref.null type should be allowed"); + curr->type.isNull(), curr, "ref.null must have a bottom heap type"); } void FunctionValidator::visitRefIs(RefIs* curr) { @@ -2454,12 +2453,15 @@ void FunctionValidator::visitCallRef(CallRef* curr) { validateReturnCall(curr); shouldBeTrue( getModule()->features.hasGC(), curr, "call_ref requires gc to be enabled"); - if (curr->target->type != Type::unreachable) { - if (shouldBeTrue(curr->target->type.isFunction(), - curr, - "call_ref target must be a function reference")) { - validateCallParamsAndResult(curr, curr->target->type.getHeapType()); - } + if (curr->target->type == Type::unreachable || + (curr->target->type.isRef() && + curr->target->type.getHeapType() == HeapType::nofunc)) { + return; + } + if (shouldBeTrue(curr->target->type.isFunction(), + curr, + "call_ref target must be a function reference")) { + validateCallParamsAndResult(curr, curr->target->type.getHeapType()); } } @@ -2580,7 +2582,7 @@ void FunctionValidator::visitStructGet(StructGet* curr) { shouldBeTrue(getModule()->features.hasGC(), curr, "struct.get requires gc to be enabled"); - if (curr->ref->type == Type::unreachable) { + if (curr->type == Type::unreachable || curr->ref->type.isNull()) { return; } if (!shouldBeTrue(curr->ref->type.isStruct(), @@ -2610,22 +2612,28 @@ void FunctionValidator::visitStructSet(StructSet* curr) { if (curr->ref->type == Type::unreachable) { return; } - if (!shouldBeTrue(curr->ref->type.isStruct(), + if (!shouldBeTrue(curr->ref->type.isRef(), curr->ref, - "struct.set ref must be a struct")) { + "struct.set ref must be a reference type")) { return; } - if (curr->ref->type != Type::unreachable) { - const auto& fields = curr->ref->type.getHeapType().getStruct().fields; - shouldBeTrue(curr->index < fields.size(), curr, "bad struct.get field"); - auto& field = fields[curr->index]; - shouldBeSubType(curr->value->type, - field.type, - curr, - "struct.set must have the proper type"); - shouldBeEqual( - field.mutable_, Mutable, curr, "struct.set field must be mutable"); + auto type = curr->ref->type.getHeapType(); + if (type == HeapType::none) { + return; } + if (!shouldBeTrue( + type.isStruct(), curr->ref, "struct.set ref must be a struct")) { + return; + } + const auto& fields = type.getStruct().fields; + shouldBeTrue(curr->index < fields.size(), curr, "bad struct.get field"); + auto& field = fields[curr->index]; + shouldBeSubType(curr->value->type, + field.type, + curr, + "struct.set must have the proper type"); + shouldBeEqual( + field.mutable_, Mutable, curr, "struct.set field must be mutable"); } void FunctionValidator::visitArrayNew(ArrayNew* curr) { @@ -2688,7 +2696,18 @@ void FunctionValidator::visitArrayGet(ArrayGet* curr) { if (curr->type == Type::unreachable) { return; } - const auto& element = curr->ref->type.getHeapType().getArray().element; + // TODO: array rather than data once we've implemented that. + if (!shouldBeSubType(curr->ref->type, + Type(HeapType::data, Nullable), + curr, + "array.get target should be an array reference")) { + return; + } + auto heapType = curr->ref->type.getHeapType(); + if (heapType == HeapType::none) { + return; + } + const auto& element = heapType.getArray().element; // If the type is not packed, it must be marked internally as unsigned, by // convention. if (element.type != Type::i32 || element.packedType == Field::not_packed) { @@ -2706,6 +2725,17 @@ void FunctionValidator::visitArraySet(ArraySet* curr) { if (curr->type == Type::unreachable) { return; } + // TODO: array rather than data once we've implemented that. + if (!shouldBeSubType(curr->ref->type, + Type(HeapType::data, Nullable), + curr, + "array.set target should be an array reference")) { + return; + } + auto heapType = curr->ref->type.getHeapType(); + if (heapType == HeapType::none) { + return; + } const auto& element = curr->ref->type.getHeapType().getArray().element; shouldBeSubType(curr->value->type, element.type, @@ -2736,9 +2766,23 @@ void FunctionValidator::visitArrayCopy(ArrayCopy* curr) { if (curr->type == Type::unreachable) { return; } - const auto& srcElement = curr->srcRef->type.getHeapType().getArray().element; - const auto& destElement = - curr->destRef->type.getHeapType().getArray().element; + if (!shouldBeSubType(curr->srcRef->type, + Type(HeapType::data, Nullable), + curr, + "array.copy source should be an array reference") || + !shouldBeSubType(curr->destRef->type, + Type(HeapType::data, Nullable), + curr, + "array.copy destination should be an array reference")) { + return; + } + auto srcHeapType = curr->srcRef->type.getHeapType(); + auto destHeapType = curr->destRef->type.getHeapType(); + if (srcHeapType == HeapType::none || destHeapType == HeapType::none) { + return; + } + const auto& srcElement = srcHeapType.getArray().element; + const auto& destElement = destHeapType.getArray().element; shouldBeSubType(srcElement.type, destElement.type, curr, diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 724fc12e2..27690f43e 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -796,7 +796,10 @@ void MemoryGrow::finalize() { } } -void RefNull::finalize(HeapType heapType) { type = Type(heapType, Nullable); } +void RefNull::finalize(HeapType heapType) { + assert(heapType.isBottom()); + type = Type(heapType, Nullable); +} void RefNull::finalize(Type type_) { type = type_; } @@ -1033,7 +1036,7 @@ void StructNew::finalize() { void StructGet::finalize() { if (ref->type == Type::unreachable) { type = Type::unreachable; - } else { + } else if (!ref->type.isNull()) { type = ref->type.getHeapType().getStruct().fields[index].type; } } @@ -1066,7 +1069,7 @@ void ArrayInit::finalize() { void ArrayGet::finalize() { if (ref->type == Type::unreachable || index->type == Type::unreachable) { type = Type::unreachable; - } else { + } else if (!ref->type.isNull()) { type = ref->type.getHeapType().getArray().element.type; } } |