diff options
author | Heejin Ahn <aheejin@gmail.com> | 2019-12-30 17:55:20 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-12-30 17:55:20 -0800 |
commit | bcc76146fed433cbc8ba01a9f568d979c145110b (patch) | |
tree | ab70ad24afc257b73513c3e62f3aab9938d05944 /src | |
parent | a30f1df5696ccb3490e2eaa3a9ed5e7e487c7b0e (diff) | |
download | binaryen-bcc76146fed433cbc8ba01a9f568d979c145110b.tar.gz binaryen-bcc76146fed433cbc8ba01a9f568d979c145110b.tar.bz2 binaryen-bcc76146fed433cbc8ba01a9f568d979c145110b.zip |
Add support for reference types proposal (#2451)
This adds support for the reference types proposal. This includes support
for all reference types (`anyref`, `funcref` (= `anyfunc`), and `nullref`)
and four new instructions: `ref.null`, `ref.is_null`, `ref.func`, and
the new typed `select`. This also adds subtype relationship support between
reference types.
This does not include table instructions yet. This also does not include
wasm2js support.
Fixes #2444 and fixes #2447.
Diffstat (limited to 'src')
59 files changed, 1532 insertions, 477 deletions
diff --git a/src/asmjs/asm_v_wasm.cpp b/src/asmjs/asm_v_wasm.cpp index 3720ca079..5959db43e 100644 --- a/src/asmjs/asm_v_wasm.cpp +++ b/src/asmjs/asm_v_wasm.cpp @@ -53,10 +53,11 @@ AsmType wasmToAsmType(Type type) { return ASM_INT64; case v128: assert(false && "v128 not implemented yet"); + case funcref: case anyref: - assert(false && "anyref is not supported by asm2wasm"); + case nullref: case exnref: - assert(false && "exnref is not supported by asm2wasm"); + assert(false && "reference types are not supported by asm2wasm"); case none: return ASM_NONE; case unreachable: @@ -77,10 +78,14 @@ char getSig(Type type) { return 'd'; case v128: return 'V'; + case funcref: + return 'F'; case anyref: - return 'a'; + return 'A'; + case nullref: + return 'N'; case exnref: - return 'e'; + return 'E'; case none: return 'v'; case unreachable: diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index 82cbc4c1f..826193e06 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -64,13 +64,16 @@ BinaryenLiteral toBinaryenLiteral(Literal x) { case Type::f64: ret.i64 = x.reinterpreti64(); break; - case Type::v128: { + case Type::v128: memcpy(&ret.v128, x.getv128Ptr(), 16); break; - } - - case Type::anyref: // there's no anyref literals - case Type::exnref: // there's no exnref literals + case Type::funcref: + ret.func = x.getFunc().c_str(); + break; + case Type::nullref: + break; + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("unexpected type"); @@ -90,8 +93,12 @@ Literal fromBinaryenLiteral(BinaryenLiteral x) { return Literal(x.i64).castToF64(); case Type::v128: return Literal(x.v128); - case Type::anyref: // there's no anyref literals - case Type::exnref: // there's no exnref literals + case Type::funcref: + return Literal::makeFuncref(x.func); + case Type::nullref: + return Literal::makeNullref(); + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("unexpected type"); @@ -209,8 
+216,14 @@ void printArg(std::ostream& setup, std::ostream& out, BinaryenLiteral arg) { out << "BinaryenLiteralVec128(" << array << ")"; break; } - case Type::anyref: // there's no anyref literals - case Type::exnref: // there's no exnref literals + case Type::funcref: + out << "BinaryenLiteralFuncref(" << arg.func << ")"; + break; + case Type::nullref: + out << "BinaryenLiteralNullref()"; + break; + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("unexpected type"); @@ -265,7 +278,9 @@ BinaryenType BinaryenTypeInt64(void) { return i64; } BinaryenType BinaryenTypeFloat32(void) { return f32; } BinaryenType BinaryenTypeFloat64(void) { return f64; } BinaryenType BinaryenTypeVec128(void) { return v128; } +BinaryenType BinaryenTypeFuncref(void) { return funcref; } BinaryenType BinaryenTypeAnyref(void) { return anyref; } +BinaryenType BinaryenTypeNullref(void) { return nullref; } BinaryenType BinaryenTypeExnref(void) { return exnref; } BinaryenType BinaryenTypeUnreachable(void) { return unreachable; } BinaryenType BinaryenTypeAuto(void) { return uint32_t(-1); } @@ -397,6 +412,15 @@ BinaryenExpressionId BinaryenMemoryCopyId(void) { BinaryenExpressionId BinaryenMemoryFillId(void) { return Expression::Id::MemoryFillId; } +BinaryenExpressionId BinaryenRefNullId(void) { + return Expression::Id::RefNullId; +} +BinaryenExpressionId BinaryenRefIsNullId(void) { + return Expression::Id::RefIsNullId; +} +BinaryenExpressionId BinaryenRefFuncId(void) { + return Expression::Id::RefFuncId; +} BinaryenExpressionId BinaryenTryId(void) { return Expression::Id::TryId; } BinaryenExpressionId BinaryenThrowId(void) { return Expression::Id::ThrowId; } BinaryenExpressionId BinaryenRethrowId(void) { @@ -1330,17 +1354,22 @@ BinaryenExpressionRef BinaryenBinary(BinaryenModuleRef module, BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressionRef condition, BinaryenExpressionRef ifTrue, - BinaryenExpressionRef ifFalse) { + 
BinaryenExpressionRef ifFalse, + BinaryenType type) { auto* ret = ((Module*)module)->allocator.alloc<Select>(); if (tracing) { - traceExpression(ret, "BinaryenSelect", condition, ifTrue, ifFalse); + traceExpression(ret, "BinaryenSelect", condition, ifTrue, ifFalse, type); } ret->condition = (Expression*)condition; ret->ifTrue = (Expression*)ifTrue; ret->ifFalse = (Expression*)ifFalse; - ret->finalize(); + if (type != BinaryenTypeAuto()) { + ret->finalize(Type(type)); + } else { + ret->finalize(); + } return static_cast<Expression*>(ret); } BinaryenExpressionRef BinaryenDrop(BinaryenModuleRef module, @@ -1695,6 +1724,32 @@ BinaryenExpressionRef BinaryenPop(BinaryenModuleRef module, BinaryenType type) { return static_cast<Expression*>(ret); } +BinaryenExpressionRef BinaryenRefNull(BinaryenModuleRef module) { + auto* ret = Builder(*(Module*)module).makeRefNull(); + if (tracing) { + traceExpression(ret, "BinaryenRefNull"); + } + return static_cast<Expression*>(ret); +} + +BinaryenExpressionRef BinaryenRefIsNull(BinaryenModuleRef module, + BinaryenExpressionRef value) { + auto* ret = Builder(*(Module*)module).makeRefIsNull((Expression*)value); + if (tracing) { + traceExpression(ret, "BinaryenRefIsNull", value); + } + return static_cast<Expression*>(ret); +} + +BinaryenExpressionRef BinaryenRefFunc(BinaryenModuleRef module, + const char* func) { + auto* ret = Builder(*(Module*)module).makeRefFunc(func); + if (tracing) { + traceExpression(ret, "BinaryenRefFunc", StringLit(func)); + } + return static_cast<Expression*>(ret); +} + BinaryenExpressionRef BinaryenTry(BinaryenModuleRef module, BinaryenExpressionRef body, BinaryenExpressionRef catchBody) { @@ -2964,6 +3019,28 @@ BinaryenExpressionRef BinaryenPushGetValue(BinaryenExpressionRef expr) { assert(expression->is<Push>()); return static_cast<Push*>(expression)->value; } +// RefIsNull +BinaryenExpressionRef BinaryenRefIsNullGetValue(BinaryenExpressionRef expr) { + if (tracing) { + std::cout << " 
BinaryenRefIsNullGetValue(expressions[" << expressions[expr] + << "]);\n"; + } + + auto* expression = (Expression*)expr; + assert(expression->is<RefIsNull>()); + return static_cast<RefIsNull*>(expression)->value; +} +// RefFunc +const char* BinaryenRefFuncGetFunc(BinaryenExpressionRef expr) { + if (tracing) { + std::cout << " BinaryenRefFuncGetFunc(expressions[" << expressions[expr] + << "]);\n"; + } + + auto* expression = (Expression*)expr; + assert(expression->is<RefFunc>()); + return static_cast<RefFunc*>(expression)->func.c_str(); +} // Try BinaryenExpressionRef BinaryenTryGetBody(BinaryenExpressionRef expr) { if (tracing) { diff --git a/src/binaryen-c.h b/src/binaryen-c.h index f169ff367..36d82ecb0 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -98,7 +98,9 @@ BINARYEN_API BinaryenType BinaryenTypeInt64(void); BINARYEN_API BinaryenType BinaryenTypeFloat32(void); BINARYEN_API BinaryenType BinaryenTypeFloat64(void); BINARYEN_API BinaryenType BinaryenTypeVec128(void); +BINARYEN_API BinaryenType BinaryenTypeFuncref(void); BINARYEN_API BinaryenType BinaryenTypeAnyref(void); +BINARYEN_API BinaryenType BinaryenTypeNullref(void); BINARYEN_API BinaryenType BinaryenTypeExnref(void); BINARYEN_API BinaryenType BinaryenTypeUnreachable(void); // Not a real type. 
Used as the last parameter to BinaryenBlock to let @@ -158,6 +160,9 @@ BINARYEN_API BinaryenExpressionId BinaryenMemoryInitId(void); BINARYEN_API BinaryenExpressionId BinaryenDataDropId(void); BINARYEN_API BinaryenExpressionId BinaryenMemoryCopyId(void); BINARYEN_API BinaryenExpressionId BinaryenMemoryFillId(void); +BINARYEN_API BinaryenExpressionId BinaryenRefNullId(void); +BINARYEN_API BinaryenExpressionId BinaryenRefIsNullId(void); +BINARYEN_API BinaryenExpressionId BinaryenRefFuncId(void); BINARYEN_API BinaryenExpressionId BinaryenTryId(void); BINARYEN_API BinaryenExpressionId BinaryenThrowId(void); BINARYEN_API BinaryenExpressionId BinaryenRethrowId(void); @@ -222,6 +227,7 @@ struct BinaryenLiteral { float f32; double f64; uint8_t v128[16]; + const char* func; }; }; @@ -692,7 +698,8 @@ BINARYEN_API BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module, BinaryenExpressionRef condition, BinaryenExpressionRef ifTrue, - BinaryenExpressionRef ifFalse); + BinaryenExpressionRef ifFalse, + BinaryenType type); BINARYEN_API BinaryenExpressionRef BinaryenDrop(BinaryenModuleRef module, BinaryenExpressionRef value); // Return: value can be NULL @@ -797,6 +804,11 @@ BinaryenMemoryFill(BinaryenModuleRef module, BinaryenExpressionRef dest, BinaryenExpressionRef value, BinaryenExpressionRef size); +BINARYEN_API BinaryenExpressionRef BinaryenRefNull(BinaryenModuleRef module); +BINARYEN_API BinaryenExpressionRef +BinaryenRefIsNull(BinaryenModuleRef module, BinaryenExpressionRef value); +BINARYEN_API BinaryenExpressionRef BinaryenRefFunc(BinaryenModuleRef module, + const char* func); BINARYEN_API BinaryenExpressionRef BinaryenTry(BinaryenModuleRef module, BinaryenExpressionRef body, BinaryenExpressionRef catchBody); @@ -1036,6 +1048,11 @@ BINARYEN_API BinaryenExpressionRef BinaryenMemoryFillGetSize(BinaryenExpressionRef expr); BINARYEN_API BinaryenExpressionRef +BinaryenRefIsNullGetValue(BinaryenExpressionRef expr); + +BINARYEN_API const char* 
BinaryenRefFuncGetFunc(BinaryenExpressionRef expr); + +BINARYEN_API BinaryenExpressionRef BinaryenTryGetBody(BinaryenExpressionRef expr); BINARYEN_API BinaryenExpressionRef BinaryenTryGetCatchBody(BinaryenExpressionRef expr); diff --git a/src/gen-s-parser.inc b/src/gen-s-parser.inc index d60ee1794..eb5626b3f 100644 --- a/src/gen-s-parser.inc +++ b/src/gen-s-parser.inc @@ -653,6 +653,9 @@ switch (op[0]) { default: goto parse_error; } } + case 'u': + if (strcmp(op, "funcref.pop") == 0) { return makePop(funcref); } + goto parse_error; default: goto parse_error; } } @@ -2480,30 +2483,57 @@ switch (op[0]) { default: goto parse_error; } } - case 'n': - if (strcmp(op, "nop") == 0) { return makeNop(); } - goto parse_error; + case 'n': { + switch (op[1]) { + case 'o': + if (strcmp(op, "nop") == 0) { return makeNop(); } + goto parse_error; + case 'u': + if (strcmp(op, "nullref.pop") == 0) { return makePop(nullref); } + goto parse_error; + default: goto parse_error; + } + } case 'p': if (strcmp(op, "push") == 0) { return makePush(s); } goto parse_error; case 'r': { - switch (op[3]) { - case 'h': - if (strcmp(op, "rethrow") == 0) { return makeRethrow(s); } - goto parse_error; - case 'u': { - switch (op[6]) { - case '\0': - if (strcmp(op, "return") == 0) { return makeReturn(s); } + switch (op[2]) { + case 'f': { + switch (op[4]) { + case 'f': + if (strcmp(op, "ref.func") == 0) { return makeRefFunc(s); } goto parse_error; - case '_': { - switch (op[11]) { + case 'i': + if (strcmp(op, "ref.is_null") == 0) { return makeRefIsNull(s); } + goto parse_error; + case 'n': + if (strcmp(op, "ref.null") == 0) { return makeRefNull(s); } + goto parse_error; + default: goto parse_error; + } + } + case 't': { + switch (op[3]) { + case 'h': + if (strcmp(op, "rethrow") == 0) { return makeRethrow(s); } + goto parse_error; + case 'u': { + switch (op[6]) { case '\0': - if (strcmp(op, "return_call") == 0) { return makeCall(s, /*isReturn=*/true); } - goto parse_error; - case '_': - if (strcmp(op, 
"return_call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/true); } + if (strcmp(op, "return") == 0) { return makeReturn(s); } goto parse_error; + case '_': { + switch (op[11]) { + case '\0': + if (strcmp(op, "return_call") == 0) { return makeCall(s, /*isReturn=*/true); } + goto parse_error; + case '_': + if (strcmp(op, "return_call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/true); } + goto parse_error; + default: goto parse_error; + } + } default: goto parse_error; } } diff --git a/src/ir/ExpressionAnalyzer.cpp b/src/ir/ExpressionAnalyzer.cpp index 7355d1856..4b9869ddd 100644 --- a/src/ir/ExpressionAnalyzer.cpp +++ b/src/ir/ExpressionAnalyzer.cpp @@ -218,6 +218,9 @@ template<typename T> void visitImmediates(Expression* curr, T& visitor) { visitor.visitInt(curr->op); visitor.visitNonScopeName(curr->nameOperand); } + void visitRefNull(RefNull* curr) {} + void visitRefIsNull(RefIsNull* curr) {} + void visitRefFunc(RefFunc* curr) { visitor.visitNonScopeName(curr->func); } void visitTry(Try* curr) {} void visitThrow(Throw* curr) { visitor.visitNonScopeName(curr->event); } void visitRethrow(Rethrow* curr) {} diff --git a/src/ir/ExpressionManipulator.cpp b/src/ir/ExpressionManipulator.cpp index fbee9f9c1..acea09bad 100644 --- a/src/ir/ExpressionManipulator.cpp +++ b/src/ir/ExpressionManipulator.cpp @@ -58,7 +58,7 @@ flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) { curr->type); } Expression* visitLoop(Loop* curr) { - return builder.makeLoop(curr->name, copy(curr->body)); + return builder.makeLoop(curr->name, copy(curr->body), curr->type); } Expression* visitBreak(Break* curr) { return builder.makeBreak( @@ -208,8 +208,10 @@ flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) { return builder.makeBinary(curr->op, copy(curr->left), copy(curr->right)); } Expression* visitSelect(Select* curr) { - return builder.makeSelect( - copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse)); + return 
builder.makeSelect(copy(curr->condition), + copy(curr->ifTrue), + copy(curr->ifFalse), + curr->type); } Expression* visitDrop(Drop* curr) { return builder.makeDrop(copy(curr->value)); @@ -226,6 +228,13 @@ flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) { builder.makeHost(curr->op, curr->nameOperand, std::move(operands)); return ret; } + Expression* visitRefNull(RefNull* curr) { return builder.makeRefNull(); } + Expression* visitRefIsNull(RefIsNull* curr) { + return builder.makeRefIsNull(copy(curr->value)); + } + Expression* visitRefFunc(RefFunc* curr) { + return builder.makeRefFunc(curr->func); + } Expression* visitTry(Try* curr) { return builder.makeTry( copy(curr->body), copy(curr->catchBody), curr->type); diff --git a/src/ir/ReFinalize.cpp b/src/ir/ReFinalize.cpp index be0a8604b..9243869a1 100644 --- a/src/ir/ReFinalize.cpp +++ b/src/ir/ReFinalize.cpp @@ -44,23 +44,13 @@ void ReFinalize::visitBlock(Block* curr) { curr->type = none; return; } - // do this quickly, without any validation - // last element determines type + // Get the least upper bound type of the last element and all branch return + // values curr->type = curr->list.back()->type; - // if concrete, it doesn't matter if we have an unreachable child, and we - // don't need to look at breaks - if (curr->type.isConcrete()) { - return; - } - // otherwise, we have no final fallthrough element to determine the type, - // could be determined by breaks if (curr->name.is()) { auto iter = breakValues.find(curr->name); if (iter != breakValues.end()) { - // there is a break to here - auto type = iter->second; - assert(type != unreachable); // we would have removed such branches - curr->type = type; + curr->type = Type::getLeastUpperBound(curr->type, iter->second); return; } } @@ -130,6 +120,9 @@ void ReFinalize::visitSelect(Select* curr) { curr->finalize(); } void ReFinalize::visitDrop(Drop* curr) { curr->finalize(); } void ReFinalize::visitReturn(Return* curr) { curr->finalize(); } void 
ReFinalize::visitHost(Host* curr) { curr->finalize(); } +void ReFinalize::visitRefNull(RefNull* curr) { curr->finalize(); } +void ReFinalize::visitRefIsNull(RefIsNull* curr) { curr->finalize(); } +void ReFinalize::visitRefFunc(RefFunc* curr) { curr->finalize(); } void ReFinalize::visitTry(Try* curr) { curr->finalize(); } void ReFinalize::visitThrow(Throw* curr) { curr->finalize(); } void ReFinalize::visitRethrow(Rethrow* curr) { curr->finalize(); } @@ -159,8 +152,12 @@ void ReFinalize::visitEvent(Event* curr) { WASM_UNREACHABLE("unimp"); } void ReFinalize::visitModule(Module* curr) { WASM_UNREACHABLE("unimp"); } void ReFinalize::updateBreakValueType(Name name, Type type) { - if (type != unreachable || breakValues.count(name) == 0) { - breakValues[name] = type; + if (type != Type::unreachable) { + if (breakValues.count(name) == 0) { + breakValues[name] = type; + } else { + breakValues[name] = Type::getLeastUpperBound(breakValues[name], type); + } } } diff --git a/src/ir/abstract.h b/src/ir/abstract.h index 384f8b555..76215d07f 100644 --- a/src/ir/abstract.h +++ b/src/ir/abstract.h @@ -80,8 +80,10 @@ inline UnaryOp getUnary(Type type, Op op) { case v128: { WASM_UNREACHABLE("v128 not implemented yet"); } - case anyref: // there's no unary instructions for anyref - case exnref: // there's no unary instructions for exnref + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: { return InvalidUnary; @@ -211,8 +213,10 @@ inline BinaryOp getBinary(Type type, Op op) { case v128: { WASM_UNREACHABLE("v128 not implemented yet"); } - case anyref: // there's no binary instructions for anyref - case exnref: // there's no binary instructions for exnref + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: { return InvalidBinary; diff --git a/src/ir/block-utils.h b/src/ir/block-utils.h index ca8b7179b..153dd45b3 100644 --- a/src/ir/block-utils.h +++ b/src/ir/block-utils.h @@ -43,7 +43,8 @@ 
simplifyToContents(Block* block, T* parent, bool allowTypeChange = false) { // no side effects, and singleton is not returning a value, so we can // throw away the block and its contents, basically return Builder(*parent->getModule()).replaceWithIdenticalType(block); - } else if (block->type == singleton->type || allowTypeChange) { + } else if (Type::isSubType(singleton->type, block->type) || + allowTypeChange) { return singleton; } else { // (side effects +) type change, must be block with declared value but diff --git a/src/ir/effects.h b/src/ir/effects.h index e93c63017..6eb2da91d 100644 --- a/src/ir/effects.h +++ b/src/ir/effects.h @@ -387,6 +387,9 @@ struct EffectAnalyzer // Atomics are also sequentially consistent with memory.grow. isAtomic = true; } + void visitRefNull(RefNull* curr) {} + void visitRefIsNull(RefIsNull* curr) {} + void visitRefFunc(RefFunc* curr) {} void visitTry(Try* curr) {} // We safely model throws as branches void visitThrow(Throw* curr) { branches = true; } diff --git a/src/ir/flat.h b/src/ir/flat.h index dd72e339d..01a94a759 100644 --- a/src/ir/flat.h +++ b/src/ir/flat.h @@ -56,6 +56,7 @@ #define wasm_ir_flat_h #include "ir/iteration.h" +#include "ir/properties.h" #include "pass.h" #include "wasm-traversal.h" @@ -64,7 +65,8 @@ namespace wasm { namespace Flat { inline bool isControlFlowStructure(Expression* curr) { - return curr->is<Block>() || curr->is<If>() || curr->is<Loop>(); + return curr->is<Block>() || curr->is<If>() || curr->is<Loop>() || + curr->is<Try>(); } inline void verifyFlatness(Function* func) { @@ -79,10 +81,10 @@ inline void verifyFlatness(Function* func) { verify(!curr->type.isConcrete(), "tees are not allowed, only sets"); } else { for (auto* child : ChildIterator(curr)) { - verify(child->is<Const>() || child->is<LocalGet>() || - child->is<Unreachable>(), - "instructions must only have const, local.get, or unreachable " - "as children"); + verify(Properties::isConstantExpression(child) || + child->is<LocalGet>() || 
child->is<Unreachable>(), + "instructions must only have constant expressions, local.get, " + "or unreachable as children"); } } } diff --git a/src/ir/global-utils.h b/src/ir/global-utils.h index 93e5c8a67..e096aec8c 100644 --- a/src/ir/global-utils.h +++ b/src/ir/global-utils.h @@ -52,6 +52,12 @@ getGlobalInitializedToImport(Module& wasm, Name module, Name base) { }); return ret; } + +inline bool canInitializeGlobal(const Expression* curr) { + return curr->is<Const>() || curr->is<RefNull>() || curr->is<RefFunc>() || + curr->is<GlobalGet>(); +} + } // namespace GlobalUtils } // namespace wasm diff --git a/src/ir/literal-utils.h b/src/ir/literal-utils.h index 63a2b3b44..4bc79eee9 100644 --- a/src/ir/literal-utils.h +++ b/src/ir/literal-utils.h @@ -39,6 +39,10 @@ inline Expression* makeZero(Type type, Module& wasm) { return builder.makeUnary(SplatVecI32x4, builder.makeConst(Literal(int32_t(0)))); } + if (type.isRef()) { + Builder builder(wasm); + return builder.makeRefNull(); + } return makeFromInt32(0, type, wasm); } diff --git a/src/ir/manipulation.h b/src/ir/manipulation.h index ec137d372..49ed7e11e 100644 --- a/src/ir/manipulation.h +++ b/src/ir/manipulation.h @@ -33,14 +33,24 @@ inline OutputType* convert(InputType* input) { return output; } -// Convenience method for nop, which is a common conversion +// Convenience methods for certain instructions, which are common conversions template<typename InputType> inline Nop* nop(InputType* target) { - return convert<InputType, Nop>(target); + auto* ret = convert<InputType, Nop>(target); + ret->finalize(); + return ret; +} + +template<typename InputType> inline RefNull* refNull(InputType* target) { + auto* ret = convert<InputType, RefNull>(target); + ret->finalize(); + return ret; } template<typename InputType> inline Unreachable* unreachable(InputType* target) { - return convert<InputType, Unreachable>(target); + auto* ret = convert<InputType, Unreachable>(target); + ret->finalize(); + return ret; } // Convert a node 
that allocates diff --git a/src/ir/properties.h b/src/ir/properties.h index bb88af6c5..f4c9686b6 100644 --- a/src/ir/properties.h +++ b/src/ir/properties.h @@ -187,6 +187,10 @@ inline Expression* getFallthrough(Expression* curr) { return curr; } +inline bool isConstantExpression(const Expression* curr) { + return curr->is<Const>() || curr->is<RefNull>() || curr->is<RefFunc>(); +} + } // namespace Properties } // namespace wasm diff --git a/src/ir/utils.h b/src/ir/utils.h index cad7bc885..9bd3c9e0b 100644 --- a/src/ir/utils.h +++ b/src/ir/utils.h @@ -146,6 +146,9 @@ struct ReFinalize void visitDrop(Drop* curr); void visitReturn(Return* curr); void visitHost(Host* curr); + void visitRefNull(RefNull* curr); + void visitRefIsNull(RefIsNull* curr); + void visitRefFunc(RefFunc* curr); void visitTry(Try* curr); void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); @@ -210,6 +213,9 @@ struct ReFinalizeNode : public OverriddenVisitor<ReFinalizeNode> { void visitDrop(Drop* curr) { curr->finalize(); } void visitReturn(Return* curr) { curr->finalize(); } void visitHost(Host* curr) { curr->finalize(); } + void visitRefNull(RefNull* curr) { curr->finalize(); } + void visitRefIsNull(RefIsNull* curr) { curr->finalize(); } + void visitRefFunc(RefFunc* curr) { curr->finalize(); } void visitTry(Try* curr) { curr->finalize(); } void visitThrow(Throw* curr) { curr->finalize(); } void visitRethrow(Rethrow* curr) { curr->finalize(); } diff --git a/src/js/binaryen.js-post.js b/src/js/binaryen.js-post.js index 8cddc61c7..2993573d1 100644 --- a/src/js/binaryen.js-post.js +++ b/src/js/binaryen.js-post.js @@ -38,7 +38,9 @@ function initializeConstants() { ['f32', 'Float32'], ['f64', 'Float64'], ['v128', 'Vec128'], + ['funcref', 'Funcref'], ['anyref', 'Anyref'], + ['nullref', 'Nullref'], ['exnref', 'Exnref'], ['unreachable', 'Unreachable'], ['auto', 'Auto'] @@ -86,6 +88,9 @@ function initializeConstants() { 'DataDrop', 'MemoryCopy', 'MemoryFill', + 'RefNull', + 'RefIsNull', + 
'RefFunc', 'Try', 'Throw', 'Rethrow', @@ -1952,20 +1957,47 @@ function wrapModule(module, self) { }, }; + self['funcref'] = { + 'pop': function() { + return Module['_BinaryenPop'](module, Module['funcref']); + } + }; + self['anyref'] = { 'pop': function() { return Module['_BinaryenPop'](module, Module['anyref']); } }; + self['nullref'] = { + 'pop': function() { + return Module['_BinaryenPop'](module, Module['nullref']); + } + }; + self['exnref'] = { 'pop': function() { return Module['_BinaryenPop'](module, Module['exnref']); } }; - self['select'] = function(condition, ifTrue, ifFalse) { - return Module['_BinaryenSelect'](module, condition, ifTrue, ifFalse); + self['ref'] = { + 'null': function() { + return Module['_BinaryenRefNull'](module); + }, + 'is_null': function(value) { + return Module['_BinaryenRefIsNull'](module, value); + }, + 'func': function(func) { + return preserveStack(function() { + return Module['_BinaryenRefFunc'](module, strToStack(func)); + }); + } + }; + + self['select'] = function(condition, ifTrue, ifFalse, type) { + return Module['_BinaryenSelect']( + module, condition, ifTrue, ifFalse, typeof type !== 'undefined' ? 
type : Module['auto']); }; self['drop'] = function(value) { return Module['_BinaryenDrop'](module, value); @@ -2651,6 +2683,23 @@ Module['getExpressionInfo'] = function(expr) { 'value': Module['_BinaryenMemoryFillGetValue'](expr), 'size': Module['_BinaryenMemoryFillGetSize'](expr) }; + case Module['RefNullId']: + return { + 'id': id, + 'type': type + }; + case Module['RefIsNullId']: + return { + 'id': id, + 'type': type, + 'value': Module['_BinaryenRefIsNullGetValue'](expr) + }; + case Module['RefFuncId']: + return { + 'id': id, + 'type': type, + 'func': UTF8ToString(Module['_BinaryenRefFuncGetFunc'](expr)), + }; case Module['TryId']: return { 'id': id, diff --git a/src/literal.h b/src/literal.h index 1d19e6661..ef3e13d44 100644 --- a/src/literal.h +++ b/src/literal.h @@ -22,6 +22,7 @@ #include "compiler-support.h" #include "support/hash.h" +#include "support/name.h" #include "support/utilities.h" #include "wasm-type.h" @@ -34,6 +35,7 @@ class Literal { int32_t i32; int64_t i64; uint8_t v128[16]; + Name func; // function name for funcref }; public: @@ -57,11 +59,12 @@ public: explicit Literal(const std::array<Literal, 8>&); explicit Literal(const std::array<Literal, 4>&); explicit Literal(const std::array<Literal, 2>&); + explicit Literal(Name func) : func(func), type(Type::funcref) {} - bool isConcrete() { return type != none; } - bool isNull() { return type == none; } + bool isConcrete() { return type != Type::none; } + bool isNone() { return type == Type::none; } - inline static Literal makeFromInt32(int32_t x, Type type) { + static Literal makeFromInt32(int32_t x, Type type) { switch (type) { case Type::i32: return Literal(int32_t(x)); @@ -80,16 +83,26 @@ public: Literal(int32_t(0)), Literal(int32_t(0)), Literal(int32_t(0))}}); - case Type::anyref: // there's no anyref literals - case Type::exnref: // there's no exnref literals - case none: - case unreachable: + case Type::funcref: + case Type::anyref: + case Type::nullref: + case Type::exnref: + case 
Type::none: + case Type::unreachable: WASM_UNREACHABLE("unexpected type"); } WASM_UNREACHABLE("unexpected type"); } - inline static Literal makeZero(Type type) { return makeFromInt32(0, type); } + static Literal makeZero(Type type) { + if (type.isRef()) { + return makeNullref(); + } + return makeFromInt32(0, type); + } + + static Literal makeNullref() { return Literal(Type(Type::nullref)); } + static Literal makeFuncref(Name func) { return Literal(func.c_str()); } Literal castToF32(); Literal castToF64(); @@ -113,6 +126,7 @@ public: return bit_cast<double>(i64); } std::array<uint8_t, 16> getv128() const; + Name getFunc() const { return func; } // careful! int32_t* geti32Ptr() { @@ -464,8 +478,10 @@ template<> struct less<wasm::Literal> { return a.reinterpreti64() < b.reinterpreti64(); case wasm::Type::v128: return memcmp(a.getv128Ptr(), b.getv128Ptr(), 16) < 0; - case wasm::Type::anyref: // anyref is an opaque value - case wasm::Type::exnref: // exnref is an opaque value + case wasm::Type::funcref: + case wasm::Type::anyref: + case wasm::Type::nullref: + case wasm::Type::exnref: case wasm::Type::none: case wasm::Type::unreachable: return false; diff --git a/src/parsing.h b/src/parsing.h index 7017fdb0f..d64236df3 100644 --- a/src/parsing.h +++ b/src/parsing.h @@ -263,8 +263,10 @@ parseConst(cashew::IString s, Type type, MixedArena& allocator) { break; } case v128: - case anyref: // there's no anyref.const - case exnref: // there's no exnref.const + case funcref: + case anyref: + case nullref: + case exnref: WASM_UNREACHABLE("unexpected const type"); case none: case unreachable: { diff --git a/src/passes/ConstHoisting.cpp b/src/passes/ConstHoisting.cpp index dbb3853d8..4e8cd9910 100644 --- a/src/passes/ConstHoisting.cpp +++ b/src/passes/ConstHoisting.cpp @@ -91,9 +91,12 @@ private: size = value.type.getByteSize(); break; } - case v128: // v128 not implemented yet - case anyref: // anyref cannot have literals - case exnref: { // exnref cannot have literals + // not 
implemented yet + case v128: + case funcref: + case anyref: + case nullref: + case exnref: { return false; } case none: diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp index be6f92ffa..7d5385a83 100644 --- a/src/passes/DeadCodeElimination.cpp +++ b/src/passes/DeadCodeElimination.cpp @@ -347,6 +347,12 @@ struct DeadCodeElimination DELEGATE(Push); case Expression::Id::PopId: DELEGATE(Pop); + case Expression::Id::RefNullId: + DELEGATE(RefNull); + case Expression::Id::RefIsNullId: + DELEGATE(RefIsNull); + case Expression::Id::RefFuncId: + DELEGATE(RefFunc); case Expression::Id::TryId: DELEGATE(Try); case Expression::Id::ThrowId: diff --git a/src/passes/Flatten.cpp b/src/passes/Flatten.cpp index f5115567b..fda8e3f80 100644 --- a/src/passes/Flatten.cpp +++ b/src/passes/Flatten.cpp @@ -21,6 +21,7 @@ #include <ir/branch-utils.h> #include <ir/effects.h> #include <ir/flat.h> +#include <ir/properties.h> #include <ir/utils.h> #include <pass.h> #include <wasm-builder.h> @@ -61,7 +62,9 @@ struct Flatten std::vector<Expression*> ourPreludes; Builder builder(*getModule()); - if (curr->is<Const>() || curr->is<Nop>() || curr->is<Unreachable>()) { + // Nothing to do for constants, nop, and unreachable + if (Properties::isConstantExpression(curr) || curr->is<Nop>() || + curr->is<Unreachable>()) { return; } @@ -194,8 +197,37 @@ struct Flatten auto type = br->value->type; if (type.isConcrete()) { // we are sending a value. use a local instead - Index temp = getTempForBreakTarget(br->name, type); + Type blockType = findBreakTarget(br->name)->type; + Index temp = getTempForBreakTarget(br->name, blockType); ourPreludes.push_back(builder.makeLocalSet(temp, br->value)); + + // br_if leaves a value on the stack if not taken, which later can + // be the last element of the enclosing innermost block and flow + // out. 
The local we created using 'getTempForBreakTarget' returns + // the return type of the block this branch is targetting, which may + // not be the same with the innermost block's return type. For + // example, + // (block $any (result anyref) + // (block (result nullref) + // (local.tee $0 + // (br_if $any + // (ref.null) + // (i32.const 0) + // ) + // ) + // ) + // ) + // In this case we need two locals to store (ref.null); one with + // anyref type that's for the target block ($label0) and one more + // with nullref type in case for flowing out. Here we create the + // second 'flowing out' local in case two block's types are + // different. + if (type != blockType) { + temp = builder.addVar(getFunction(), type); + ourPreludes.push_back(builder.makeLocalSet( + temp, ExpressionManipulator::copy(br->value, *getModule()))); + } + if (br->condition) { // the value must also flow out ourPreludes.push_back(br); @@ -239,6 +271,7 @@ struct Flatten } } } + // TODO Handle br_on_exn // continue for general handling of everything, control flow or otherwise curr = getCurrent(); // we may have replaced it diff --git a/src/passes/FuncCastEmulation.cpp b/src/passes/FuncCastEmulation.cpp index 729a4a6c3..9d5109a83 100644 --- a/src/passes/FuncCastEmulation.cpp +++ b/src/passes/FuncCastEmulation.cpp @@ -65,11 +65,11 @@ static Expression* toABI(Expression* value, Module* module) { case v128: { WASM_UNREACHABLE("v128 not implemented yet"); } - case anyref: { - WASM_UNREACHABLE("anyref cannot be converted to i64"); - } + case funcref: + case anyref: + case nullref: case exnref: { - WASM_UNREACHABLE("exnref cannot be converted to i64"); + WASM_UNREACHABLE("reference types cannot be converted to i64"); } case none: { // the value is none, but we need a value here @@ -108,11 +108,11 @@ static Expression* fromABI(Expression* value, Type type, Module* module) { case v128: { WASM_UNREACHABLE("v128 not implemented yet"); } - case anyref: { - WASM_UNREACHABLE("anyref cannot be converted from 
i64"); - } + case funcref: + case anyref: + case nullref: case exnref: { - WASM_UNREACHABLE("exnref cannot be converted from i64"); + WASM_UNREACHABLE("reference types cannot be converted from i64"); } case none: { value = builder.makeDrop(value); diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index db1db5971..c43d41e7f 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -46,13 +46,13 @@ namespace wasm { // Useful into on a function, helping us decide if we can inline it struct FunctionInfo { - std::atomic<Index> calls; + std::atomic<Index> refs; Index size; std::atomic<bool> lightweight; bool usedGlobally; // in a table or export FunctionInfo() { - calls = 0; + refs = 0; size = 0; lightweight = true; usedGlobally = false; @@ -75,7 +75,7 @@ struct FunctionInfo { // FIXME: move this check to be first in this function, since we should // return true if oneCallerInlineMaxSize is bigger than // flexibleInlineMaxSize (which it typically should be). - if (calls == 1 && !usedGlobally && + if (refs == 1 && !usedGlobally && size <= options.inlining.oneCallerInlineMaxSize) { return true; } @@ -108,11 +108,16 @@ struct FunctionInfoScanner void visitCall(Call* curr) { // can't add a new element in parallel assert(infos->count(curr->target) > 0); - (*infos)[curr->target].calls++; + (*infos)[curr->target].refs++; // having a call is not lightweight (*infos)[getFunction()->name].lightweight = false; } + void visitRefFunc(RefFunc* curr) { + assert(infos->count(curr->func) > 0); + (*infos)[curr->func].refs++; + } + void visitFunction(Function* curr) { (*infos)[curr->name].size = Measurer::measure(curr->body); } @@ -374,7 +379,7 @@ struct Inlining : public Pass { doInlining(module, func.get(), action); inlinedUses[inlinedName]++; inlinedInto.insert(func.get()); - assert(inlinedUses[inlinedName] <= infos[inlinedName].calls); + assert(inlinedUses[inlinedName] <= infos[inlinedName].refs); } } // anything we inlined into may now have non-unique label 
names, fix it up @@ -388,7 +393,7 @@ struct Inlining : public Pass { module->removeFunctions([&](Function* func) { auto name = func->name; auto& info = infos[name]; - return inlinedUses.count(name) && inlinedUses[name] == info.calls && + return inlinedUses.count(name) && inlinedUses[name] == info.refs && !info.usedGlobally; }); // return whether we did any work diff --git a/src/passes/InstrumentLocals.cpp b/src/passes/InstrumentLocals.cpp index 407903219..ae35ec2d1 100644 --- a/src/passes/InstrumentLocals.cpp +++ b/src/passes/InstrumentLocals.cpp @@ -56,14 +56,18 @@ Name get_i32("get_i32"); Name get_i64("get_i64"); Name get_f32("get_f32"); Name get_f64("get_f64"); +Name get_funcref("get_funcref"); Name get_anyref("get_anyref"); +Name get_nullref("get_nullref"); Name get_exnref("get_exnref"); Name set_i32("set_i32"); Name set_i64("set_i64"); Name set_f32("set_f32"); Name set_f64("set_f64"); +Name set_funcref("set_funcref"); Name set_anyref("set_anyref"); +Name set_nullref("set_nullref"); Name set_exnref("set_exnref"); struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> { @@ -84,9 +88,15 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> { break; case v128: assert(false && "v128 not implemented yet"); + case funcref: + import = get_funcref; + break; case anyref: import = get_anyref; break; + case nullref: + import = get_nullref; + break; case exnref: import = get_exnref; break; @@ -126,9 +136,15 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> { break; case v128: assert(false && "v128 not implemented yet"); + case funcref: + import = set_funcref; + break; case anyref: import = set_anyref; break; + case nullref: + import = set_nullref; + break; case exnref: import = set_exnref; break; @@ -156,10 +172,26 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> { addImport(curr, set_f64, {Type::i32, Type::i32, Type::f64}, Type::f64); if (curr->features.hasReferenceTypes()) { 
+ addImport(curr, + get_funcref, + {Type::i32, Type::i32, Type::funcref}, + Type::funcref); + addImport(curr, + set_funcref, + {Type::i32, Type::i32, Type::funcref}, + Type::funcref); addImport( curr, get_anyref, {Type::i32, Type::i32, Type::anyref}, Type::anyref); addImport( curr, set_anyref, {Type::i32, Type::i32, Type::anyref}, Type::anyref); + addImport(curr, + get_nullref, + {Type::i32, Type::i32, Type::nullref}, + Type::nullref); + addImport(curr, + set_nullref, + {Type::i32, Type::i32, Type::nullref}, + Type::nullref); } if (curr->features.hasExceptionHandling()) { addImport( diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp index 8c7bc4414..df6651b0d 100644 --- a/src/passes/LegalizeJSInterface.cpp +++ b/src/passes/LegalizeJSInterface.cpp @@ -107,14 +107,43 @@ struct LegalizeJSInterface : public Pass { } } } + if (!illegalImportsToLegal.empty()) { + // Gather functions used in 'ref.func'. They should not be removed. + std::unordered_map<Name, std::atomic<bool>> usedInRefFunc; + + struct RefFuncScanner : public WalkerPass<PostWalker<RefFuncScanner>> { + Module& wasm; + std::unordered_map<Name, std::atomic<bool>>& usedInRefFunc; + + bool isFunctionParallel() override { return true; } + + Pass* create() override { + return new RefFuncScanner(wasm, usedInRefFunc); + } + + RefFuncScanner( + Module& wasm, + std::unordered_map<Name, std::atomic<bool>>& usedInRefFunc) + : wasm(wasm), usedInRefFunc(usedInRefFunc) { + // Fill in unordered_map, as we operate on it in parallel + for (auto& func : wasm.functions) { + usedInRefFunc[func->name]; + } + } + + void visitRefFunc(RefFunc* curr) { usedInRefFunc[curr->func] = true; } + }; + + RefFuncScanner(*module, usedInRefFunc).run(runner, module); for (auto& pair : illegalImportsToLegal) { - module->removeFunction(pair.first); + if (!usedInRefFunc[pair.first]) { + module->removeFunction(pair.first); + } } // fix up imports: call_import of an illegal must be turned to a call of a // legal - 
struct FixImports : public WalkerPass<PostWalker<FixImports>> { bool isFunctionParallel() override { return true; } diff --git a/src/passes/LocalCSE.cpp b/src/passes/LocalCSE.cpp index 0816bf6ea..b49c92310 100644 --- a/src/passes/LocalCSE.cpp +++ b/src/passes/LocalCSE.cpp @@ -172,9 +172,12 @@ struct LocalCSE : public WalkerPass<LinearExecutionWalker<LocalCSE>> { void handle(Expression* curr) { if (auto* set = curr->dynCast<LocalSet>()) { // Calculate equivalences + auto* func = getFunction(); equivalences.reset(set->index); if (auto* get = set->value->dynCast<LocalGet>()) { - equivalences.add(set->index, get->index); + if (func->getLocalType(set->index) == func->getLocalType(get->index)) { + equivalences.add(set->index, get->index); + } } // consider the value auto* value = set->value; @@ -184,7 +187,7 @@ struct LocalCSE : public WalkerPass<LinearExecutionWalker<LocalCSE>> { if (iter != usables.end()) { // already exists in the table, this is good to reuse auto& info = iter->second; - Type localType = getFunction()->getLocalType(info.index); + Type localType = func->getLocalType(info.index); set->value = Builder(*getModule()).makeLocalGet(info.index, localType); anotherPass = true; diff --git a/src/passes/MergeLocals.cpp b/src/passes/MergeLocals.cpp index 0116753f1..2223594b6 100644 --- a/src/passes/MergeLocals.cpp +++ b/src/passes/MergeLocals.cpp @@ -100,7 +100,8 @@ struct MergeLocals return; } // compute all dependencies - LocalGraph preGraph(getFunction()); + auto* func = getFunction(); + LocalGraph preGraph(func); preGraph.computeInfluences(); // optimize each copy std::unordered_map<LocalSet*, LocalSet*> optimizedToCopy, @@ -119,6 +120,11 @@ struct MergeLocals if (preGraph.getSetses[influencedGet].size() == 1) { // this is ok assert(*preGraph.getSetses[influencedGet].begin() == trivial); + // If local types are different (when one is a subtype of the + // other), don't optimize + if (func->getLocalType(copy->index) != influencedGet->type) { + canOptimizeToCopy 
= false; + } } else { canOptimizeToCopy = false; break; @@ -152,6 +158,11 @@ struct MergeLocals if (preGraph.getSetses[influencedGet].size() == 1) { // this is ok assert(*preGraph.getSetses[influencedGet].begin() == copy); + // If local types are different (when one is a subtype of the + // other), don't optimize + if (func->getLocalType(trivial->index) != influencedGet->type) { + canOptimizeToTrivial = false; + } } else { canOptimizeToTrivial = false; break; @@ -176,7 +187,7 @@ struct MergeLocals // if one does not work, we need to undo all its siblings (don't extend // the live range unless we are definitely removing a conflict, same // logic as before). - LocalGraph postGraph(getFunction()); + LocalGraph postGraph(func); postGraph.computeInfluences(); for (auto& pair : optimizedToCopy) { auto* copy = pair.first; diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 6de1d3d00..edd6ba2b6 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -751,12 +751,12 @@ struct OptimizeInstructions // condition, do that auto needCondition = EffectAnalyzer(getPassOptions(), iff->condition).hasSideEffects(); - auto typeIsIdentical = iff->ifTrue->type == iff->type; - if (typeIsIdentical && !needCondition) { + auto isSubType = Type::isSubType(iff->ifTrue->type, iff->type); + if (isSubType && !needCondition) { return iff->ifTrue; } else { Builder builder(*getModule()); - if (typeIsIdentical) { + if (isSubType) { return builder.makeSequence(builder.makeDrop(iff->condition), iff->ifTrue); } else { diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp index 57a3ab27f..85eb026f9 100644 --- a/src/passes/Precompute.cpp +++ b/src/passes/Precompute.cpp @@ -177,7 +177,7 @@ struct Precompute void visitExpression(Expression* curr) { // TODO: if local.get, only replace with a constant if we don't care about // size...? 
- if (curr->is<Const>() || curr->is<Nop>()) { + if (Properties::isConstantExpression(curr) || curr->is<Nop>()) { return; } // Until engines implement v128.const and we have SIMD-aware optimizations @@ -208,14 +208,16 @@ struct Precompute return; } } - ret->value = Builder(*getModule()).makeConst(flow.value); + ret->value = Builder(*getModule()).makeConstExpression(flow.value); } else { ret->value = nullptr; } } else { Builder builder(*getModule()); - replaceCurrent(builder.makeReturn( - flow.value.type != none ? builder.makeConst(flow.value) : nullptr)); + replaceCurrent( + builder.makeReturn(flow.value.type != Type::none + ? builder.makeConstExpression(flow.value) + : nullptr)); } return; } @@ -234,7 +236,7 @@ struct Precompute return; } } - br->value = Builder(*getModule()).makeConst(flow.value); + br->value = Builder(*getModule()).makeConstExpression(flow.value); } else { br->value = nullptr; } @@ -243,13 +245,14 @@ struct Precompute Builder builder(*getModule()); replaceCurrent(builder.makeBreak( flow.breakTo, - flow.value.type != none ? builder.makeConst(flow.value) : nullptr)); + flow.value.type != none ? 
builder.makeConstExpression(flow.value) + : nullptr)); } return; } // this was precomputed if (flow.value.type.isConcrete()) { - replaceCurrent(Builder(*getModule()).makeConst(flow.value)); + replaceCurrent(Builder(*getModule()).makeConstExpression(flow.value)); worked = true; } else { ExpressionManipulator::nop(curr); @@ -350,7 +353,7 @@ private: } else { curr = setValues[set]; } - if (curr.isNull()) { + if (curr.isNone()) { // not a constant, give up value = Literal(); break; diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 5efd1fd28..51e78c8a7 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -1333,7 +1333,12 @@ struct PrintExpressionContents } restoreNormalColor(o); } - void visitSelect(Select* curr) { prepareColor(o) << "select"; } + void visitSelect(Select* curr) { + prepareColor(o) << "select"; + if (curr->type.isRef()) { + o << " (result " << curr->type << ')'; + } + } void visitDrop(Drop* curr) { printMedium(o, "drop"); } void visitReturn(Return* curr) { printMedium(o, "return"); } void visitHost(Host* curr) { @@ -1346,6 +1351,12 @@ struct PrintExpressionContents break; } } + void visitRefNull(RefNull* curr) { printMedium(o, "ref.null"); } + void visitRefIsNull(RefIsNull* curr) { printMedium(o, "ref.is_null"); } + void visitRefFunc(RefFunc* curr) { + printMedium(o, "ref.func "); + printName(curr->func, o); + } void visitTry(Try* curr) { printMedium(o, "try"); if (curr->type.isConcrete()) { @@ -1852,6 +1863,23 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { } } } + void visitRefNull(RefNull* curr) { + o << '('; + PrintExpressionContents(currFunction, o).visit(curr); + o << ')'; + } + void visitRefIsNull(RefIsNull* curr) { + o << '('; + PrintExpressionContents(currFunction, o).visit(curr); + incIndent(); + printFullLine(curr->value); + decIndent(); + } + void visitRefFunc(RefFunc* curr) { + o << '('; + PrintExpressionContents(currFunction, o).visit(curr); + o << ')'; + } // try-catch-end is written in the 
folded wat format as // (try // ... @@ -2434,13 +2462,15 @@ WasmPrinter::printStackInst(StackInst* inst, std::ostream& o, Function* func) { } case StackInst::BlockBegin: case StackInst::IfBegin: - case StackInst::LoopBegin: { + case StackInst::LoopBegin: + case StackInst::TryBegin: { o << getExpressionName(inst->origin); break; } case StackInst::BlockEnd: case StackInst::IfEnd: - case StackInst::LoopEnd: { + case StackInst::LoopEnd: + case StackInst::TryEnd: { o << "end (" << inst->type << ')'; break; } @@ -2448,6 +2478,10 @@ WasmPrinter::printStackInst(StackInst* inst, std::ostream& o, Function* func) { o << "else"; break; } + case StackInst::Catch: { + o << "catch"; + break; + } default: WASM_UNREACHABLE("unexpeted op"); } diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp index f5000e3a4..21cbc5e5b 100644 --- a/src/passes/RemoveUnusedModuleElements.cpp +++ b/src/passes/RemoveUnusedModuleElements.cpp @@ -116,6 +116,12 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { usesMemory = true; } } + void visitRefFunc(RefFunc* curr) { + if (reachable.count( + ModuleElement(ModuleElementKind::Function, curr->func)) == 0) { + queue.emplace_back(ModuleElementKind::Function, curr->func); + } + } void visitThrow(Throw* curr) { if (reachable.count(ModuleElement(ModuleElementKind::Event, curr->event)) == 0) { diff --git a/src/passes/SimplifyGlobals.cpp b/src/passes/SimplifyGlobals.cpp index 88f27f8be..b18f726ed 100644 --- a/src/passes/SimplifyGlobals.cpp +++ b/src/passes/SimplifyGlobals.cpp @@ -37,6 +37,7 @@ #include <atomic> #include "ir/effects.h" +#include "ir/properties.h" #include "ir/utils.h" #include "pass.h" #include "wasm-builder.h" @@ -106,8 +107,9 @@ struct ConstantGlobalApplier void visitExpression(Expression* curr) { if (auto* set = curr->dynCast<GlobalSet>()) { - if (auto* c = set->value->dynCast<Const>()) { - currConstantGlobals[set->name] = c->value; + if 
(Properties::isConstantExpression(set->value)) { + currConstantGlobals[set->name] = + getLiteralFromConstExpression(set->value); } else { currConstantGlobals.erase(set->name); } @@ -116,7 +118,7 @@ struct ConstantGlobalApplier // Check if the global is known to be constant all the time. if (constantGlobals->count(get->name)) { auto* global = getModule()->getGlobal(get->name); - assert(global->init->is<Const>()); + assert(Properties::isConstantExpression(global->init)); replaceCurrent(ExpressionManipulator::copy(global->init, *getModule())); replaced = true; return; @@ -125,7 +127,7 @@ struct ConstantGlobalApplier auto iter = currConstantGlobals.find(get->name); if (iter != currConstantGlobals.end()) { Builder builder(*getModule()); - replaceCurrent(builder.makeConst(iter->second)); + replaceCurrent(builder.makeConstExpression(iter->second)); replaced = true; } return; @@ -249,13 +251,14 @@ struct SimplifyGlobals : public Pass { std::map<Name, Literal> constantGlobals; for (auto& global : module->globals) { if (!global->imported()) { - if (auto* c = global->init->dynCast<Const>()) { - constantGlobals[global->name] = c->value; + if (Properties::isConstantExpression(global->init)) { + constantGlobals[global->name] = + getLiteralFromConstExpression(global->init); } else if (auto* get = global->init->dynCast<GlobalGet>()) { auto iter = constantGlobals.find(get->name); if (iter != constantGlobals.end()) { Builder builder(*module); - global->init = builder.makeConst(iter->second); + global->init = builder.makeConstExpression(iter->second); } } } @@ -268,7 +271,7 @@ struct SimplifyGlobals : public Pass { NameSet constantGlobals; for (auto& global : module->globals) { if (!global->mutable_ && !global->imported() && - global->init->is<Const>()) { + Properties::isConstantExpression(global->init)) { constantGlobals.insert(global->name); } } diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index a3fa4a34d..a952f8a38 100644 --- 
a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -546,7 +546,6 @@ struct SimplifyLocals auto* blockLocalSetPointer = sinkables.at(sharedIndex).item; auto* value = (*blockLocalSetPointer)->template cast<LocalSet>()->value; block->list[block->list.size() - 1] = value; - block->type = value->type; ExpressionManipulator::nop(*blockLocalSetPointer); for (size_t j = 0; j < breaks.size(); j++) { // move break local.set's value to the break @@ -577,6 +576,7 @@ struct SimplifyLocals this->replaceCurrent(newLocalSet); sinkables.clear(); anotherCycle = true; + block->finalize(); } // optimize local.sets from both sides of an if into a return value @@ -915,6 +915,7 @@ struct SimplifyLocals void visitLocalSet(LocalSet* curr) { // Remove trivial copies, even through a tee auto* value = curr->value; + Function* func = this->getFunction(); while (auto* subSet = value->dynCast<LocalSet>()) { value = subSet->value; } @@ -929,7 +930,8 @@ struct SimplifyLocals } anotherCycle = true; } - } else { + } else if (func->getLocalType(curr->index) == + func->getLocalType(get->index)) { // There is a new equivalence now. 
equivalences.reset(curr->index); equivalences.add(curr->index, get->index); diff --git a/src/passes/opt-utils.h b/src/passes/opt-utils.h index 93fac137f..7912a7d92 100644 --- a/src/passes/opt-utils.h +++ b/src/passes/opt-utils.h @@ -54,19 +54,22 @@ inline void optimizeAfterInlining(std::unordered_set<Function*>& funcs, module->updateMaps(); } -struct CallTargetReplacer : public WalkerPass<PostWalker<CallTargetReplacer>> { +struct FunctionRefReplacer + : public WalkerPass<PostWalker<FunctionRefReplacer>> { bool isFunctionParallel() override { return true; } using MaybeReplace = std::function<void(Name&)>; - CallTargetReplacer(MaybeReplace maybeReplace) : maybeReplace(maybeReplace) {} + FunctionRefReplacer(MaybeReplace maybeReplace) : maybeReplace(maybeReplace) {} - CallTargetReplacer* create() override { - return new CallTargetReplacer(maybeReplace); + FunctionRefReplacer* create() override { + return new FunctionRefReplacer(maybeReplace); } void visitCall(Call* curr) { maybeReplace(curr->target); } + void visitRefFunc(RefFunc* curr) { maybeReplace(curr->func); } + private: MaybeReplace maybeReplace; }; @@ -81,7 +84,7 @@ inline void replaceFunctions(PassRunner* runner, } }; // replace direct calls - CallTargetReplacer(maybeReplace).run(runner, &module); + FunctionRefReplacer(maybeReplace).run(runner, &module); // replace in table for (auto& segment : module.table.segments) { for (auto& name : segment.data) { diff --git a/src/shell-interface.h b/src/shell-interface.h index 52533f37c..75f8e81b8 100644 --- a/src/shell-interface.h +++ b/src/shell-interface.h @@ -114,10 +114,12 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { break; case v128: assert(false && "v128 not implemented yet"); + case funcref: case anyref: - assert(false && "anyref not implemented yet"); + case nullref: case exnref: - assert(false && "exnref not implemented yet"); + globals[import->name] = Literal::makeNullref(); + break; case none: case unreachable: 
WASM_UNREACHABLE("unexpected type"); @@ -163,7 +165,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { trap("callIndirect: bad # of arguments"); } for (size_t i = 0; i < params.size(); i++) { - if (params[i] != arguments[i].type) { + if (!Type::isSubType(arguments[i].type, params[i])) { trap("callIndirect: bad argument type"); } } diff --git a/src/support/name.h b/src/support/name.h index 2bc50abf0..615740e09 100644 --- a/src/support/name.h +++ b/src/support/name.h @@ -17,7 +17,7 @@ #ifndef wasm_support_name_h #define wasm_support_name_h -#include <cstring> +#include <string> #include "emscripten-optimizer/istring.h" diff --git a/src/support/small_vector.h b/src/support/small_vector.h index 7f00bd4a6..d4ad961a7 100644 --- a/src/support/small_vector.h +++ b/src/support/small_vector.h @@ -38,17 +38,15 @@ template<typename T, size_t N> class SmallVector { std::vector<T> flexible; public: + using value_type = T; + SmallVector() {} T& operator[](size_t i) { - if (i < N) { - return fixed[i]; - } else { - return flexible[i - N]; - } + return const_cast<T&>(static_cast<const SmallVector<T, N>&>(*this)[i]); } - T operator[](size_t i) const { + const T& operator[](size_t i) const { if (i < N) { return fixed[i]; } else { diff --git a/src/tools/execution-results.h b/src/tools/execution-results.h index c0c7428cc..7787dba25 100644 --- a/src/tools/execution-results.h +++ b/src/tools/execution-results.h @@ -69,11 +69,17 @@ struct ExecutionResults { auto* func = wasm.getFunction(exp->value); if (func->sig.results != Type::none) { // this has a result - results[exp->name] = run(func, wasm, instance); - // ignore the result if we hit an unreachable and returned no value - if (results[exp->name].type.isConcrete()) { - std::cout << "[fuzz-exec] note result: " << exp->name << " => " - << results[exp->name] << '\n'; + Literal ret = run(func, wasm, instance); + // We cannot compare funcrefs by name because function names can + // change (after duplicate function 
elimination or roundtripping) + // while the function contents are still the same + if (ret.type != Type::funcref) { + results[exp->name] = ret; + // ignore the result if we hit an unreachable and returned no value + if (results[exp->name].type.isConcrete()) { + std::cout << "[fuzz-exec] note result: " << exp->name << " => " + << results[exp->name] << '\n'; + } } } else { // no result, run it anyhow (it might modify memory etc.) @@ -100,17 +106,17 @@ struct ExecutionResults { auto name = iter.first; if (results.find(name) == results.end()) { std::cout << "[fuzz-exec] missing " << name << '\n'; - abort(); + return false; } std::cout << "[fuzz-exec] comparing " << name << '\n'; if (results[name] != other.results[name]) { std::cout << "not identical!\n"; - abort(); + return false; } } if (loggings != other.loggings) { std::cout << "logging not identical!\n"; - abort(); + return false; } return true; } @@ -138,7 +144,7 @@ struct ExecutionResults { // call the method for (Type param : func->sig.params.expand()) { // zeros in arguments TODO: more? - arguments.push_back(Literal(param)); + arguments.push_back(Literal::makeZero(param)); } return instance.callFunction(func->name, arguments); } catch (const TrapException&) { diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h index ce302fac6..ff0888f1d 100644 --- a/src/tools/fuzzing.h +++ b/src/tools/fuzzing.h @@ -25,8 +25,7 @@ high chance for set at start of loop high chance of a tee in that case => loop var */ -// TODO Complete exnref type support. Its support is partialy implemented -// and the type is currently not generated in fuzzed programs yet. 
+// TODO Generate exception handling instructions #include "ir/memory-utils.h" #include <ir/find_all.h> @@ -310,6 +309,24 @@ private: double getDouble() { return Literal(get64()).reinterpretf64(); } + SmallVector<Type, 2> getSubTypes(Type type) { + SmallVector<Type, 2> ret; + ret.push_back(type); // includes itself + switch (type) { + case Type::anyref: + ret.push_back(Type::funcref); + ret.push_back(Type::exnref); + // falls through + case Type::funcref: + case Type::exnref: + ret.push_back(Type::nullref); + break; + default: + break; + } + return ret; + } + void setupMemory() { // Add memory itself MemoryUtils::ensureExists(wasm.memory); @@ -404,10 +421,12 @@ private: Index num = upTo(3); for (size_t i = 0; i < num; i++) { // Events should have void return type and at least one param type + Type type = getConcreteType(); std::vector<Type> params; + params.push_back(type); Index numValues = upToSquared(MAX_PARAMS - 1); for (Index i = 0; i < numValues + 1; i++) { - params.push_back(pick(i32, i64, f32, f64)); + params.push_back(getConcreteType()); } auto* event = builder.makeEvent(std::string("event$") + std::to_string(i), WASM_EVENT_ATTRIBUTE_EXCEPTION, @@ -447,7 +466,7 @@ private: } void addImportLoggingSupport() { - for (auto type : getConcreteTypes()) { + for (auto type : getLoggableTypes()) { auto* func = new Function; Name name = std::string("log-") + type.toString(); func->name = name; @@ -501,7 +520,7 @@ private: // function generation state - Function* func; + Function* func = nullptr; std::vector<Expression*> breakableStack; // things we can break to Index labelIndex; @@ -585,10 +604,12 @@ private: // loop limit FindAll<Loop> loops(func->body); for (auto* loop : loops.list) { - loop->body = builder.makeSequence(makeHangLimitCheck(), loop->body); + loop->body = + builder.makeSequence(makeHangLimitCheck(), loop->body, loop->type); } // recursion limit - func->body = builder.makeSequence(makeHangLimitCheck(), func->body); + func->body = + 
builder.makeSequence(makeHangLimitCheck(), func->body, func->sig.results); } void recombine(Function* func) { @@ -841,7 +862,9 @@ private: case f32: case f64: case v128: + case funcref: case anyref: + case nullref: case exnref: ret = _makeConcrete(type); break; @@ -852,7 +875,8 @@ private: ret = _makeunreachable(); break; } - assert(ret->type == type); // we should create the right type of thing + // we should create the right type of thing + assert(Type::isSubType(ret->type, type)); nesting--; return ret; } @@ -898,9 +922,12 @@ private: &Self::makeSelect, &Self::makeGlobalGet) .add(FeatureSet::SIMD, &Self::makeSIMD); - if (type == i32 || type == i64) { + if (type == Type::i32 || type == Type::i64) { options.add(FeatureSet::Atomics, &Self::makeAtomic); } + if (type == Type::i32) { + options.add(FeatureSet::ReferenceTypes, &Self::makeRefIsNull); + } return (this->*pick(options))(type); } @@ -1064,11 +1091,11 @@ private: // possible branch back list.push_back(builder.makeBreak(ret->name, nullptr, makeCondition())); list.push_back(make(type)); // final element, so we have the right type - ret->body = builder.makeBlock(list); + ret->body = builder.makeBlock(list, type); } breakableStack.pop_back(); hangStack.pop_back(); - ret->finalize(); + ret->finalize(type); return ret; } @@ -1093,15 +1120,15 @@ private: } } - Expression* buildIf(const struct ThreeArgs& args) { - return builder.makeIf(args.a, args.b, args.c); + Expression* buildIf(const struct ThreeArgs& args, Type type) { + return builder.makeIf(args.a, args.b, args.c, type); } Expression* makeIf(Type type) { auto* condition = makeCondition(); hangStack.push_back(nullptr); auto* ret = - buildIf({condition, makeMaybeBlock(type), makeMaybeBlock(type)}); + buildIf({condition, makeMaybeBlock(type), makeMaybeBlock(type)}, type); hangStack.pop_back(); return ret; } @@ -1360,8 +1387,10 @@ private: return builder.makeLoad( 16, false, offset, pick(1, 2, 4, 8, 16), ptr, type); } - case anyref: // anyref cannot be loaded from 
memory - case exnref: // exnref cannot be loaded from memory + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("invalid type"); @@ -1370,8 +1399,8 @@ private: } Expression* makeLoad(Type type) { - // exnref type cannot be stored in memory - if (!allowMemory || type == exnref) { + // reference types cannot be stored in memory + if (!allowMemory || type.isRef()) { return makeTrivial(type); } auto* ret = makeNonAtomicLoad(type); @@ -1393,7 +1422,7 @@ private: Expression* makeNonAtomicStore(Type type) { if (type == unreachable) { // make a normal store, then make it unreachable - auto* ret = makeNonAtomicStore(getConcreteType()); + auto* ret = makeNonAtomicStore(getStorableType()); auto* store = ret->dynCast<Store>(); if (!store) { return ret; @@ -1416,7 +1445,7 @@ private: // the type is none or unreachable. we also need to pick the value // type. if (type == none) { - type = getConcreteType(); + type = getStorableType(); } auto offset = logify(get()); auto ptr = makePointer(); @@ -1462,8 +1491,10 @@ private: return builder.makeStore( 16, offset, pick(1, 2, 4, 8, 16), ptr, value, type); } - case anyref: // anyref cannot be stored in memory - case exnref: // exnref cannot be stored in memory + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("invalid type"); @@ -1472,7 +1503,6 @@ private: } Expression* makeStore(Type type) { - // exnref type cannot be stored in memory if (!allowMemory || type.isRef()) { return makeTrivial(type); } @@ -1558,8 +1588,10 @@ private: case f64: return Literal(getDouble()); case v128: - case anyref: // anyref cannot have literals - case exnref: // exnref cannot have literals + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("invalid type"); @@ -1601,8 +1633,10 @@ private: case f64: return Literal(double(small)); case v128: - case anyref: // anyref cannot have literals - 
case exnref: // exnref cannot have literals + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1667,8 +1701,10 @@ private: std::numeric_limits<uint64_t>::max())); break; case v128: - case anyref: // anyref cannot have literals - case exnref: // exnref cannot have literals + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1699,8 +1735,10 @@ private: value = Literal(double(int64_t(1) << upTo(64))); break; case v128: - case anyref: // anyref cannot have literals - case exnref: // exnref cannot have literals + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1724,21 +1762,23 @@ private: } Expression* makeConst(Type type) { - switch (type) { - case anyref: - // There's no anyref.const. - // TODO We should return a nullref once we implement instructions for - // reference types proposal. - assert(false && "anyref const is not implemented yet"); - case exnref: - // There's no exnref.const. - // TODO We should return a nullref once we implement instructions for - // reference types proposal. - assert(false && "exnref const is not implemented yet"); - default: - break; + if (type.isRef()) { + assert(wasm.features.hasReferenceTypes()); + // Check if we can use ref.func. + // 'func' is the pointer to the last created function and can be null when + // we set up globals (before we create any functions), in which case we + // can't use ref.func. 
+ if (type == Type::funcref && func && oneIn(2)) { + // First set to target to the last created function, and try to select + // among other existing function if possible + Function* target = func; + if (!wasm.functions.empty() && !oneIn(wasm.functions.size())) { + target = pick(wasm.functions).get(); + } + return builder.makeRefFunc(target->name); + } + return builder.makeRefNull(); } - auto* ret = wasm.allocator.alloc<Const>(); ret->value = makeLiteral(type); ret->type = type; @@ -1757,9 +1797,9 @@ private: // give up return makeTrivial(type); } - // There's no binary ops for exnref - if (type == exnref) { - makeTrivial(type); + // There's no unary ops for reference types + if (type.isRef()) { + return makeTrivial(type); } switch (type) { @@ -1807,8 +1847,11 @@ private: AllTrueVecI64x2), make(v128)}); } - case anyref: // there's no unary ops for anyref - case exnref: // there's no unary ops for exnref + case funcref: + case anyref: + case nullref: + case exnref: + return makeTrivial(type); case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1947,8 +1990,10 @@ private: } WASM_UNREACHABLE("invalid value"); } - case anyref: // there's no unary ops for anyref - case exnref: // there's no unary ops for exnref + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1969,9 +2014,9 @@ private: // give up return makeTrivial(type); } - // There's no binary ops for exnref + // There's no binary ops for reference types if (type.isRef()) { - makeTrivial(type); + return makeTrivial(type); } switch (type) { @@ -2180,8 +2225,10 @@ private: make(v128), make(v128)}); } - case anyref: // there's no binary ops for anyref - case exnref: // there's no binary ops for exnref + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -2189,12 +2236,15 @@ private: WASM_UNREACHABLE("invalid type"); } - Expression* 
buildSelect(const ThreeArgs& args) { - return builder.makeSelect(args.a, args.b, args.c); + Expression* buildSelect(const ThreeArgs& args, Type type) { + return builder.makeSelect(args.a, args.b, args.c, type); } Expression* makeSelect(Type type) { - return makeDeNanOp(buildSelect({make(i32), make(type), make(type)})); + Type subType1 = pick(getSubTypes(type)); + Type subType2 = pick(getSubTypes(type)); + return makeDeNanOp( + buildSelect({make(i32), make(subType1), make(subType2)}, type)); } Expression* makeSwitch(Type type) { @@ -2338,6 +2388,9 @@ private: Expression* makeSIMD(Type type) { assert(wasm.features.hasSIMD()); + if (type.isRef()) { + return makeTrivial(type); + } if (type != v128) { return makeSIMDExtract(type); } @@ -2380,7 +2433,9 @@ private: op = ExtractLaneVecF64x2; break; case v128: + case funcref: case anyref: + case nullref: case exnref: case none: case unreachable: @@ -2549,6 +2604,18 @@ private: WASM_UNREACHABLE("invalid value"); } + Expression* makeRefIsNull(Type type) { + assert(type == Type::i32); + assert(wasm.features.hasReferenceTypes()); + Type refType; + if (wasm.features.hasExceptionHandling()) { + refType = pick(Type::funcref, Type::anyref, Type::nullref, Type::exnref); + } else { + refType = pick(Type::funcref, Type::anyref, Type::nullref); + } + return builder.makeRefIsNull(make(refType)); + } + Expression* makeMemoryInit() { if (!allowMemory) { return makeTrivial(none); @@ -2593,7 +2660,7 @@ private: // special makers Expression* makeLogging() { - auto type = getConcreteType(); + auto type = getLoggableType(); return builder.makeCall( std::string("log-") + type.toString(), {make(type)}, none); } @@ -2605,20 +2672,64 @@ private: // special getters - Type getReachableType() { - return pick(FeatureOptions<Type>() - .add(FeatureSet::MVP, i32, i64, f32, f64, none) - .add(FeatureSet::SIMD, v128)); - } + std::vector<Type> getReachableTypes() { + return items(FeatureOptions<Type>() + .add(FeatureSet::MVP, + Type::i32, + Type::i64, + 
Type::f32, + Type::f64, + Type::none) + .add(FeatureSet::SIMD, Type::v128) + .add(FeatureSet::ReferenceTypes, + Type::funcref, + Type::anyref, + Type::nullref) + .add((FeatureSet::Feature)(FeatureSet::ReferenceTypes | + FeatureSet::ExceptionHandling), + Type::exnref)); + } + Type getReachableType() { return pick(getReachableTypes()); } std::vector<Type> getConcreteTypes() { - return items(FeatureOptions<Type>() - .add(FeatureSet::MVP, i32, i64, f32, f64) - .add(FeatureSet::SIMD, v128)); + return items( + FeatureOptions<Type>() + .add(FeatureSet::MVP, Type::i32, Type::i64, Type::f32, Type::f64) + .add(FeatureSet::SIMD, Type::v128) + .add(FeatureSet::ReferenceTypes, + Type::funcref, + Type::anyref, + Type::nullref) + .add((FeatureSet::Feature)(FeatureSet::ReferenceTypes | + FeatureSet::ExceptionHandling), + Type::exnref)); } - Type getConcreteType() { return pick(getConcreteTypes()); } + // Get types that can be stored in memory + std::vector<Type> getStorableTypes() { + return items( + FeatureOptions<Type>() + .add(FeatureSet::MVP, Type::i32, Type::i64, Type::f32, Type::f64) + .add(FeatureSet::SIMD, Type::v128)); + } + Type getStorableType() { return pick(getStorableTypes()); } + + // - funcref cannot be logged because referenced functions can be inlined or + // removed during optimization + // - there's no point in logging anyref because it is opaque + std::vector<Type> getLoggableTypes() { + return items( + FeatureOptions<Type>() + .add(FeatureSet::MVP, Type::i32, Type::i64, Type::f32, Type::f64) + .add(FeatureSet::SIMD, Type::v128) + .add(FeatureSet::ReferenceTypes, Type::nullref) + .add((FeatureSet::Feature)(FeatureSet::ReferenceTypes | + FeatureSet::ExceptionHandling), + Type::exnref)); + } + Type getLoggableType() { return pick(getLoggableTypes()); } + // statistical distributions // 0 to the limit, logarithmic scale @@ -2659,8 +2770,8 @@ private: // low values Index upToSquared(Index x) { return upTo(upTo(x)); } - // pick from a vector - template<typename T> 
const T& pick(const std::vector<T>& vec) { + // pick from a vector-like container + template<typename T> const typename T::value_type& pick(const T& vec) { assert(!vec.empty()); auto index = upTo(vec.size()); return vec[index]; diff --git a/src/tools/spec-wrapper.h b/src/tools/spec-wrapper.h index beada1b4b..f59291e55 100644 --- a/src/tools/spec-wrapper.h +++ b/src/tools/spec-wrapper.h @@ -48,8 +48,12 @@ static std::string generateSpecWrapper(Module& wasm) { case v128: ret += "(v128.const i32x4 0 0 0 0)"; break; - case anyref: // there's no anyref.const - case exnref: // there's no exnref.const + case funcref: + case anyref: + case nullref: + case exnref: + ret += "(ref.null)"; + break; case none: case unreachable: WASM_UNREACHABLE("unexpected type"); diff --git a/src/tools/wasm-reduce.cpp b/src/tools/wasm-reduce.cpp index 274b6de29..6adb1e174 100644 --- a/src/tools/wasm-reduce.cpp +++ b/src/tools/wasm-reduce.cpp @@ -592,7 +592,9 @@ struct Reducer fixed = builder->makeUnary(TruncSFloat64ToInt32, child); break; case v128: + case funcref: case anyref: + case nullref: case exnref: continue; // not implemented yet case none: @@ -615,7 +617,9 @@ struct Reducer fixed = builder->makeUnary(TruncSFloat64ToInt64, child); break; case v128: + case funcref: case anyref: + case nullref: case exnref: continue; // not implemented yet case none: @@ -638,7 +642,9 @@ struct Reducer fixed = builder->makeUnary(DemoteFloat64, child); break; case v128: + case funcref: case anyref: + case nullref: case exnref: continue; // not implemented yet case none: @@ -661,7 +667,9 @@ struct Reducer case f64: WASM_UNREACHABLE("unexpected type"); case v128: + case funcref: case anyref: + case nullref: case exnref: continue; // not implemented yet case none: @@ -671,7 +679,9 @@ struct Reducer break; } case v128: + case funcref: case anyref: + case nullref: case exnref: continue; // not implemented yet case none: @@ -999,6 +1009,10 @@ struct Reducer return false; } // try to replace with a trivial value 
+ if (curr->type.isRef()) { + RefNull* n = builder->makeRefNull(); + return tryToReplaceCurrent(n); + } Const* c = builder->makeConst(Literal(int32_t(0))); if (tryToReplaceCurrent(c)) { return true; diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index d5ee60d8a..6c9d3f36a 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -74,7 +74,7 @@ struct Operation { name = element[i++]->str(); for (size_t j = i; j < element.size(); j++) { Expression* argument = builder.parseExpression(*element[j]); - arguments.push_back(argument->dynCast<Const>()->value); + arguments.push_back(getLiteralFromConstExpression(argument)); } } @@ -214,7 +214,7 @@ static void run_asserts(Name moduleName, assert(!trapped); if (curr.size() >= 3) { Literal expected = - builder->parseExpression(*curr[2])->dynCast<Const>()->value; + getLiteralFromConstExpression(builder->parseExpression(*curr[2])); std::cerr << "seen " << result << ", expected " << expected << '\n'; if (expected != result) { std::cout << "unexpected, should be identical\n"; diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 4206defdf..f019d0792 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -343,10 +343,12 @@ enum EncodedType { f32 = -0x3, // 0x7d f64 = -0x4, // 0x7c v128 = -0x5, // 0x7b - // elem_type - AnyFunc = -0x10, // 0x70 + // function reference type + funcref = -0x10, // 0x70 // opaque reference type anyref = -0x11, // 0x6f + // null reference type + nullref = -0x12, // 0x6e // exception reference type exnref = -0x18, // 0x68 // func_type form @@ -402,6 +404,7 @@ enum ASTNodes { Drop = 0x1a, Select = 0x1b, + SelectWithType = 0x1c, // added in reference types proposal LocalGet = 0x20, LocalSet = 0x21, @@ -867,6 +870,12 @@ enum ASTNodes { MemoryCopy = 0x0a, MemoryFill = 0x0b, + // reference types opcodes + + RefNull = 0xd0, + RefIsNull = 0xd1, + RefFunc = 0xd2, + // exception handling opcodes Try = 0x06, @@ -914,9 +923,15 @@ inline S32LEB binaryType(Type type) { case v128: 
ret = BinaryConsts::EncodedType::v128; break; + case funcref: + ret = BinaryConsts::EncodedType::funcref; + break; case anyref: ret = BinaryConsts::EncodedType::anyref; break; + case nullref: + ret = BinaryConsts::EncodedType::nullref; + break; case exnref: ret = BinaryConsts::EncodedType::exnref; break; @@ -1143,8 +1158,8 @@ public: // we store function imports here before wasm.addFunctionImport after we know // their names std::vector<Function*> functionImports; - // at index i we have all calls to the function i - std::map<Index, std::vector<Call*>> functionCalls; + // at index i we have all refs to the function i + std::map<Index, std::vector<Expression*>> functionRefs; Function* currFunction = nullptr; // before we see a function (like global init expressions), there is no end of // function to check @@ -1279,12 +1294,15 @@ public: bool maybeVisitDataDrop(Expression*& out, uint32_t code); bool maybeVisitMemoryCopy(Expression*& out, uint32_t code); bool maybeVisitMemoryFill(Expression*& out, uint32_t code); - void visitSelect(Select* curr); + void visitSelect(Select* curr, uint8_t code); void visitReturn(Return* curr); bool maybeVisitHost(Expression*& out, uint8_t code); void visitNop(Nop* curr); void visitUnreachable(Unreachable* curr); void visitDrop(Drop* curr); + void visitRefNull(RefNull* curr); + void visitRefIsNull(RefIsNull* curr); + void visitRefFunc(RefFunc* curr); void visitTry(Try* curr); void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 918e6a4ab..38009cb8a 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -110,6 +110,12 @@ public: ret->finalize(); return ret; } + Block* makeBlock(const std::vector<Expression*>& items, Type type) { + auto* ret = allocator.alloc<Block>(); + ret->list.set(items); + ret->finalize(type); + return ret; + } Block* makeBlock(const ExpressionList& items) { auto* ret = allocator.alloc<Block>(); ret->list.set(items); @@ -164,6 +170,13 @@ 
public: ret->finalize(); return ret; } + Loop* makeLoop(Name name, Expression* body, Type type) { + auto* ret = allocator.alloc<Loop>(); + ret->name = name; + ret->body = body; + ret->finalize(type); + return ret; + } Break* makeBreak(Name name, Expression* value = nullptr, Expression* condition = nullptr) { @@ -459,6 +472,7 @@ public: return ret; } Const* makeConst(Literal value) { + assert(value.type.isNumber()); auto* ret = allocator.alloc<Const>(); ret->value = value; ret->type = value.type; @@ -488,6 +502,17 @@ public: ret->finalize(); return ret; } + Select* makeSelect(Expression* condition, + Expression* ifTrue, + Expression* ifFalse, + Type type) { + auto* ret = allocator.alloc<Select>(); + ret->condition = condition; + ret->ifTrue = ifTrue; + ret->ifFalse = ifFalse; + ret->finalize(type); + return ret; + } Return* makeReturn(Expression* value = nullptr) { auto* ret = allocator.alloc<Return>(); ret->value = value; @@ -502,6 +527,23 @@ public: ret->finalize(); return ret; } + RefNull* makeRefNull() { + auto* ret = allocator.alloc<RefNull>(); + ret->finalize(); + return ret; + } + RefIsNull* makeRefIsNull(Expression* value) { + auto* ret = allocator.alloc<RefIsNull>(); + ret->value = value; + ret->finalize(); + return ret; + } + RefFunc* makeRefFunc(Name func) { + auto* ret = allocator.alloc<RefFunc>(); + ret->func = func; + ret->finalize(); + return ret; + } Try* makeTry(Expression* body, Expression* catchBody) { auto* ret = allocator.alloc<Try>(); ret->body = body; @@ -569,6 +611,21 @@ public: return ret; } + Expression* makeConstExpression(Literal value) { + switch (value.type) { + case Type::nullref: + return makeRefNull(); + case Type::funcref: + if (value.getFunc()[0] != 0) { + return makeRefFunc(value.getFunc()); + } + return makeRefNull(); + default: + assert(value.type.isNumber()); + return makeConst(value); + } + } + // Additional utility functions for building on top of nodes // Convenient to have these on Builder, as it has allocation built in @@ 
-663,6 +720,13 @@ public: return block; } + Block* makeSequence(Expression* left, Expression* right, Type type) { + auto* block = makeBlock(left); + block->list.push_back(right); + block->finalize(type); + return block; + } + // Grab a slice out of a block, replacing it with nops, and returning // either another block with the contents (if more than 1) or a single // expression @@ -728,16 +792,15 @@ public: value = Literal(bytes.data()); break; } + case funcref: case anyref: - // TODO Implement and return nullref - assert(false && "anyref not implemented yet"); + case nullref: case exnref: - // TODO Implement and return nullref - assert(false && "exnref not implemented yet"); + return ExpressionManipulator::refNull(curr); case none: return ExpressionManipulator::nop(curr); case unreachable: - return ExpressionManipulator::convert<T, Unreachable>(curr); + return ExpressionManipulator::unreachable(curr); } return makeConst(value); } diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 571f0d1a5..f37a6edd6 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -143,13 +143,13 @@ public: if (!ret.breaking() && (curr->type.isConcrete() || ret.value.type.isConcrete())) { #if 1 // def WASM_INTERPRETER_DEBUG - if (ret.value.type != curr->type) { + if (!Type::isSubType(ret.value.type, curr->type)) { std::cerr << "expected " << curr->type << ", seeing " << ret.value.type << " from\n" << curr << '\n'; } #endif - assert(ret.value.type == curr->type); + assert(Type::isSubType(ret.value.type, curr->type)); } depth--; return ret; @@ -1095,7 +1095,7 @@ public: return Literal(uint64_t(val)); } } - Flow visitAtomicFence(AtomicFence*) { + Flow visitAtomicFence(AtomicFence* curr) { // Wasm currently supports only sequentially consistent atomics, in which // case atomic_fence can be lowered to nothing. 
NOTE_ENTER("AtomicFence"); @@ -1123,6 +1123,26 @@ public: Flow visitSIMDLoadExtend(SIMDLoad*) { WASM_UNREACHABLE("unimp"); } Flow visitPush(Push*) { WASM_UNREACHABLE("unimp"); } Flow visitPop(Pop*) { WASM_UNREACHABLE("unimp"); } + Flow visitRefNull(RefNull* curr) { + NOTE_ENTER("RefNull"); + return Literal::makeNullref(); + } + Flow visitRefIsNull(RefIsNull* curr) { + NOTE_ENTER("RefIsNull"); + Flow flow = visit(curr->value); + if (flow.breaking()) { + return flow; + } + Literal value = flow.value; + NOTE_EVAL1(value); + return Literal(value.type == nullref); + } + Flow visitRefFunc(RefFunc* curr) { + NOTE_ENTER("RefFunc"); + NOTE_NAME(curr->func); + return Literal::makeFuncref(curr->func); + } + // TODO Implement EH instructions Flow visitTry(Try*) { WASM_UNREACHABLE("unimp"); } Flow visitThrow(Throw*) { WASM_UNREACHABLE("unimp"); } Flow visitRethrow(Rethrow*) { WASM_UNREACHABLE("unimp"); } @@ -1217,8 +1237,10 @@ public: return Literal(load64u(addr)).castToF64(); case v128: return Literal(load128(addr).data()); - case anyref: // anyref cannot be loaded from memory - case exnref: // exnref cannot be loaded from memory + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1272,8 +1294,10 @@ public: case v128: store128(addr, value.getv128()); break; - case anyref: // anyref cannot be stored from memory - case exnref: // exnref cannot be stored in memory + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1464,7 +1488,7 @@ private: for (size_t i = 0; i < function->getNumLocals(); i++) { if (i < arguments.size()) { assert(i < params.size()); - if (params[i] != arguments[i].type) { + if (!Type::isSubType(arguments[i].type, params[i])) { std::cerr << "Function `" << function->name << "` expects type " << params[i] << " for parameter " << i << ", got " << arguments[i].type << "." 
<< std::endl; @@ -1473,7 +1497,7 @@ private: locals[i] = arguments[i]; } else { assert(function->isVar(i)); - locals[i].type = function->getLocalType(i); + locals[i] = Literal::makeZero(function->getLocalType(i)); } } } @@ -1580,7 +1604,8 @@ private: } NOTE_EVAL1(index); NOTE_EVAL1(flow.value); - assert(curr->isTee() ? flow.value.type == curr->type : true); + assert(curr->isTee() ? Type::isSubType(flow.value.type, curr->type) + : true); scope.locals[index] = flow.value; return curr->isTee() ? flow : Flow(); } @@ -2067,7 +2092,7 @@ public: // cannot still be breaking, it means we missed our stop assert(!flow.breaking() || flow.breakTo == RETURN_FLOW); Literal ret = flow.value; - if (function->sig.results != ret.type) { + if (!Type::isSubType(ret.type, function->sig.results)) { std::cerr << "calling " << function->name << " resulted in " << ret << " but the function type is " << function->sig.results << '\n'; diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index d7324d756..8cdcb88f4 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -225,6 +225,9 @@ private: Expression* makeBreak(Element& s); Expression* makeBreakTable(Element& s); Expression* makeReturn(Element& s); + Expression* makeRefNull(Element& s); + Expression* makeRefIsNull(Element& s); + Expression* makeRefFunc(Element& s); Expression* makeTry(Element& s); Expression* makeCatch(Element& s, Type type); Expression* makeThrow(Element& s); diff --git a/src/wasm-stack.h b/src/wasm-stack.h index fbd28b0d5..91c0c5383 100644 --- a/src/wasm-stack.h +++ b/src/wasm-stack.h @@ -128,6 +128,9 @@ public: void visitSelect(Select* curr); void visitReturn(Return* curr); void visitHost(Host* curr); + void visitRefNull(RefNull* curr); + void visitRefIsNull(RefIsNull* curr); + void visitRefFunc(RefFunc* curr); void visitTry(Try* curr); void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); @@ -207,6 +210,9 @@ public: void visitSelect(Select* curr); void visitReturn(Return* curr); void visitHost(Host* 
curr); + void visitRefNull(RefNull* curr); + void visitRefIsNull(RefIsNull* curr); + void visitRefFunc(RefFunc* curr); void visitTry(Try* curr); void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); @@ -698,6 +704,30 @@ void BinaryenIRWriter<SubType>::visitHost(Host* curr) { emit(curr); } +template<typename SubType> +void BinaryenIRWriter<SubType>::visitRefNull(RefNull* curr) { + emit(curr); +} + +template<typename SubType> +void BinaryenIRWriter<SubType>::visitRefIsNull(RefIsNull* curr) { + visit(curr->value); + if (curr->type == Type::unreachable) { + emitUnreachable(); + return; + } + emit(curr); +} + +template<typename SubType> +void BinaryenIRWriter<SubType>::visitRefFunc(RefFunc* curr) { + if (curr->type == Type::unreachable) { + emitUnreachable(); + return; + } + emit(curr); +} + template<typename SubType> void BinaryenIRWriter<SubType>::visitTry(Try* curr) { emit(curr); visitPossibleBlockContents(curr->body); diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 9c6e78360..c9290cbab 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -72,6 +72,9 @@ template<typename SubType, typename ReturnType = void> struct Visitor { ReturnType visitDrop(Drop* curr) { return ReturnType(); } ReturnType visitReturn(Return* curr) { return ReturnType(); } ReturnType visitHost(Host* curr) { return ReturnType(); } + ReturnType visitRefNull(RefNull* curr) { return ReturnType(); } + ReturnType visitRefIsNull(RefIsNull* curr) { return ReturnType(); } + ReturnType visitRefFunc(RefFunc* curr) { return ReturnType(); } ReturnType visitTry(Try* curr) { return ReturnType(); } ReturnType visitThrow(Throw* curr) { return ReturnType(); } ReturnType visitRethrow(Rethrow* curr) { return ReturnType(); } @@ -167,6 +170,12 @@ template<typename SubType, typename ReturnType = void> struct Visitor { DELEGATE(Return); case Expression::Id::HostId: DELEGATE(Host); + case Expression::Id::RefNullId: + DELEGATE(RefNull); + case Expression::Id::RefIsNullId: + 
DELEGATE(RefIsNull); + case Expression::Id::RefFuncId: + DELEGATE(RefFunc); case Expression::Id::TryId: DELEGATE(Try); case Expression::Id::ThrowId: @@ -241,6 +250,9 @@ struct OverriddenVisitor { UNIMPLEMENTED(Drop); UNIMPLEMENTED(Return); UNIMPLEMENTED(Host); + UNIMPLEMENTED(RefNull); + UNIMPLEMENTED(RefIsNull); + UNIMPLEMENTED(RefFunc); UNIMPLEMENTED(Try); UNIMPLEMENTED(Throw); UNIMPLEMENTED(Rethrow); @@ -337,6 +349,12 @@ struct OverriddenVisitor { DELEGATE(Return); case Expression::Id::HostId: DELEGATE(Host); + case Expression::Id::RefNullId: + DELEGATE(RefNull); + case Expression::Id::RefIsNullId: + DELEGATE(RefIsNull); + case Expression::Id::RefFuncId: + DELEGATE(RefFunc); case Expression::Id::TryId: DELEGATE(Try); case Expression::Id::ThrowId: @@ -476,6 +494,15 @@ struct UnifiedExpressionVisitor : public Visitor<SubType, ReturnType> { ReturnType visitHost(Host* curr) { return static_cast<SubType*>(this)->visitExpression(curr); } + ReturnType visitRefNull(RefNull* curr) { + return static_cast<SubType*>(this)->visitExpression(curr); + } + ReturnType visitRefIsNull(RefIsNull* curr) { + return static_cast<SubType*>(this)->visitExpression(curr); + } + ReturnType visitRefFunc(RefFunc* curr) { + return static_cast<SubType*>(this)->visitExpression(curr); + } ReturnType visitTry(Try* curr) { return static_cast<SubType*>(this)->visitExpression(curr); } @@ -778,6 +805,15 @@ struct Walker : public VisitorType { static void doVisitHost(SubType* self, Expression** currp) { self->visitHost((*currp)->cast<Host>()); } + static void doVisitRefNull(SubType* self, Expression** currp) { + self->visitRefNull((*currp)->cast<RefNull>()); + } + static void doVisitRefIsNull(SubType* self, Expression** currp) { + self->visitRefIsNull((*currp)->cast<RefIsNull>()); + } + static void doVisitRefFunc(SubType* self, Expression** currp) { + self->visitRefFunc((*currp)->cast<RefFunc>()); + } static void doVisitTry(SubType* self, Expression** currp) { self->visitTry((*currp)->cast<Try>()); } @@ 
-1036,6 +1072,19 @@ struct PostWalker : public Walker<SubType, VisitorType> { } break; } + case Expression::Id::RefNullId: { + self->pushTask(SubType::doVisitRefNull, currp); + break; + } + case Expression::Id::RefIsNullId: { + self->pushTask(SubType::doVisitRefIsNull, currp); + self->pushTask(SubType::scan, &curr->cast<RefIsNull>()->value); + break; + } + case Expression::Id::RefFuncId: { + self->pushTask(SubType::doVisitRefFunc, currp); + break; + } case Expression::Id::TryId: { self->pushTask(SubType::doVisitTry, currp); self->pushTask(SubType::scan, &curr->cast<Try>()->catchBody); @@ -1099,7 +1148,7 @@ struct ControlFlowWalker : public PostWalker<SubType, VisitorType> { Expression* findBreakTarget(Name name) { assert(!controlFlowStack.empty()); Index i = controlFlowStack.size() - 1; - while (1) { + while (true) { auto* curr = controlFlowStack[i]; if (Block* block = curr->template dynCast<Block>()) { if (name == block->name) { @@ -1111,7 +1160,7 @@ struct ControlFlowWalker : public PostWalker<SubType, VisitorType> { } } else { // an if, ignorable - assert(curr->template is<If>()); + assert(curr->template is<If>() || curr->template is<Try>()); } if (i == 0) { return nullptr; @@ -1169,7 +1218,7 @@ struct ExpressionStackWalker : public PostWalker<SubType, VisitorType> { Expression* findBreakTarget(Name name) { assert(!expressionStack.empty()); Index i = expressionStack.size() - 1; - while (1) { + while (true) { auto* curr = expressionStack[i]; if (Block* block = curr->template dynCast<Block>()) { if (name == block->name) { @@ -1179,8 +1228,6 @@ struct ExpressionStackWalker : public PostWalker<SubType, VisitorType> { if (name == loop->name) { return curr; } - } else { - WASM_UNREACHABLE("unexpected expression type"); } if (i == 0) { return nullptr; diff --git a/src/wasm-type.h b/src/wasm-type.h index 53ef39ef8..668ac3e4d 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -36,7 +36,9 @@ public: f32, f64, v128, + funcref, anyref, + nullref, exnref, 
_last_value_type, }; @@ -64,7 +66,8 @@ public: bool isInteger() const { return id == i32 || id == i64; } bool isFloat() const { return id == f32 || id == f64; } bool isVector() const { return id == v128; }; - bool isRef() const { return id == anyref || id == exnref; } + bool isNumber() const { return id >= i32 && id <= v128; } + bool isRef() const { return id >= funcref && id <= exnref; } // (In)equality must be defined for both Type and ValueType because it is // otherwise ambiguous whether to convert both this and other to int or @@ -94,6 +97,23 @@ public: // type. static Type get(unsigned byteSize, bool float_); + // Returns true if left is a subtype of right. Subtype includes itself. + static bool isSubType(Type left, Type right); + + // Computes the least upper bound from the type lattice. + // If one of the type is unreachable, the other type becomes the result. If + // the common supertype does not exist, returns none, a poison value. + static Type getLeastUpperBound(Type a, Type b); + + // Computes the least upper bound for all types in the given list. 
+ template<typename T> static Type mergeTypes(const T& types) { + Type type = Type::unreachable; + for (auto other : types) { + type = Type::getLeastUpperBound(type, other); + } + return type; + } + std::string toString() const; }; @@ -134,7 +154,9 @@ constexpr Type i64 = Type::i64; constexpr Type f32 = Type::f32; constexpr Type f64 = Type::f64; constexpr Type v128 = Type::v128; +constexpr Type funcref = Type::funcref; constexpr Type anyref = Type::anyref; +constexpr Type nullref = Type::nullref; constexpr Type exnref = Type::exnref; constexpr Type unreachable = Type::unreachable; diff --git a/src/wasm.h b/src/wasm.h index 48adf103b..c4dbd2f3f 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -531,6 +531,9 @@ public: MemoryFillId, PushId, PopId, + RefNullId, + RefIsNullId, + RefFuncId, TryId, ThrowId, RethrowId, @@ -569,6 +572,8 @@ public: const char* getExpressionName(Expression* curr); +Literal getLiteralFromConstExpression(Expression* curr); + typedef ArenaVector<Expression*> ExpressionList; template<Expression::Id SID> class SpecificExpression : public Expression { @@ -1008,6 +1013,7 @@ public: Expression* condition; void finalize(); + void finalize(Type type_); }; class Drop : public SpecificExpression<Expression::DropId> { @@ -1070,6 +1076,32 @@ public: Pop(MixedArena& allocator) {} }; +class RefNull : public SpecificExpression<Expression::RefNullId> { +public: + RefNull() = default; + RefNull(MixedArena& allocator) {} + + void finalize(); +}; + +class RefIsNull : public SpecificExpression<Expression::RefIsNullId> { +public: + RefIsNull(MixedArena& allocator) {} + + Expression* value; + + void finalize(); +}; + +class RefFunc : public SpecificExpression<Expression::RefFuncId> { +public: + RefFunc(MixedArena& allocator) {} + + Name func; + + void finalize(); +}; + class Try : public SpecificExpression<Expression::TryId> { public: Try(MixedArena& allocator) {} diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index 82a150257..4f66b36e3 100644 --- 
a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -137,8 +137,11 @@ void Literal::getBits(uint8_t (&buf)[16]) const { case Type::v128: memcpy(buf, &v128, sizeof(v128)); break; - case Type::anyref: // anyref type is opaque - case Type::exnref: // exnref type is opaque + case Type::funcref: + case Type::nullref: + break; + case Type::anyref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("invalid type"); @@ -146,10 +149,20 @@ void Literal::getBits(uint8_t (&buf)[16]) const { } bool Literal::operator==(const Literal& other) const { + if (type.isRef() && other.type.isRef()) { + if (type == Type::nullref && other.type == Type::nullref) { + return true; + } + if (type == Type::funcref && other.type == Type::funcref && + func == other.func) { + return true; + } + return false; + } if (type != other.type) { return false; } - if (type == none) { + if (type == Type::none) { return true; } uint8_t bits[16], other_bits[16]; @@ -273,8 +286,14 @@ std::ostream& operator<<(std::ostream& o, Literal literal) { o << "i32x4 "; literal.printVec128(o, literal.getv128()); break; - case Type::anyref: // anyref type is opaque - case Type::exnref: // exnref type is opaque + case Type::funcref: + o << "funcref(" << literal.getFunc() << ")"; + break; + case Type::nullref: + o << "nullref"; + break; + case Type::anyref: + case Type::exnref: case Type::unreachable: WASM_UNREACHABLE("invalid type"); } @@ -477,7 +496,9 @@ Literal Literal::eqz() const { case Type::f64: return eq(Literal(double(0))); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -497,7 +518,9 @@ Literal Literal::neg() const { case Type::f64: return Literal(int64_t(i64 ^ 0x8000000000000000ULL)).castToF64(); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -517,7 +540,9 @@ Literal Literal::abs() const { case 
Type::f64: return Literal(int64_t(i64 & 0x7fffffffffffffffULL)).castToF64(); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -620,7 +645,9 @@ Literal Literal::add(const Literal& other) const { case Type::f64: return Literal(getf64() + other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -640,7 +667,9 @@ Literal Literal::sub(const Literal& other) const { case Type::f64: return Literal(getf64() - other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -731,7 +760,9 @@ Literal Literal::mul(const Literal& other) const { case Type::f64: return Literal(getf64() * other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -967,7 +998,9 @@ Literal Literal::eq(const Literal& other) const { case Type::f64: return Literal(getf64() == other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: @@ -987,7 +1020,9 @@ Literal Literal::ne(const Literal& other) const { case Type::f64: return Literal(getf64() != other.getf64()); case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case Type::unreachable: diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 82eb51d7e..ba5a8d3dd 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -262,7 +262,7 @@ void WasmBinaryWriter::writeImports() { BYN_TRACE("write one table\n"); writeImportHeader(&wasm->table); o << U32LEB(int32_t(ExternalKind::Table)); - o << S32LEB(BinaryConsts::EncodedType::AnyFunc); + o << S32LEB(BinaryConsts::EncodedType::funcref); 
writeResizableLimits(wasm->table.initial, wasm->table.max, wasm->table.hasMax(), @@ -463,7 +463,7 @@ void WasmBinaryWriter::writeFunctionTableDeclaration() { BYN_TRACE("== writeFunctionTableDeclaration\n"); auto start = startSection(BinaryConsts::Section::Table); o << U32LEB(1); // Declare 1 table. - o << S32LEB(BinaryConsts::EncodedType::AnyFunc); + o << S32LEB(BinaryConsts::EncodedType::funcref); writeResizableLimits(wasm->table.initial, wasm->table.max, wasm->table.hasMax(), @@ -1059,8 +1059,12 @@ Type WasmBinaryBuilder::getType() { return f64; case BinaryConsts::EncodedType::v128: return v128; + case BinaryConsts::EncodedType::funcref: + return funcref; case BinaryConsts::EncodedType::anyref: return anyref; + case BinaryConsts::EncodedType::nullref: + return nullref; case BinaryConsts::EncodedType::exnref: return exnref; default: @@ -1258,8 +1262,8 @@ void WasmBinaryBuilder::readImports() { wasm.table.name = Name(std::string("timport$") + std::to_string(i)); auto elementType = getS32LEB(); WASM_UNUSED(elementType); - if (elementType != BinaryConsts::EncodedType::AnyFunc) { - throwError("Imported table type is not AnyFunc"); + if (elementType != BinaryConsts::EncodedType::funcref) { + throwError("Imported table type is not funcref"); } wasm.table.exists = true; bool is_shared; @@ -1802,11 +1806,17 @@ void WasmBinaryBuilder::processFunctions() { wasm.addExport(curr); } - for (auto& iter : functionCalls) { + for (auto& iter : functionRefs) { size_t index = iter.first; - auto& calls = iter.second; - for (auto* call : calls) { - call->target = getFunctionName(index); + auto& refs = iter.second; + for (auto* ref : refs) { + if (auto* call = ref->dynCast<Call>()) { + call->target = getFunctionName(index); + } else if (auto* refFunc = ref->dynCast<RefFunc>()) { + refFunc->func = getFunctionName(index); + } else { + WASM_UNREACHABLE("Invalid type in function references"); + } } } @@ -1869,8 +1879,8 @@ void WasmBinaryBuilder::readFunctionTableDeclaration() { } 
wasm.table.exists = true; auto elemType = getS32LEB(); - if (elemType != BinaryConsts::EncodedType::AnyFunc) { - throwError("ElementType must be AnyFunc in MVP"); + if (elemType != BinaryConsts::EncodedType::funcref) { + throwError("ElementType must be funcref in MVP"); } bool is_shared; getResizableLimits( @@ -2117,7 +2127,8 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { visitGlobalSet((curr = allocator.alloc<GlobalSet>())->cast<GlobalSet>()); break; case BinaryConsts::Select: - visitSelect((curr = allocator.alloc<Select>())->cast<Select>()); + case BinaryConsts::SelectWithType: + visitSelect((curr = allocator.alloc<Select>())->cast<Select>(), code); break; case BinaryConsts::Return: visitReturn((curr = allocator.alloc<Return>())->cast<Return>()); @@ -2137,6 +2148,15 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { case BinaryConsts::Catch: curr = nullptr; break; + case BinaryConsts::RefNull: + visitRefNull((curr = allocator.alloc<RefNull>())->cast<RefNull>()); + break; + case BinaryConsts::RefIsNull: + visitRefIsNull((curr = allocator.alloc<RefIsNull>())->cast<RefIsNull>()); + break; + case BinaryConsts::RefFunc: + visitRefFunc((curr = allocator.alloc<RefFunc>())->cast<RefFunc>()); + break; case BinaryConsts::Try: visitTry((curr = allocator.alloc<Try>())->cast<Try>()); break; @@ -2510,7 +2530,7 @@ void WasmBinaryBuilder::visitCall(Call* curr) { curr->operands[num - i - 1] = popNonVoidExpression(); } curr->type = sig.results; - functionCalls[index].push_back(curr); // we don't know function names yet + functionRefs[index].push_back(curr); // we don't know function names yet curr->finalize(); } @@ -4326,12 +4346,24 @@ bool WasmBinaryBuilder::maybeVisitSIMDLoad(Expression*& out, uint32_t code) { return true; } -void WasmBinaryBuilder::visitSelect(Select* curr) { - BYN_TRACE("zz node: Select\n"); +void WasmBinaryBuilder::visitSelect(Select* curr, uint8_t code) { + BYN_TRACE("zz node: Select, code " << 
int32_t(code) << std::endl); + if (code == BinaryConsts::SelectWithType) { + size_t numTypes = getU32LEB(); + std::vector<Type> types; + for (size_t i = 0; i < numTypes; i++) { + types.push_back(getType()); + } + curr->type = Type(types); + } curr->condition = popNonVoidExpression(); curr->ifFalse = popNonVoidExpression(); curr->ifTrue = popNonVoidExpression(); - curr->finalize(); + if (code == BinaryConsts::SelectWithType) { + curr->finalize(curr->type); + } else { + curr->finalize(); + } } void WasmBinaryBuilder::visitReturn(Return* curr) { @@ -4383,6 +4415,27 @@ void WasmBinaryBuilder::visitDrop(Drop* curr) { curr->finalize(); } +void WasmBinaryBuilder::visitRefNull(RefNull* curr) { + BYN_TRACE("zz node: RefNull\n"); + curr->finalize(); +} + +void WasmBinaryBuilder::visitRefIsNull(RefIsNull* curr) { + BYN_TRACE("zz node: RefIsNull\n"); + curr->value = popNonVoidExpression(); + curr->finalize(); +} + +void WasmBinaryBuilder::visitRefFunc(RefFunc* curr) { + BYN_TRACE("zz node: RefFunc\n"); + Index index = getU32LEB(); + if (index >= functionImports.size() + functionSignatures.size()) { + throwError("ref.func: invalid call index"); + } + functionRefs[index].push_back(curr); // we don't know function names yet + curr->finalize(); +} + void WasmBinaryBuilder::visitTry(Try* curr) { BYN_TRACE("zz node: Try\n"); // For simplicity of implementation, like if scopes, we create a hidden block diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 20aff2091..3b12c4346 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -850,16 +850,22 @@ Type SExpressionWasmBuilder::stringToType(const char* str, return v128; } } + if (strncmp(str, "funcref", 7) == 0 && (prefix || str[7] == 0)) { + return funcref; + } if (strncmp(str, "anyref", 6) == 0 && (prefix || str[6] == 0)) { return anyref; } + if (strncmp(str, "nullref", 7) == 0 && (prefix || str[7] == 0)) { + return nullref; + } if (strncmp(str, "exnref", 6) == 0 && (prefix || str[6] == 0)) 
{ return exnref; } if (allowError) { return none; } - throw ParseException("invalid wasm type"); + throw ParseException(std::string("invalid wasm type: ") + str); } Type SExpressionWasmBuilder::stringToLaneType(const char* str) { @@ -936,10 +942,16 @@ Expression* SExpressionWasmBuilder::makeUnary(Element& s, UnaryOp op) { Expression* SExpressionWasmBuilder::makeSelect(Element& s) { auto ret = allocator.alloc<Select>(); - ret->ifTrue = parseExpression(s[1]); - ret->ifFalse = parseExpression(s[2]); - ret->condition = parseExpression(s[3]); - ret->finalize(); + Index i = 1; + Type type = parseOptionalResultType(s, i); + ret->ifTrue = parseExpression(s[i++]); + ret->ifFalse = parseExpression(s[i++]); + ret->condition = parseExpression(s[i]); + if (type.isConcrete()) { + ret->finalize(type); + } else { + ret->finalize(); + } return ret; } @@ -1718,6 +1730,27 @@ Expression* SExpressionWasmBuilder::makeReturn(Element& s) { return ret; } +Expression* SExpressionWasmBuilder::makeRefNull(Element& s) { + auto ret = allocator.alloc<RefNull>(); + ret->finalize(); + return ret; +} + +Expression* SExpressionWasmBuilder::makeRefIsNull(Element& s) { + auto ret = allocator.alloc<RefIsNull>(); + ret->value = parseExpression(s[1]); + ret->finalize(); + return ret; +} + +Expression* SExpressionWasmBuilder::makeRefFunc(Element& s) { + auto func = getFunctionName(*s[1]); + auto ret = allocator.alloc<RefFunc>(); + ret->func = func; + ret->finalize(); + return ret; +} + // try-catch-end is written in the folded wast format as // (try // ... 
diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index 593214838..22d6a0036 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -147,8 +147,10 @@ void BinaryInstWriter::visitLoad(Load* curr) { // the pointer is unreachable, so we are never reached; just don't emit // a load return; - case anyref: // anyref cannot be loaded from memory - case exnref: // exnref cannot be loaded from memory + case funcref: + case anyref: + case nullref: + case exnref: case none: WASM_UNREACHABLE("unexpected type"); } @@ -247,8 +249,10 @@ void BinaryInstWriter::visitStore(Store* curr) { o << int8_t(BinaryConsts::SIMDPrefix) << U32LEB(BinaryConsts::V128Store); break; - case anyref: // anyref cannot be stored from memory - case exnref: // exnref cannot be stored in memory + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -642,8 +646,10 @@ void BinaryInstWriter::visitConst(Const* curr) { } break; } - case anyref: // there's no anyref.const - case exnref: // there's no exnref.const + case funcref: + case anyref: + case nullref: + case exnref: case none: case unreachable: WASM_UNREACHABLE("unexpected type"); @@ -1541,7 +1547,15 @@ void BinaryInstWriter::visitBinary(Binary* curr) { } void BinaryInstWriter::visitSelect(Select* curr) { - o << int8_t(BinaryConsts::Select); + if (curr->type.isRef()) { + o << int8_t(BinaryConsts::SelectWithType) << U32LEB(curr->type.size()); + for (size_t i = 0; i < curr->type.size(); i++) { + o << binaryType(curr->type != Type::unreachable ? 
curr->type + : Type::none); + } + } else { + o << int8_t(BinaryConsts::Select); + } } void BinaryInstWriter::visitReturn(Return* curr) { @@ -1562,6 +1576,19 @@ void BinaryInstWriter::visitHost(Host* curr) { o << U32LEB(0); // Reserved flags field } +void BinaryInstWriter::visitRefNull(RefNull* curr) { + o << int8_t(BinaryConsts::RefNull); +} + +void BinaryInstWriter::visitRefIsNull(RefIsNull* curr) { + o << int8_t(BinaryConsts::RefIsNull); +} + +void BinaryInstWriter::visitRefFunc(RefFunc* curr) { + o << int8_t(BinaryConsts::RefFunc) + << U32LEB(parent.getFunctionIndex(curr->func)); +} + void BinaryInstWriter::visitTry(Try* curr) { breakStack.emplace_back(IMPOSSIBLE_CONTINUE); o << int8_t(BinaryConsts::Try); @@ -1659,11 +1686,21 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() { continue; } index += numLocalsByType[v128]; + if (type == funcref) { + mappedLocals[i] = index + currLocalsByType[funcref] - 1; + continue; + } + index += numLocalsByType[funcref]; if (type == anyref) { mappedLocals[i] = index + currLocalsByType[anyref] - 1; continue; } index += numLocalsByType[anyref]; + if (type == nullref) { + mappedLocals[i] = index + currLocalsByType[nullref] - 1; + continue; + } + index += numLocalsByType[nullref]; if (type == exnref) { mappedLocals[i] = index + currLocalsByType[exnref] - 1; continue; @@ -1671,11 +1708,12 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() { WASM_UNREACHABLE("unexpected type"); } // Emit them. - o << U32LEB((numLocalsByType[i32] ? 1 : 0) + (numLocalsByType[i64] ? 1 : 0) + - (numLocalsByType[f32] ? 1 : 0) + (numLocalsByType[f64] ? 1 : 0) + - (numLocalsByType[v128] ? 1 : 0) + - (numLocalsByType[anyref] ? 1 : 0) + - (numLocalsByType[exnref] ? 1 : 0)); + o << U32LEB( + (numLocalsByType[i32] ? 1 : 0) + (numLocalsByType[i64] ? 1 : 0) + + (numLocalsByType[f32] ? 1 : 0) + (numLocalsByType[f64] ? 1 : 0) + + (numLocalsByType[v128] ? 1 : 0) + (numLocalsByType[funcref] ? 1 : 0) + + (numLocalsByType[anyref] ? 1 : 0) + (numLocalsByType[nullref] ? 
1 : 0) + + (numLocalsByType[exnref] ? 1 : 0)); if (numLocalsByType[i32]) { o << U32LEB(numLocalsByType[i32]) << binaryType(i32); } @@ -1691,9 +1729,15 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() { if (numLocalsByType[v128]) { o << U32LEB(numLocalsByType[v128]) << binaryType(v128); } + if (numLocalsByType[funcref]) { + o << U32LEB(numLocalsByType[funcref]) << binaryType(funcref); + } if (numLocalsByType[anyref]) { o << U32LEB(numLocalsByType[anyref]) << binaryType(anyref); } + if (numLocalsByType[nullref]) { + o << U32LEB(numLocalsByType[nullref]) << binaryType(nullref); + } if (numLocalsByType[exnref]) { o << U32LEB(numLocalsByType[exnref]) << binaryType(exnref); } @@ -1760,7 +1804,7 @@ StackInst* StackIRGenerator::makeStackInst(StackInst::Op op, // type. stackType = none; } else if (op != StackInst::BlockEnd && op != StackInst::IfEnd && - op != StackInst::LoopEnd) { + op != StackInst::LoopEnd && op != StackInst::TryEnd) { // If a concrete type is returned, we mark the end of the construct has // having that type (as it is pushed to the value stack at that point), // other parts are marked as none). 
@@ -1781,13 +1825,15 @@ void StackIRToBinaryWriter::write() { case StackInst::Basic: case StackInst::BlockBegin: case StackInst::IfBegin: - case StackInst::LoopBegin: { + case StackInst::LoopBegin: + case StackInst::TryBegin: { writer.visit(inst->origin); break; } case StackInst::BlockEnd: case StackInst::IfEnd: - case StackInst::LoopEnd: { + case StackInst::LoopEnd: + case StackInst::TryEnd: { writer.emitScopeEnd(); break; } diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index 939ee8c93..62c30d1e0 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -62,7 +62,9 @@ std::vector<std::unique_ptr<std::vector<Type>>> typeLists = [] { add({Type::f32}); add({Type::f64}); add({Type::v128}); + add({Type::funcref}); add({Type::anyref}); + add({Type::nullref}); add({Type::exnref}); return lists; }(); @@ -75,7 +77,9 @@ std::unordered_map<std::vector<Type>, uint32_t> indices = { {{Type::f32}, Type::f32}, {{Type::f64}, Type::f64}, {{Type::v128}, Type::v128}, + {{Type::funcref}, Type::funcref}, {{Type::anyref}, Type::anyref}, + {{Type::nullref}, Type::nullref}, {{Type::exnref}, Type::exnref}, }; @@ -154,8 +158,10 @@ unsigned Type::getByteSize() const { return 8; case Type::v128: return 16; - case Type::anyref: // anyref type is opaque - case Type::exnref: // exnref type is opaque + case Type::funcref: + case Type::anyref: + case Type::nullref: + case Type::exnref: case Type::none: case Type::unreachable: WASM_UNREACHABLE("invalid type"); @@ -164,7 +170,7 @@ unsigned Type::getByteSize() const { } Type Type::reinterpret() const { - assert(isSingle() && "reinterpret only works with single types"); + assert(isSingle() && "reinterpretType only works with single types"); Type singleType = *expand().begin(); switch (singleType) { case Type::i32: @@ -176,7 +182,9 @@ Type Type::reinterpret() const { case Type::f64: return i64; case Type::v128: + case Type::funcref: case Type::anyref: + case Type::nullref: case Type::exnref: case Type::none: case 
Type::unreachable: @@ -221,6 +229,39 @@ Type Type::get(unsigned byteSize, bool float_) { WASM_UNREACHABLE("invalid size"); } +bool Type::Type::isSubType(Type left, Type right) { + if (left == right) { + return true; + } + if (left.isRef() && right.isRef() && + (right == Type::anyref || left == Type::nullref)) { + return true; + } + return false; +} + +Type Type::Type::getLeastUpperBound(Type a, Type b) { + if (a == b) { + return a; + } + if (a == Type::unreachable) { + return b; + } + if (b == Type::unreachable) { + return a; + } + if (!a.isRef() || !b.isRef()) { + return none; // a poison value that must not be consumed + } + if (a == Type::nullref) { + return b; + } + if (b == Type::nullref) { + return a; + } + return Type::anyref; +} + namespace { std::ostream& @@ -280,9 +321,15 @@ std::ostream& operator<<(std::ostream& os, Type type) { case Type::v128: os << "v128"; break; + case Type::funcref: + os << "funcref"; + break; case Type::anyref: os << "anyref"; break; + case Type::nullref: + os << "nullref"; + break; case Type::exnref: os << "exnref"; break; diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 55e115d95..7bf51c5f7 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -21,6 +21,7 @@ #include "ir/branch-utils.h" #include "ir/features.h" +#include "ir/global-utils.h" #include "ir/module-utils.h" #include "ir/utils.h" #include "support/colors.h" @@ -181,6 +182,31 @@ struct ValidationInfo { fail(text, curr, func); } } + + // Type 'left' should be a subtype of 'right'. + bool shouldBeSubType(Type left, + Type right, + Expression* curr, + const char* text, + Function* func = nullptr) { + if (Type::isSubType(left, right)) { + return true; + } + fail(text, curr, func); + return false; + } + + // Type 'left' should be a subtype of 'right', or unreachable. 
+ bool shouldBeSubTypeOrUnreachable(Type left, + Type right, + Expression* curr, + const char* text, + Function* func = nullptr) { + if (left == Type::unreachable) { + return true; + } + return shouldBeSubType(left, right, curr, text, func); + } }; struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> { @@ -210,7 +236,7 @@ struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> { std::unordered_map<Name, BreakInfo> breakInfos; - Type returnType = unreachable; // type used in returns + std::set<Type> returnTypes; // types used in returns // Binaryen IR requires that label names must be unique - IR generators must // ensure that @@ -287,6 +313,8 @@ public: void visitDrop(Drop* curr); void visitReturn(Return* curr); void visitHost(Host* curr); + void visitRefIsNull(RefIsNull* curr); + void visitRefFunc(RefFunc* curr); void visitTry(Try* curr); void visitThrow(Throw* curr); void visitRethrow(Rethrow* curr); @@ -327,6 +355,19 @@ private: return info.shouldBeIntOrUnreachable(ty, curr, text, getFunction()); } + bool + shouldBeSubType(Type left, Type right, Expression* curr, const char* text) { + return info.shouldBeSubType(left, right, curr, text, getFunction()); + } + + bool shouldBeSubTypeOrUnreachable(Type left, + Type right, + Expression* curr, + const char* text) { + return info.shouldBeSubTypeOrUnreachable( + left, right, curr, text, getFunction()); + } + void validateAlignment( size_t align, Type type, Index bytes, bool isAtomic, Expression* curr); void validateMemBytes(uint8_t bytes, Type type, Expression* curr); @@ -364,29 +405,23 @@ void FunctionValidator::visitBlock(Block* curr) { // none or unreachable means a poison value that we should ignore - if // consumed, it will error if (info.type.isConcrete() && curr->type.isConcrete()) { - shouldBeEqual( - curr->type, + shouldBeSubType( info.type, + curr->type, curr, "block+breaks must have right type if breaks return a value"); } if (curr->type.isConcrete() && info.arity && 
info.type != unreachable) { - shouldBeEqual(curr->type, - info.type, - curr, - "block+breaks must have right type if breaks have arity"); + shouldBeSubType( + info.type, + curr->type, + curr, + "block+breaks must have right type if breaks have arity"); } shouldBeTrue( info.arity != BreakInfo::PoisonArity, curr, "break arities must match"); if (curr->list.size() > 0) { auto last = curr->list.back()->type; - if (last.isConcrete() && info.type != unreachable) { - shouldBeEqual(last, - info.type, - curr, - "block+breaks must have right type if block ends with " - "a reachable value"); - } if (last == none) { shouldBeTrue(info.arity == Index(0), curr, @@ -420,9 +455,9 @@ void FunctionValidator::visitBlock(Block* curr) { "not flow out a value"); } else { if (backType.isConcrete()) { - shouldBeEqual( - curr->type, + shouldBeSubType( backType, + curr->type, curr, "block with value and last element with value must match types"); } else { @@ -457,6 +492,23 @@ void FunctionValidator::visitLoop(Loop* curr) { curr, "bad body for a loop that has no value"); } + + // When there are multiple instructions within a loop, they are wrapped in a + // Block internally, so visitBlock can take care of verification. Here we + // check cases when there is only one instruction in a Loop. 
+ if (!curr->body->is<Block>()) { + if (!curr->type.isConcrete()) { + shouldBeFalse(curr->body->type.isConcrete(), + curr, + "if loop is not returning a value, final element should " + "not flow out a value"); + } else { + shouldBeSubTypeOrUnreachable(curr->body->type, + curr->type, + curr, + "loop with value and body must match types"); + } + } } void FunctionValidator::visitIf(If* curr) { @@ -476,12 +528,12 @@ void FunctionValidator::visitIf(If* curr) { } } else { if (curr->type != unreachable) { - shouldBeEqualOrFirstIsUnreachable( + shouldBeSubTypeOrUnreachable( curr->ifTrue->type, curr->type, curr, "returning if-else's true must have right type"); - shouldBeEqualOrFirstIsUnreachable( + shouldBeSubTypeOrUnreachable( curr->ifFalse->type, curr->type, curr, @@ -499,25 +551,16 @@ void FunctionValidator::visitIf(If* curr) { } } if (curr->ifTrue->type.isConcrete()) { - shouldBeEqual(curr->type, - curr->ifTrue->type, - curr, - "if type must match concrete ifTrue"); - shouldBeEqualOrFirstIsUnreachable(curr->ifFalse->type, - curr->ifTrue->type, - curr, - "other arm must match concrete ifTrue"); + shouldBeSubType(curr->ifTrue->type, + curr->type, + curr, + "if type must match concrete ifTrue"); } if (curr->ifFalse->type.isConcrete()) { - shouldBeEqual(curr->type, - curr->ifFalse->type, - curr, - "if type must match concrete ifFalse"); - shouldBeEqualOrFirstIsUnreachable( - curr->ifTrue->type, - curr->ifFalse->type, - curr, - "other arm must match concrete ifFalse"); + shouldBeSubType(curr->ifFalse->type, + curr->type, + curr, + "if type must match concrete ifFalse"); } } } @@ -545,13 +588,7 @@ void FunctionValidator::noteBreak(Name name, Type valueType, Expression* curr) { if (!info.hasBeenSet()) { info = BreakInfo(valueType, arity); } else { - if (info.type == unreachable) { - info.type = valueType; - } else if (valueType != unreachable) { - if (valueType != info.type) { - info.type = none; // a poison value that must not be consumed - } - } + info.type = 
Type::getLeastUpperBound(info.type, valueType); if (arity != info.arity) { info.arity = BreakInfo::PoisonArity; } @@ -600,10 +637,10 @@ void FunctionValidator::visitCall(Call* curr) { return; } for (size_t i = 0; i < curr->operands.size(); i++) { - if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, - params[i], - curr, - "call param types must match") && + if (!shouldBeSubTypeOrUnreachable(curr->operands[i]->type, + params[i], + curr, + "call param types must match") && !info.quiet) { getStream() << "(on argument " << i << ")\n"; } @@ -653,10 +690,10 @@ void FunctionValidator::visitCallIndirect(CallIndirect* curr) { return; } for (size_t i = 0; i < curr->operands.size(); i++) { - if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, - params[i], - curr, - "call param types must match") && + if (!shouldBeSubTypeOrUnreachable(curr->operands[i]->type, + params[i], + curr, + "call param types must match") && !info.quiet) { getStream() << "(on argument " << i << ")\n"; } @@ -723,10 +760,10 @@ void FunctionValidator::visitLocalSet(LocalSet* curr) { curr, "local.set type must be correct"); } - shouldBeEqual(curr->value->type, - getFunction()->getLocalType(curr->index), - curr, - "local.set's value type must be correct"); + shouldBeSubType(curr->value->type, + getFunction()->getLocalType(curr->index), + curr, + "local.set's value type must be correct"); } } } @@ -750,10 +787,10 @@ void FunctionValidator::visitGlobalSet(GlobalSet* curr) { "global.set name must be valid (and not an import; imports " "can't be modified)")) { shouldBeTrue(global->mutable_, curr, "global.set global must be mutable"); - shouldBeEqualOrFirstIsUnreachable(curr->value->type, - global->type, - curr, - "global.set value must have right type"); + shouldBeSubTypeOrUnreachable(curr->value->type, + global->type, + curr, + "global.set value must have right type"); } } @@ -1182,12 +1219,14 @@ void FunctionValidator::validateMemBytes(uint8_t bytes, shouldBeEqual( bytes, uint8_t(16), 
curr, "expected v128 operation to touch 16 bytes"); break; - case anyref: // anyref cannot be stored in memory - case exnref: // exnref cannot be stored in memory - case none: - WASM_UNREACHABLE("unexpected type"); case unreachable: break; + case funcref: + case anyref: + case nullref: + case exnref: + case none: + WASM_UNREACHABLE("unexpected type"); } } @@ -1616,15 +1655,18 @@ void FunctionValidator::visitSelect(Select* curr) { shouldBeUnequal(curr->ifTrue->type, none, curr, "select left must be valid"); shouldBeUnequal( curr->ifFalse->type, none, curr, "select right must be valid"); + shouldBeUnequal(curr->type, none, curr, "select type must be valid"); shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "select condition must be valid"); - if (curr->ifTrue->type != unreachable && curr->ifFalse->type != unreachable) { - shouldBeEqual(curr->ifTrue->type, - curr->ifFalse->type, - curr, - "select sides must be equal"); + if (curr->type != unreachable) { + shouldBeTrue(Type::isSubType(curr->ifTrue->type, curr->type), + curr, + "select's left expression must be subtype of select's type"); + shouldBeTrue(Type::isSubType(curr->ifFalse->type, curr->type), + curr, + "select's right expression must be subtype of select's type"); } } @@ -1636,16 +1678,7 @@ void FunctionValidator::visitDrop(Drop* curr) { } void FunctionValidator::visitReturn(Return* curr) { - if (curr->value) { - if (returnType == unreachable) { - returnType = curr->value->type; - } else if (curr->value->type != unreachable) { - shouldBeEqual( - curr->value->type, returnType, curr, "function results must match"); - } - } else { - returnType = none; - } + returnTypes.insert(curr->value ? 
curr->value->type : Type::none); } void FunctionValidator::visitHost(Host* curr) { @@ -1668,32 +1701,37 @@ void FunctionValidator::visitHost(Host* curr) { } } +void FunctionValidator::visitRefIsNull(RefIsNull* curr) { + shouldBeTrue(curr->value->type == Type::unreachable || + curr->value->type.isRef(), + curr->value, + "ref.is_null's argument should be a reference type"); +} + +void FunctionValidator::visitRefFunc(RefFunc* curr) { + auto* func = getModule()->getFunctionOrNull(curr->func); + shouldBeTrue(!!func, curr, "function argument of ref.func must exist"); +} + void FunctionValidator::visitTry(Try* curr) { if (curr->type != unreachable) { - shouldBeEqualOrFirstIsUnreachable( - curr->body->type, - curr->type, - curr->body, - "try's type does not match try body's type"); - shouldBeEqualOrFirstIsUnreachable( - curr->catchBody->type, - curr->type, - curr->catchBody, - "try's type does not match catch's body type"); - } - if (curr->body->type.isConcrete()) { - shouldBeEqualOrFirstIsUnreachable( - curr->catchBody->type, - curr->body->type, - curr->catchBody, - "try's body type must match catch's body type"); - } - if (curr->catchBody->type.isConcrete()) { - shouldBeEqualOrFirstIsUnreachable( - curr->body->type, - curr->catchBody->type, - curr->body, - "try's body type must match catch's body type"); + shouldBeSubTypeOrUnreachable(curr->body->type, + curr->type, + curr->body, + "try's type does not match try body's type"); + shouldBeSubTypeOrUnreachable(curr->catchBody->type, + curr->type, + curr->catchBody, + "try's type does not match catch's body type"); + } else { + shouldBeEqual(curr->body->type, + unreachable, + curr, + "unreachable try-catch must have unreachable try body"); + shouldBeEqual(curr->catchBody->type, + unreachable, + curr, + "unreachable try-catch must have unreachable catch body"); } } @@ -1727,10 +1765,10 @@ void FunctionValidator::visitThrow(Throw* curr) { void FunctionValidator::visitRethrow(Rethrow* curr) { shouldBeEqual( curr->type, 
unreachable, curr, "rethrow's type must be unreachable"); - shouldBeEqual(curr->exnref->type, - exnref, - curr->exnref, - "rethrow's argument must be exnref type"); + shouldBeSubType(curr->exnref->type, + Type::exnref, + curr->exnref, + "rethrow's argument must be exnref type or its subtype"); } void FunctionValidator::visitBrOnExn(BrOnExn* curr) { @@ -1740,10 +1778,11 @@ void FunctionValidator::visitBrOnExn(BrOnExn* curr) { curr, "br_on_exn's event params and event's params are different"); noteBreak(curr->name, curr->sent, curr); - shouldBeTrue(curr->exnref->type == unreachable || - curr->exnref->type == exnref, - curr, - "br_on_exn's argument must be unreachable or exnref type"); + shouldBeSubTypeOrUnreachable( + curr->exnref->type, + Type::exnref, + curr, + "br_on_exn's argument must be unreachable or exnref type or its subtype"); if (curr->exnref->type == unreachable) { shouldBeTrue(curr->type == unreachable, curr, @@ -1779,21 +1818,22 @@ void FunctionValidator::visitFunction(Function* curr) { "all used types should be allowed"); // if function has no result, it is ignored // if body is unreachable, it might be e.g. 
a return - if (curr->body->type != unreachable) { - shouldBeEqual(curr->sig.results, - curr->body->type, - curr->body, - "function body type must match, if function returns"); - } - if (returnType != unreachable) { - shouldBeEqual(curr->sig.results, - returnType, - curr->body, - "function result must match, if function has returns"); + shouldBeSubTypeOrUnreachable( + curr->body->type, + curr->sig.results, + curr->body, + "function body type must match, if function returns"); + for (Type returnType : returnTypes) { + shouldBeSubTypeOrUnreachable( + returnType, + curr->sig.results, + curr->body, + "function result must match, if function has returns"); } + shouldBeTrue( breakInfos.empty(), curr->body, "all named break targets must exist"); - returnType = unreachable; + returnTypes.clear(); labelNames.clear(); // validate optional local names std::set<Name> seen; @@ -1858,8 +1898,10 @@ void FunctionValidator::validateAlignment( case v128: case unreachable: break; - case anyref: // anyref cannot be stored in memory - case exnref: // exnref cannot be stored in memory + case funcref: + case anyref: + case nullref: + case exnref: case none: WASM_UNREACHABLE("invalid type"); } @@ -1890,7 +1932,8 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) { // // The block has an added type, not derived from the ast itself, so it // is ok for it to be either i32 or unreachable. 
- if (!(oldType.isConcrete() && newType == unreachable)) { + if (!Type::isSubType(newType, oldType) && + !(oldType.isConcrete() && newType == Type::unreachable)) { std::ostringstream ss; ss << "stale type found in " << scope << " on " << curr << "\n(marked as " << oldType << ", should be " << newType @@ -2011,13 +2054,14 @@ static void validateGlobals(Module& module, ValidationInfo& info) { info.shouldBeTrue( curr->init != nullptr, curr->name, "global init must be non-null"); assert(curr->init); - info.shouldBeTrue(curr->init->is<Const>() || curr->init->is<GlobalGet>(), + info.shouldBeTrue(GlobalUtils::canInitializeGlobal(curr->init), curr->name, "global init must be valid"); - if (!info.shouldBeEqual(curr->type, - curr->init->type, - curr->init, - "global init must have correct type") && + + if (!info.shouldBeSubType(curr->init->type, + curr->type, + curr->init, + "global init must have correct type") && !info.quiet) { info.getStream(nullptr) << "(on global " << curr->name << ")\n"; } @@ -2118,9 +2162,9 @@ static void validateEvents(Module& module, ValidationInfo& info) { curr->name, "Event type's result type should be none"); for (auto type : curr->sig.params.expand()) { - info.shouldBeTrue(type.isInteger() || type.isFloat(), + info.shouldBeTrue(type.isConcrete(), curr->name, - "Values in an event should have integer or float type"); + "Values in an event should have concrete types"); } } } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index ff1295bad..11d203835 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -173,13 +173,19 @@ const char* getExpressionName(Expression* curr) { return "push"; case Expression::Id::PopId: return "pop"; - case Expression::TryId: + case Expression::Id::RefNullId: + return "ref.null"; + case Expression::Id::RefIsNullId: + return "ref.is_null"; + case Expression::Id::RefFuncId: + return "ref.func"; + case Expression::Id::TryId: return "try"; - case Expression::ThrowId: + case Expression::Id::ThrowId: return "throw"; - 
case Expression::RethrowId: + case Expression::Id::RethrowId: return "rethrow"; - case Expression::BrOnExnId: + case Expression::Id::BrOnExnId: return "br_on_exn"; case Expression::Id::NumExpressionIds: WASM_UNREACHABLE("invalid expr id"); @@ -187,6 +193,18 @@ const char* getExpressionName(Expression* curr) { WASM_UNREACHABLE("invalid expr id"); } +Literal getLiteralFromConstExpression(Expression* curr) { + if (auto* c = curr->dynCast<Const>()) { + return c->value; + } else if (curr->is<RefNull>()) { + return Literal::makeNullref(); + } else if (auto* r = curr->dynCast<RefFunc>()) { + return Literal::makeFuncref(r->func); + } else { + WASM_UNREACHABLE("Not a constant expression"); + } +} + // core AST type checking struct TypeSeeker : public PostWalker<TypeSeeker> { @@ -248,27 +266,6 @@ struct TypeSeeker : public PostWalker<TypeSeeker> { } }; -static Type mergeTypes(std::vector<Type>& types) { - Type type = unreachable; - for (auto other : types) { - // once none, stop. it then indicates a poison value, that must not be - // consumed and ignore unreachable - if (type != none) { - if (other == none) { - type = none; - } else if (other != unreachable) { - if (type == unreachable) { - type = other; - } else if (type != other) { - // poison value, we saw multiple types; this should not be consumed - type = none; - } - } - } - } - return type; -} - // a block is unreachable if one of its elements is unreachable, // and there are no branches to it static void handleUnreachable(Block* block, @@ -336,7 +333,7 @@ void Block::finalize() { } TypeSeeker seeker(this, this->name); - type = mergeTypes(seeker.types); + type = Type::mergeTypes(seeker.types); handleUnreachable(this); } @@ -364,19 +361,8 @@ void If::finalize(Type type_) { } void If::finalize() { - if (ifFalse) { - if (ifTrue->type == ifFalse->type) { - type = ifTrue->type; - } else if (ifTrue->type.isConcrete() && ifFalse->type == unreachable) { - type = ifTrue->type; - } else if (ifFalse->type.isConcrete() && 
ifTrue->type == unreachable) { - type = ifFalse->type; - } else { - type = none; - } - } else { - type = none; // if without else - } + type = ifFalse ? Type::getLeastUpperBound(ifTrue->type, ifFalse->type) + : Type::none; // if the arms return a value, leave it even if the condition // is unreachable, we still mark ourselves as having that type, e.g. // (if (result i32) @@ -828,13 +814,15 @@ void Binary::finalize() { } } +void Select::finalize(Type type_) { type = type_; } + void Select::finalize() { assert(ifTrue && ifFalse); if (ifTrue->type == unreachable || ifFalse->type == unreachable || condition->type == unreachable) { type = unreachable; } else { - type = ifTrue->type; + type = Type::getLeastUpperBound(ifTrue->type, ifFalse->type); } } @@ -864,16 +852,20 @@ void Host::finalize() { } } -void Try::finalize() { - if (body->type == catchBody->type) { - type = body->type; - } else if (body->type.isConcrete() && catchBody->type == unreachable) { - type = body->type; - } else if (catchBody->type.isConcrete() && body->type == unreachable) { - type = catchBody->type; - } else { - type = none; +void RefNull::finalize() { type = Type::nullref; } + +void RefIsNull::finalize() { + if (value->type == Type::unreachable) { + type = Type::unreachable; + return; } + type = Type::i32; +} + +void RefFunc::finalize() { type = Type::funcref; } + +void Try::finalize() { + type = Type::getLeastUpperBound(body->type, catchBody->type); } void Try::finalize(Type type_) { diff --git a/src/wasm2js.h b/src/wasm2js.h index 2cacbe8e5..1adde23b8 100644 --- a/src/wasm2js.h +++ b/src/wasm2js.h @@ -1848,6 +1848,18 @@ Ref Wasm2JSBuilder::processFunctionBody(Module* m, unimplemented(curr); WASM_UNREACHABLE("unimp"); } + Ref visitRefNull(RefNull* curr) { + unimplemented(curr); + WASM_UNREACHABLE("unimp"); + } + Ref visitRefIsNull(RefIsNull* curr) { + unimplemented(curr); + WASM_UNREACHABLE("unimp"); + } + Ref visitRefFunc(RefFunc* curr) { + unimplemented(curr); + WASM_UNREACHABLE("unimp"); + } 
Ref visitTry(Try* curr) { unimplemented(curr); WASM_UNREACHABLE("unimp"); |