From bcc76146fed433cbc8ba01a9f568d979c145110b Mon Sep 17 00:00:00 2001 From: Heejin Ahn Date: Mon, 30 Dec 2019 17:55:20 -0800 Subject: Add support for reference types proposal (#2451) This adds support for the reference type proposal. This includes support for all reference types (`anyref`, `funcref`(=`anyfunc`), and `nullref`) and four new instructions: `ref.null`, `ref.is_null`, `ref.func`, and new typed `select`. This also adds subtype relationship support between reference types. This does not include table instructions yet. This also does not include wasm2js support. Fixes #2444 and fixes #2447. --- src/asmjs/asm_v_wasm.cpp | 13 +- src/binaryen-c.cpp | 101 ++++++++-- src/binaryen-c.h | 19 +- src/gen-s-parser.inc | 64 +++++-- src/ir/ExpressionAnalyzer.cpp | 3 + src/ir/ExpressionManipulator.cpp | 15 +- src/ir/ReFinalize.cpp | 27 ++- src/ir/abstract.h | 12 +- src/ir/block-utils.h | 3 +- src/ir/effects.h | 3 + src/ir/flat.h | 12 +- src/ir/global-utils.h | 6 + src/ir/literal-utils.h | 4 + src/ir/manipulation.h | 16 +- src/ir/properties.h | 4 + src/ir/utils.h | 6 + src/js/binaryen.js-post.js | 53 +++++- src/literal.h | 36 +++- src/parsing.h | 6 +- src/passes/ConstHoisting.cpp | 9 +- src/passes/DeadCodeElimination.cpp | 6 + src/passes/Flatten.cpp | 37 +++- src/passes/FuncCastEmulation.cpp | 16 +- src/passes/Inlining.cpp | 17 +- src/passes/InstrumentLocals.cpp | 32 ++++ src/passes/LegalizeJSInterface.cpp | 33 +++- src/passes/LocalCSE.cpp | 7 +- src/passes/MergeLocals.cpp | 15 +- src/passes/OptimizeInstructions.cpp | 6 +- src/passes/Precompute.cpp | 19 +- src/passes/Print.cpp | 40 +++- src/passes/RemoveUnusedModuleElements.cpp | 6 + src/passes/SimplifyGlobals.cpp | 19 +- src/passes/SimplifyLocals.cpp | 6 +- src/passes/opt-utils.h | 13 +- src/shell-interface.h | 8 +- src/support/name.h | 2 +- src/support/small_vector.h | 10 +- src/tools/execution-results.h | 24 ++- src/tools/fuzzing.h | 253 +++++++++++++++++------- src/tools/spec-wrapper.h | 8 +- src/tools/wasm-reduce.cpp | 14 ++ src/tools/wasm-shell.cpp | 4 +- src/wasm-binary.h | 28 ++- src/wasm-builder.h | 73 ++++++- src/wasm-interpreter.h | 47 +++-- src/wasm-s-parser.h | 3 + src/wasm-stack.h | 30 +++ src/wasm-traversal.h | 57 +++++- src/wasm-type.h | 24 ++- src/wasm.h | 32 ++++ src/wasm/literal.cpp | 45 ++++- src/wasm/wasm-binary.cpp | 83 ++++++-- src/wasm/wasm-s-parser.cpp | 43 ++++- src/wasm/wasm-stack.cpp | 76 ++++++-- src/wasm/wasm-type.cpp | 53 +++++- src/wasm/wasm-validator.cpp | 306 +++++++++++++++++------------- src/wasm/wasm.cpp | 90 ++++----- src/wasm2js.h | 12 ++ 59 files changed, 1532 insertions(+), 477 deletions(-) (limited to 'src') diff --git a/src/asmjs/asm_v_wasm.cpp b/src/asmjs/asm_v_wasm.cpp index 3720ca079..5959db43e 100644 --- a/src/asmjs/asm_v_wasm.cpp +++ b/src/asmjs/asm_v_wasm.cpp @@ -53,10 +53,11 @@ AsmType wasmToAsmType(Type type) { return ASM_INT64; case v128: assert(false && "v128 not implemented yet"); + case funcref: case anyref: - assert(false && "anyref is not supported by asm2wasm"); + case nullref: case exnref: - assert(false && "exnref is not supported by asm2wasm"); + assert(false && "reference types are not supported by asm2wasm"); case none: return ASM_NONE; case unreachable: @@ -77,10 +78,14 @@ char getSig(Type type) { return 'd'; case v128: return 'V'; + case funcref: + return 'F'; case anyref: - return 'a'; + return 'A'; + case nullref: + return 'N'; case exnref: - return 'e'; + return 'E'; case none: return 'v'; case unreachable: diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index 82cbc4c1f..826193e06 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -64,13 +64,16 @@ BinaryenLiteral toBinaryenLiteral(Literal x) { case Type::f64: ret.i64 = x.reinterpreti64(); break; - case Type::v128: { + case Type::v128: memcpy(&ret.v128, x.getv128Ptr(), 16); break; - } - - case Type::anyref: // there's no anyref literals - case Type::exnref: // there's no exnref literals + case Type::funcref: + ret.func = x.getFunc().c_str(); + break; + case Type::nullref: + break; + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("unexpected type"); @@ -90,8 +93,12 @@ Literal fromBinaryenLiteral(BinaryenLiteral x) { return Literal(x.i64).castToF64(); case Type::v128: return Literal(x.v128); - case Type::anyref: // there's no anyref literals - case Type::exnref: // there's no exnref literals + case Type::funcref: + return Literal::makeFuncref(x.func); + case Type::nullref: + return Literal::makeNullref(); + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("unexpected type"); @@ -209,8 +216,14 @@ void printArg(std::ostream& setup, std::ostream& out, BinaryenLiteral arg) { out << "BinaryenLiteralVec128(" << array << ")"; break; } - case Type::anyref: // there's no anyref literals - case Type::exnref: // there's no exnref literals + case Type::funcref: + out << "BinaryenLiteralFuncref(" << arg.func << ")"; + break; + case Type::nullref: + out << "BinaryenLiteralNullref()"; + break; + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("unexpected type"); @@ -265,7 +278,9 @@ BinaryenType BinaryenTypeInt64(void) { return i64; } BinaryenType BinaryenTypeFloat32(void) { return f32; } BinaryenType BinaryenTypeFloat64(void) { return f64; } BinaryenType BinaryenTypeVec128(void) { return v128; } +BinaryenType BinaryenTypeFuncref(void) { return funcref; } BinaryenType BinaryenTypeAnyref(void) { return anyref; } +BinaryenType BinaryenTypeNullref(void) { return nullref; } BinaryenType BinaryenTypeExnref(void) { return exnref; } BinaryenType BinaryenTypeUnreachable(void) { return unreachable; } BinaryenType BinaryenTypeAuto(void) { return uint32_t(-1); } @@ -397,6 +412,15 @@ BinaryenExpressionId BinaryenMemoryCopyId(void) { BinaryenExpressionId BinaryenMemoryFillId(void) { return Expression::Id::MemoryFillId; } +BinaryenExpressionId BinaryenRefNullId(void) { + return Expression::Id::RefNullId; +} +BinaryenExpressionId BinaryenRefIsNullId(void) { + return Expression::Id::RefIsNullId; +} +BinaryenExpressionId BinaryenRefFuncId(void) { + return Expression::Id::RefFuncId; +} BinaryenExpressionId BinaryenTryId(void) { return Expression::Id::TryId; } BinaryenExpressionId BinaryenThrowId(void) { return Expression::Id::ThrowId; } BinaryenExpressionId BinaryenRethrowId(void) { @@ -1330,17 +1354,22 @@ BinaryenExpressionRef BinaryenBinary(BinaryenModuleRef module, BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressionRef condition, BinaryenExpressionRef ifTrue, - BinaryenExpressionRef ifFalse) { + BinaryenExpressionRef ifFalse, + BinaryenType type) { auto* ret = ((Module*)module)->allocator.alloc(); + ret->condition = condition; + ret->ifTrue = ifTrue; + ret->ifFalse = ifFalse; + ret->finalize(type); + return ret; + } Return* makeReturn(Expression* value = nullptr) { auto* ret = allocator.alloc(); ret->value = value; @@ -502,6 +527,23 @@ public: ret->finalize(); return ret; } + RefNull* makeRefNull() { + auto* ret = allocator.alloc(); + ret->finalize(); + return ret; + } + RefIsNull* makeRefIsNull(Expression* value) { + auto* ret = allocator.alloc(); + ret->value = value; + ret->finalize(); + return ret; + } + RefFunc* makeRefFunc(Name func) { + auto* ret = allocator.alloc(); + ret->func = func; + ret->finalize(); + return ret; + } Try* makeTry(Expression* body, Expression* catchBody) { auto* ret = allocator.alloc(); ret->body = body; @@ -569,6 +611,21 @@ public: return ret; } + Expression* makeConstExpression(Literal value) { + switch (value.type) { + case Type::nullref: + return makeRefNull(); + case Type::funcref: + if (value.getFunc()[0] != 0) { + return makeRefFunc(value.getFunc()); + } + return makeRefNull(); + default: + assert(value.type.isNumber()); + return makeConst(value); + } + } + // Additional utility functions for building on top of nodes // Convenient to have these on Builder, as it has allocation built in @@ -663,6 +720,13 @@ public: return block; } + Block* makeSequence(Expression* left, Expression* right, Type type) { + auto* block = makeBlock(left); + block->list.push_back(right); + block->finalize(type); + return block; + } + // Grab a slice out of a block, replacing it with nops, and returning // either another block with the contents (if more than 1) or a single // expression @@ -728,16 +792,15 @@ public: value = Literal(bytes.data()); break; } + case funcref: case anyref: - // TODO Implement and return nullref - assert(false && "anyref not implemented yet"); + case nullref: case exnref: - // TODO Implement and return nullref - assert(false && "exnref not implemented yet"); + return ExpressionManipulator::refNull(curr); case none: return ExpressionManipulator::nop(curr); case unreachable: - return ExpressionManipulator::convert(curr); + return ExpressionManipulator::unreachable(curr); } return makeConst(value); } diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 571f0d1a5..f37a6edd6 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -143,13 +143,13 @@ public: if (!ret.breaking() && (curr->type.isConcrete() || ret.value.type.isConcrete())) { #if 1 // def WASM_INTERPRETER_DEBUG - if (ret.value.type != curr->type) { + if (!Type::isSubType(ret.value.type, curr->type)) { std::cerr << "expected " << curr->type << ", seeing " << ret.value.type << " from\n" << curr << '\n'; } #endif - assert(ret.value.type == curr->type); + assert(Type::isSubType(ret.value.type, curr->type)); } depth--; return ret; @@ -1095,7 +1095,7 @@ public: return Literal(uint64_t(val)); } } - Flow visitAtomicFence(AtomicFence*) { + Flow visitAtomicFence(AtomicFence* curr) { // Wasm currently supports only sequentially consistent atomics, in which // case atomic_fence can be lowered to nothing. NOTE_ENTER("AtomicFence"); @@ -1123,6 +1123,26 @@ public: Flow visitSIMDLoadExtend(SIMDLoad*) { WASM_UNREACHABLE("unimp"); } Flow visitPush(Push*) { WASM_UNREACHABLE("unimp"); } Flow visitPop(Pop*) { WASM_UNREACHABLE("unimp"); } + Flow visitRefNull(RefNull* curr) { + NOTE_ENTER("RefNull"); + return Literal::makeNullref(); + } + Flow visitRefIsNull(RefIsNull* curr) { + NOTE_ENTER("RefIsNull"); + Flow flow = visit(curr->value); + if (flow.breaking()) { + return flow; + } + Literal value = flow.value; + NOTE_EVAL1(value); + return Literal(value.type == nullref); + } + Flow visitRefFunc(RefFunc* curr) { + NOTE_ENTER("RefFunc"); + NOTE_NAME(curr->func); + return Literal::makeFuncref(curr->func); + } + // TODO Implement EH instructions Flow visitTry(Try*) { WASM_UNREACHABLE("unimp"); } Flow visitThrow(Throw*) { WASM_UNREACHABLE("unimp"); } Flow visitRethrow(Rethrow*) { WASM_UNREACHABLE("unimp"); } @@ -1217,8 +1237,10 @@ public: return Literal(load64u(addr)).castToF64(); case v128: return Literal(load128(addr).data()); - case anyref: // anyref cannot be loaded from memory - case exnref: // exnref cannot be loaded from memory + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1272,8 +1294,10 @@ public: case v128: store128(addr, value.getv128()); break; - case anyref: // anyref cannot be stored from memory - case exnref: // exnref cannot be stored in memory + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1464,7 +1488,7 @@ private: for (size_t i = 0; i < function->getNumLocals(); i++) { if (i < arguments.size()) { assert(i < params.size()); - if (params[i] != arguments[i].type) { + if (!Type::isSubType(arguments[i].type, params[i])) { std::cerr << "Function `" << function->name << "` expects type " << params[i] << " for parameter " << i << ", got " << arguments[i].type << "." << std::endl; @@ -1473,7 +1497,7 @@ private: locals[i] = arguments[i]; } else { assert(function->isVar(i)); - locals[i].type = function->getLocalType(i); + locals[i] = Literal::makeZero(function->getLocalType(i)); } } } @@ -1580,7 +1604,8 @@ private: } NOTE_EVAL1(index); NOTE_EVAL1(flow.value); - assert(curr->isTee() ? flow.value.type == curr->type : true); + assert(curr->isTee() ? Type::isSubType(flow.value.type, curr->type) + : true); scope.locals[index] = flow.value; return curr->isTee() ? flow : Flow(); } @@ -2067,7 +2092,7 @@ public: // cannot still be breaking, it means we missed our stop assert(!flow.breaking() || flow.breakTo == RETURN_FLOW); Literal ret = flow.value; - if (function->sig.results != ret.type) { + if (!Type::isSubType(ret.type, function->sig.results)) { std::cerr << "calling " << function->name << " resulted in " << ret << " but the function type is " << function->sig.results << '\n'; diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index d7324d756..8cdcb88f4 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -225,6 +225,9 @@ private: Expression* makeBreak(Element& s); Expression* makeBreakTable(Element& s); Expression* makeReturn(Element& s); + Expression* makeRefNull(Element& s); + Expression* makeRefIsNull(Element& s); + Expression* makeRefFunc(Element& s); Expression* makeTry(Element& s); Expression* makeCatch(Element& s, Type type); Expression* makeThrow(Element& s); diff --git a/src/wasm-stack.h b/src/wasm-stack.h index fbd28b0d5..91c0c5383 100644 --- a/src/wasm-stack.h +++ b/src/wasm-stack.h @@ -128,6 +128,9 @@ public: void visitSelect(Select* curr); void visitReturn(Return* curr); void visitHost(Host* curr); + void visitRefNull(RefNull* curr); + void visitRefIsNull(RefIsNull* curr); + void visitRefFunc(RefFunc* curr); void visitTry(Try* curr); void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); @@ -207,6 +210,9 @@ public: void visitSelect(Select* curr); void visitReturn(Return* curr); void visitHost(Host* curr); + void visitRefNull(RefNull* curr); + void visitRefIsNull(RefIsNull* curr); + void visitRefFunc(RefFunc* curr); void visitTry(Try* curr); void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); @@ -698,6 +704,30 @@ void BinaryenIRWriter::visitHost(Host* curr) { emit(curr); } +template +void BinaryenIRWriter::visitRefNull(RefNull* curr) { + emit(curr); +} + +template +void BinaryenIRWriter::visitRefIsNull(RefIsNull* curr) { + visit(curr->value); + if (curr->type == Type::unreachable) { + emitUnreachable(); + return; + } + emit(curr); +} + +template +void BinaryenIRWriter::visitRefFunc(RefFunc* curr) { + if (curr->type == Type::unreachable) { + emitUnreachable(); + return; + } + emit(curr); +} + template void BinaryenIRWriter::visitTry(Try* curr) { emit(curr); visitPossibleBlockContents(curr->body); diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 9c6e78360..c9290cbab 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -72,6 +72,9 @@ template struct Visitor { ReturnType visitDrop(Drop* curr) { return ReturnType(); } ReturnType visitReturn(Return* curr) { return ReturnType(); } ReturnType visitHost(Host* curr) { return ReturnType(); } + ReturnType visitRefNull(RefNull* curr) { return ReturnType(); } + ReturnType visitRefIsNull(RefIsNull* curr) { return ReturnType(); } + ReturnType visitRefFunc(RefFunc* curr) { return ReturnType(); } ReturnType visitTry(Try* curr) { return ReturnType(); } ReturnType visitThrow(Throw* curr) { return ReturnType(); } ReturnType visitRethrow(Rethrow* curr) { return ReturnType(); } @@ -167,6 +170,12 @@ template struct Visitor { DELEGATE(Return); case Expression::Id::HostId: DELEGATE(Host); + case Expression::Id::RefNullId: + DELEGATE(RefNull); + case Expression::Id::RefIsNullId: + DELEGATE(RefIsNull); + case Expression::Id::RefFuncId: + DELEGATE(RefFunc); case Expression::Id::TryId: DELEGATE(Try); case Expression::Id::ThrowId: @@ -241,6 +250,9 @@ struct OverriddenVisitor { UNIMPLEMENTED(Drop); UNIMPLEMENTED(Return); UNIMPLEMENTED(Host); + UNIMPLEMENTED(RefNull); + UNIMPLEMENTED(RefIsNull); + UNIMPLEMENTED(RefFunc); UNIMPLEMENTED(Try); UNIMPLEMENTED(Throw); UNIMPLEMENTED(Rethrow); @@ -337,6 +349,12 @@ struct OverriddenVisitor { DELEGATE(Return); case Expression::Id::HostId: DELEGATE(Host); + case Expression::Id::RefNullId: + DELEGATE(RefNull); + case Expression::Id::RefIsNullId: + DELEGATE(RefIsNull); + case Expression::Id::RefFuncId: + DELEGATE(RefFunc); case Expression::Id::TryId: DELEGATE(Try); case Expression::Id::ThrowId: @@ -476,6 +494,15 @@ struct UnifiedExpressionVisitor : public Visitor { ReturnType visitHost(Host* curr) { return static_cast(this)->visitExpression(curr); } + ReturnType visitRefNull(RefNull* curr) { + return static_cast(this)->visitExpression(curr); + } + ReturnType visitRefIsNull(RefIsNull* curr) { + return static_cast(this)->visitExpression(curr); + } + ReturnType visitRefFunc(RefFunc* curr) { + return static_cast(this)->visitExpression(curr); + } ReturnType visitTry(Try* curr) { return static_cast(this)->visitExpression(curr); } @@ -778,6 +805,15 @@ struct Walker : public VisitorType { static void doVisitHost(SubType* self, Expression** currp) { self->visitHost((*currp)->cast()); } + static void doVisitRefNull(SubType* self, Expression** currp) { + self->visitRefNull((*currp)->cast()); + } + static void doVisitRefIsNull(SubType* self, Expression** currp) { + self->visitRefIsNull((*currp)->cast()); + } + static void doVisitRefFunc(SubType* self, Expression** currp) { + self->visitRefFunc((*currp)->cast()); + } static void doVisitTry(SubType* self, Expression** currp) { self->visitTry((*currp)->cast()); } @@ -1036,6 +1072,19 @@ struct PostWalker : public Walker { } break; } + case Expression::Id::RefNullId: { + self->pushTask(SubType::doVisitRefNull, currp); + break; + } + case Expression::Id::RefIsNullId: { + self->pushTask(SubType::doVisitRefIsNull, currp); + self->pushTask(SubType::scan, &curr->cast()->value); + break; + } + case Expression::Id::RefFuncId: { + self->pushTask(SubType::doVisitRefFunc, currp); + break; + } case Expression::Id::TryId: { self->pushTask(SubType::doVisitTry, currp); self->pushTask(SubType::scan, &curr->cast()->catchBody); @@ -1099,7 +1148,7 @@ struct ControlFlowWalker : public PostWalker { Expression* findBreakTarget(Name name) { assert(!controlFlowStack.empty()); Index i = controlFlowStack.size() - 1; - while (1) { + while (true) { auto* curr = controlFlowStack[i]; if (Block* block = curr->template dynCast()) { if (name == block->name) { @@ -1111,7 +1160,7 @@ struct ControlFlowWalker : public PostWalker { } } else { // an if, ignorable - assert(curr->template is()); + assert(curr->template is() || curr->template is()); } if (i == 0) { return nullptr; @@ -1169,7 +1218,7 @@ struct ExpressionStackWalker : public PostWalker { Expression* findBreakTarget(Name name) { assert(!expressionStack.empty()); Index i = expressionStack.size() - 1; - while (1) { + while (true) { auto* curr = expressionStack[i]; if (Block* block = curr->template dynCast()) { if (name == block->name) { @@ -1179,8 +1228,6 @@ struct ExpressionStackWalker : public PostWalker { if (name == loop->name) { return curr; } - } else { - WASM_UNREACHABLE("unexpected expression type"); } if (i == 0) { return nullptr; diff --git a/src/wasm-type.h b/src/wasm-type.h index 53ef39ef8..668ac3e4d 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -36,7 +36,9 @@ public: f32, f64, v128, + funcref, anyref, + nullref, exnref, _last_value_type, }; @@ -64,7 +66,8 @@ public: bool isInteger() const { return id == i32 || id == i64; } bool isFloat() const { return id == f32 || id == f64; } bool isVector() const { return id == v128; }; - bool isRef() const { return id == anyref || id == exnref; } + bool isNumber() const { return id >= i32 && id <= v128; } + bool isRef() const { return id >= funcref && id <= exnref; } // (In)equality must be defined for both Type and ValueType because it is // otherwise ambiguous whether to convert both this and other to int or @@ -94,6 +97,23 @@ public: // type. static Type get(unsigned byteSize, bool float_); + // Returns true if left is a subtype of right. Subtype includes itself. + static bool isSubType(Type left, Type right); + + // Computes the least upper bound from the type lattice. + // If one of the type is unreachable, the other type becomes the result. If + // the common supertype does not exist, returns none, a poison value. + static Type getLeastUpperBound(Type a, Type b); + + // Computes the least upper bound for all types in the given list. + template static Type mergeTypes(const T& types) { + Type type = Type::unreachable; + for (auto other : types) { + type = Type::getLeastUpperBound(type, other); + } + return type; + } + std::string toString() const; }; @@ -134,7 +154,9 @@ constexpr Type i64 = Type::i64; constexpr Type f32 = Type::f32; constexpr Type f64 = Type::f64; constexpr Type v128 = Type::v128; +constexpr Type funcref = Type::funcref; constexpr Type anyref = Type::anyref; +constexpr Type nullref = Type::nullref; constexpr Type exnref = Type::exnref; constexpr Type unreachable = Type::unreachable; diff --git a/src/wasm.h b/src/wasm.h index 48adf103b..c4dbd2f3f 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -531,6 +531,9 @@ public: MemoryFillId, PushId, PopId, + RefNullId, + RefIsNullId, + RefFuncId, TryId, ThrowId, RethrowId, @@ -569,6 +572,8 @@ public: const char* getExpressionName(Expression* curr); +Literal getLiteralFromConstExpression(Expression* curr); + typedef ArenaVector ExpressionList; template class SpecificExpression : public Expression { @@ -1008,6 +1013,7 @@ public: Expression* condition; void finalize(); + void finalize(Type type_); }; class Drop : public SpecificExpression { @@ -1070,6 +1076,32 @@ public: Pop(MixedArena& allocator) {} }; +class RefNull : public SpecificExpression { +public: + RefNull() = default; + RefNull(MixedArena& allocator) {} + + void finalize(); +}; + +class RefIsNull : public SpecificExpression { +public: + RefIsNull(MixedArena& allocator) {} + + Expression* value; + + void finalize(); +}; + +class RefFunc : public SpecificExpression { +public: + RefFunc(MixedArena& allocator) {} + + Name func; + + void finalize(); +}; + class Try : public SpecificExpression { public: Try(MixedArena& allocator) {} diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index 82a150257..4f66b36e3 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -137,8 +137,11 @@ void Literal::getBits(uint8_t (&buf)[16]) const { case Type::v128: memcpy(buf, &v128, sizeof(v128)); break; - case Type::anyref: // anyref type is opaque - case Type::exnref: // exnref type is opaque + case Type::funcref: + case Type::nullref: + break; + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("invalid type"); @@ -146,10 +149,20 @@ void Literal::getBits(uint8_t (&buf)[16]) const { } bool Literal::operator==(const Literal& other) const { + if (type.isRef() && other.type.isRef()) { + if (type == Type::nullref && other.type == Type::nullref) { + return true; + } + if (type == Type::funcref && other.type == Type::funcref && + func == other.func) { + return true; + } + return false; + } if (type != other.type) { return false; } - if (type == none) { + if (type == Type::none) { return true; } uint8_t bits[16], other_bits[16]; @@ -273,8 +286,14 @@ std::ostream& operator<<(std::ostream& o, Literal literal) { o << "i32x4 "; literal.printVec128(o, literal.getv128()); break; - case Type::anyref: // anyref type is opaque - case Type::exnref: // exnref type is opaque + case Type::funcref: + o << "funcref(" << literal.getFunc() << ")"; + break; + case Type::nullref: + o << "nullref"; + break; + case Type::anyref: + case Type::exnref: case Type::unreachable: WASM_UNREACHABLE("invalid type"); } @@ -477,7 +496,9 @@ Literal Literal::eqz() const { case Type::f64: return eq(Literal(double(0))); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -497,7 +518,9 @@ Literal Literal::neg() const { case Type::f64: return Literal(int64_t(i64 ^ 0x8000000000000000ULL)).castToF64(); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -517,7 +540,9 @@ Literal Literal::abs() const { case Type::f64: return Literal(int64_t(i64 & 0x7fffffffffffffffULL)).castToF64(); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -620,7 +645,9 @@ Literal Literal::add(const Literal& other) const { case Type::f64: return Literal(getf64() + other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -640,7 +667,9 @@ Literal Literal::sub(const Literal& other) const { case Type::f64: return Literal(getf64() - other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -731,7 +760,9 @@ Literal Literal::mul(const Literal& other) const { case Type::f64: return Literal(getf64() * other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -967,7 +998,9 @@ Literal Literal::eq(const Literal& other) const { case Type::f64: return Literal(getf64() == other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -987,7 +1020,9 @@ Literal Literal::ne(const Literal& other) const { case Type::f64: return Literal(getf64() != other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 82eb51d7e..ba5a8d3dd 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -262,7 +262,7 @@ void WasmBinaryWriter::writeImports() { BYN_TRACE("write one table\n"); writeImportHeader(&wasm->table); o << U32LEB(int32_t(ExternalKind::Table)); - o << S32LEB(BinaryConsts::EncodedType::AnyFunc); + o << S32LEB(BinaryConsts::EncodedType::funcref); writeResizableLimits(wasm->table.initial, wasm->table.max, wasm->table.hasMax(), @@ -463,7 +463,7 @@ void WasmBinaryWriter::writeFunctionTableDeclaration() { BYN_TRACE("== writeFunctionTableDeclaration\n"); auto start = startSection(BinaryConsts::Section::Table); o << U32LEB(1); // Declare 1 table. - o << S32LEB(BinaryConsts::EncodedType::AnyFunc); + o << S32LEB(BinaryConsts::EncodedType::funcref); writeResizableLimits(wasm->table.initial, wasm->table.max, wasm->table.hasMax(), @@ -1059,8 +1059,12 @@ Type WasmBinaryBuilder::getType() { return f64; case BinaryConsts::EncodedType::v128: return v128; + case BinaryConsts::EncodedType::funcref: + return funcref; case BinaryConsts::EncodedType::anyref: return anyref; + case BinaryConsts::EncodedType::nullref: + return nullref; case BinaryConsts::EncodedType::exnref: return exnref; default: @@ -1258,8 +1262,8 @@ void WasmBinaryBuilder::readImports() { wasm.table.name = Name(std::string("timport$") + std::to_string(i)); auto elementType = getS32LEB(); WASM_UNUSED(elementType); - if (elementType != BinaryConsts::EncodedType::AnyFunc) { - throwError("Imported table type is not AnyFunc"); + if (elementType != BinaryConsts::EncodedType::funcref) { + throwError("Imported table type is not funcref"); } wasm.table.exists = true; bool is_shared; @@ -1802,11 +1806,17 @@ void WasmBinaryBuilder::processFunctions() { wasm.addExport(curr); } - for (auto& iter : functionCalls) { + for (auto& iter : functionRefs) { size_t index = iter.first; - auto& calls = iter.second; - for (auto* call : calls) { - call->target = getFunctionName(index); + auto& refs = iter.second; + for (auto* ref : refs) { + if (auto* call = ref->dynCast()) { + call->target = getFunctionName(index); + } else if (auto* refFunc = ref->dynCast()) { + refFunc->func = getFunctionName(index); + } else { + WASM_UNREACHABLE("Invalid type in function references"); + } } } @@ -1869,8 +1879,8 @@ void WasmBinaryBuilder::readFunctionTableDeclaration() { } wasm.table.exists = true; auto elemType = getS32LEB(); - if (elemType != BinaryConsts::EncodedType::AnyFunc) { - throwError("ElementType must be AnyFunc in MVP"); + if (elemType != BinaryConsts::EncodedType::funcref) { + throwError("ElementType must be funcref in MVP"); } bool is_shared; getResizableLimits( @@ -2117,7 +2127,8 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { visitGlobalSet((curr = allocator.alloc())->cast()); break; case BinaryConsts::Select: - visitSelect((curr = allocator.alloc()); + case BinaryConsts::SelectWithType: + visitSelect((curr = allocator.alloc(), code); break; case BinaryConsts::Return: visitReturn((curr = allocator.alloc())->cast()); @@ -2137,6 +2148,15 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { case BinaryConsts::Catch: curr = nullptr; break; + case BinaryConsts::RefNull: + visitRefNull((curr = allocator.alloc())->cast()); + break; + case BinaryConsts::RefIsNull: + visitRefIsNull((curr = allocator.alloc())->cast()); + break; + case BinaryConsts::RefFunc: + visitRefFunc((curr = allocator.alloc())->cast()); + break; case BinaryConsts::Try: visitTry((curr = allocator.alloc())->cast()); break; @@ -2510,7 +2530,7 @@ void WasmBinaryBuilder::visitCall(Call* curr) { curr->operands[num - i - 1] = popNonVoidExpression(); } curr->type = sig.results; - functionCalls[index].push_back(curr); // we don't know function names yet + functionRefs[index].push_back(curr); // we don't know function names yet curr->finalize(); } @@ -4326,12 +4346,24 @@ bool WasmBinaryBuilder::maybeVisitSIMDLoad(Expression*& out, uint32_t code) { return true; } -void WasmBinaryBuilder::visitSelect(Select* curr) { - BYN_TRACE("zz node: Select\n"); +void WasmBinaryBuilder::visitSelect(Select* curr, uint8_t code) { + BYN_TRACE("zz node: Select, code " << int32_t(code) << std::endl); + if (code == BinaryConsts::SelectWithType) { + size_t numTypes = getU32LEB(); + std::vector types; + for (size_t i = 0; i < numTypes; i++) { + types.push_back(getType()); + } + curr->type = Type(types); + } curr->condition = popNonVoidExpression(); curr->ifFalse = popNonVoidExpression(); curr->ifTrue = popNonVoidExpression(); - curr->finalize(); + if (code == BinaryConsts::SelectWithType) { + curr->finalize(curr->type); + } else { + curr->finalize(); + } } void WasmBinaryBuilder::visitReturn(Return* curr) { @@ -4383,6 +4415,27 @@ void WasmBinaryBuilder::visitDrop(Drop* curr) { curr->finalize(); } +void WasmBinaryBuilder::visitRefNull(RefNull* curr) { + BYN_TRACE("zz node: RefNull\n"); + curr->finalize(); +} + +void WasmBinaryBuilder::visitRefIsNull(RefIsNull* curr) { + BYN_TRACE("zz node: RefIsNull\n"); + curr->value = popNonVoidExpression(); + curr->finalize(); +} + +void WasmBinaryBuilder::visitRefFunc(RefFunc* curr) { + BYN_TRACE("zz node: RefFunc\n"); + Index index = getU32LEB(); + if (index >= functionImports.size() + functionSignatures.size()) { + throwError("ref.func: invalid call index"); + } + functionRefs[index].push_back(curr); // we don't know function names yet + curr->finalize(); +} + void WasmBinaryBuilder::visitTry(Try* curr) { BYN_TRACE("zz node: Try\n"); // For simplicity of implementation, like if scopes, we create a hidden block diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 20aff2091..3b12c4346 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -850,16 +850,22 @@ Type SExpressionWasmBuilder::stringToType(const char* str, return v128; } } + if (strncmp(str, "funcref", 7) == 0 && (prefix || str[7] == 0)) { + return funcref; + } if (strncmp(str, "anyref", 6) == 0 && (prefix || str[6] == 0)) { return anyref; } + if (strncmp(str, "nullref", 7) == 0 && (prefix || str[7] == 0)) { + return nullref; + } if (strncmp(str, "exnref", 6) == 0 && (prefix || str[6] == 0)) { return exnref; } if (allowError) { return none; } - throw ParseException("invalid wasm type"); + throw ParseException(std::string("invalid wasm type: ") + str); } Type SExpressionWasmBuilder::stringToLaneType(const char* str) { @@ -936,10 +942,16 @@ Expression* SExpressionWasmBuilder::makeUnary(Element& s, UnaryOp op) { Expression* SExpressionWasmBuilder::makeSelect(Element& s) { auto ret = allocator.alloc