diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ir/child-typer.h | 1118 | ||||
-rw-r--r-- | src/wasm-ir-builder.h | 63 | ||||
-rw-r--r-- | src/wasm/wasm-ir-builder.cpp | 957 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 13 |
4 files changed, 1645 insertions, 506 deletions
diff --git a/src/ir/child-typer.h b/src/ir/child-typer.h new file mode 100644 index 000000000..94ef11d77 --- /dev/null +++ b/src/ir/child-typer.h @@ -0,0 +1,1118 @@ +/* + * Copyright 2024 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef wasm_ir_child_typer_h +#define wasm_ir_child_typer_h + +#include "wasm-traversal.h" +#include "wasm.h" + +namespace wasm { + +// CRTP visitor for determining constraints on the types of expression children. +// For each child of the visited expression, calls a callback with a pointer to +// the child and information on how the child is constrained. The possible +// callbacks are: +// +// noteSubtype(Expression** childp, Type type) - The child must be a subtype +// of `type`, which may be a tuple type. For children that must not produce +// values, this may be `Type::none`. This accounts for most type constraints. +// +// noteAnyType(Expression** childp) - The child may have any non-tuple type. +// Used for the children of polymorphic instructions like `drop` and `select`. +// +// noteAnyReference(Expression** childp) - The child may have any reference +// type. Used for the children of polymorphic reference instructions like +// `ref.is_null`. +// +// noteAnyTupleType(Expression** childp, size_t arity) - The child may have +// any tuple type with the given arity. Used for the children of polymorphic +// tuple instructions like `tuple.drop` and `tuple.extract`. +// +// Subclasses must additionally implement a callback for getting the type of a +// branch target. This callback will only be used when the label type is not +// passed directly as an argument to the branch visitor method (see below). +// +// Type getLabelType(Name label) +// +// Children with type `unreachable` satisfy all constraints. +// +// Constraints are determined using information that would be present in the +// binary, e.g. type annotation immediates. Many of the visitor methods take +// optional additional parameter for passing this information directly, and if +// those parameters are not used, it is an error to use this utility in cases +// where that information cannot be recovered from the IR. +// +// For example, it is an error to visit a `StructSet` expression whose `ref` +// field is unreachable or null without directly passing the heap type to the +// visitor because it is not possible for the utility to determine what type the +// value child should be in that case. +// +// Conversely, this utility does not use any information that would not be +// present in the binary, and in particular it generally does not introspect on +// the types of children. For example, it does not report the constraint that +// two non-reference children of `select` must have the same type because that +// would require inspecting the types of those children. +template<typename Subtype> struct ChildTyper : OverriddenVisitor<Subtype> { + Module& wasm; + Function* func; + + ChildTyper(Module& wasm, Function* func) : wasm(wasm), func(func) {} + + Subtype& self() { return *static_cast<Subtype*>(this); } + + void note(Expression** childp, Type type) { + self().noteSubtype(childp, type); + } + + void notePointer(Expression** ptrp, Name mem) { + note(ptrp, wasm.getMemory(mem)->indexType); + } + + void noteAny(Expression** childp) { self().noteAnyType(childp); } + + void noteAnyReference(Expression** childp) { + self().noteAnyReferenceType(childp); + } + + void noteAnyTuple(Expression** childp, size_t arity) { + self().noteAnyTupleType(childp, arity); + } + + Type getLabelType(Name label) { return self().getLabelType(label); } + + void visitNop(Nop* curr) {} + + void visitBlock(Block* curr) { + size_t n = curr->list.size(); + if (n == 0) { + return; + } + for (size_t i = 0; i < n - 1; ++i) { + note(&curr->list[i], Type::none); + } + note(&curr->list.back(), curr->type); + } + + void visitIf(If* curr) { + note(&curr->condition, Type::i32); + note(&curr->ifTrue, curr->type); + if (curr->ifFalse) { + note(&curr->ifFalse, curr->type); + } + } + + void visitLoop(Loop* curr) { note(&curr->body, curr->type); } + + void visitBreak(Break* curr, std::optional<Type> labelType = std::nullopt) { + if (!labelType) { + labelType = getLabelType(curr->name); + } + if (*labelType != Type::none) { + note(&curr->value, *labelType); + } + if (curr->condition) { + note(&curr->condition, Type::i32); + } + } + + void visitSwitch(Switch* curr, std::optional<Type> labelType = std::nullopt) { + if (!labelType) { + Type glb = getLabelType(curr->default_); + for (auto label : curr->targets) { + glb = Type::getGreatestLowerBound(glb, getLabelType(label)); + } + labelType = glb; + } + if (*labelType != Type::none) { + note(&curr->value, *labelType); + } + note(&curr->condition, Type::i32); + } + + template<typename T> void handleCall(T* curr, Type params) { + assert(params.size() == curr->operands.size()); + for (size_t i = 0; i < params.size(); ++i) { + note(&curr->operands[i], params[i]); + } + } + + void visitCall(Call* curr) { + auto params = wasm.getFunction(curr->target)->getParams(); + handleCall(curr, params); + } + + void visitCallIndirect(CallIndirect* curr) { + auto params = curr->heapType.getSignature().params; + handleCall(curr, params); + note(&curr->target, Type::i32); + } + + void visitLocalGet(LocalGet* curr) {} + + void visitLocalSet(LocalSet* curr) { + assert(func); + note(&curr->value, func->getLocalType(curr->index)); + } + + void visitGlobalGet(GlobalGet* curr) {} + + void visitGlobalSet(GlobalSet* curr) { + note(&curr->value, wasm.getGlobal(curr->name)->type); + } + + void visitLoad(Load* curr) { notePointer(&curr->ptr, curr->memory); } + + void visitStore(Store* curr) { + notePointer(&curr->ptr, curr->memory); + note(&curr->value, curr->valueType); + } + + void visitAtomicRMW(AtomicRMW* curr) { + assert(curr->type == Type::i32 || curr->type == Type::i64); + notePointer(&curr->ptr, curr->memory); + note(&curr->value, curr->type); + } + + void visitAtomicCmpxchg(AtomicCmpxchg* curr, + std::optional<Type> type = std::nullopt) { + assert(!type || *type == Type::i32 || *type == Type::i64); + notePointer(&curr->ptr, curr->memory); + if (!type) { + if (curr->expected->type == Type::i64 || + curr->replacement->type == Type::i64) { + type = Type::i64; + } else { + type = Type::i32; + } + } + note(&curr->expected, *type); + note(&curr->replacement, *type); + } + + void visitAtomicWait(AtomicWait* curr) { + notePointer(&curr->ptr, curr->memory); + note(&curr->expected, curr->expectedType); + note(&curr->timeout, Type::i64); + } + + void visitAtomicNotify(AtomicNotify* curr) { + notePointer(&curr->ptr, curr->memory); + note(&curr->notifyCount, Type::i32); + } + + void visitAtomicFence(AtomicFence* curr) {} + + void visitSIMDExtract(SIMDExtract* curr) { note(&curr->vec, Type::v128); } + + void visitSIMDReplace(SIMDReplace* curr) { + note(&curr->vec, Type::v128); + switch (curr->op) { + case ReplaceLaneVecI8x16: + case ReplaceLaneVecI16x8: + case ReplaceLaneVecI32x4: + note(&curr->value, Type::i32); + break; + case ReplaceLaneVecI64x2: + note(&curr->value, Type::i64); + break; + case ReplaceLaneVecF32x4: + note(&curr->value, Type::f32); + break; + case ReplaceLaneVecF64x2: + note(&curr->value, Type::f64); + break; + } + } + + void visitSIMDShuffle(SIMDShuffle* curr) { + note(&curr->left, Type::v128); + note(&curr->right, Type::v128); + } + + void visitSIMDTernary(SIMDTernary* curr) { + note(&curr->a, Type::v128); + note(&curr->b, Type::v128); + note(&curr->c, Type::v128); + } + + void visitSIMDShift(SIMDShift* curr) { + note(&curr->vec, Type::v128); + note(&curr->shift, Type::i32); + } + + void visitSIMDLoad(SIMDLoad* curr) { notePointer(&curr->ptr, curr->memory); } + + void visitSIMDLoadStoreLane(SIMDLoadStoreLane* curr) { + notePointer(&curr->ptr, curr->memory); + note(&curr->vec, Type::v128); + } + + void visitMemoryInit(MemoryInit* curr) { + notePointer(&curr->dest, curr->memory); + note(&curr->offset, Type::i32); + note(&curr->size, Type::i32); + } + + void visitDataDrop(DataDrop* curr) {} + + void visitMemoryCopy(MemoryCopy* curr) { + assert(wasm.getMemory(curr->destMemory)->indexType == + wasm.getMemory(curr->sourceMemory)->indexType); + notePointer(&curr->dest, curr->destMemory); + notePointer(&curr->source, curr->sourceMemory); + notePointer(&curr->size, curr->destMemory); + } + + void visitMemoryFill(MemoryFill* curr) { + notePointer(&curr->dest, curr->memory); + note(&curr->value, Type::i32); + notePointer(&curr->size, curr->memory); + } + + void visitConst(Const* curr) {} + + void visitUnary(Unary* curr) { + switch (curr->op) { + case ClzInt32: + case CtzInt32: + case PopcntInt32: + case EqZInt32: + case ExtendSInt32: + case ExtendUInt32: + case ExtendS8Int32: + case ExtendS16Int32: + case ConvertUInt32ToFloat32: + case ConvertUInt32ToFloat64: + case ConvertSInt32ToFloat32: + case ConvertSInt32ToFloat64: + case ReinterpretInt32: + case SplatVecI8x16: + case SplatVecI16x8: + case SplatVecI32x4: + note(&curr->value, Type::i32); + break; + case ClzInt64: + case CtzInt64: + case PopcntInt64: + case EqZInt64: + case ExtendS8Int64: + case ExtendS16Int64: + case ExtendS32Int64: + case WrapInt64: + case ConvertUInt64ToFloat32: + case ConvertUInt64ToFloat64: + case ConvertSInt64ToFloat32: + case ConvertSInt64ToFloat64: + case ReinterpretInt64: + case SplatVecI64x2: + note(&curr->value, Type::i64); + break; + case NegFloat32: + case AbsFloat32: + case CeilFloat32: + case FloorFloat32: + case TruncFloat32: + case NearestFloat32: + case SqrtFloat32: + case TruncSFloat32ToInt32: + case TruncSFloat32ToInt64: + case TruncUFloat32ToInt32: + case TruncUFloat32ToInt64: + case TruncSatSFloat32ToInt32: + case TruncSatSFloat32ToInt64: + case TruncSatUFloat32ToInt32: + case TruncSatUFloat32ToInt64: + case ReinterpretFloat32: + case PromoteFloat32: + case SplatVecF32x4: + note(&curr->value, Type::f32); + break; + case NegFloat64: + case AbsFloat64: + case CeilFloat64: + case FloorFloat64: + case TruncFloat64: + case NearestFloat64: + case SqrtFloat64: + case TruncSFloat64ToInt32: + case TruncSFloat64ToInt64: + case TruncUFloat64ToInt32: + case TruncUFloat64ToInt64: + case TruncSatSFloat64ToInt32: + case TruncSatSFloat64ToInt64: + case TruncSatUFloat64ToInt32: + case TruncSatUFloat64ToInt64: + case ReinterpretFloat64: + case DemoteFloat64: + case SplatVecF64x2: + note(&curr->value, Type::f64); + break; + case NotVec128: + case PopcntVecI8x16: + case AbsVecI8x16: + case AbsVecI16x8: + case AbsVecI32x4: + case AbsVecI64x2: + case NegVecI8x16: + case NegVecI16x8: + case NegVecI32x4: + case NegVecI64x2: + case AbsVecF32x4: + case NegVecF32x4: + case SqrtVecF32x4: + case CeilVecF32x4: + case FloorVecF32x4: + case TruncVecF32x4: + case NearestVecF32x4: + case AbsVecF64x2: + case NegVecF64x2: + case SqrtVecF64x2: + case CeilVecF64x2: + case FloorVecF64x2: + case TruncVecF64x2: + case NearestVecF64x2: + case ExtAddPairwiseSVecI8x16ToI16x8: + case ExtAddPairwiseUVecI8x16ToI16x8: + case ExtAddPairwiseSVecI16x8ToI32x4: + case ExtAddPairwiseUVecI16x8ToI32x4: + case TruncSatSVecF32x4ToVecI32x4: + case TruncSatUVecF32x4ToVecI32x4: + case ConvertSVecI32x4ToVecF32x4: + case ConvertUVecI32x4ToVecF32x4: + case ExtendLowSVecI8x16ToVecI16x8: + case ExtendHighSVecI8x16ToVecI16x8: + case ExtendLowUVecI8x16ToVecI16x8: + case ExtendHighUVecI8x16ToVecI16x8: + case ExtendLowSVecI16x8ToVecI32x4: + case ExtendHighSVecI16x8ToVecI32x4: + case ExtendLowUVecI16x8ToVecI32x4: + case ExtendHighUVecI16x8ToVecI32x4: + case ExtendLowSVecI32x4ToVecI64x2: + case ExtendHighSVecI32x4ToVecI64x2: + case ExtendLowUVecI32x4ToVecI64x2: + case ExtendHighUVecI32x4ToVecI64x2: + case ConvertLowSVecI32x4ToVecF64x2: + case ConvertLowUVecI32x4ToVecF64x2: + case TruncSatZeroSVecF64x2ToVecI32x4: + case TruncSatZeroUVecF64x2ToVecI32x4: + case DemoteZeroVecF64x2ToVecF32x4: + case PromoteLowVecF32x4ToVecF64x2: + case RelaxedTruncSVecF32x4ToVecI32x4: + case RelaxedTruncUVecF32x4ToVecI32x4: + case RelaxedTruncZeroSVecF64x2ToVecI32x4: + case RelaxedTruncZeroUVecF64x2ToVecI32x4: + case AnyTrueVec128: + case AllTrueVecI8x16: + case AllTrueVecI16x8: + case AllTrueVecI32x4: + case AllTrueVecI64x2: + case BitmaskVecI8x16: + case BitmaskVecI16x8: + case BitmaskVecI32x4: + case BitmaskVecI64x2: + note(&curr->value, Type::v128); + break; + case InvalidUnary: + WASM_UNREACHABLE("invalid unary op"); + } + } + + void visitBinary(Binary* curr) { + switch (curr->op) { + case AddInt32: + case SubInt32: + case MulInt32: + case DivSInt32: + case DivUInt32: + case RemSInt32: + case RemUInt32: + case AndInt32: + case OrInt32: + case XorInt32: + case ShlInt32: + case ShrUInt32: + case ShrSInt32: + case RotLInt32: + case RotRInt32: + case EqInt32: + case NeInt32: + case LtSInt32: + case LtUInt32: + case LeSInt32: + case LeUInt32: + case GtSInt32: + case GtUInt32: + case GeSInt32: + case GeUInt32: + note(&curr->left, Type::i32); + note(&curr->right, Type::i32); + break; + case AddInt64: + case SubInt64: + case MulInt64: + case DivSInt64: + case DivUInt64: + case RemSInt64: + case RemUInt64: + case AndInt64: + case OrInt64: + case XorInt64: + case ShlInt64: + case ShrUInt64: + case ShrSInt64: + case RotLInt64: + case RotRInt64: + case EqInt64: + case NeInt64: + case LtSInt64: + case LtUInt64: + case LeSInt64: + case LeUInt64: + case GtSInt64: + case GtUInt64: + case GeSInt64: + case GeUInt64: + note(&curr->left, Type::i64); + note(&curr->right, Type::i64); + break; + case AddFloat32: + case SubFloat32: + case MulFloat32: + case DivFloat32: + case CopySignFloat32: + case MinFloat32: + case MaxFloat32: + case EqFloat32: + case NeFloat32: + case LtFloat32: + case LeFloat32: + case GtFloat32: + case GeFloat32: + note(&curr->left, Type::f32); + note(&curr->right, Type::f32); + break; + case AddFloat64: + case SubFloat64: + case MulFloat64: + case DivFloat64: + case CopySignFloat64: + case MinFloat64: + case MaxFloat64: + case EqFloat64: + case NeFloat64: + case LtFloat64: + case LeFloat64: + case GtFloat64: + case GeFloat64: + note(&curr->left, Type::f64); + note(&curr->right, Type::f64); + break; + case EqVecI8x16: + case NeVecI8x16: + case LtSVecI8x16: + case LtUVecI8x16: + case LeSVecI8x16: + case LeUVecI8x16: + case GtSVecI8x16: + case GtUVecI8x16: + case GeSVecI8x16: + case GeUVecI8x16: + case EqVecI16x8: + case NeVecI16x8: + case LtSVecI16x8: + case LtUVecI16x8: + case LeSVecI16x8: + case LeUVecI16x8: + case GtSVecI16x8: + case GtUVecI16x8: + case GeSVecI16x8: + case GeUVecI16x8: + case EqVecI32x4: + case NeVecI32x4: + case LtSVecI32x4: + case LtUVecI32x4: + case LeSVecI32x4: + case LeUVecI32x4: + case GtSVecI32x4: + case GtUVecI32x4: + case GeSVecI32x4: + case GeUVecI32x4: + case EqVecI64x2: + case NeVecI64x2: + case LtSVecI64x2: + case LeSVecI64x2: + case GtSVecI64x2: + case GeSVecI64x2: + case EqVecF32x4: + case NeVecF32x4: + case LtVecF32x4: + case LeVecF32x4: + case GtVecF32x4: + case GeVecF32x4: + case EqVecF64x2: + case NeVecF64x2: + case LtVecF64x2: + case LeVecF64x2: + case GtVecF64x2: + case GeVecF64x2: + case AndVec128: + case OrVec128: + case XorVec128: + case AndNotVec128: + case AddVecI8x16: + case AddSatSVecI8x16: + case AddSatUVecI8x16: + case SubVecI8x16: + case SubSatSVecI8x16: + case SubSatUVecI8x16: + case MinSVecI8x16: + case MinUVecI8x16: + case MaxSVecI8x16: + case MaxUVecI8x16: + case AvgrUVecI8x16: + case Q15MulrSatSVecI16x8: + case ExtMulLowSVecI16x8: + case ExtMulHighSVecI16x8: + case ExtMulLowUVecI16x8: + case ExtMulHighUVecI16x8: + case AddVecI16x8: + case AddSatSVecI16x8: + case AddSatUVecI16x8: + case SubVecI16x8: + case SubSatSVecI16x8: + case SubSatUVecI16x8: + case MulVecI16x8: + case MinSVecI16x8: + case MinUVecI16x8: + case MaxSVecI16x8: + case MaxUVecI16x8: + case AvgrUVecI16x8: + case AddVecI32x4: + case SubVecI32x4: + case MulVecI32x4: + case MinSVecI32x4: + case MinUVecI32x4: + case MaxSVecI32x4: + case MaxUVecI32x4: + case DotSVecI16x8ToVecI32x4: + case ExtMulLowSVecI32x4: + case ExtMulHighSVecI32x4: + case ExtMulLowUVecI32x4: + case ExtMulHighUVecI32x4: + case AddVecI64x2: + case SubVecI64x2: + case MulVecI64x2: + case ExtMulLowSVecI64x2: + case ExtMulHighSVecI64x2: + case ExtMulLowUVecI64x2: + case ExtMulHighUVecI64x2: + case AddVecF32x4: + case SubVecF32x4: + case MulVecF32x4: + case DivVecF32x4: + case MinVecF32x4: + case MaxVecF32x4: + case PMinVecF32x4: + case PMaxVecF32x4: + case RelaxedMinVecF32x4: + case RelaxedMaxVecF32x4: + case AddVecF64x2: + case SubVecF64x2: + case MulVecF64x2: + case DivVecF64x2: + case MinVecF64x2: + case MaxVecF64x2: + case PMinVecF64x2: + case PMaxVecF64x2: + case RelaxedMinVecF64x2: + case RelaxedMaxVecF64x2: + case NarrowSVecI16x8ToVecI8x16: + case NarrowUVecI16x8ToVecI8x16: + case NarrowSVecI32x4ToVecI16x8: + case NarrowUVecI32x4ToVecI16x8: + case SwizzleVecI8x16: + case RelaxedSwizzleVecI8x16: + case RelaxedQ15MulrSVecI16x8: + case DotI8x16I7x16SToVecI16x8: + note(&curr->left, Type::v128); + note(&curr->right, Type::v128); + break; + case InvalidBinary: + WASM_UNREACHABLE("invalid binary op"); + } + } + + void visitSelect(Select* curr, std::optional<Type> type = std::nullopt) { + if (type) { + note(&curr->ifTrue, *type); + note(&curr->ifFalse, *type); + } else { + noteAny(&curr->ifTrue); + noteAny(&curr->ifFalse); + } + note(&curr->condition, Type::i32); + } + + void visitDrop(Drop* curr, std::optional<Index> arity = std::nullopt) { + if (!arity) { + arity = curr->value->type.size(); + } + if (*arity > 1) { + noteAnyTuple(&curr->value, *arity); + } else { + noteAny(&curr->value); + } + } + + void visitReturn(Return* curr) { + assert(func); + auto type = func->getResults(); + if (type != Type::none) { + note(&curr->value, type); + } + } + + void visitMemorySize(MemorySize* curr) {} + + void visitMemoryGrow(MemoryGrow* curr) { + notePointer(&curr->delta, curr->memory); + } + + void visitUnreachable(Unreachable* curr) {} + + void visitPop(Pop* curr) {} + + void visitRefNull(RefNull* curr) {} + + void visitRefIsNull(RefIsNull* curr) { noteAnyReference(&curr->value); } + + void visitRefFunc(RefFunc* curr) {} + + void visitRefEq(RefEq* curr) { + Type eqref(HeapType::eq, Nullable); + note(&curr->left, eqref); + note(&curr->right, eqref); + } + + void visitTableGet(TableGet* curr) { note(&curr->index, Type::i32); } + + void visitTableSet(TableSet* curr) { + note(&curr->index, Type::i32); + note(&curr->value, wasm.getTable(curr->table)->type); + } + + void visitTableSize(TableSize* curr) {} + + void visitTableGrow(TableGrow* curr) { + note(&curr->value, wasm.getTable(curr->table)->type); + note(&curr->delta, Type::i32); + } + + void visitTableFill(TableFill* curr) { + auto type = wasm.getTable(curr->table)->type; + note(&curr->dest, Type::i32); + note(&curr->value, type); + note(&curr->size, Type::i32); + } + + void visitTableCopy(TableCopy* curr) { + note(&curr->dest, Type::i32); + note(&curr->source, Type::i32); + note(&curr->size, Type::i32); + } + + void visitTry(Try* curr) { + note(&curr->body, curr->type); + for (auto& expr : curr->catchBodies) { + note(&expr, curr->type); + } + } + + void visitTryTable(TryTable* curr) { note(&curr->body, curr->type); } + + void visitThrow(Throw* curr) { + auto type = wasm.getTag(curr->tag)->sig.params; + assert(curr->operands.size() == type.size()); + for (size_t i = 0; i < type.size(); ++i) { + note(&curr->operands[i], type[i]); + } + } + + void visitRethrow(Rethrow* curr) {} + + void visitThrowRef(ThrowRef* curr) { + note(&curr->exnref, Type(HeapType::exn, Nullable)); + } + + void visitTupleMake(TupleMake* curr) { + for (auto& expr : curr->operands) { + noteAny(&expr); + } + } + + void visitTupleExtract(TupleExtract* curr, + std::optional<size_t> arity = std::nullopt) { + if (!arity) { + assert(curr->tuple->type.isTuple()); + arity = curr->tuple->type.size(); + } + noteAnyTuple(&curr->tuple, *arity); + } + + void visitRefI31(RefI31* curr) { note(&curr->value, Type::i32); } + + void visitI31Get(I31Get* curr) { + note(&curr->i31, Type(HeapType::i31, Nullable)); + } + + void visitCallRef(CallRef* curr, std::optional<HeapType> ht = std::nullopt) { + if (!ht) { + ht = curr->target->type.getHeapType().getSignature(); + } + auto params = ht->getSignature().params; + assert(curr->operands.size() == params.size()); + for (size_t i = 0; i < params.size(); ++i) { + note(&curr->operands[i], params[i]); + } + note(&curr->target, Type(*ht, Nullable)); + } + + void visitRefTest(RefTest* curr) { + auto top = curr->castType.getHeapType().getTop(); + note(&curr->ref, Type(top, Nullable)); + } + + void visitRefCast(RefCast* curr) { + auto top = curr->type.getHeapType().getTop(); + note(&curr->ref, Type(top, Nullable)); + } + + void visitBrOn(BrOn* curr) { + switch (curr->op) { + case BrOnNull: + case BrOnNonNull: + noteAnyReference(&curr->ref); + return; + case BrOnCast: + case BrOnCastFail: { + auto top = curr->castType.getHeapType().getTop(); + note(&curr->ref, Type(top, Nullable)); + return; + } + } + WASM_UNREACHABLE("unexpected op"); + } + + void visitStructNew(StructNew* curr) { + if (curr->isWithDefault()) { + return; + } + const auto& fields = curr->type.getHeapType().getStruct().fields; + assert(fields.size() == curr->operands.size()); + for (size_t i = 0; i < fields.size(); ++i) { + note(&curr->operands[i], fields[i].type); + } + } + + void visitStructGet(StructGet* curr, + std::optional<HeapType> ht = std::nullopt) { + if (!ht) { + ht = curr->ref->type.getHeapType(); + } + note(&curr->ref, Type(*ht, Nullable)); + } + + void visitStructSet(StructSet* curr, + std::optional<HeapType> ht = std::nullopt) { + if (!ht) { + ht = curr->ref->type.getHeapType(); + } + const auto& fields = ht->getStruct().fields; + assert(curr->index < fields.size()); + note(&curr->ref, Type(*ht, Nullable)); + note(&curr->value, fields[curr->index].type); + } + + void visitArrayNew(ArrayNew* curr) { + if (!curr->isWithDefault()) { + note(&curr->init, curr->type.getHeapType().getArray().element.type); + } + note(&curr->size, Type::i32); + } + + void visitArrayNewData(ArrayNewData* curr) { + note(&curr->offset, Type::i32); + note(&curr->size, Type::i32); + } + + void visitArrayNewElem(ArrayNewElem* curr) { + note(&curr->offset, Type::i32); + note(&curr->size, Type::i32); + } + + void visitArrayNewFixed(ArrayNewFixed* curr) { + auto type = curr->type.getHeapType().getArray().element.type; + for (auto& expr : curr->values) { + note(&expr, type); + } + } + + void visitArrayGet(ArrayGet* curr, + std::optional<HeapType> ht = std::nullopt) { + if (!ht) { + ht = curr->ref->type.getHeapType(); + } + note(&curr->ref, Type(*ht, Nullable)); + note(&curr->index, Type::i32); + } + + void visitArraySet(ArraySet* curr, + std::optional<HeapType> ht = std::nullopt) { + if (!ht) { + ht = curr->ref->type.getHeapType(); + } + auto type = ht->getArray().element.type; + note(&curr->ref, Type(*ht, Nullable)); + note(&curr->index, Type::i32); + note(&curr->value, type); + } + + void visitArrayLen(ArrayLen* curr) { + note(&curr->ref, Type(HeapType::array, Nullable)); + } + + void visitArrayCopy(ArrayCopy* curr, + std::optional<HeapType> dest = std::nullopt, + std::optional<HeapType> src = std::nullopt) { + if (!dest) { + dest = curr->destRef->type.getHeapType(); + } + if (!src) { + src = curr->srcRef->type.getHeapType(); + } + note(&curr->destRef, Type(*dest, Nullable)); + note(&curr->destIndex, Type::i32); + note(&curr->srcRef, Type(*src, Nullable)); + note(&curr->srcIndex, Type::i32); + note(&curr->length, Type::i32); + } + + void visitArrayFill(ArrayFill* curr, + std::optional<HeapType> ht = std::nullopt) { + if (!ht) { + ht = curr->ref->type.getHeapType(); + } + auto type = ht->getArray().element.type; + note(&curr->ref, Type(*ht, Nullable)); + note(&curr->index, Type::i32); + note(&curr->value, type); + note(&curr->size, Type::i32); + } + + void visitArrayInitData(ArrayInitData* curr, + std::optional<HeapType> ht = std::nullopt) { + if (!ht) { + ht = curr->ref->type.getHeapType(); + } + note(&curr->ref, Type(*ht, Nullable)); + note(&curr->index, Type::i32); + note(&curr->offset, Type::i32); + note(&curr->size, Type::i32); + } + + void visitArrayInitElem(ArrayInitElem* curr, + std::optional<HeapType> ht = std::nullopt) { + if (!ht) { + ht = curr->ref->type.getHeapType(); + } + note(&curr->ref, Type(*ht, Nullable)); + note(&curr->index, Type::i32); + note(&curr->offset, Type::i32); + note(&curr->size, Type::i32); + } + + void visitRefAs(RefAs* curr) { + switch (curr->op) { + case RefAsNonNull: + noteAnyReference(&curr->value); + return; + case ExternInternalize: + note(&curr->value, Type(HeapType::ext, Nullable)); + return; + case ExternExternalize: + note(&curr->value, Type(HeapType::any, Nullable)); + return; + } + WASM_UNREACHABLE("unexpected op"); + } + + void visitStringNew(StringNew* curr, + std::optional<HeapType> ht = std::nullopt) { + switch (curr->op) { + case StringNewUTF8: + case StringNewWTF8: + case StringNewLossyUTF8: + case StringNewWTF16: + // TODO: This should be notePointer, but we don't have a memory. + note(&curr->ptr, Type::i32); + note(&curr->length, Type::i32); + return; + case StringNewUTF8Array: + case StringNewWTF8Array: + case StringNewLossyUTF8Array: + case StringNewWTF16Array: + if (!ht) { + ht = curr->ptr->type.getHeapType(); + } + note(&curr->ptr, Type(*ht, Nullable)); + note(&curr->start, Type::i32); + note(&curr->end, Type::i32); + return; + case StringNewFromCodePoint: + note(&curr->ptr, Type::i32); + return; + } + WASM_UNREACHABLE("unexpected op"); + } + + void visitStringConst(StringConst* curr) {} + + void visitStringMeasure(StringMeasure* curr) { + if (curr->op == StringMeasureWTF16View) { + note(&curr->ref, Type(HeapType::stringview_wtf16, Nullable)); + } else { + note(&curr->ref, Type(HeapType::string, Nullable)); + } + } + + void visitStringEncode(StringEncode* curr, + std::optional<HeapType> ht = std::nullopt) { + note(&curr->ref, Type(HeapType::string, Nullable)); + switch (curr->op) { + case StringEncodeUTF8: + case StringEncodeLossyUTF8: + case StringEncodeWTF8: + case StringEncodeWTF16: + // TODO: This should be notePointer, but we don't have a memory. + note(&curr->ptr, Type::i32); + return; + case StringEncodeUTF8Array: + case StringEncodeLossyUTF8Array: + case StringEncodeWTF8Array: + case StringEncodeWTF16Array: + if (!ht) { + ht = curr->ptr->type.getHeapType(); + } + note(&curr->ptr, Type(*ht, Nullable)); + note(&curr->start, Type::i32); + return; + } + WASM_UNREACHABLE("unexpected op"); + } + + void visitStringConcat(StringConcat* curr) { + auto stringref = Type(HeapType::string, Nullable); + note(&curr->left, stringref); + note(&curr->right, stringref); + } + + void visitStringEq(StringEq* curr) { + auto stringref = Type(HeapType::string, Nullable); + note(&curr->left, stringref); + note(&curr->right, stringref); + } + + void visitStringAs(StringAs* curr) { + note(&curr->ref, Type(HeapType::string, Nullable)); + } + + void visitStringWTF8Advance(StringWTF8Advance* curr) { + note(&curr->ref, Type(HeapType::stringview_wtf8, Nullable)); + note(&curr->pos, Type::i32); + note(&curr->bytes, Type::i32); + } + + void visitStringWTF16Get(StringWTF16Get* curr) { + note(&curr->ref, Type(HeapType::stringview_wtf16, Nullable)); + note(&curr->pos, Type::i32); + } + + void visitStringIterNext(StringIterNext* curr) { + note(&curr->ref, Type(HeapType::stringview_iter, Nullable)); + } + + void visitStringIterMove(StringIterMove* curr) { + note(&curr->ref, Type(HeapType::stringview_iter, Nullable)); + note(&curr->num, Type::i32); + } + + void visitStringSliceWTF(StringSliceWTF* curr) { + switch (curr->op) { + case StringSliceWTF8: + note(&curr->ref, Type(HeapType::stringview_wtf8, Nullable)); + break; + case StringSliceWTF16: + note(&curr->ref, Type(HeapType::stringview_wtf16, Nullable)); + break; + } + note(&curr->start, Type::i32); + note(&curr->end, Type::i32); + } + + void visitStringSliceIter(StringSliceIter* curr) { + note(&curr->ref, Type(HeapType::stringview_iter, Nullable)); + note(&curr->num, Type::i32); + } + + void visitContBind(ContBind* curr) { + auto paramsBefore = + curr->contTypeBefore.getContinuation().type.getSignature().params; + auto paramsAfter = + curr->contTypeAfter.getContinuation().type.getSignature().params; + assert(paramsBefore.size() >= paramsAfter.size()); + auto n = paramsBefore.size() - paramsAfter.size(); + assert(curr->operands.size() == n); + for (size_t i = 0; i < n; ++i) { + note(&curr->operands[i], paramsBefore[i]); + } + note(&curr->cont, Type(curr->contTypeBefore, Nullable)); + } + + void visitContNew(ContNew* curr) { + note(&curr->func, Type(curr->contType.getContinuation().type, Nullable)); + } + + void visitResume(Resume* curr) { + auto params = curr->contType.getContinuation().type.getSignature().params; + assert(params.size() == curr->operands.size()); + for (size_t i = 0; i < params.size(); ++i) { + note(&curr->operands[i], params[i]); + } + note(&curr->cont, Type(curr->contType, Nullable)); + } + + void visitSuspend(Suspend* curr) { + auto params = wasm.getTag(curr->tag)->sig.params; + assert(params.size() == curr->operands.size()); + for (size_t i = 0; i < params.size(); ++i) { + note(&curr->operands[i], params[i]); + } + } +}; + +} // namespace wasm + +#endif // wasm_ir_child_typer_h diff --git a/src/wasm-ir-builder.h b/src/wasm-ir-builder.h index 7c4c4a36d..3b6588e86 100644 --- a/src/wasm-ir-builder.h +++ b/src/wasm-ir-builder.h @@ -76,6 +76,15 @@ public: Name label = {}); [[nodiscard]] Result<> visitEnd(); + // Used to visit break nodes when traversing a single block without its + // context. The type indicates how many values the break carries to its + // destination. + [[nodiscard]] Result<> visitBreakWithType(Break*, Type); + // Used to visit switch nodes when traversing a single block without its + // context. The type indicates how many values the switch carries to its + // destination. + [[nodiscard]] Result<> visitSwitchWithType(Switch*, Type); + // Binaryen IR uses names to refer to branch targets, but in general there may // be branches to constructs that do not yet have names, so in IRBuilder we // use indices to refer to branch targets instead, just as the binary format @@ -220,49 +229,10 @@ public: // Private functions that must be public for technical reasons. [[nodiscard]] Result<> visitExpression(Expression*); - [[nodiscard]] Result<> - visitDrop(Drop*, std::optional<uint32_t> arity = std::nullopt); - [[nodiscard]] Result<> visitIf(If*); - [[nodiscard]] Result<> visitReturn(Return*); - [[nodiscard]] Result<> visitStructNew(StructNew*); - [[nodiscard]] Result<> visitArrayNew(ArrayNew*); - [[nodiscard]] Result<> visitArrayNewFixed(ArrayNewFixed*); - // Used to visit break exprs when traversing the module in the fully nested - // format. Break label destinations are assumed to have already been visited, - // with a corresponding push onto the scope stack. As a result, an error will - // return if a corresponding scope is not found for the break. - [[nodiscard]] Result<> visitBreak(Break*, - std::optional<Index> label = std::nullopt); - // Used to visit break nodes when traversing a single block without its - // context. The type indicates how many values the break carries to its - // destination. - [[nodiscard]] Result<> visitBreakWithType(Break*, Type); - [[nodiscard]] Result<> - // Used to visit switch exprs when traversing the module in the fully nested - // format. Switch label destinations are assumed to have already been visited, - // with a corresponding push onto the scope stack. As a result, an error will - // return if a corresponding scope is not found for the switch. - visitSwitch(Switch*, std::optional<Index> defaultLabel = std::nullopt); - // Used to visit switch nodes when traversing a single block without its - // context. The type indicates how many values the switch carries to its - // destination. - [[nodiscard]] Result<> visitSwitchWithType(Switch*, Type); - [[nodiscard]] Result<> visitCall(Call*); - [[nodiscard]] Result<> visitCallIndirect(CallIndirect*); - [[nodiscard]] Result<> visitCallRef(CallRef*); - [[nodiscard]] Result<> visitLocalSet(LocalSet*); - [[nodiscard]] Result<> visitGlobalSet(GlobalSet*); - [[nodiscard]] Result<> visitThrow(Throw*); - [[nodiscard]] Result<> visitStringNew(StringNew*); - [[nodiscard]] Result<> visitStringEncode(StringEncode*); - [[nodiscard]] Result<> visitContBind(ContBind*); - [[nodiscard]] Result<> visitResume(Resume*); - [[nodiscard]] Result<> visitSuspend(Suspend*); - [[nodiscard]] Result<> visitTupleMake(TupleMake*); - [[nodiscard]] Result<> - visitTupleExtract(TupleExtract*, - std::optional<uint32_t> arity = std::nullopt); - [[nodiscard]] Result<> visitPop(Pop*); + + // Do not push pops onto the stack since we generate our own pops as necessary + // when visiting the beginnings of try blocks. + [[nodiscard]] Result<> visitPop(Pop*) { return Ok{}; } private: Module& wasm; @@ -270,6 +240,8 @@ private: Builder builder; std::optional<Function::DebugLocation> debugLoc; + struct ChildPopper; + void applyDebugLoc(Expression* expr); // The context for a single block scope, including the instructions parsed @@ -534,7 +506,6 @@ private: [[nodiscard]] Result<Name> getLabelName(Index label); [[nodiscard]] Result<Name> getDelegateLabelName(Index label); [[nodiscard]] Result<Index> addScratchLocal(Type); - [[nodiscard]] Result<Expression*> pop(size_t size = 1); struct HoistedVal { // The index in the stack of the original value-producing expression. @@ -556,8 +527,8 @@ private: [[nodiscard]] Result<> packageHoistedValue(const HoistedVal&, size_t sizeHint = 1); - [[nodiscard]] Result<Expression*> - getBranchValue(Expression* curr, Name labelName, std::optional<Index> label); + [[nodiscard]] Result<Type> getLabelType(Index label); + [[nodiscard]] Result<Type> getLabelType(Name labelName); void dump(); }; diff --git a/src/wasm/wasm-ir-builder.cpp b/src/wasm/wasm-ir-builder.cpp index 7b88d345f..bee858435 100644 --- a/src/wasm/wasm-ir-builder.cpp +++ b/src/wasm/wasm-ir-builder.cpp @@ -16,6 +16,7 @@ #include <cassert> +#include "ir/child-typer.h" #include "ir/names.h" #include "ir/properties.h" #include "ir/utils.h" @@ -140,13 +141,6 @@ Result<> IRBuilder::packageHoistedValue(const HoistedVal& hoisted, void IRBuilder::push(Expression* expr) { auto& scope = getScope(); if (expr->type == Type::unreachable) { - // We want to avoid popping back past this most recent unreachable - // instruction. Drop all prior instructions so they won't be consumed by - // later instructions but will still be emitted for their side effects, if - // any. - for (auto& expr : scope.exprStack) { - expr = builder.dropIfConcretelyTyped(expr); - } scope.unreachable = true; } scope.exprStack.push_back(expr); @@ -157,44 +151,6 @@ void IRBuilder::push(Expression* expr) { DBG(dump()); } -Result<Expression*> IRBuilder::pop(size_t size) { - assert(size >= 1); - auto& scope = getScope(); - - // Find the suffix of expressions that do not produce values. - auto hoisted = hoistLastValue(); - CHECK_ERR(hoisted); - if (!hoisted) { - // There are no expressions that produce values. - if (scope.unreachable) { - return builder.makeUnreachable(); - } - return Err{"popping from empty stack"}; - } - - CHECK_ERR(packageHoistedValue(*hoisted, size)); - - auto* ret = scope.exprStack.back(); - // If the top value has the correct size, we can pop it and be done. - // Unreachable values satisfy any size. - if (ret->type.size() == size || ret->type == Type::unreachable) { - scope.exprStack.pop_back(); - return ret; - } - - // The last value-producing expression did not produce exactly the right - // number of values, so we need to construct a tuple piecewise instead. - assert(size > 1); - std::vector<Expression*> elems; - elems.resize(size); - for (int i = size - 1; i >= 0; --i) { - auto elem = pop(); - CHECK_ERR(elem); - elems[i] = *elem; - } - return builder.makeTupleMake(elems); -} - Result<Expression*> IRBuilder::build() { if (scopeStack.empty()) { return builder.makeNop(); @@ -292,417 +248,424 @@ void IRBuilder::dump() { #endif // IR_BUILDER_DEBUG } -Result<> IRBuilder::visit(Expression* curr) { - // Call either `visitExpression` or an expression-specific override. - auto val = UnifiedExpressionVisitor<IRBuilder, Result<>>::visit(curr); - CHECK_ERR(val); - if (auto* block = curr->dynCast<Block>()) { - block->finalize(block->type); - } else { - // TODO: Call more efficient versions of finalize() that take the known type - // for other kinds of nodes as well, as done above. - ReFinalizeNode{}.visit(curr); - } - push(curr); - return Ok{}; -} +struct IRBuilder::ChildPopper + : UnifiedExpressionVisitor<ChildPopper, Result<>> { + struct Subtype { + Type bound; + }; -// Handle the common case of instructions with a constant number of children -// uniformly. -Result<> IRBuilder::visitExpression(Expression* curr) { - if (Properties::isControlFlowStructure(curr)) { - // Control flow structures (besides `if`, handled separately) do not consume - // stack values. - return Ok{}; - } + struct AnyType {}; -#define DELEGATE_ID curr->_id -#define DELEGATE_START(id) [[maybe_unused]] auto* expr = curr->cast<id>(); -#define DELEGATE_GET_FIELD(id, field) expr->field -#define DELEGATE_FIELD_CHILD(id, field) \ - auto field = pop(); \ - CHECK_ERR(field); \ - expr->field = *field; -#define DELEGATE_FIELD_SCOPE_NAME_DEF(id, field) \ - if (labelDepths.count(expr->field)) { \ - return Err{"repeated label"}; \ - } -#define DELEGATE_END(id) + struct AnyReference {}; -#define DELEGATE_FIELD_OPTIONAL_CHILD(id, field) \ - WASM_UNREACHABLE("should have called visit" #id " because " #id \ - " has optional child " #field); -#define DELEGATE_FIELD_CHILD_VECTOR(id, field) \ - WASM_UNREACHABLE("should have called visit" #id " because " #id \ - " has child vector " #field); + struct AnyTuple { + size_t arity; + }; -#define DELEGATE_FIELD_INT(id, field) -#define DELEGATE_FIELD_LITERAL(id, field) -#define DELEGATE_FIELD_NAME(id, field) -#define DELEGATE_FIELD_SCOPE_NAME_USE(id, field) + struct Constraint : std::variant<Subtype, AnyType, AnyReference, AnyTuple> { + std::optional<Type> getSubtype() const { + if (auto* subtype = std::get_if<Subtype>(this)) { + return subtype->bound; + } + return std::nullopt; + } + bool isAnyType() const { return std::get_if<AnyType>(this); } + bool isAnyReference() const { return std::get_if<AnyReference>(this); } + std::optional<size_t> getAnyTuple() const { + if (auto* tuple = std::get_if<AnyTuple>(this)) { + return tuple->arity; + } + return std::nullopt; + } + size_t size() const { + if (auto type = getSubtype()) { + return type->size(); + } + if (auto arity = getAnyTuple()) { + return *arity; + } + return 1; + } + Constraint operator[](size_t i) const { + if (auto type = getSubtype()) { + return {Subtype{(*type)[i]}}; + } + if (getAnyTuple()) { + return {AnyType{}}; + } + return *this; + } + }; -#define DELEGATE_FIELD_TYPE(id, field) -#define DELEGATE_FIELD_HEAPTYPE(id, field) -#define DELEGATE_FIELD_ADDRESS(id, field) + struct Child { + Expression** childp; + Constraint constraint; + }; -#include "wasm-delegations-fields.def" + struct ConstraintCollector : ChildTyper<ConstraintCollector> { + IRBuilder& builder; + std::vector<Child>& children; - return Ok{}; -} + ConstraintCollector(IRBuilder& builder, std::vector<Child>& children) + : ChildTyper(builder.wasm, builder.func), builder(builder), + children(children) {} -Result<> IRBuilder::visitDrop(Drop* curr, std::optional<uint32_t> arity) { - // Multivalue drops must remain multivalue drops. - if (!arity) { - arity = curr->value->type.size(); - } - if (*arity >= 2) { - auto val = pop(*arity); - CHECK_ERR(val); - curr->value = *val; + void noteSubtype(Expression** childp, Type type) { + children.push_back({childp, {Subtype{type}}}); + } + + void noteAnyType(Expression** childp) { + children.push_back({childp, {AnyType{}}}); + } + + void noteAnyReferenceType(Expression** childp) { + children.push_back({childp, {AnyReference{}}}); + } + + void noteAnyTupleType(Expression** childp, size_t arity) { + children.push_back({childp, {AnyTuple{arity}}}); + } + + Type getLabelType(Name label) { + WASM_UNREACHABLE("labels should be explicitly provided"); + }; + + void visitIf(If* curr) { + // Skip the control flow children because we only want to pop the + // condition. + children.push_back({&curr->condition, {Subtype{Type::i32}}}); + } + }; + + IRBuilder& builder; + + ChildPopper(IRBuilder& builder) : builder(builder) {} + +private: + [[nodiscard]] Result<> popConstrainedChildren(std::vector<Child>& children) { + auto& scope = builder.getScope(); + + // Two-part indices into the stack of available expressions and the vector + // of requirements, allowing them to move independently with the granularity + // of a single tuple element. + size_t stackIndex = scope.exprStack.size(); + size_t stackTupleIndex = 0; + size_t childIndex = children.size(); + size_t childTupleIndex = 0; + + // The index of the shallowest unreachable instruction on the stack. + std::optional<size_t> unreachableIndex; + + // Whether popping the children past the unreachable would produce a type + // mismatch or try to pop from an empty stack. + bool needUnreachableFallback = false; + + if (!scope.unreachable) { + // We only need to check requirements if there is an unreachable. + // Otherwise the validator will catch any problems. + goto pop; + } + + // Check whether the values on the stack will be able to meet the given + // requirements. + while (true) { + // Advance to the next requirement. + if (childTupleIndex > 0) { + --childTupleIndex; + } else { + if (childIndex == 0) { + // We have examined all the requirements. + break; + } + --childIndex; + childTupleIndex = children[childIndex].constraint.size() - 1; + } + + // Advance to the next available value on the stack. + while (true) { + if (stackTupleIndex > 0) { + --stackTupleIndex; + } else { + if (stackIndex == 0) { + // No more available values. This is fine iff we are reaching past + // an unreachable. Any error will be caught later when we are + // popping. + goto pop; + } + --stackIndex; + stackTupleIndex = scope.exprStack[stackIndex]->type.size() - 1; + } + + // Skip expressions that don't produce values. + if (scope.exprStack[stackIndex]->type == Type::none) { + stackTupleIndex = 0; + continue; + } + break; + } + + // We have an available type and a constraint. Only check constraints if + // we are past an unreachable, since otherwise we can leave problems to be + // caught by the validator later. + auto type = scope.exprStack[stackIndex]->type[stackTupleIndex]; + if (unreachableIndex) { + auto constraint = children[childIndex].constraint[childTupleIndex]; + if (constraint.isAnyType()) { + // Always succeeds. + } else if (constraint.isAnyReference()) { + if (!type.isRef() && type != Type::unreachable) { + needUnreachableFallback = true; + break; + } + } else if (auto bound = constraint.getSubtype()) { + if (!Type::isSubType(type, *bound)) { + needUnreachableFallback = true; + break; + } + } else { + WASM_UNREACHABLE("unexpected constraint"); + } + } + + // No problems for children after this unreachable. + if (type == Type::unreachable) { + assert(!needUnreachableFallback); + unreachableIndex = stackIndex; + } + } + + pop: + // We have checked all the constraints, so we are ready to pop children. + for (int i = children.size() - 1; i >= 0; --i) { + if (needUnreachableFallback && + scope.exprStack.size() == *unreachableIndex + 1) { + // The expressions remaining on the stack may be executed, but they do + // not satisfy the requirements to be children of the current parent. + // Explicitly drop them so they will still be executed for their side + // effects and so the remaining children will be filled with + // unreachables. + assert(scope.exprStack.back()->type == Type::unreachable); + for (auto& expr : scope.exprStack) { + expr = Builder(builder.wasm).dropIfConcretelyTyped(expr); + } + } + + auto val = pop(children[i].constraint.size()); + CHECK_ERR(val); + *children[i].childp = *val; + } return Ok{}; } - return visitExpression(curr); -} -Result<> IRBuilder::visitIf(If* curr) { - // Only the condition is popped from the stack. The ifTrue and ifFalse are - // self-contained so we do not modify them. - auto cond = pop(); - CHECK_ERR(cond); - curr->condition = *cond; - return Ok{}; -} + Result<Expression*> pop(size_t size) { + assert(size >= 1); + auto& scope = builder.getScope(); -Result<> IRBuilder::visitReturn(Return* curr) { - if (!func) { - return Err{"cannot return outside of a function"}; - } - size_t n = func->getResults().size(); - if (n == 0) { - curr->value = nullptr; - } else { - auto val = pop(n); - CHECK_ERR(val); - curr->value = *val; - } - return Ok{}; -} + // Find the suffix of expressions that do not produce values. + auto hoisted = builder.hoistLastValue(); + CHECK_ERR(hoisted); + if (!hoisted) { + // There are no expressions that produce values. + if (scope.unreachable) { + return builder.builder.makeUnreachable(); + } + return Err{"popping from empty stack"}; + } + + CHECK_ERR(builder.packageHoistedValue(*hoisted, size)); + + auto* ret = scope.exprStack.back(); + // If the top value has the correct size, we can pop it and be done. + // Unreachable values satisfy any size. + if (ret->type.size() == size || ret->type == Type::unreachable) { + scope.exprStack.pop_back(); + return ret; + } -Result<> IRBuilder::visitStructNew(StructNew* curr) { - for (size_t i = 0, n = curr->operands.size(); i < n; ++i) { - auto val = pop(); - CHECK_ERR(val); - curr->operands[n - 1 - i] = *val; + // The last value-producing expression did not produce exactly the right + // number of values, so we need to construct a tuple piecewise instead. + assert(size > 1); + std::vector<Expression*> elems; + elems.resize(size); + for (int i = size - 1; i >= 0; --i) { + auto elem = pop(1); + CHECK_ERR(elem); + elems[i] = *elem; + } + return builder.builder.makeTupleMake(elems); } - return Ok{}; -} -Result<> IRBuilder::visitArrayNew(ArrayNew* curr) { - auto size = pop(); - CHECK_ERR(size); - curr->size = *size; - if (!curr->isWithDefault()) { - auto init = pop(); - CHECK_ERR(init); - curr->init = *init; +public: + Result<> visitExpression(Expression* expr) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visit(expr); + return popConstrainedChildren(children); } - return Ok{}; -} -Result<> IRBuilder::visitArrayNewFixed(ArrayNewFixed* curr) { - for (size_t i = 0, size = curr->values.size(); i < size; ++i) { - auto val = pop(); - CHECK_ERR(val); - curr->values[size - i - 1] = *val; + Result<> visitAtomicCmpxchg(AtomicCmpxchg* curr, + std::optional<Type> type = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitAtomicCmpxchg(curr, type); + return popConstrainedChildren(children); } - return Ok{}; -} -Result<Expression*> IRBuilder::getBranchValue(Expression* curr, - Name labelName, - std::optional<Index> label) { - // As new branch instructions are added, one of the existing branch visit* - // functions is likely to be copied, along with its call to getBranchValue(). - // This assert serves as a reminder to also add an implementation of - // visit*WithType() for new branch instructions. - assert(curr->is<Break>() || curr->is<Switch>()); - if (!label) { - auto index = getLabelIndex(labelName); - CHECK_ERR(index); - label = *index; + Result<> visitStructGet(StructGet* curr, + std::optional<HeapType> ht = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitStructGet(curr, ht); + return popConstrainedChildren(children); } - auto scope = getScope(*label); - CHECK_ERR(scope); - // Loops would receive their input type rather than their output type, if we - // supported that. - size_t numValues = (*scope)->getLoop() ? 0 : (*scope)->getResultType().size(); - return numValues == 0 ? nullptr : pop(numValues); -} -Result<> IRBuilder::visitBreak(Break* curr, std::optional<Index> label) { - if (curr->condition) { - auto cond = pop(); - CHECK_ERR(cond); - curr->condition = *cond; + Result<> visitStructSet(StructSet* curr, + std::optional<HeapType> ht = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitStructSet(curr, ht); + return popConstrainedChildren(children); } - auto value = getBranchValue(curr, curr->name, label); - CHECK_ERR(value); - curr->value = *value; - return Ok{}; -} -Result<> IRBuilder::visitBreakWithType(Break* curr, Type type) { - if (curr->condition) { - auto cond = pop(); - CHECK_ERR(cond); - curr->condition = *cond; + Result<> visitArrayGet(ArrayGet* curr, + std::optional<HeapType> ht = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitArrayGet(curr, ht); + return popConstrainedChildren(children); } - if (type == Type::none) { - curr->value = nullptr; - } else { - auto value = pop(type.size()); - CHECK_ERR(value) - curr->value = *value; + + Result<> visitArraySet(ArraySet* curr, + std::optional<HeapType> ht = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitArraySet(curr, ht); + return popConstrainedChildren(children); } - curr->finalize(); - push(curr); - return Ok{}; -} -Result<> IRBuilder::visitSwitch(Switch* curr, - std::optional<Index> defaultLabel) { - auto cond = pop(); - CHECK_ERR(cond); - curr->condition = *cond; - auto value = getBranchValue(curr, curr->default_, defaultLabel); - CHECK_ERR(value); - curr->value = *value; - return Ok{}; -} + Result<> visitArrayCopy(ArrayCopy* curr, + std::optional<HeapType> dest = std::nullopt, + std::optional<HeapType> src = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitArrayCopy(curr, dest, src); + return popConstrainedChildren(children); + } -Result<> IRBuilder::visitSwitchWithType(Switch* curr, Type type) { - auto cond = pop(); - CHECK_ERR(cond); - curr->condition = *cond; - if (type == Type::none) { - curr->value = nullptr; - } else { - auto value = pop(type.size()); - CHECK_ERR(value) - curr->value = *value; + Result<> visitArrayFill(ArrayFill* curr, + std::optional<HeapType> ht = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitArrayFill(curr, ht); + return popConstrainedChildren(children); } - curr->finalize(); - push(curr); - return Ok{}; -} -Result<> IRBuilder::visitCall(Call* curr) { - auto numArgs = wasm.getFunction(curr->target)->getNumParams(); - curr->operands.resize(numArgs); - for (size_t i = 0; i < numArgs; ++i) { - auto arg = pop(); - CHECK_ERR(arg); - curr->operands[numArgs - 1 - i] = *arg; + Result<> visitArrayInitData(ArrayInitData* curr, + std::optional<HeapType> ht = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitArrayInitData(curr, ht); + return popConstrainedChildren(children); } - return Ok{}; -} -Result<> IRBuilder::visitCallIndirect(CallIndirect* curr) { - auto target = pop(); - CHECK_ERR(target); - curr->target = *target; - auto numArgs = curr->heapType.getSignature().params.size(); - curr->operands.resize(numArgs); - for (size_t i = 0; i < numArgs; ++i) { - auto arg = pop(); - CHECK_ERR(arg); - curr->operands[numArgs - 1 - i] = *arg; + Result<> visitArrayInitElem(ArrayInitElem* curr, + std::optional<HeapType> ht = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitArrayInitElem(curr, ht); + return popConstrainedChildren(children); } - return Ok{}; -} -Result<> IRBuilder::visitCallRef(CallRef* curr) { - auto target = pop(); - CHECK_ERR(target); - curr->target = *target; - for (size_t i = 0, numArgs = curr->operands.size(); i < numArgs; ++i) { - auto arg = pop(); - CHECK_ERR(arg); - curr->operands[numArgs - 1 - i] = *arg; + Result<> visitStringNew(StringNew* curr, + std::optional<HeapType> ht = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitStringNew(curr, ht); + return popConstrainedChildren(children); } - return Ok{}; -} -Result<> IRBuilder::visitLocalSet(LocalSet* curr) { - auto type = func->getLocalType(curr->index); - auto val = pop(type.size()); - CHECK_ERR(val); - curr->value = *val; - return Ok{}; -} + Result<> visitStringEncode(StringEncode* curr, + std::optional<HeapType> ht = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitStringEncode(curr, ht); + return popConstrainedChildren(children); + } -Result<> IRBuilder::visitGlobalSet(GlobalSet* curr) { - auto type = wasm.getGlobal(curr->name)->type; - auto val = pop(type.size()); - CHECK_ERR(val); - curr->value = *val; - return Ok{}; -} -Result<> IRBuilder::visitThrow(Throw* curr) { - auto numArgs = wasm.getTag(curr->tag)->sig.params.size(); - curr->operands.resize(numArgs); - for (size_t i = 0; i < numArgs; ++i) { - auto arg = pop(); - CHECK_ERR(arg); - curr->operands[numArgs - 1 - i] = *arg; + Result<> visitCallRef(CallRef* curr, + std::optional<HeapType> ht = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitCallRef(curr, ht); + return popConstrainedChildren(children); } - return Ok{}; -} -Result<> IRBuilder::visitStringNew(StringNew* curr) { - switch (curr->op) { - case StringNewUTF8: - case StringNewWTF8: - case StringNewLossyUTF8: - case StringNewWTF16: { - auto len = pop(); - CHECK_ERR(len); - curr->length = *len; - break; - } - case StringNewUTF8Array: - case StringNewWTF8Array: - case StringNewLossyUTF8Array: - case StringNewWTF16Array: { - auto end = pop(); - CHECK_ERR(end); - curr->end = *end; - auto start = pop(); - CHECK_ERR(start); - curr->start = *start; - break; - } - case StringNewFromCodePoint: - break; + Result<> visitBreak(Break* curr, + std::optional<Type> labelType = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitBreak(curr, labelType); + return popConstrainedChildren(children); } - auto ptr = pop(); - CHECK_ERR(ptr); - curr->ptr = *ptr; - return Ok{}; -} -Result<> IRBuilder::visitStringEncode(StringEncode* curr) { - switch (curr->op) { - case StringEncodeUTF8Array: - case StringEncodeLossyUTF8Array: - case StringEncodeWTF8Array: - case StringEncodeWTF16Array: { - auto start = pop(); - CHECK_ERR(start); - curr->start = *start; - } - [[fallthrough]]; - case StringEncodeUTF8: - case StringEncodeLossyUTF8: - case StringEncodeWTF8: - case StringEncodeWTF16: { - auto ptr = pop(); - CHECK_ERR(ptr); - curr->ptr = *ptr; - auto ref = pop(); - CHECK_ERR(ref); - curr->ref = *ref; - return Ok{}; - } + Result<> visitSwitch(Switch* curr, + std::optional<Type> labelType = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitSwitch(curr, labelType); + return popConstrainedChildren(children); } - WASM_UNREACHABLE("unexpected op"); -} -Result<> IRBuilder::visitContBind(ContBind* curr) { - auto cont = pop(); - CHECK_ERR(cont); - curr->cont = *cont; + Result<> visitDrop(Drop* curr, std::optional<Index> arity = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitDrop(curr, arity); + return popConstrainedChildren(children); + } - size_t paramsBefore = - curr->contTypeBefore.getContinuation().type.getSignature().params.size(); - size_t paramsAfter = - curr->contTypeAfter.getContinuation().type.getSignature().params.size(); - if (paramsBefore < paramsAfter) { - return Err{"incompatible continuation types in cont.bind: source type " + - curr->contTypeBefore.toString() + - " has fewer parameters than destination " + - curr->contTypeAfter.toString()}; + Result<> visitTupleExtract(TupleExtract* curr, + std::optional<Index> arity = std::nullopt) { + std::vector<Child> children; + ConstraintCollector{builder, children}.visitTupleExtract(curr, arity); + return popConstrainedChildren(children); } - size_t numArgs = paramsBefore - paramsAfter; +}; - curr->operands.resize(numArgs); - for (size_t i = 0; i < numArgs; ++i) { - auto val = pop(); - CHECK_ERR(val); - curr->operands[numArgs - i - 1] = *val; +Result<> IRBuilder::visit(Expression* curr) { + // Call either `visitExpression` or an expression-specific override. + auto val = UnifiedExpressionVisitor<IRBuilder, Result<>>::visit(curr); + CHECK_ERR(val); + if (auto* block = curr->dynCast<Block>()) { + block->finalize(block->type); + } else { + // TODO: Call more efficient versions of finalize() that take the known type + // for other kinds of nodes as well, as done above. + ReFinalizeNode{}.visit(curr); } + push(curr); return Ok{}; } -Result<> IRBuilder::visitResume(Resume* curr) { - auto cont = pop(); - CHECK_ERR(cont); - curr->cont = *cont; - - auto sig = curr->contType.getContinuation().type.getSignature(); - auto size = sig.params.size(); - curr->operands.resize(size); - for (size_t i = 0; i < size; ++i) { - auto val = pop(); - CHECK_ERR(val); - curr->operands[size - i - 1] = *val; +// Handle the common case of instructions with a constant number of children +// uniformly. +Result<> IRBuilder::visitExpression(Expression* curr) { + if (Properties::isControlFlowStructure(curr) && !curr->is<If>()) { + // Control flow structures (besides `if`, handled separately) do not consume + // stack values. + return Ok{}; } - return Ok{}; + return ChildPopper{*this}.visit(curr); } -Result<> IRBuilder::visitSuspend(Suspend* curr) { - auto tag = wasm.getTag(curr->tag); - auto sig = tag->sig; - auto size = sig.params.size(); - curr->operands.resize(size); - for (size_t i = 0; i < size; ++i) { - auto val = pop(); - CHECK_ERR(val); - curr->operands[size - i - 1] = *val; - } - return Ok{}; +Result<Type> IRBuilder::getLabelType(Index label) { + auto scope = getScope(label); + CHECK_ERR(scope); + // Loops would receive their input type rather than their output type, if we + // supported that. + return (*scope)->getLoop() ? Type::none : (*scope)->getResultType(); } -Result<> IRBuilder::visitTupleMake(TupleMake* curr) { - assert(curr->operands.size() >= 2); - for (size_t i = 0, size = curr->operands.size(); i < size; ++i) { - auto elem = pop(); - CHECK_ERR(elem); - curr->operands[size - 1 - i] = *elem; - } - return Ok{}; +Result<Type> IRBuilder::getLabelType(Name labelName) { + auto label = getLabelIndex(labelName); + CHECK_ERR(label); + return getLabelType(*label); } -Result<> IRBuilder::visitTupleExtract(TupleExtract* curr, - std::optional<uint32_t> arity) { - if (!arity) { - if (curr->tuple->type == Type::unreachable) { - // Fallback to an arbitrary valid arity. - arity = 2; - } else { - arity = curr->tuple->type.size(); - } - } - assert(*arity >= 2); - auto tuple = pop(*arity); - CHECK_ERR(tuple); - curr->tuple = *tuple; +Result<> IRBuilder::visitBreakWithType(Break* curr, Type type) { + CHECK_ERR(ChildPopper{*this}.visitBreak(curr, type)); + curr->finalize(); + push(curr); return Ok{}; } -Result<> IRBuilder::visitPop(Pop*) { - // Do not actually push this pop onto the stack since we generate our own pops - // as necessary when visiting the beginnings of try blocks. +Result<> IRBuilder::visitSwitchWithType(Switch* curr, Type type) { + CHECK_ERR(ChildPopper{*this}.visitSwitch(curr, type)); + curr->finalize(); + push(curr); return Ok{}; } @@ -727,9 +690,7 @@ Result<> IRBuilder::visitBlockStart(Block* curr) { Result<> IRBuilder::visitIfStart(If* iff, Name label) { applyDebugLoc(iff); - auto cond = pop(); - CHECK_ERR(cond); - iff->condition = *cond; + CHECK_ERR(visitIf(iff)); pushScope(ScopeCtx::makeIf(iff, label)); return Ok{}; } @@ -769,52 +730,36 @@ Result<Expression*> IRBuilder::finishScope(Block* block) { auto& scope = scopeStack.back(); auto type = scope.getResultType(); - if (type.isTuple()) { - if (scope.unreachable) { - // We may not have enough concrete values on the stack to construct the - // full tuple, and if we tried to fill out the beginning of a tuple.make - // with additional popped `unreachable`s, that could cause a trap to - // happen before important side effects. Instead, just drop everything on - // the stack and finish with a single unreachable. - // - // TODO: Validate that the available expressions are a correct suffix of - // the expected type, since this will no longer be caught by normal - // validation? - for (auto& expr : scope.exprStack) { - expr = builder.dropIfConcretelyTyped(expr); - } - if (scope.exprStack.back()->type != Type::unreachable) { - scope.exprStack.push_back(builder.makeUnreachable()); - } - } else { - auto hoisted = hoistLastValue(); - CHECK_ERR(hoisted); - if (!hoisted) { - return Err{"popping from empty stack"}; - } - auto hoistedType = scope.exprStack.back()->type; - if (hoistedType.size() != type.size()) { - // We cannot propagate the hoisted value directly because it does not - // have the correct number of elements. Break it up if necessary and - // construct our returned tuple from parts. - CHECK_ERR(packageHoistedValue(*hoisted)); - std::vector<Expression*> elems(type.size()); - for (size_t i = 0; i < elems.size(); ++i) { - auto elem = pop(); - CHECK_ERR(elem); - elems[elems.size() - 1 - i] = *elem; - } - scope.exprStack.push_back(builder.makeTupleMake(std::move(elems))); + + if (scope.unreachable) { + // Drop everything before the last unreachable. + bool sawUnreachable = false; + for (int i = scope.exprStack.size() - 1; i >= 0; --i) { + if (sawUnreachable) { + scope.exprStack[i] = builder.dropIfConcretelyTyped(scope.exprStack[i]); + } else if (scope.exprStack[i]->type == Type::unreachable) { + sawUnreachable = true; } } - } else if (type.isConcrete()) { - // If the value is buried in none-typed expressions, we have to bring it to - // the top. + } + + if (type.isConcrete()) { auto hoisted = hoistLastValue(); CHECK_ERR(hoisted); if (!hoisted) { return Err{"popping from empty stack"}; } + + if (type.isTuple()) { + auto hoistedType = scope.exprStack.back()->type; + if (hoistedType != Type::unreachable && + hoistedType.size() != type.size()) { + // We cannot propagate the hoisted value directly because it does not + // have the correct number of elements. Repackage it. + CHECK_ERR(packageHoistedValue(*hoisted, hoistedType.size())); + CHECK_ERR(makeTupleMake(type.size())); + } + } } Expression* ret = nullptr; @@ -1124,44 +1069,58 @@ Result<> IRBuilder::makeLoop(Name label, Type type) { Result<> IRBuilder::makeBreak(Index label, bool isConditional) { auto name = getLabelName(label); CHECK_ERR(name); + auto labelType = getLabelType(label); + CHECK_ERR(labelType); + Break curr; curr.name = *name; // Use a dummy condition value if we need to pop a condition. curr.condition = isConditional ? &curr : nullptr; - CHECK_ERR(visitBreak(&curr, label)); + CHECK_ERR(ChildPopper{*this}.visitBreak(&curr, *labelType)); push(builder.makeBreak(curr.name, curr.value, curr.condition)); return Ok{}; } Result<> IRBuilder::makeSwitch(const std::vector<Index>& labels, Index defaultLabel) { + auto defaultType = getLabelType(defaultLabel); + CHECK_ERR(defaultType); + std::vector<Name> names; names.reserve(labels.size()); + Type glbLabelType = *defaultType; for (auto label : labels) { auto name = getLabelName(label); CHECK_ERR(name); names.push_back(*name); + auto type = getLabelType(label); + CHECK_ERR(type); + glbLabelType = Type::getGreatestLowerBound(glbLabelType, *type); } + auto defaultName = getLabelName(defaultLabel); CHECK_ERR(defaultName); + Switch curr(wasm.allocator); - CHECK_ERR(visitSwitch(&curr, defaultLabel)); + CHECK_ERR(ChildPopper{*this}.visitSwitch(&curr, glbLabelType)); push(builder.makeSwitch(names, *defaultName, curr.condition, curr.value)); return Ok{}; } Result<> IRBuilder::makeCall(Name func, bool isReturn) { + auto sig = wasm.getFunction(func)->getSig(); Call curr(wasm.allocator); curr.target = func; + curr.operands.resize(sig.params.size()); CHECK_ERR(visitCall(&curr)); - auto type = wasm.getFunction(func)->getResults(); - push(builder.makeCall(curr.target, curr.operands, type, isReturn)); + push(builder.makeCall(curr.target, curr.operands, sig.results, isReturn)); return Ok{}; } Result<> IRBuilder::makeCallIndirect(Name table, HeapType type, bool isReturn) { CallIndirect curr(wasm.allocator); curr.heapType = type; + curr.operands.resize(type.getSignature().params.size()); CHECK_ERR(visitCallIndirect(&curr)); push(builder.makeCallIndirect( table, curr.target, curr.operands, type, isReturn)); @@ -1209,6 +1168,7 @@ Result<> IRBuilder::makeLoad(unsigned bytes, Type type, Name mem) { Load curr; + curr.memory = mem; CHECK_ERR(visitLoad(&curr)); push(builder.makeLoad(bytes, signed_, offset, align, curr.ptr, type, mem)); return Ok{}; @@ -1217,6 +1177,8 @@ Result<> IRBuilder::makeLoad(unsigned bytes, Result<> IRBuilder::makeStore( unsigned bytes, Address offset, unsigned align, Type type, Name mem) { Store curr; + curr.memory = mem; + curr.valueType = type; CHECK_ERR(visitStore(&curr)); push( builder.makeStore(bytes, offset, align, curr.ptr, curr.value, type, mem)); @@ -1226,6 +1188,7 @@ Result<> IRBuilder::makeStore( Result<> IRBuilder::makeAtomicLoad(unsigned bytes, Address offset, Type type, Name mem) { Load curr; + curr.memory = mem; CHECK_ERR(visitLoad(&curr)); push(builder.makeAtomicLoad(bytes, offset, curr.ptr, type, mem)); return Ok{}; @@ -1236,6 +1199,8 @@ Result<> IRBuilder::makeAtomicStore(unsigned bytes, Type type, Name mem) { Store curr; + curr.memory = mem; + curr.valueType = type; CHECK_ERR(visitStore(&curr)); push(builder.makeAtomicStore(bytes, offset, curr.ptr, curr.value, type, mem)); return Ok{}; @@ -1244,6 +1209,8 @@ Result<> IRBuilder::makeAtomicStore(unsigned bytes, Result<> IRBuilder::makeAtomicRMW( AtomicRMWOp op, unsigned bytes, Address offset, Type type, Name mem) { AtomicRMW curr; + curr.memory = mem; + curr.type = type; CHECK_ERR(visitAtomicRMW(&curr)); push( builder.makeAtomicRMW(op, bytes, offset, curr.ptr, curr.value, type, mem)); @@ -1255,7 +1222,8 @@ Result<> IRBuilder::makeAtomicCmpxchg(unsigned bytes, Type type, Name mem) { AtomicCmpxchg curr; - CHECK_ERR(visitAtomicCmpxchg(&curr)); + curr.memory = mem; + CHECK_ERR(ChildPopper{*this}.visitAtomicCmpxchg(&curr, type)); push(builder.makeAtomicCmpxchg( bytes, offset, curr.ptr, curr.expected, curr.replacement, type, mem)); return Ok{}; @@ -1263,6 +1231,8 @@ Result<> IRBuilder::makeAtomicCmpxchg(unsigned bytes, Result<> IRBuilder::makeAtomicWait(Type type, Address offset, Name mem) { AtomicWait curr; + curr.memory = mem; + curr.expectedType = type; CHECK_ERR(visitAtomicWait(&curr)); push(builder.makeAtomicWait( curr.ptr, curr.expected, curr.timeout, type, offset, mem)); @@ -1271,6 +1241,7 @@ Result<> IRBuilder::makeAtomicWait(Type type, Address offset, Name mem) { Result<> IRBuilder::makeAtomicNotify(Address offset, Name mem) { AtomicNotify curr; + curr.memory = mem; CHECK_ERR(visitAtomicNotify(&curr)); push(builder.makeAtomicNotify(curr.ptr, curr.notifyCount, offset, mem)); return Ok{}; @@ -1290,6 +1261,7 @@ Result<> IRBuilder::makeSIMDExtract(SIMDExtractOp op, uint8_t lane) { Result<> IRBuilder::makeSIMDReplace(SIMDReplaceOp op, uint8_t lane) { SIMDReplace curr; + curr.op = op; CHECK_ERR(visitSIMDReplace(&curr)); push(builder.makeSIMDReplace(op, curr.vec, lane, curr.value)); return Ok{}; @@ -1321,6 +1293,7 @@ Result<> IRBuilder::makeSIMDLoad(SIMDLoadOp op, unsigned align, Name mem) { SIMDLoad curr; + curr.memory = mem; CHECK_ERR(visitSIMDLoad(&curr)); push(builder.makeSIMDLoad(op, offset, align, curr.ptr, mem)); return Ok{}; @@ -1332,6 +1305,7 @@ Result<> IRBuilder::makeSIMDLoadStoreLane(SIMDLoadStoreLaneOp op, uint8_t lane, Name mem) { SIMDLoadStoreLane curr; + curr.memory = mem; CHECK_ERR(visitSIMDLoadStoreLane(&curr)); push(builder.makeSIMDLoadStoreLane( op, offset, align, lane, curr.ptr, curr.vec, mem)); @@ -1340,6 +1314,7 @@ Result<> IRBuilder::makeSIMDLoadStoreLane(SIMDLoadStoreLaneOp op, Result<> IRBuilder::makeMemoryInit(Name data, Name mem) { MemoryInit curr; + curr.memory = mem; CHECK_ERR(visitMemoryInit(&curr)); push(builder.makeMemoryInit(data, curr.dest, curr.offset, curr.size, mem)); return Ok{}; @@ -1352,6 +1327,8 @@ Result<> IRBuilder::makeDataDrop(Name data) { Result<> IRBuilder::makeMemoryCopy(Name destMem, Name srcMem) { MemoryCopy curr; + curr.destMemory = destMem; + curr.sourceMemory = srcMem; CHECK_ERR(visitMemoryCopy(&curr)); push( builder.makeMemoryCopy(curr.dest, curr.source, curr.size, destMem, srcMem)); @@ -1360,6 +1337,7 @@ Result<> IRBuilder::makeMemoryCopy(Name destMem, Name srcMem) { Result<> IRBuilder::makeMemoryFill(Name mem) { MemoryFill curr; + curr.memory = mem; CHECK_ERR(visitMemoryFill(&curr)); push(builder.makeMemoryFill(curr.dest, curr.value, curr.size, mem)); return Ok{}; @@ -1372,6 +1350,7 @@ Result<> IRBuilder::makeConst(Literal val) { Result<> IRBuilder::makeUnary(UnaryOp op) { Unary curr; + curr.op = op; CHECK_ERR(visitUnary(&curr)); push(builder.makeUnary(op, curr.value)); return Ok{}; @@ -1379,6 +1358,7 @@ Result<> IRBuilder::makeUnary(UnaryOp op) { Result<> IRBuilder::makeBinary(BinaryOp op) { Binary curr; + curr.op = op; CHECK_ERR(visitBinary(&curr)); push(builder.makeBinary(op, curr.left, curr.right)); return Ok{}; @@ -1399,7 +1379,7 @@ Result<> IRBuilder::makeSelect(std::optional<Type> type) { Result<> IRBuilder::makeDrop() { Drop curr; - CHECK_ERR(visitDrop(&curr, 1)); + CHECK_ERR(ChildPopper{*this}.visitDrop(&curr, 1)); push(builder.makeDrop(curr.value)); return Ok{}; } @@ -1418,6 +1398,7 @@ Result<> IRBuilder::makeMemorySize(Name mem) { Result<> IRBuilder::makeMemoryGrow(Name mem) { MemoryGrow curr; + curr.memory = mem; CHECK_ERR(visitMemoryGrow(&curr)); push(builder.makeMemoryGrow(curr.delta, mem)); return Ok{}; @@ -1480,6 +1461,7 @@ Result<> IRBuilder::makeTableGet(Name table) { Result<> IRBuilder::makeTableSet(Name table) { TableSet curr; + curr.table = table; CHECK_ERR(visitTableSet(&curr)); push(builder.makeTableSet(table, curr.index, curr.value)); return Ok{}; @@ -1492,6 +1474,7 @@ Result<> IRBuilder::makeTableSize(Name table) { Result<> IRBuilder::makeTableGrow(Name table) { TableGrow curr; + curr.table = table; CHECK_ERR(visitTableGrow(&curr)); push(builder.makeTableGrow(table, curr.value, curr.delta)); return Ok{}; @@ -1499,6 +1482,7 @@ Result<> IRBuilder::makeTableGrow(Name table) { Result<> IRBuilder::makeTableFill(Name table) { TableFill curr; + curr.table = table; CHECK_ERR(visitTableFill(&curr)); push(builder.makeTableFill(table, curr.dest, curr.value, curr.size)); return Ok{}; @@ -1539,6 +1523,7 @@ Result<> IRBuilder::makeTryTable(Name label, Result<> IRBuilder::makeThrow(Name tag) { Throw curr(wasm.allocator); curr.tag = tag; + curr.operands.resize(wasm.getTag(tag)->sig.params.size()); CHECK_ERR(visitThrow(&curr)); push(builder.makeThrow(tag, curr.operands)); return Ok{}; @@ -1578,7 +1563,7 @@ Result<> IRBuilder::makeTupleExtract(uint32_t arity, uint32_t index) { return Err{"tuple arity must be at least 2"}; } TupleExtract curr; - CHECK_ERR(visitTupleExtract(&curr, arity)); + CHECK_ERR(ChildPopper{*this}.visitTupleExtract(&curr, arity)); push(builder.makeTupleExtract(curr.tuple, index)); return Ok{}; } @@ -1588,7 +1573,7 @@ Result<> IRBuilder::makeTupleDrop(uint32_t arity) { return Err{"tuple arity must be at least 2"}; } Drop curr; - CHECK_ERR(visitDrop(&curr, arity)); + CHECK_ERR(ChildPopper{*this}.visitDrop(&curr, arity)); push(builder.makeDrop(curr.value)); return Ok{}; } @@ -1614,7 +1599,7 @@ Result<> IRBuilder::makeCallRef(HeapType type, bool isReturn) { } auto sig = type.getSignature(); curr.operands.resize(type.getSignature().params.size()); - CHECK_ERR(visitCallRef(&curr)); + CHECK_ERR(ChildPopper{*this}.visitCallRef(&curr, type)); CHECK_ERR(validateTypeAnnotation(type, curr.target)); push(builder.makeCallRef(curr.target, curr.operands, sig.results, isReturn)); return Ok{}; @@ -1622,6 +1607,7 @@ Result<> IRBuilder::makeCallRef(HeapType type, bool isReturn) { Result<> IRBuilder::makeRefTest(Type type) { RefTest curr; + curr.castType = type; CHECK_ERR(visitRefTest(&curr)); push(builder.makeRefTest(curr.ref, type)); return Ok{}; @@ -1629,6 +1615,7 @@ Result<> IRBuilder::makeRefTest(Type type) { Result<> IRBuilder::makeRefCast(Type type) { RefCast curr; + curr.type = type; CHECK_ERR(visitRefCast(&curr)); push(builder.makeRefCast(curr.ref, type)); return Ok{}; @@ -1636,6 +1623,8 @@ Result<> IRBuilder::makeRefCast(Type type) { Result<> IRBuilder::makeBrOn(Index label, BrOnOp op, Type in, Type out) { BrOn curr; + curr.op = op; + curr.castType = out; CHECK_ERR(visitBrOn(&curr)); if (out != Type::none) { if (!Type::isSubType(out, in)) { @@ -1653,6 +1642,7 @@ Result<> IRBuilder::makeBrOn(Index label, BrOnOp op, Type in, Type out) { Result<> IRBuilder::makeStructNew(HeapType type) { StructNew curr(wasm.allocator); + curr.type = Type(type, NonNullable); // Differentiate from struct.new_default with a non-empty expression list. curr.operands.resize(type.getStruct().fields.size()); CHECK_ERR(visitStructNew(&curr)); @@ -1668,7 +1658,7 @@ Result<> IRBuilder::makeStructNewDefault(HeapType type) { Result<> IRBuilder::makeStructGet(HeapType type, Index field, bool signed_) { const auto& fields = type.getStruct().fields; StructGet curr; - CHECK_ERR(visitStructGet(&curr)); + CHECK_ERR(ChildPopper{*this}.visitStructGet(&curr, type)); CHECK_ERR(validateTypeAnnotation(type, curr.ref)); push(builder.makeStructGet(field, curr.ref, fields[field].type, signed_)); return Ok{}; @@ -1676,7 +1666,8 @@ Result<> IRBuilder::makeStructGet(HeapType type, Index field, bool signed_) { Result<> IRBuilder::makeStructSet(HeapType type, Index field) { StructSet curr; - CHECK_ERR(visitStructSet(&curr)); + curr.index = field; + CHECK_ERR(ChildPopper{*this}.visitStructSet(&curr, type)); CHECK_ERR(validateTypeAnnotation(type, curr.ref)); push(builder.makeStructSet(field, curr.ref, curr.value)); return Ok{}; @@ -1684,6 +1675,7 @@ Result<> IRBuilder::makeStructSet(HeapType type, Index field) { Result<> IRBuilder::makeArrayNew(HeapType type) { ArrayNew curr; + curr.type = Type(type, NonNullable); // Differentiate from array.new_default with dummy initializer. curr.init = (Expression*)0x01; CHECK_ERR(visitArrayNew(&curr)); @@ -1693,6 +1685,7 @@ Result<> IRBuilder::makeArrayNew(HeapType type) { Result<> IRBuilder::makeArrayNewDefault(HeapType type) { ArrayNew curr; + curr.init = nullptr; CHECK_ERR(visitArrayNew(&curr)); push(builder.makeArrayNew(type, curr.size)); return Ok{}; @@ -1714,6 +1707,7 @@ Result<> IRBuilder::makeArrayNewElem(HeapType type, Name elem) { Result<> IRBuilder::makeArrayNewFixed(HeapType type, uint32_t arity) { ArrayNewFixed curr(wasm.allocator); + curr.type = Type(type, NonNullable); curr.values.resize(arity); CHECK_ERR(visitArrayNewFixed(&curr)); push(builder.makeArrayNewFixed(type, curr.values)); @@ -1722,7 +1716,7 @@ Result<> IRBuilder::makeArrayNewFixed(HeapType type, uint32_t arity) { Result<> IRBuilder::makeArrayGet(HeapType type, bool signed_) { ArrayGet curr; - CHECK_ERR(visitArrayGet(&curr)); + CHECK_ERR(ChildPopper{*this}.visitArrayGet(&curr, type)); CHECK_ERR(validateTypeAnnotation(type, curr.ref)); push(builder.makeArrayGet( curr.ref, curr.index, type.getArray().element.type, signed_)); @@ -1731,7 +1725,7 @@ Result<> IRBuilder::makeArrayGet(HeapType type, bool signed_) { Result<> IRBuilder::makeArraySet(HeapType type) { ArraySet curr; - CHECK_ERR(visitArraySet(&curr)); + CHECK_ERR(ChildPopper{*this}.visitArraySet(&curr, type)); CHECK_ERR(validateTypeAnnotation(type, curr.ref)); push(builder.makeArraySet(curr.ref, curr.index, curr.value)); return Ok{}; @@ -1746,7 +1740,7 @@ Result<> IRBuilder::makeArrayLen() { Result<> IRBuilder::makeArrayCopy(HeapType destType, HeapType srcType) { ArrayCopy curr; - CHECK_ERR(visitArrayCopy(&curr)); + CHECK_ERR(ChildPopper{*this}.visitArrayCopy(&curr, destType, srcType)); CHECK_ERR(validateTypeAnnotation(destType, curr.destRef)); CHECK_ERR(validateTypeAnnotation(srcType, curr.srcRef)); push(builder.makeArrayCopy( @@ -1756,7 +1750,7 @@ Result<> IRBuilder::makeArrayCopy(HeapType destType, HeapType srcType) { Result<> IRBuilder::makeArrayFill(HeapType type) { ArrayFill curr; - CHECK_ERR(visitArrayFill(&curr)); + CHECK_ERR(ChildPopper{*this}.visitArrayFill(&curr, type)); CHECK_ERR(validateTypeAnnotation(type, curr.ref)); push(builder.makeArrayFill(curr.ref, curr.index, curr.value, curr.size)); return Ok{}; @@ -1764,7 +1758,7 @@ Result<> IRBuilder::makeArrayFill(HeapType type) { Result<> IRBuilder::makeArrayInitData(HeapType type, Name data) { ArrayInitData curr; - CHECK_ERR(visitArrayInitData(&curr)); + CHECK_ERR(ChildPopper{*this}.visitArrayInitData(&curr, type)); CHECK_ERR(validateTypeAnnotation(type, curr.ref)); push(builder.makeArrayInitData( data, curr.ref, curr.index, curr.offset, curr.size)); @@ -1773,7 +1767,7 @@ Result<> IRBuilder::makeArrayInitData(HeapType type, Name data) { Result<> IRBuilder::makeArrayInitElem(HeapType type, Name elem) { ArrayInitElem curr; - CHECK_ERR(visitArrayInitElem(&curr)); + CHECK_ERR(ChildPopper{*this}.visitArrayInitElem(&curr, type)); CHECK_ERR(validateTypeAnnotation(type, curr.ref)); push(builder.makeArrayInitElem( elem, curr.ref, curr.index, curr.offset, curr.size)); @@ -1782,6 +1776,7 @@ Result<> IRBuilder::makeArrayInitElem(HeapType type, Name elem) { Result<> IRBuilder::makeRefAs(RefAsOp op) { RefAs curr; + curr.op = op; CHECK_ERR(visitRefAs(&curr)); push(builder.makeRefAs(op, curr.value)); return Ok{}; @@ -1790,22 +1785,28 @@ Result<> IRBuilder::makeRefAs(RefAsOp op) { Result<> IRBuilder::makeStringNew(StringNewOp op, bool try_, Name mem) { StringNew curr; curr.op = op; - CHECK_ERR(visitStringNew(&curr)); // TODO: Store the memory in the IR. switch (op) { case StringNewUTF8: case StringNewWTF8: case StringNewLossyUTF8: case StringNewWTF16: + CHECK_ERR(visitStringNew(&curr)); push(builder.makeStringNew(op, curr.ptr, curr.length, try_)); return Ok{}; case StringNewUTF8Array: case StringNewWTF8Array: case StringNewLossyUTF8Array: case StringNewWTF16Array: + // There's no type annotation on these instructions due to a bug in the + // stringref proposal, so we just fudge it and pass `array` instead of a + // defined heap type. This will allow us to pop a child with an invalid + // array type, but that's just too bad. + CHECK_ERR(ChildPopper{*this}.visitStringNew(&curr, HeapType::array)); push(builder.makeStringNew(op, curr.ptr, curr.start, curr.end, try_)); return Ok{}; case StringNewFromCodePoint: + CHECK_ERR(visitStringNew(&curr)); push(builder.makeStringNew(op, curr.ptr, nullptr, try_)); return Ok{}; } @@ -1819,6 +1820,7 @@ Result<> IRBuilder::makeStringConst(Name string) { Result<> IRBuilder::makeStringMeasure(StringMeasureOp op) { StringMeasure curr; + curr.op = op; CHECK_ERR(visitStringMeasure(&curr)); push(builder.makeStringMeasure(op, curr.ref)); return Ok{}; @@ -1827,10 +1829,30 @@ Result<> IRBuilder::makeStringMeasure(StringMeasureOp op) { Result<> IRBuilder::makeStringEncode(StringEncodeOp op, Name mem) { StringEncode curr; curr.op = op; - CHECK_ERR(visitStringEncode(&curr)); // TODO: Store the memory in the IR. - push(builder.makeStringEncode(op, curr.ref, curr.ptr, curr.start)); - return Ok{}; + switch (op) { + case StringEncodeUTF8: + case StringEncodeLossyUTF8: + case StringEncodeWTF8: + case StringEncodeWTF16: { + CHECK_ERR(visitStringEncode(&curr)); + push(builder.makeStringEncode(op, curr.ref, curr.ptr, curr.start)); + return Ok{}; + } + case StringEncodeUTF8Array: + case StringEncodeLossyUTF8Array: + case StringEncodeWTF8Array: + case StringEncodeWTF16Array: { + // There's no type annotation on these instructions due to a bug in the + // stringref proposal, so we just fudge it and pass `array` instead of a + // defined heap type. This will allow us to pop a child with an invalid + // array type, but that's just too bad. + CHECK_ERR(ChildPopper{*this}.visitStringEncode(&curr, HeapType::array)); + push(builder.makeStringEncode(op, curr.ref, curr.ptr, curr.start)); + return Ok{}; + } + } + WASM_UNREACHABLE("unexpected op"); } Result<> IRBuilder::makeStringConcat() { @@ -1884,6 +1906,7 @@ Result<> IRBuilder::makeStringIterMove(StringIterMoveOp op) { Result<> IRBuilder::makeStringSliceWTF(StringSliceWTFOp op) { StringSliceWTF curr; + curr.op = op; CHECK_ERR(visitStringSliceWTF(&curr)); push(builder.makeStringSliceWTF(op, curr.ref, curr.start, curr.end)); return Ok{}; @@ -1904,6 +1927,17 @@ Result<> IRBuilder::makeContBind(HeapType contTypeBefore, ContBind curr(wasm.allocator); curr.contTypeBefore = contTypeBefore; curr.contTypeAfter = contTypeAfter; + size_t paramsBefore = + contTypeBefore.getContinuation().type.getSignature().params.size(); + size_t paramsAfter = + contTypeAfter.getContinuation().type.getSignature().params.size(); + if (paramsBefore < paramsAfter) { + return Err{"incompatible continuation types in cont.bind: source type " + + contTypeBefore.toString() + + " has fewer parameters than destination " + + contTypeAfter.toString()}; + } + curr.operands.resize(paramsBefore - paramsAfter); CHECK_ERR(visitContBind(&curr)); std::vector<Expression*> operands(curr.operands.begin(), curr.operands.end()); @@ -1917,6 +1951,7 @@ Result<> IRBuilder::makeContNew(HeapType ct) { return Err{"expected continuation type"}; } ContNew curr; + curr.contType = ct; CHECK_ERR(visitContNew(&curr)); push(builder.makeContNew(ct, curr.func)); @@ -1931,6 +1966,7 @@ Result<> IRBuilder::makeResume(HeapType ct, } Resume curr(wasm.allocator); curr.contType = ct; + curr.operands.resize(ct.getContinuation().type.getSignature().params.size()); CHECK_ERR(visitResume(&curr)); std::vector<Name> labelNames; @@ -1948,6 +1984,7 @@ Result<> IRBuilder::makeResume(HeapType ct, Result<> IRBuilder::makeSuspend(Name tag) { Suspend curr(wasm.allocator); curr.tag = tag; + curr.operands.resize(wasm.getTag(tag)->sig.params.size()); CHECK_ERR(visitSuspend(&curr)); std::vector<Expression*> operands(curr.operands.begin(), curr.operands.end()); diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index 16ab426ad..155d8556e 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -1081,6 +1081,19 @@ Type Type::getGreatestLowerBound(Type a, Type b) { if (a == b) { return a; } + if (a.isTuple() && b.isTuple() && a.size() == b.size()) { + std::vector<Type> elems; + size_t size = a.size(); + elems.reserve(size); + for (size_t i = 0; i < size; ++i) { + auto glb = Type::getGreatestLowerBound(a[i], b[i]); + if (glb == Type::unreachable) { + return Type::unreachable; + } + elems.push_back(glb); + } + return Tuple(elems); + } if (!a.isRef() || !b.isRef()) { return Type::unreachable; } |