summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorHeejin Ahn <aheejin@gmail.com>2019-12-30 17:55:20 -0800
committerGitHub <noreply@github.com>2019-12-30 17:55:20 -0800
commitbcc76146fed433cbc8ba01a9f568d979c145110b (patch)
treeab70ad24afc257b73513c3e62f3aab9938d05944 /src
parenta30f1df5696ccb3490e2eaa3a9ed5e7e487c7b0e (diff)
downloadbinaryen-bcc76146fed433cbc8ba01a9f568d979c145110b.tar.gz
binaryen-bcc76146fed433cbc8ba01a9f568d979c145110b.tar.bz2
binaryen-bcc76146fed433cbc8ba01a9f568d979c145110b.zip
Add support for reference types proposal (#2451)
This adds support for the reference type proposal. This includes support for all reference types (`anyref`, `funcref`(=`anyfunc`), and `nullref`) and four new instructions: `ref.null`, `ref.is_null`, `ref.func`, and new typed `select`. This also adds subtype relationship support between reference types. This does not include table instructions yet. This also does not include wasm2js support. Fixes #2444 and fixes #2447.
Diffstat (limited to 'src')
-rw-r--r--src/asmjs/asm_v_wasm.cpp13
-rw-r--r--src/binaryen-c.cpp101
-rw-r--r--src/binaryen-c.h19
-rw-r--r--src/gen-s-parser.inc64
-rw-r--r--src/ir/ExpressionAnalyzer.cpp3
-rw-r--r--src/ir/ExpressionManipulator.cpp15
-rw-r--r--src/ir/ReFinalize.cpp27
-rw-r--r--src/ir/abstract.h12
-rw-r--r--src/ir/block-utils.h3
-rw-r--r--src/ir/effects.h3
-rw-r--r--src/ir/flat.h12
-rw-r--r--src/ir/global-utils.h6
-rw-r--r--src/ir/literal-utils.h4
-rw-r--r--src/ir/manipulation.h16
-rw-r--r--src/ir/properties.h4
-rw-r--r--src/ir/utils.h6
-rw-r--r--src/js/binaryen.js-post.js53
-rw-r--r--src/literal.h36
-rw-r--r--src/parsing.h6
-rw-r--r--src/passes/ConstHoisting.cpp9
-rw-r--r--src/passes/DeadCodeElimination.cpp6
-rw-r--r--src/passes/Flatten.cpp37
-rw-r--r--src/passes/FuncCastEmulation.cpp16
-rw-r--r--src/passes/Inlining.cpp17
-rw-r--r--src/passes/InstrumentLocals.cpp32
-rw-r--r--src/passes/LegalizeJSInterface.cpp33
-rw-r--r--src/passes/LocalCSE.cpp7
-rw-r--r--src/passes/MergeLocals.cpp15
-rw-r--r--src/passes/OptimizeInstructions.cpp6
-rw-r--r--src/passes/Precompute.cpp19
-rw-r--r--src/passes/Print.cpp40
-rw-r--r--src/passes/RemoveUnusedModuleElements.cpp6
-rw-r--r--src/passes/SimplifyGlobals.cpp19
-rw-r--r--src/passes/SimplifyLocals.cpp6
-rw-r--r--src/passes/opt-utils.h13
-rw-r--r--src/shell-interface.h8
-rw-r--r--src/support/name.h2
-rw-r--r--src/support/small_vector.h10
-rw-r--r--src/tools/execution-results.h24
-rw-r--r--src/tools/fuzzing.h253
-rw-r--r--src/tools/spec-wrapper.h8
-rw-r--r--src/tools/wasm-reduce.cpp14
-rw-r--r--src/tools/wasm-shell.cpp4
-rw-r--r--src/wasm-binary.h28
-rw-r--r--src/wasm-builder.h73
-rw-r--r--src/wasm-interpreter.h47
-rw-r--r--src/wasm-s-parser.h3
-rw-r--r--src/wasm-stack.h30
-rw-r--r--src/wasm-traversal.h57
-rw-r--r--src/wasm-type.h24
-rw-r--r--src/wasm.h32
-rw-r--r--src/wasm/literal.cpp45
-rw-r--r--src/wasm/wasm-binary.cpp83
-rw-r--r--src/wasm/wasm-s-parser.cpp43
-rw-r--r--src/wasm/wasm-stack.cpp76
-rw-r--r--src/wasm/wasm-type.cpp53
-rw-r--r--src/wasm/wasm-validator.cpp306
-rw-r--r--src/wasm/wasm.cpp90
-rw-r--r--src/wasm2js.h12
59 files changed, 1532 insertions, 477 deletions
diff --git a/src/asmjs/asm_v_wasm.cpp b/src/asmjs/asm_v_wasm.cpp
index 3720ca079..5959db43e 100644
--- a/src/asmjs/asm_v_wasm.cpp
+++ b/src/asmjs/asm_v_wasm.cpp
@@ -53,10 +53,11 @@ AsmType wasmToAsmType(Type type) {
return ASM_INT64;
case v128:
assert(false && "v128 not implemented yet");
+ case funcref:
case anyref:
- assert(false && "anyref is not supported by asm2wasm");
+ case nullref:
case exnref:
- assert(false && "exnref is not supported by asm2wasm");
+ assert(false && "reference types are not supported by asm2wasm");
case none:
return ASM_NONE;
case unreachable:
@@ -77,10 +78,14 @@ char getSig(Type type) {
return 'd';
case v128:
return 'V';
+ case funcref:
+ return 'F';
case anyref:
- return 'a';
+ return 'A';
+ case nullref:
+ return 'N';
case exnref:
- return 'e';
+ return 'E';
case none:
return 'v';
case unreachable:
diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp
index 82cbc4c1f..826193e06 100644
--- a/src/binaryen-c.cpp
+++ b/src/binaryen-c.cpp
@@ -64,13 +64,16 @@ BinaryenLiteral toBinaryenLiteral(Literal x) {
case Type::f64:
ret.i64 = x.reinterpreti64();
break;
- case Type::v128: {
+ case Type::v128:
memcpy(&ret.v128, x.getv128Ptr(), 16);
break;
- }
-
- case Type::anyref: // there's no anyref literals
- case Type::exnref: // there's no exnref literals
+ case Type::funcref:
+ ret.func = x.getFunc().c_str();
+ break;
+ case Type::nullref:
+ break;
+ case Type::anyref:
+ case Type::exnref:
case Type::none:
case Type::unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -90,8 +93,12 @@ Literal fromBinaryenLiteral(BinaryenLiteral x) {
return Literal(x.i64).castToF64();
case Type::v128:
return Literal(x.v128);
- case Type::anyref: // there's no anyref literals
- case Type::exnref: // there's no exnref literals
+ case Type::funcref:
+ return Literal::makeFuncref(x.func);
+ case Type::nullref:
+ return Literal::makeNullref();
+ case Type::anyref:
+ case Type::exnref:
case Type::none:
case Type::unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -209,8 +216,14 @@ void printArg(std::ostream& setup, std::ostream& out, BinaryenLiteral arg) {
out << "BinaryenLiteralVec128(" << array << ")";
break;
}
- case Type::anyref: // there's no anyref literals
- case Type::exnref: // there's no exnref literals
+ case Type::funcref:
+ out << "BinaryenLiteralFuncref(" << arg.func << ")";
+ break;
+ case Type::nullref:
+ out << "BinaryenLiteralNullref()";
+ break;
+ case Type::anyref:
+ case Type::exnref:
case Type::none:
case Type::unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -265,7 +278,9 @@ BinaryenType BinaryenTypeInt64(void) { return i64; }
BinaryenType BinaryenTypeFloat32(void) { return f32; }
BinaryenType BinaryenTypeFloat64(void) { return f64; }
BinaryenType BinaryenTypeVec128(void) { return v128; }
+BinaryenType BinaryenTypeFuncref(void) { return funcref; }
BinaryenType BinaryenTypeAnyref(void) { return anyref; }
+BinaryenType BinaryenTypeNullref(void) { return nullref; }
BinaryenType BinaryenTypeExnref(void) { return exnref; }
BinaryenType BinaryenTypeUnreachable(void) { return unreachable; }
BinaryenType BinaryenTypeAuto(void) { return uint32_t(-1); }
@@ -397,6 +412,15 @@ BinaryenExpressionId BinaryenMemoryCopyId(void) {
BinaryenExpressionId BinaryenMemoryFillId(void) {
return Expression::Id::MemoryFillId;
}
+BinaryenExpressionId BinaryenRefNullId(void) {
+ return Expression::Id::RefNullId;
+}
+BinaryenExpressionId BinaryenRefIsNullId(void) {
+ return Expression::Id::RefIsNullId;
+}
+BinaryenExpressionId BinaryenRefFuncId(void) {
+ return Expression::Id::RefFuncId;
+}
BinaryenExpressionId BinaryenTryId(void) { return Expression::Id::TryId; }
BinaryenExpressionId BinaryenThrowId(void) { return Expression::Id::ThrowId; }
BinaryenExpressionId BinaryenRethrowId(void) {
@@ -1330,17 +1354,22 @@ BinaryenExpressionRef BinaryenBinary(BinaryenModuleRef module,
BinaryenExpressionRef BinaryenSelect(BinaryenModuleRef module,
BinaryenExpressionRef condition,
BinaryenExpressionRef ifTrue,
- BinaryenExpressionRef ifFalse) {
+ BinaryenExpressionRef ifFalse,
+ BinaryenType type) {
auto* ret = ((Module*)module)->allocator.alloc<Select>();
if (tracing) {
- traceExpression(ret, "BinaryenSelect", condition, ifTrue, ifFalse);
+ traceExpression(ret, "BinaryenSelect", condition, ifTrue, ifFalse, type);
}
ret->condition = (Expression*)condition;
ret->ifTrue = (Expression*)ifTrue;
ret->ifFalse = (Expression*)ifFalse;
- ret->finalize();
+ if (type != BinaryenTypeAuto()) {
+ ret->finalize(Type(type));
+ } else {
+ ret->finalize();
+ }
return static_cast<Expression*>(ret);
}
BinaryenExpressionRef BinaryenDrop(BinaryenModuleRef module,
@@ -1695,6 +1724,32 @@ BinaryenExpressionRef BinaryenPop(BinaryenModuleRef module, BinaryenType type) {
return static_cast<Expression*>(ret);
}
+BinaryenExpressionRef BinaryenRefNull(BinaryenModuleRef module) {
+ auto* ret = Builder(*(Module*)module).makeRefNull();
+ if (tracing) {
+ traceExpression(ret, "BinaryenRefNull");
+ }
+ return static_cast<Expression*>(ret);
+}
+
+BinaryenExpressionRef BinaryenRefIsNull(BinaryenModuleRef module,
+ BinaryenExpressionRef value) {
+ auto* ret = Builder(*(Module*)module).makeRefIsNull((Expression*)value);
+ if (tracing) {
+ traceExpression(ret, "BinaryenRefIsNull", value);
+ }
+ return static_cast<Expression*>(ret);
+}
+
+BinaryenExpressionRef BinaryenRefFunc(BinaryenModuleRef module,
+ const char* func) {
+ auto* ret = Builder(*(Module*)module).makeRefFunc(func);
+ if (tracing) {
+ traceExpression(ret, "BinaryenRefFunc", StringLit(func));
+ }
+ return static_cast<Expression*>(ret);
+}
+
BinaryenExpressionRef BinaryenTry(BinaryenModuleRef module,
BinaryenExpressionRef body,
BinaryenExpressionRef catchBody) {
@@ -2964,6 +3019,28 @@ BinaryenExpressionRef BinaryenPushGetValue(BinaryenExpressionRef expr) {
assert(expression->is<Push>());
return static_cast<Push*>(expression)->value;
}
+// RefIsNull
+BinaryenExpressionRef BinaryenRefIsNullGetValue(BinaryenExpressionRef expr) {
+ if (tracing) {
+ std::cout << " BinaryenRefIsNullGetValue(expressions[" << expressions[expr]
+ << "]);\n";
+ }
+
+ auto* expression = (Expression*)expr;
+ assert(expression->is<RefIsNull>());
+ return static_cast<RefIsNull*>(expression)->value;
+}
+// RefFunc
+const char* BinaryenRefFuncGetFunc(BinaryenExpressionRef expr) {
+ if (tracing) {
+ std::cout << " BinaryenRefFuncGetFunc(expressions[" << expressions[expr]
+ << "]);\n";
+ }
+
+ auto* expression = (Expression*)expr;
+ assert(expression->is<RefFunc>());
+ return static_cast<RefFunc*>(expression)->func.c_str();
+}
// Try
BinaryenExpressionRef BinaryenTryGetBody(BinaryenExpressionRef expr) {
if (tracing) {
diff --git a/src/binaryen-c.h b/src/binaryen-c.h
index f169ff367..36d82ecb0 100644
--- a/src/binaryen-c.h
+++ b/src/binaryen-c.h
@@ -98,7 +98,9 @@ BINARYEN_API BinaryenType BinaryenTypeInt64(void);
BINARYEN_API BinaryenType BinaryenTypeFloat32(void);
BINARYEN_API BinaryenType BinaryenTypeFloat64(void);
BINARYEN_API BinaryenType BinaryenTypeVec128(void);
+BINARYEN_API BinaryenType BinaryenTypeFuncref(void);
BINARYEN_API BinaryenType BinaryenTypeAnyref(void);
+BINARYEN_API BinaryenType BinaryenTypeNullref(void);
BINARYEN_API BinaryenType BinaryenTypeExnref(void);
BINARYEN_API BinaryenType BinaryenTypeUnreachable(void);
// Not a real type. Used as the last parameter to BinaryenBlock to let
@@ -158,6 +160,9 @@ BINARYEN_API BinaryenExpressionId BinaryenMemoryInitId(void);
BINARYEN_API BinaryenExpressionId BinaryenDataDropId(void);
BINARYEN_API BinaryenExpressionId BinaryenMemoryCopyId(void);
BINARYEN_API BinaryenExpressionId BinaryenMemoryFillId(void);
+BINARYEN_API BinaryenExpressionId BinaryenRefNullId(void);
+BINARYEN_API BinaryenExpressionId BinaryenRefIsNullId(void);
+BINARYEN_API BinaryenExpressionId BinaryenRefFuncId(void);
BINARYEN_API BinaryenExpressionId BinaryenTryId(void);
BINARYEN_API BinaryenExpressionId BinaryenThrowId(void);
BINARYEN_API BinaryenExpressionId BinaryenRethrowId(void);
@@ -222,6 +227,7 @@ struct BinaryenLiteral {
float f32;
double f64;
uint8_t v128[16];
+ const char* func;
};
};
@@ -692,7 +698,8 @@ BINARYEN_API BinaryenExpressionRef
BinaryenSelect(BinaryenModuleRef module,
BinaryenExpressionRef condition,
BinaryenExpressionRef ifTrue,
- BinaryenExpressionRef ifFalse);
+ BinaryenExpressionRef ifFalse,
+ BinaryenType type);
BINARYEN_API BinaryenExpressionRef BinaryenDrop(BinaryenModuleRef module,
BinaryenExpressionRef value);
// Return: value can be NULL
@@ -797,6 +804,11 @@ BinaryenMemoryFill(BinaryenModuleRef module,
BinaryenExpressionRef dest,
BinaryenExpressionRef value,
BinaryenExpressionRef size);
+BINARYEN_API BinaryenExpressionRef BinaryenRefNull(BinaryenModuleRef module);
+BINARYEN_API BinaryenExpressionRef
+BinaryenRefIsNull(BinaryenModuleRef module, BinaryenExpressionRef value);
+BINARYEN_API BinaryenExpressionRef BinaryenRefFunc(BinaryenModuleRef module,
+ const char* func);
BINARYEN_API BinaryenExpressionRef BinaryenTry(BinaryenModuleRef module,
BinaryenExpressionRef body,
BinaryenExpressionRef catchBody);
@@ -1036,6 +1048,11 @@ BINARYEN_API BinaryenExpressionRef
BinaryenMemoryFillGetSize(BinaryenExpressionRef expr);
BINARYEN_API BinaryenExpressionRef
+BinaryenRefIsNullGetValue(BinaryenExpressionRef expr);
+
+BINARYEN_API const char* BinaryenRefFuncGetFunc(BinaryenExpressionRef expr);
+
+BINARYEN_API BinaryenExpressionRef
BinaryenTryGetBody(BinaryenExpressionRef expr);
BINARYEN_API BinaryenExpressionRef
BinaryenTryGetCatchBody(BinaryenExpressionRef expr);
diff --git a/src/gen-s-parser.inc b/src/gen-s-parser.inc
index d60ee1794..eb5626b3f 100644
--- a/src/gen-s-parser.inc
+++ b/src/gen-s-parser.inc
@@ -653,6 +653,9 @@ switch (op[0]) {
default: goto parse_error;
}
}
+ case 'u':
+ if (strcmp(op, "funcref.pop") == 0) { return makePop(funcref); }
+ goto parse_error;
default: goto parse_error;
}
}
@@ -2480,30 +2483,57 @@ switch (op[0]) {
default: goto parse_error;
}
}
- case 'n':
- if (strcmp(op, "nop") == 0) { return makeNop(); }
- goto parse_error;
+ case 'n': {
+ switch (op[1]) {
+ case 'o':
+ if (strcmp(op, "nop") == 0) { return makeNop(); }
+ goto parse_error;
+ case 'u':
+ if (strcmp(op, "nullref.pop") == 0) { return makePop(nullref); }
+ goto parse_error;
+ default: goto parse_error;
+ }
+ }
case 'p':
if (strcmp(op, "push") == 0) { return makePush(s); }
goto parse_error;
case 'r': {
- switch (op[3]) {
- case 'h':
- if (strcmp(op, "rethrow") == 0) { return makeRethrow(s); }
- goto parse_error;
- case 'u': {
- switch (op[6]) {
- case '\0':
- if (strcmp(op, "return") == 0) { return makeReturn(s); }
+ switch (op[2]) {
+ case 'f': {
+ switch (op[4]) {
+ case 'f':
+ if (strcmp(op, "ref.func") == 0) { return makeRefFunc(s); }
goto parse_error;
- case '_': {
- switch (op[11]) {
+ case 'i':
+ if (strcmp(op, "ref.is_null") == 0) { return makeRefIsNull(s); }
+ goto parse_error;
+ case 'n':
+ if (strcmp(op, "ref.null") == 0) { return makeRefNull(s); }
+ goto parse_error;
+ default: goto parse_error;
+ }
+ }
+ case 't': {
+ switch (op[3]) {
+ case 'h':
+ if (strcmp(op, "rethrow") == 0) { return makeRethrow(s); }
+ goto parse_error;
+ case 'u': {
+ switch (op[6]) {
case '\0':
- if (strcmp(op, "return_call") == 0) { return makeCall(s, /*isReturn=*/true); }
- goto parse_error;
- case '_':
- if (strcmp(op, "return_call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/true); }
+ if (strcmp(op, "return") == 0) { return makeReturn(s); }
goto parse_error;
+ case '_': {
+ switch (op[11]) {
+ case '\0':
+ if (strcmp(op, "return_call") == 0) { return makeCall(s, /*isReturn=*/true); }
+ goto parse_error;
+ case '_':
+ if (strcmp(op, "return_call_indirect") == 0) { return makeCallIndirect(s, /*isReturn=*/true); }
+ goto parse_error;
+ default: goto parse_error;
+ }
+ }
default: goto parse_error;
}
}
diff --git a/src/ir/ExpressionAnalyzer.cpp b/src/ir/ExpressionAnalyzer.cpp
index 7355d1856..4b9869ddd 100644
--- a/src/ir/ExpressionAnalyzer.cpp
+++ b/src/ir/ExpressionAnalyzer.cpp
@@ -218,6 +218,9 @@ template<typename T> void visitImmediates(Expression* curr, T& visitor) {
visitor.visitInt(curr->op);
visitor.visitNonScopeName(curr->nameOperand);
}
+ void visitRefNull(RefNull* curr) {}
+ void visitRefIsNull(RefIsNull* curr) {}
+ void visitRefFunc(RefFunc* curr) { visitor.visitNonScopeName(curr->func); }
void visitTry(Try* curr) {}
void visitThrow(Throw* curr) { visitor.visitNonScopeName(curr->event); }
void visitRethrow(Rethrow* curr) {}
diff --git a/src/ir/ExpressionManipulator.cpp b/src/ir/ExpressionManipulator.cpp
index fbee9f9c1..acea09bad 100644
--- a/src/ir/ExpressionManipulator.cpp
+++ b/src/ir/ExpressionManipulator.cpp
@@ -58,7 +58,7 @@ flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) {
curr->type);
}
Expression* visitLoop(Loop* curr) {
- return builder.makeLoop(curr->name, copy(curr->body));
+ return builder.makeLoop(curr->name, copy(curr->body), curr->type);
}
Expression* visitBreak(Break* curr) {
return builder.makeBreak(
@@ -208,8 +208,10 @@ flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) {
return builder.makeBinary(curr->op, copy(curr->left), copy(curr->right));
}
Expression* visitSelect(Select* curr) {
- return builder.makeSelect(
- copy(curr->condition), copy(curr->ifTrue), copy(curr->ifFalse));
+ return builder.makeSelect(copy(curr->condition),
+ copy(curr->ifTrue),
+ copy(curr->ifFalse),
+ curr->type);
}
Expression* visitDrop(Drop* curr) {
return builder.makeDrop(copy(curr->value));
@@ -226,6 +228,13 @@ flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) {
builder.makeHost(curr->op, curr->nameOperand, std::move(operands));
return ret;
}
+ Expression* visitRefNull(RefNull* curr) { return builder.makeRefNull(); }
+ Expression* visitRefIsNull(RefIsNull* curr) {
+ return builder.makeRefIsNull(copy(curr->value));
+ }
+ Expression* visitRefFunc(RefFunc* curr) {
+ return builder.makeRefFunc(curr->func);
+ }
Expression* visitTry(Try* curr) {
return builder.makeTry(
copy(curr->body), copy(curr->catchBody), curr->type);
diff --git a/src/ir/ReFinalize.cpp b/src/ir/ReFinalize.cpp
index be0a8604b..9243869a1 100644
--- a/src/ir/ReFinalize.cpp
+++ b/src/ir/ReFinalize.cpp
@@ -44,23 +44,13 @@ void ReFinalize::visitBlock(Block* curr) {
curr->type = none;
return;
}
- // do this quickly, without any validation
- // last element determines type
+ // Get the least upper bound type of the last element and all branch return
+ // values
curr->type = curr->list.back()->type;
- // if concrete, it doesn't matter if we have an unreachable child, and we
- // don't need to look at breaks
- if (curr->type.isConcrete()) {
- return;
- }
- // otherwise, we have no final fallthrough element to determine the type,
- // could be determined by breaks
if (curr->name.is()) {
auto iter = breakValues.find(curr->name);
if (iter != breakValues.end()) {
- // there is a break to here
- auto type = iter->second;
- assert(type != unreachable); // we would have removed such branches
- curr->type = type;
+ curr->type = Type::getLeastUpperBound(curr->type, iter->second);
return;
}
}
@@ -130,6 +120,9 @@ void ReFinalize::visitSelect(Select* curr) { curr->finalize(); }
void ReFinalize::visitDrop(Drop* curr) { curr->finalize(); }
void ReFinalize::visitReturn(Return* curr) { curr->finalize(); }
void ReFinalize::visitHost(Host* curr) { curr->finalize(); }
+void ReFinalize::visitRefNull(RefNull* curr) { curr->finalize(); }
+void ReFinalize::visitRefIsNull(RefIsNull* curr) { curr->finalize(); }
+void ReFinalize::visitRefFunc(RefFunc* curr) { curr->finalize(); }
void ReFinalize::visitTry(Try* curr) { curr->finalize(); }
void ReFinalize::visitThrow(Throw* curr) { curr->finalize(); }
void ReFinalize::visitRethrow(Rethrow* curr) { curr->finalize(); }
@@ -159,8 +152,12 @@ void ReFinalize::visitEvent(Event* curr) { WASM_UNREACHABLE("unimp"); }
void ReFinalize::visitModule(Module* curr) { WASM_UNREACHABLE("unimp"); }
void ReFinalize::updateBreakValueType(Name name, Type type) {
- if (type != unreachable || breakValues.count(name) == 0) {
- breakValues[name] = type;
+ if (type != Type::unreachable) {
+ if (breakValues.count(name) == 0) {
+ breakValues[name] = type;
+ } else {
+ breakValues[name] = Type::getLeastUpperBound(breakValues[name], type);
+ }
}
}
diff --git a/src/ir/abstract.h b/src/ir/abstract.h
index 384f8b555..76215d07f 100644
--- a/src/ir/abstract.h
+++ b/src/ir/abstract.h
@@ -80,8 +80,10 @@ inline UnaryOp getUnary(Type type, Op op) {
case v128: {
WASM_UNREACHABLE("v128 not implemented yet");
}
- case anyref: // there's no unary instructions for anyref
- case exnref: // there's no unary instructions for exnref
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable: {
return InvalidUnary;
@@ -211,8 +213,10 @@ inline BinaryOp getBinary(Type type, Op op) {
case v128: {
WASM_UNREACHABLE("v128 not implemented yet");
}
- case anyref: // there's no binary instructions for anyref
- case exnref: // there's no binary instructions for exnref
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable: {
return InvalidBinary;
diff --git a/src/ir/block-utils.h b/src/ir/block-utils.h
index ca8b7179b..153dd45b3 100644
--- a/src/ir/block-utils.h
+++ b/src/ir/block-utils.h
@@ -43,7 +43,8 @@ simplifyToContents(Block* block, T* parent, bool allowTypeChange = false) {
// no side effects, and singleton is not returning a value, so we can
// throw away the block and its contents, basically
return Builder(*parent->getModule()).replaceWithIdenticalType(block);
- } else if (block->type == singleton->type || allowTypeChange) {
+ } else if (Type::isSubType(singleton->type, block->type) ||
+ allowTypeChange) {
return singleton;
} else {
// (side effects +) type change, must be block with declared value but
diff --git a/src/ir/effects.h b/src/ir/effects.h
index e93c63017..6eb2da91d 100644
--- a/src/ir/effects.h
+++ b/src/ir/effects.h
@@ -387,6 +387,9 @@ struct EffectAnalyzer
// Atomics are also sequentially consistent with memory.grow.
isAtomic = true;
}
+ void visitRefNull(RefNull* curr) {}
+ void visitRefIsNull(RefIsNull* curr) {}
+ void visitRefFunc(RefFunc* curr) {}
void visitTry(Try* curr) {}
// We safely model throws as branches
void visitThrow(Throw* curr) { branches = true; }
diff --git a/src/ir/flat.h b/src/ir/flat.h
index dd72e339d..01a94a759 100644
--- a/src/ir/flat.h
+++ b/src/ir/flat.h
@@ -56,6 +56,7 @@
#define wasm_ir_flat_h
#include "ir/iteration.h"
+#include "ir/properties.h"
#include "pass.h"
#include "wasm-traversal.h"
@@ -64,7 +65,8 @@ namespace wasm {
namespace Flat {
inline bool isControlFlowStructure(Expression* curr) {
- return curr->is<Block>() || curr->is<If>() || curr->is<Loop>();
+ return curr->is<Block>() || curr->is<If>() || curr->is<Loop>() ||
+ curr->is<Try>();
}
inline void verifyFlatness(Function* func) {
@@ -79,10 +81,10 @@ inline void verifyFlatness(Function* func) {
verify(!curr->type.isConcrete(), "tees are not allowed, only sets");
} else {
for (auto* child : ChildIterator(curr)) {
- verify(child->is<Const>() || child->is<LocalGet>() ||
- child->is<Unreachable>(),
- "instructions must only have const, local.get, or unreachable "
- "as children");
+ verify(Properties::isConstantExpression(child) ||
+ child->is<LocalGet>() || child->is<Unreachable>(),
+ "instructions must only have constant expressions, local.get, "
+ "or unreachable as children");
}
}
}
diff --git a/src/ir/global-utils.h b/src/ir/global-utils.h
index 93e5c8a67..e096aec8c 100644
--- a/src/ir/global-utils.h
+++ b/src/ir/global-utils.h
@@ -52,6 +52,12 @@ getGlobalInitializedToImport(Module& wasm, Name module, Name base) {
});
return ret;
}
+
+inline bool canInitializeGlobal(const Expression* curr) {
+ return curr->is<Const>() || curr->is<RefNull>() || curr->is<RefFunc>() ||
+ curr->is<GlobalGet>();
+}
+
} // namespace GlobalUtils
} // namespace wasm
diff --git a/src/ir/literal-utils.h b/src/ir/literal-utils.h
index 63a2b3b44..4bc79eee9 100644
--- a/src/ir/literal-utils.h
+++ b/src/ir/literal-utils.h
@@ -39,6 +39,10 @@ inline Expression* makeZero(Type type, Module& wasm) {
return builder.makeUnary(SplatVecI32x4,
builder.makeConst(Literal(int32_t(0))));
}
+ if (type.isRef()) {
+ Builder builder(wasm);
+ return builder.makeRefNull();
+ }
return makeFromInt32(0, type, wasm);
}
diff --git a/src/ir/manipulation.h b/src/ir/manipulation.h
index ec137d372..49ed7e11e 100644
--- a/src/ir/manipulation.h
+++ b/src/ir/manipulation.h
@@ -33,14 +33,24 @@ inline OutputType* convert(InputType* input) {
return output;
}
-// Convenience method for nop, which is a common conversion
+// Convenience methods for certain instructions, which are common conversions
template<typename InputType> inline Nop* nop(InputType* target) {
- return convert<InputType, Nop>(target);
+ auto* ret = convert<InputType, Nop>(target);
+ ret->finalize();
+ return ret;
+}
+
+template<typename InputType> inline RefNull* refNull(InputType* target) {
+ auto* ret = convert<InputType, RefNull>(target);
+ ret->finalize();
+ return ret;
}
template<typename InputType>
inline Unreachable* unreachable(InputType* target) {
- return convert<InputType, Unreachable>(target);
+ auto* ret = convert<InputType, Unreachable>(target);
+ ret->finalize();
+ return ret;
}
// Convert a node that allocates
diff --git a/src/ir/properties.h b/src/ir/properties.h
index bb88af6c5..f4c9686b6 100644
--- a/src/ir/properties.h
+++ b/src/ir/properties.h
@@ -187,6 +187,10 @@ inline Expression* getFallthrough(Expression* curr) {
return curr;
}
+inline bool isConstantExpression(const Expression* curr) {
+ return curr->is<Const>() || curr->is<RefNull>() || curr->is<RefFunc>();
+}
+
} // namespace Properties
} // namespace wasm
diff --git a/src/ir/utils.h b/src/ir/utils.h
index cad7bc885..9bd3c9e0b 100644
--- a/src/ir/utils.h
+++ b/src/ir/utils.h
@@ -146,6 +146,9 @@ struct ReFinalize
void visitDrop(Drop* curr);
void visitReturn(Return* curr);
void visitHost(Host* curr);
+ void visitRefNull(RefNull* curr);
+ void visitRefIsNull(RefIsNull* curr);
+ void visitRefFunc(RefFunc* curr);
void visitTry(Try* curr);
void visitThrow(Throw* curr);
void visitRethrow(Rethrow* curr);
@@ -210,6 +213,9 @@ struct ReFinalizeNode : public OverriddenVisitor<ReFinalizeNode> {
void visitDrop(Drop* curr) { curr->finalize(); }
void visitReturn(Return* curr) { curr->finalize(); }
void visitHost(Host* curr) { curr->finalize(); }
+ void visitRefNull(RefNull* curr) { curr->finalize(); }
+ void visitRefIsNull(RefIsNull* curr) { curr->finalize(); }
+ void visitRefFunc(RefFunc* curr) { curr->finalize(); }
void visitTry(Try* curr) { curr->finalize(); }
void visitThrow(Throw* curr) { curr->finalize(); }
void visitRethrow(Rethrow* curr) { curr->finalize(); }
diff --git a/src/js/binaryen.js-post.js b/src/js/binaryen.js-post.js
index 8cddc61c7..2993573d1 100644
--- a/src/js/binaryen.js-post.js
+++ b/src/js/binaryen.js-post.js
@@ -38,7 +38,9 @@ function initializeConstants() {
['f32', 'Float32'],
['f64', 'Float64'],
['v128', 'Vec128'],
+ ['funcref', 'Funcref'],
['anyref', 'Anyref'],
+ ['nullref', 'Nullref'],
['exnref', 'Exnref'],
['unreachable', 'Unreachable'],
['auto', 'Auto']
@@ -86,6 +88,9 @@ function initializeConstants() {
'DataDrop',
'MemoryCopy',
'MemoryFill',
+ 'RefNull',
+ 'RefIsNull',
+ 'RefFunc',
'Try',
'Throw',
'Rethrow',
@@ -1952,20 +1957,47 @@ function wrapModule(module, self) {
},
};
+ self['funcref'] = {
+ 'pop': function() {
+ return Module['_BinaryenPop'](module, Module['funcref']);
+ }
+ };
+
self['anyref'] = {
'pop': function() {
return Module['_BinaryenPop'](module, Module['anyref']);
}
};
+ self['nullref'] = {
+ 'pop': function() {
+ return Module['_BinaryenPop'](module, Module['nullref']);
+ }
+ };
+
self['exnref'] = {
'pop': function() {
return Module['_BinaryenPop'](module, Module['exnref']);
}
};
- self['select'] = function(condition, ifTrue, ifFalse) {
- return Module['_BinaryenSelect'](module, condition, ifTrue, ifFalse);
+ self['ref'] = {
+ 'null': function() {
+ return Module['_BinaryenRefNull'](module);
+ },
+ 'is_null': function(value) {
+ return Module['_BinaryenRefIsNull'](module, value);
+ },
+ 'func': function(func) {
+ return preserveStack(function() {
+ return Module['_BinaryenRefFunc'](module, strToStack(func));
+ });
+ }
+ };
+
+ self['select'] = function(condition, ifTrue, ifFalse, type) {
+ return Module['_BinaryenSelect'](
+ module, condition, ifTrue, ifFalse, typeof type !== 'undefined' ? type : Module['auto']);
};
self['drop'] = function(value) {
return Module['_BinaryenDrop'](module, value);
@@ -2651,6 +2683,23 @@ Module['getExpressionInfo'] = function(expr) {
'value': Module['_BinaryenMemoryFillGetValue'](expr),
'size': Module['_BinaryenMemoryFillGetSize'](expr)
};
+ case Module['RefNullId']:
+ return {
+ 'id': id,
+ 'type': type
+ };
+ case Module['RefIsNullId']:
+ return {
+ 'id': id,
+ 'type': type,
+ 'value': Module['_BinaryenRefIsNullGetValue'](expr)
+ };
+ case Module['RefFuncId']:
+ return {
+ 'id': id,
+ 'type': type,
+ 'func': UTF8ToString(Module['_BinaryenRefFuncGetFunc'](expr)),
+ };
case Module['TryId']:
return {
'id': id,
diff --git a/src/literal.h b/src/literal.h
index 1d19e6661..ef3e13d44 100644
--- a/src/literal.h
+++ b/src/literal.h
@@ -22,6 +22,7 @@
#include "compiler-support.h"
#include "support/hash.h"
+#include "support/name.h"
#include "support/utilities.h"
#include "wasm-type.h"
@@ -34,6 +35,7 @@ class Literal {
int32_t i32;
int64_t i64;
uint8_t v128[16];
+ Name func; // function name for funcref
};
public:
@@ -57,11 +59,12 @@ public:
explicit Literal(const std::array<Literal, 8>&);
explicit Literal(const std::array<Literal, 4>&);
explicit Literal(const std::array<Literal, 2>&);
+ explicit Literal(Name func) : func(func), type(Type::funcref) {}
- bool isConcrete() { return type != none; }
- bool isNull() { return type == none; }
+ bool isConcrete() { return type != Type::none; }
+ bool isNone() { return type == Type::none; }
- inline static Literal makeFromInt32(int32_t x, Type type) {
+ static Literal makeFromInt32(int32_t x, Type type) {
switch (type) {
case Type::i32:
return Literal(int32_t(x));
@@ -80,16 +83,26 @@ public:
Literal(int32_t(0)),
Literal(int32_t(0)),
Literal(int32_t(0))}});
- case Type::anyref: // there's no anyref literals
- case Type::exnref: // there's no exnref literals
- case none:
- case unreachable:
+ case Type::funcref:
+ case Type::anyref:
+ case Type::nullref:
+ case Type::exnref:
+ case Type::none:
+ case Type::unreachable:
WASM_UNREACHABLE("unexpected type");
}
WASM_UNREACHABLE("unexpected type");
}
- inline static Literal makeZero(Type type) { return makeFromInt32(0, type); }
+ static Literal makeZero(Type type) {
+ if (type.isRef()) {
+ return makeNullref();
+ }
+ return makeFromInt32(0, type);
+ }
+
+ static Literal makeNullref() { return Literal(Type(Type::nullref)); }
+ static Literal makeFuncref(Name func) { return Literal(func.c_str()); }
Literal castToF32();
Literal castToF64();
@@ -113,6 +126,7 @@ public:
return bit_cast<double>(i64);
}
std::array<uint8_t, 16> getv128() const;
+ Name getFunc() const { return func; }
// careful!
int32_t* geti32Ptr() {
@@ -464,8 +478,10 @@ template<> struct less<wasm::Literal> {
return a.reinterpreti64() < b.reinterpreti64();
case wasm::Type::v128:
return memcmp(a.getv128Ptr(), b.getv128Ptr(), 16) < 0;
- case wasm::Type::anyref: // anyref is an opaque value
- case wasm::Type::exnref: // exnref is an opaque value
+ case wasm::Type::funcref:
+ case wasm::Type::anyref:
+ case wasm::Type::nullref:
+ case wasm::Type::exnref:
case wasm::Type::none:
case wasm::Type::unreachable:
return false;
diff --git a/src/parsing.h b/src/parsing.h
index 7017fdb0f..d64236df3 100644
--- a/src/parsing.h
+++ b/src/parsing.h
@@ -263,8 +263,10 @@ parseConst(cashew::IString s, Type type, MixedArena& allocator) {
break;
}
case v128:
- case anyref: // there's no anyref.const
- case exnref: // there's no exnref.const
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
WASM_UNREACHABLE("unexpected const type");
case none:
case unreachable: {
diff --git a/src/passes/ConstHoisting.cpp b/src/passes/ConstHoisting.cpp
index dbb3853d8..4e8cd9910 100644
--- a/src/passes/ConstHoisting.cpp
+++ b/src/passes/ConstHoisting.cpp
@@ -91,9 +91,12 @@ private:
size = value.type.getByteSize();
break;
}
- case v128: // v128 not implemented yet
- case anyref: // anyref cannot have literals
- case exnref: { // exnref cannot have literals
+ // not implemented yet
+ case v128:
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref: {
return false;
}
case none:
diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp
index be6f92ffa..7d5385a83 100644
--- a/src/passes/DeadCodeElimination.cpp
+++ b/src/passes/DeadCodeElimination.cpp
@@ -347,6 +347,12 @@ struct DeadCodeElimination
DELEGATE(Push);
case Expression::Id::PopId:
DELEGATE(Pop);
+ case Expression::Id::RefNullId:
+ DELEGATE(RefNull);
+ case Expression::Id::RefIsNullId:
+ DELEGATE(RefIsNull);
+ case Expression::Id::RefFuncId:
+ DELEGATE(RefFunc);
case Expression::Id::TryId:
DELEGATE(Try);
case Expression::Id::ThrowId:
diff --git a/src/passes/Flatten.cpp b/src/passes/Flatten.cpp
index f5115567b..fda8e3f80 100644
--- a/src/passes/Flatten.cpp
+++ b/src/passes/Flatten.cpp
@@ -21,6 +21,7 @@
#include <ir/branch-utils.h>
#include <ir/effects.h>
#include <ir/flat.h>
+#include <ir/properties.h>
#include <ir/utils.h>
#include <pass.h>
#include <wasm-builder.h>
@@ -61,7 +62,9 @@ struct Flatten
std::vector<Expression*> ourPreludes;
Builder builder(*getModule());
- if (curr->is<Const>() || curr->is<Nop>() || curr->is<Unreachable>()) {
+ // Nothing to do for constants, nop, and unreachable
+ if (Properties::isConstantExpression(curr) || curr->is<Nop>() ||
+ curr->is<Unreachable>()) {
return;
}
@@ -194,8 +197,37 @@ struct Flatten
auto type = br->value->type;
if (type.isConcrete()) {
// we are sending a value. use a local instead
- Index temp = getTempForBreakTarget(br->name, type);
+ Type blockType = findBreakTarget(br->name)->type;
+ Index temp = getTempForBreakTarget(br->name, blockType);
ourPreludes.push_back(builder.makeLocalSet(temp, br->value));
+
+ // br_if leaves a value on the stack if not taken, which later can
+ // be the last element of the enclosing innermost block and flow
+ // out. The local we created using 'getTempForBreakTarget' returns
+ // the return type of the block this branch is targetting, which may
+ // not be the same with the innermost block's return type. For
+ // example,
+ // (block $any (result anyref)
+ // (block (result nullref)
+ // (local.tee $0
+ // (br_if $any
+ // (ref.null)
+ // (i32.const 0)
+ // )
+ // )
+ // )
+ // )
+ // In this case we need two locals to store (ref.null); one with
+ // anyref type that's for the target block ($label0) and one more
+ // with nullref type in case for flowing out. Here we create the
+ // second 'flowing out' local in case two block's types are
+ // different.
+ if (type != blockType) {
+ temp = builder.addVar(getFunction(), type);
+ ourPreludes.push_back(builder.makeLocalSet(
+ temp, ExpressionManipulator::copy(br->value, *getModule())));
+ }
+
if (br->condition) {
// the value must also flow out
ourPreludes.push_back(br);
@@ -239,6 +271,7 @@ struct Flatten
}
}
}
+ // TODO Handle br_on_exn
// continue for general handling of everything, control flow or otherwise
curr = getCurrent(); // we may have replaced it
diff --git a/src/passes/FuncCastEmulation.cpp b/src/passes/FuncCastEmulation.cpp
index 729a4a6c3..9d5109a83 100644
--- a/src/passes/FuncCastEmulation.cpp
+++ b/src/passes/FuncCastEmulation.cpp
@@ -65,11 +65,11 @@ static Expression* toABI(Expression* value, Module* module) {
case v128: {
WASM_UNREACHABLE("v128 not implemented yet");
}
- case anyref: {
- WASM_UNREACHABLE("anyref cannot be converted to i64");
- }
+ case funcref:
+ case anyref:
+ case nullref:
case exnref: {
- WASM_UNREACHABLE("exnref cannot be converted to i64");
+ WASM_UNREACHABLE("reference types cannot be converted to i64");
}
case none: {
// the value is none, but we need a value here
@@ -108,11 +108,11 @@ static Expression* fromABI(Expression* value, Type type, Module* module) {
case v128: {
WASM_UNREACHABLE("v128 not implemented yet");
}
- case anyref: {
- WASM_UNREACHABLE("anyref cannot be converted from i64");
- }
+ case funcref:
+ case anyref:
+ case nullref:
case exnref: {
- WASM_UNREACHABLE("exnref cannot be converted from i64");
+ WASM_UNREACHABLE("reference types cannot be converted from i64");
}
case none: {
value = builder.makeDrop(value);
diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp
index db1db5971..c43d41e7f 100644
--- a/src/passes/Inlining.cpp
+++ b/src/passes/Inlining.cpp
@@ -46,13 +46,13 @@ namespace wasm {
// Useful into on a function, helping us decide if we can inline it
struct FunctionInfo {
- std::atomic<Index> calls;
+ std::atomic<Index> refs;
Index size;
std::atomic<bool> lightweight;
bool usedGlobally; // in a table or export
FunctionInfo() {
- calls = 0;
+ refs = 0;
size = 0;
lightweight = true;
usedGlobally = false;
@@ -75,7 +75,7 @@ struct FunctionInfo {
// FIXME: move this check to be first in this function, since we should
// return true if oneCallerInlineMaxSize is bigger than
// flexibleInlineMaxSize (which it typically should be).
- if (calls == 1 && !usedGlobally &&
+ if (refs == 1 && !usedGlobally &&
size <= options.inlining.oneCallerInlineMaxSize) {
return true;
}
@@ -108,11 +108,16 @@ struct FunctionInfoScanner
void visitCall(Call* curr) {
// can't add a new element in parallel
assert(infos->count(curr->target) > 0);
- (*infos)[curr->target].calls++;
+ (*infos)[curr->target].refs++;
// having a call is not lightweight
(*infos)[getFunction()->name].lightweight = false;
}
+ void visitRefFunc(RefFunc* curr) {
+ assert(infos->count(curr->func) > 0);
+ (*infos)[curr->func].refs++;
+ }
+
void visitFunction(Function* curr) {
(*infos)[curr->name].size = Measurer::measure(curr->body);
}
@@ -374,7 +379,7 @@ struct Inlining : public Pass {
doInlining(module, func.get(), action);
inlinedUses[inlinedName]++;
inlinedInto.insert(func.get());
- assert(inlinedUses[inlinedName] <= infos[inlinedName].calls);
+ assert(inlinedUses[inlinedName] <= infos[inlinedName].refs);
}
}
// anything we inlined into may now have non-unique label names, fix it up
@@ -388,7 +393,7 @@ struct Inlining : public Pass {
module->removeFunctions([&](Function* func) {
auto name = func->name;
auto& info = infos[name];
- return inlinedUses.count(name) && inlinedUses[name] == info.calls &&
+ return inlinedUses.count(name) && inlinedUses[name] == info.refs &&
!info.usedGlobally;
});
// return whether we did any work
diff --git a/src/passes/InstrumentLocals.cpp b/src/passes/InstrumentLocals.cpp
index 407903219..ae35ec2d1 100644
--- a/src/passes/InstrumentLocals.cpp
+++ b/src/passes/InstrumentLocals.cpp
@@ -56,14 +56,18 @@ Name get_i32("get_i32");
Name get_i64("get_i64");
Name get_f32("get_f32");
Name get_f64("get_f64");
+Name get_funcref("get_funcref");
Name get_anyref("get_anyref");
+Name get_nullref("get_nullref");
Name get_exnref("get_exnref");
Name set_i32("set_i32");
Name set_i64("set_i64");
Name set_f32("set_f32");
Name set_f64("set_f64");
+Name set_funcref("set_funcref");
Name set_anyref("set_anyref");
+Name set_nullref("set_nullref");
Name set_exnref("set_exnref");
struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> {
@@ -84,9 +88,15 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> {
break;
case v128:
assert(false && "v128 not implemented yet");
+ case funcref:
+ import = get_funcref;
+ break;
case anyref:
import = get_anyref;
break;
+ case nullref:
+ import = get_nullref;
+ break;
case exnref:
import = get_exnref;
break;
@@ -126,9 +136,15 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> {
break;
case v128:
assert(false && "v128 not implemented yet");
+ case funcref:
+ import = set_funcref;
+ break;
case anyref:
import = set_anyref;
break;
+ case nullref:
+ import = set_nullref;
+ break;
case exnref:
import = set_exnref;
break;
@@ -156,10 +172,26 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> {
addImport(curr, set_f64, {Type::i32, Type::i32, Type::f64}, Type::f64);
if (curr->features.hasReferenceTypes()) {
+ addImport(curr,
+ get_funcref,
+ {Type::i32, Type::i32, Type::funcref},
+ Type::funcref);
+ addImport(curr,
+ set_funcref,
+ {Type::i32, Type::i32, Type::funcref},
+ Type::funcref);
addImport(
curr, get_anyref, {Type::i32, Type::i32, Type::anyref}, Type::anyref);
addImport(
curr, set_anyref, {Type::i32, Type::i32, Type::anyref}, Type::anyref);
+ addImport(curr,
+ get_nullref,
+ {Type::i32, Type::i32, Type::nullref},
+ Type::nullref);
+ addImport(curr,
+ set_nullref,
+ {Type::i32, Type::i32, Type::nullref},
+ Type::nullref);
}
if (curr->features.hasExceptionHandling()) {
addImport(
diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp
index 8c7bc4414..df6651b0d 100644
--- a/src/passes/LegalizeJSInterface.cpp
+++ b/src/passes/LegalizeJSInterface.cpp
@@ -107,14 +107,43 @@ struct LegalizeJSInterface : public Pass {
}
}
}
+
if (!illegalImportsToLegal.empty()) {
+ // Gather functions used in 'ref.func'. They should not be removed.
+ std::unordered_map<Name, std::atomic<bool>> usedInRefFunc;
+
+ struct RefFuncScanner : public WalkerPass<PostWalker<RefFuncScanner>> {
+ Module& wasm;
+ std::unordered_map<Name, std::atomic<bool>>& usedInRefFunc;
+
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override {
+ return new RefFuncScanner(wasm, usedInRefFunc);
+ }
+
+ RefFuncScanner(
+ Module& wasm,
+ std::unordered_map<Name, std::atomic<bool>>& usedInRefFunc)
+ : wasm(wasm), usedInRefFunc(usedInRefFunc) {
+ // Fill in unordered_map, as we operate on it in parallel
+ for (auto& func : wasm.functions) {
+ usedInRefFunc[func->name];
+ }
+ }
+
+ void visitRefFunc(RefFunc* curr) { usedInRefFunc[curr->func] = true; }
+ };
+
+ RefFuncScanner(*module, usedInRefFunc).run(runner, module);
for (auto& pair : illegalImportsToLegal) {
- module->removeFunction(pair.first);
+ if (!usedInRefFunc[pair.first]) {
+ module->removeFunction(pair.first);
+ }
}
// fix up imports: call_import of an illegal must be turned to a call of a
// legal
-
struct FixImports : public WalkerPass<PostWalker<FixImports>> {
bool isFunctionParallel() override { return true; }
diff --git a/src/passes/LocalCSE.cpp b/src/passes/LocalCSE.cpp
index 0816bf6ea..b49c92310 100644
--- a/src/passes/LocalCSE.cpp
+++ b/src/passes/LocalCSE.cpp
@@ -172,9 +172,12 @@ struct LocalCSE : public WalkerPass<LinearExecutionWalker<LocalCSE>> {
void handle(Expression* curr) {
if (auto* set = curr->dynCast<LocalSet>()) {
// Calculate equivalences
+ auto* func = getFunction();
equivalences.reset(set->index);
if (auto* get = set->value->dynCast<LocalGet>()) {
- equivalences.add(set->index, get->index);
+ if (func->getLocalType(set->index) == func->getLocalType(get->index)) {
+ equivalences.add(set->index, get->index);
+ }
}
// consider the value
auto* value = set->value;
@@ -184,7 +187,7 @@ struct LocalCSE : public WalkerPass<LinearExecutionWalker<LocalCSE>> {
if (iter != usables.end()) {
// already exists in the table, this is good to reuse
auto& info = iter->second;
- Type localType = getFunction()->getLocalType(info.index);
+ Type localType = func->getLocalType(info.index);
set->value =
Builder(*getModule()).makeLocalGet(info.index, localType);
anotherPass = true;
diff --git a/src/passes/MergeLocals.cpp b/src/passes/MergeLocals.cpp
index 0116753f1..2223594b6 100644
--- a/src/passes/MergeLocals.cpp
+++ b/src/passes/MergeLocals.cpp
@@ -100,7 +100,8 @@ struct MergeLocals
return;
}
// compute all dependencies
- LocalGraph preGraph(getFunction());
+ auto* func = getFunction();
+ LocalGraph preGraph(func);
preGraph.computeInfluences();
// optimize each copy
std::unordered_map<LocalSet*, LocalSet*> optimizedToCopy,
@@ -119,6 +120,11 @@ struct MergeLocals
if (preGraph.getSetses[influencedGet].size() == 1) {
// this is ok
assert(*preGraph.getSetses[influencedGet].begin() == trivial);
+ // If local types are different (when one is a subtype of the
+ // other), don't optimize
+ if (func->getLocalType(copy->index) != influencedGet->type) {
+ canOptimizeToCopy = false;
+ }
} else {
canOptimizeToCopy = false;
break;
@@ -152,6 +158,11 @@ struct MergeLocals
if (preGraph.getSetses[influencedGet].size() == 1) {
// this is ok
assert(*preGraph.getSetses[influencedGet].begin() == copy);
+ // If local types are different (when one is a subtype of the
+ // other), don't optimize
+ if (func->getLocalType(trivial->index) != influencedGet->type) {
+ canOptimizeToTrivial = false;
+ }
} else {
canOptimizeToTrivial = false;
break;
@@ -176,7 +187,7 @@ struct MergeLocals
// if one does not work, we need to undo all its siblings (don't extend
// the live range unless we are definitely removing a conflict, same
// logic as before).
- LocalGraph postGraph(getFunction());
+ LocalGraph postGraph(func);
postGraph.computeInfluences();
for (auto& pair : optimizedToCopy) {
auto* copy = pair.first;
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index 6de1d3d00..edd6ba2b6 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -751,12 +751,12 @@ struct OptimizeInstructions
// condition, do that
auto needCondition =
EffectAnalyzer(getPassOptions(), iff->condition).hasSideEffects();
- auto typeIsIdentical = iff->ifTrue->type == iff->type;
- if (typeIsIdentical && !needCondition) {
+ auto isSubType = Type::isSubType(iff->ifTrue->type, iff->type);
+ if (isSubType && !needCondition) {
return iff->ifTrue;
} else {
Builder builder(*getModule());
- if (typeIsIdentical) {
+ if (isSubType) {
return builder.makeSequence(builder.makeDrop(iff->condition),
iff->ifTrue);
} else {
diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp
index 57a3ab27f..85eb026f9 100644
--- a/src/passes/Precompute.cpp
+++ b/src/passes/Precompute.cpp
@@ -177,7 +177,7 @@ struct Precompute
void visitExpression(Expression* curr) {
// TODO: if local.get, only replace with a constant if we don't care about
// size...?
- if (curr->is<Const>() || curr->is<Nop>()) {
+ if (Properties::isConstantExpression(curr) || curr->is<Nop>()) {
return;
}
// Until engines implement v128.const and we have SIMD-aware optimizations
@@ -208,14 +208,16 @@ struct Precompute
return;
}
}
- ret->value = Builder(*getModule()).makeConst(flow.value);
+ ret->value = Builder(*getModule()).makeConstExpression(flow.value);
} else {
ret->value = nullptr;
}
} else {
Builder builder(*getModule());
- replaceCurrent(builder.makeReturn(
- flow.value.type != none ? builder.makeConst(flow.value) : nullptr));
+ replaceCurrent(
+ builder.makeReturn(flow.value.type != Type::none
+ ? builder.makeConstExpression(flow.value)
+ : nullptr));
}
return;
}
@@ -234,7 +236,7 @@ struct Precompute
return;
}
}
- br->value = Builder(*getModule()).makeConst(flow.value);
+ br->value = Builder(*getModule()).makeConstExpression(flow.value);
} else {
br->value = nullptr;
}
@@ -243,13 +245,14 @@ struct Precompute
Builder builder(*getModule());
replaceCurrent(builder.makeBreak(
flow.breakTo,
- flow.value.type != none ? builder.makeConst(flow.value) : nullptr));
+ flow.value.type != none ? builder.makeConstExpression(flow.value)
+ : nullptr));
}
return;
}
// this was precomputed
if (flow.value.type.isConcrete()) {
- replaceCurrent(Builder(*getModule()).makeConst(flow.value));
+ replaceCurrent(Builder(*getModule()).makeConstExpression(flow.value));
worked = true;
} else {
ExpressionManipulator::nop(curr);
@@ -350,7 +353,7 @@ private:
} else {
curr = setValues[set];
}
- if (curr.isNull()) {
+ if (curr.isNone()) {
// not a constant, give up
value = Literal();
break;
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index 5efd1fd28..51e78c8a7 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -1333,7 +1333,12 @@ struct PrintExpressionContents
}
restoreNormalColor(o);
}
- void visitSelect(Select* curr) { prepareColor(o) << "select"; }
+ void visitSelect(Select* curr) {
+ prepareColor(o) << "select";
+ if (curr->type.isRef()) {
+ o << " (result " << curr->type << ')';
+ }
+ }
void visitDrop(Drop* curr) { printMedium(o, "drop"); }
void visitReturn(Return* curr) { printMedium(o, "return"); }
void visitHost(Host* curr) {
@@ -1346,6 +1351,12 @@ struct PrintExpressionContents
break;
}
}
+ void visitRefNull(RefNull* curr) { printMedium(o, "ref.null"); }
+ void visitRefIsNull(RefIsNull* curr) { printMedium(o, "ref.is_null"); }
+ void visitRefFunc(RefFunc* curr) {
+ printMedium(o, "ref.func ");
+ printName(curr->func, o);
+ }
void visitTry(Try* curr) {
printMedium(o, "try");
if (curr->type.isConcrete()) {
@@ -1852,6 +1863,23 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
}
}
}
+ void visitRefNull(RefNull* curr) {
+ o << '(';
+ PrintExpressionContents(currFunction, o).visit(curr);
+ o << ')';
+ }
+ void visitRefIsNull(RefIsNull* curr) {
+ o << '(';
+ PrintExpressionContents(currFunction, o).visit(curr);
+ incIndent();
+ printFullLine(curr->value);
+ decIndent();
+ }
+ void visitRefFunc(RefFunc* curr) {
+ o << '(';
+ PrintExpressionContents(currFunction, o).visit(curr);
+ o << ')';
+ }
// try-catch-end is written in the folded wat format as
// (try
// ...
@@ -2434,13 +2462,15 @@ WasmPrinter::printStackInst(StackInst* inst, std::ostream& o, Function* func) {
}
case StackInst::BlockBegin:
case StackInst::IfBegin:
- case StackInst::LoopBegin: {
+ case StackInst::LoopBegin:
+ case StackInst::TryBegin: {
o << getExpressionName(inst->origin);
break;
}
case StackInst::BlockEnd:
case StackInst::IfEnd:
- case StackInst::LoopEnd: {
+ case StackInst::LoopEnd:
+ case StackInst::TryEnd: {
o << "end (" << inst->type << ')';
break;
}
@@ -2448,6 +2478,10 @@ WasmPrinter::printStackInst(StackInst* inst, std::ostream& o, Function* func) {
o << "else";
break;
}
+ case StackInst::Catch: {
+ o << "catch";
+ break;
+ }
default:
WASM_UNREACHABLE("unexpeted op");
}
diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp
index f5000e3a4..21cbc5e5b 100644
--- a/src/passes/RemoveUnusedModuleElements.cpp
+++ b/src/passes/RemoveUnusedModuleElements.cpp
@@ -116,6 +116,12 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> {
usesMemory = true;
}
}
+ void visitRefFunc(RefFunc* curr) {
+ if (reachable.count(
+ ModuleElement(ModuleElementKind::Function, curr->func)) == 0) {
+ queue.emplace_back(ModuleElementKind::Function, curr->func);
+ }
+ }
void visitThrow(Throw* curr) {
if (reachable.count(ModuleElement(ModuleElementKind::Event, curr->event)) ==
0) {
diff --git a/src/passes/SimplifyGlobals.cpp b/src/passes/SimplifyGlobals.cpp
index 88f27f8be..b18f726ed 100644
--- a/src/passes/SimplifyGlobals.cpp
+++ b/src/passes/SimplifyGlobals.cpp
@@ -37,6 +37,7 @@
#include <atomic>
#include "ir/effects.h"
+#include "ir/properties.h"
#include "ir/utils.h"
#include "pass.h"
#include "wasm-builder.h"
@@ -106,8 +107,9 @@ struct ConstantGlobalApplier
void visitExpression(Expression* curr) {
if (auto* set = curr->dynCast<GlobalSet>()) {
- if (auto* c = set->value->dynCast<Const>()) {
- currConstantGlobals[set->name] = c->value;
+ if (Properties::isConstantExpression(set->value)) {
+ currConstantGlobals[set->name] =
+ getLiteralFromConstExpression(set->value);
} else {
currConstantGlobals.erase(set->name);
}
@@ -116,7 +118,7 @@ struct ConstantGlobalApplier
// Check if the global is known to be constant all the time.
if (constantGlobals->count(get->name)) {
auto* global = getModule()->getGlobal(get->name);
- assert(global->init->is<Const>());
+ assert(Properties::isConstantExpression(global->init));
replaceCurrent(ExpressionManipulator::copy(global->init, *getModule()));
replaced = true;
return;
@@ -125,7 +127,7 @@ struct ConstantGlobalApplier
auto iter = currConstantGlobals.find(get->name);
if (iter != currConstantGlobals.end()) {
Builder builder(*getModule());
- replaceCurrent(builder.makeConst(iter->second));
+ replaceCurrent(builder.makeConstExpression(iter->second));
replaced = true;
}
return;
@@ -249,13 +251,14 @@ struct SimplifyGlobals : public Pass {
std::map<Name, Literal> constantGlobals;
for (auto& global : module->globals) {
if (!global->imported()) {
- if (auto* c = global->init->dynCast<Const>()) {
- constantGlobals[global->name] = c->value;
+ if (Properties::isConstantExpression(global->init)) {
+ constantGlobals[global->name] =
+ getLiteralFromConstExpression(global->init);
} else if (auto* get = global->init->dynCast<GlobalGet>()) {
auto iter = constantGlobals.find(get->name);
if (iter != constantGlobals.end()) {
Builder builder(*module);
- global->init = builder.makeConst(iter->second);
+ global->init = builder.makeConstExpression(iter->second);
}
}
}
@@ -268,7 +271,7 @@ struct SimplifyGlobals : public Pass {
NameSet constantGlobals;
for (auto& global : module->globals) {
if (!global->mutable_ && !global->imported() &&
- global->init->is<Const>()) {
+ Properties::isConstantExpression(global->init)) {
constantGlobals.insert(global->name);
}
}
diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp
index a3fa4a34d..a952f8a38 100644
--- a/src/passes/SimplifyLocals.cpp
+++ b/src/passes/SimplifyLocals.cpp
@@ -546,7 +546,6 @@ struct SimplifyLocals
auto* blockLocalSetPointer = sinkables.at(sharedIndex).item;
auto* value = (*blockLocalSetPointer)->template cast<LocalSet>()->value;
block->list[block->list.size() - 1] = value;
- block->type = value->type;
ExpressionManipulator::nop(*blockLocalSetPointer);
for (size_t j = 0; j < breaks.size(); j++) {
// move break local.set's value to the break
@@ -577,6 +576,7 @@ struct SimplifyLocals
this->replaceCurrent(newLocalSet);
sinkables.clear();
anotherCycle = true;
+ block->finalize();
}
// optimize local.sets from both sides of an if into a return value
@@ -915,6 +915,7 @@ struct SimplifyLocals
void visitLocalSet(LocalSet* curr) {
// Remove trivial copies, even through a tee
auto* value = curr->value;
+ Function* func = this->getFunction();
while (auto* subSet = value->dynCast<LocalSet>()) {
value = subSet->value;
}
@@ -929,7 +930,8 @@ struct SimplifyLocals
}
anotherCycle = true;
}
- } else {
+ } else if (func->getLocalType(curr->index) ==
+ func->getLocalType(get->index)) {
// There is a new equivalence now.
equivalences.reset(curr->index);
equivalences.add(curr->index, get->index);
diff --git a/src/passes/opt-utils.h b/src/passes/opt-utils.h
index 93fac137f..7912a7d92 100644
--- a/src/passes/opt-utils.h
+++ b/src/passes/opt-utils.h
@@ -54,19 +54,22 @@ inline void optimizeAfterInlining(std::unordered_set<Function*>& funcs,
module->updateMaps();
}
-struct CallTargetReplacer : public WalkerPass<PostWalker<CallTargetReplacer>> {
+struct FunctionRefReplacer
+ : public WalkerPass<PostWalker<FunctionRefReplacer>> {
bool isFunctionParallel() override { return true; }
using MaybeReplace = std::function<void(Name&)>;
- CallTargetReplacer(MaybeReplace maybeReplace) : maybeReplace(maybeReplace) {}
+ FunctionRefReplacer(MaybeReplace maybeReplace) : maybeReplace(maybeReplace) {}
- CallTargetReplacer* create() override {
- return new CallTargetReplacer(maybeReplace);
+ FunctionRefReplacer* create() override {
+ return new FunctionRefReplacer(maybeReplace);
}
void visitCall(Call* curr) { maybeReplace(curr->target); }
+ void visitRefFunc(RefFunc* curr) { maybeReplace(curr->func); }
+
private:
MaybeReplace maybeReplace;
};
@@ -81,7 +84,7 @@ inline void replaceFunctions(PassRunner* runner,
}
};
// replace direct calls
- CallTargetReplacer(maybeReplace).run(runner, &module);
+ FunctionRefReplacer(maybeReplace).run(runner, &module);
// replace in table
for (auto& segment : module.table.segments) {
for (auto& name : segment.data) {
diff --git a/src/shell-interface.h b/src/shell-interface.h
index 52533f37c..75f8e81b8 100644
--- a/src/shell-interface.h
+++ b/src/shell-interface.h
@@ -114,10 +114,12 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
break;
case v128:
assert(false && "v128 not implemented yet");
+ case funcref:
case anyref:
- assert(false && "anyref not implemented yet");
+ case nullref:
case exnref:
- assert(false && "exnref not implemented yet");
+ globals[import->name] = Literal::makeNullref();
+ break;
case none:
case unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -163,7 +165,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface {
trap("callIndirect: bad # of arguments");
}
for (size_t i = 0; i < params.size(); i++) {
- if (params[i] != arguments[i].type) {
+ if (!Type::isSubType(arguments[i].type, params[i])) {
trap("callIndirect: bad argument type");
}
}
diff --git a/src/support/name.h b/src/support/name.h
index 2bc50abf0..615740e09 100644
--- a/src/support/name.h
+++ b/src/support/name.h
@@ -17,7 +17,7 @@
#ifndef wasm_support_name_h
#define wasm_support_name_h
-#include <cstring>
+#include <string>
#include "emscripten-optimizer/istring.h"
diff --git a/src/support/small_vector.h b/src/support/small_vector.h
index 7f00bd4a6..d4ad961a7 100644
--- a/src/support/small_vector.h
+++ b/src/support/small_vector.h
@@ -38,17 +38,15 @@ template<typename T, size_t N> class SmallVector {
std::vector<T> flexible;
public:
+ using value_type = T;
+
SmallVector() {}
T& operator[](size_t i) {
- if (i < N) {
- return fixed[i];
- } else {
- return flexible[i - N];
- }
+ return const_cast<T&>(static_cast<const SmallVector<T, N>&>(*this)[i]);
}
- T operator[](size_t i) const {
+ const T& operator[](size_t i) const {
if (i < N) {
return fixed[i];
} else {
diff --git a/src/tools/execution-results.h b/src/tools/execution-results.h
index c0c7428cc..7787dba25 100644
--- a/src/tools/execution-results.h
+++ b/src/tools/execution-results.h
@@ -69,11 +69,17 @@ struct ExecutionResults {
auto* func = wasm.getFunction(exp->value);
if (func->sig.results != Type::none) {
// this has a result
- results[exp->name] = run(func, wasm, instance);
- // ignore the result if we hit an unreachable and returned no value
- if (results[exp->name].type.isConcrete()) {
- std::cout << "[fuzz-exec] note result: " << exp->name << " => "
- << results[exp->name] << '\n';
+ Literal ret = run(func, wasm, instance);
+ // We cannot compare funcrefs by name because function names can
+ // change (after duplicate function elimination or roundtripping)
+ // while the function contents are still the same
+ if (ret.type != Type::funcref) {
+ results[exp->name] = ret;
+ // ignore the result if we hit an unreachable and returned no value
+ if (results[exp->name].type.isConcrete()) {
+ std::cout << "[fuzz-exec] note result: " << exp->name << " => "
+ << results[exp->name] << '\n';
+ }
}
} else {
// no result, run it anyhow (it might modify memory etc.)
@@ -100,17 +106,17 @@ struct ExecutionResults {
auto name = iter.first;
if (results.find(name) == results.end()) {
std::cout << "[fuzz-exec] missing " << name << '\n';
- abort();
+ return false;
}
std::cout << "[fuzz-exec] comparing " << name << '\n';
if (results[name] != other.results[name]) {
std::cout << "not identical!\n";
- abort();
+ return false;
}
}
if (loggings != other.loggings) {
std::cout << "logging not identical!\n";
- abort();
+ return false;
}
return true;
}
@@ -138,7 +144,7 @@ struct ExecutionResults {
// call the method
for (Type param : func->sig.params.expand()) {
// zeros in arguments TODO: more?
- arguments.push_back(Literal(param));
+ arguments.push_back(Literal::makeZero(param));
}
return instance.callFunction(func->name, arguments);
} catch (const TrapException&) {
diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h
index ce302fac6..ff0888f1d 100644
--- a/src/tools/fuzzing.h
+++ b/src/tools/fuzzing.h
@@ -25,8 +25,7 @@ high chance for set at start of loop
high chance of a tee in that case => loop var
*/
-// TODO Complete exnref type support. Its support is partialy implemented
-// and the type is currently not generated in fuzzed programs yet.
+// TODO Generate exception handling instructions
#include "ir/memory-utils.h"
#include <ir/find_all.h>
@@ -310,6 +309,24 @@ private:
double getDouble() { return Literal(get64()).reinterpretf64(); }
+ SmallVector<Type, 2> getSubTypes(Type type) {
+ SmallVector<Type, 2> ret;
+ ret.push_back(type); // includes itself
+ switch (type) {
+ case Type::anyref:
+ ret.push_back(Type::funcref);
+ ret.push_back(Type::exnref);
+ // falls through
+ case Type::funcref:
+ case Type::exnref:
+ ret.push_back(Type::nullref);
+ break;
+ default:
+ break;
+ }
+ return ret;
+ }
+
void setupMemory() {
// Add memory itself
MemoryUtils::ensureExists(wasm.memory);
@@ -404,10 +421,12 @@ private:
Index num = upTo(3);
for (size_t i = 0; i < num; i++) {
// Events should have void return type and at least one param type
+ Type type = getConcreteType();
std::vector<Type> params;
+ params.push_back(type);
Index numValues = upToSquared(MAX_PARAMS - 1);
for (Index i = 0; i < numValues + 1; i++) {
- params.push_back(pick(i32, i64, f32, f64));
+ params.push_back(getConcreteType());
}
auto* event = builder.makeEvent(std::string("event$") + std::to_string(i),
WASM_EVENT_ATTRIBUTE_EXCEPTION,
@@ -447,7 +466,7 @@ private:
}
void addImportLoggingSupport() {
- for (auto type : getConcreteTypes()) {
+ for (auto type : getLoggableTypes()) {
auto* func = new Function;
Name name = std::string("log-") + type.toString();
func->name = name;
@@ -501,7 +520,7 @@ private:
// function generation state
- Function* func;
+ Function* func = nullptr;
std::vector<Expression*> breakableStack; // things we can break to
Index labelIndex;
@@ -585,10 +604,12 @@ private:
// loop limit
FindAll<Loop> loops(func->body);
for (auto* loop : loops.list) {
- loop->body = builder.makeSequence(makeHangLimitCheck(), loop->body);
+ loop->body =
+ builder.makeSequence(makeHangLimitCheck(), loop->body, loop->type);
}
// recursion limit
- func->body = builder.makeSequence(makeHangLimitCheck(), func->body);
+ func->body =
+ builder.makeSequence(makeHangLimitCheck(), func->body, func->sig.results);
}
void recombine(Function* func) {
@@ -841,7 +862,9 @@ private:
case f32:
case f64:
case v128:
+ case funcref:
case anyref:
+ case nullref:
case exnref:
ret = _makeConcrete(type);
break;
@@ -852,7 +875,8 @@ private:
ret = _makeunreachable();
break;
}
- assert(ret->type == type); // we should create the right type of thing
+ // we should create the right type of thing
+ assert(Type::isSubType(ret->type, type));
nesting--;
return ret;
}
@@ -898,9 +922,12 @@ private:
&Self::makeSelect,
&Self::makeGlobalGet)
.add(FeatureSet::SIMD, &Self::makeSIMD);
- if (type == i32 || type == i64) {
+ if (type == Type::i32 || type == Type::i64) {
options.add(FeatureSet::Atomics, &Self::makeAtomic);
}
+ if (type == Type::i32) {
+ options.add(FeatureSet::ReferenceTypes, &Self::makeRefIsNull);
+ }
return (this->*pick(options))(type);
}
@@ -1064,11 +1091,11 @@ private:
// possible branch back
list.push_back(builder.makeBreak(ret->name, nullptr, makeCondition()));
list.push_back(make(type)); // final element, so we have the right type
- ret->body = builder.makeBlock(list);
+ ret->body = builder.makeBlock(list, type);
}
breakableStack.pop_back();
hangStack.pop_back();
- ret->finalize();
+ ret->finalize(type);
return ret;
}
@@ -1093,15 +1120,15 @@ private:
}
}
- Expression* buildIf(const struct ThreeArgs& args) {
- return builder.makeIf(args.a, args.b, args.c);
+ Expression* buildIf(const struct ThreeArgs& args, Type type) {
+ return builder.makeIf(args.a, args.b, args.c, type);
}
Expression* makeIf(Type type) {
auto* condition = makeCondition();
hangStack.push_back(nullptr);
auto* ret =
- buildIf({condition, makeMaybeBlock(type), makeMaybeBlock(type)});
+ buildIf({condition, makeMaybeBlock(type), makeMaybeBlock(type)}, type);
hangStack.pop_back();
return ret;
}
@@ -1360,8 +1387,10 @@ private:
return builder.makeLoad(
16, false, offset, pick(1, 2, 4, 8, 16), ptr, type);
}
- case anyref: // anyref cannot be loaded from memory
- case exnref: // exnref cannot be loaded from memory
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable:
WASM_UNREACHABLE("invalid type");
@@ -1370,8 +1399,8 @@ private:
}
Expression* makeLoad(Type type) {
- // exnref type cannot be stored in memory
- if (!allowMemory || type == exnref) {
+ // reference types cannot be stored in memory
+ if (!allowMemory || type.isRef()) {
return makeTrivial(type);
}
auto* ret = makeNonAtomicLoad(type);
@@ -1393,7 +1422,7 @@ private:
Expression* makeNonAtomicStore(Type type) {
if (type == unreachable) {
// make a normal store, then make it unreachable
- auto* ret = makeNonAtomicStore(getConcreteType());
+ auto* ret = makeNonAtomicStore(getStorableType());
auto* store = ret->dynCast<Store>();
if (!store) {
return ret;
@@ -1416,7 +1445,7 @@ private:
// the type is none or unreachable. we also need to pick the value
// type.
if (type == none) {
- type = getConcreteType();
+ type = getStorableType();
}
auto offset = logify(get());
auto ptr = makePointer();
@@ -1462,8 +1491,10 @@ private:
return builder.makeStore(
16, offset, pick(1, 2, 4, 8, 16), ptr, value, type);
}
- case anyref: // anyref cannot be stored in memory
- case exnref: // exnref cannot be stored in memory
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable:
WASM_UNREACHABLE("invalid type");
@@ -1472,7 +1503,6 @@ private:
}
Expression* makeStore(Type type) {
- // exnref type cannot be stored in memory
if (!allowMemory || type.isRef()) {
return makeTrivial(type);
}
@@ -1558,8 +1588,10 @@ private:
case f64:
return Literal(getDouble());
case v128:
- case anyref: // anyref cannot have literals
- case exnref: // exnref cannot have literals
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable:
WASM_UNREACHABLE("invalid type");
@@ -1601,8 +1633,10 @@ private:
case f64:
return Literal(double(small));
case v128:
- case anyref: // anyref cannot have literals
- case exnref: // exnref cannot have literals
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -1667,8 +1701,10 @@ private:
std::numeric_limits<uint64_t>::max()));
break;
case v128:
- case anyref: // anyref cannot have literals
- case exnref: // exnref cannot have literals
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -1699,8 +1735,10 @@ private:
value = Literal(double(int64_t(1) << upTo(64)));
break;
case v128:
- case anyref: // anyref cannot have literals
- case exnref: // exnref cannot have literals
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -1724,21 +1762,23 @@ private:
}
Expression* makeConst(Type type) {
- switch (type) {
- case anyref:
- // There's no anyref.const.
- // TODO We should return a nullref once we implement instructions for
- // reference types proposal.
- assert(false && "anyref const is not implemented yet");
- case exnref:
- // There's no exnref.const.
- // TODO We should return a nullref once we implement instructions for
- // reference types proposal.
- assert(false && "exnref const is not implemented yet");
- default:
- break;
+ if (type.isRef()) {
+ assert(wasm.features.hasReferenceTypes());
+ // Check if we can use ref.func.
+ // 'func' is the pointer to the last created function and can be null when
+ // we set up globals (before we create any functions), in which case we
+ // can't use ref.func.
+ if (type == Type::funcref && func && oneIn(2)) {
+ // First set to target to the last created function, and try to select
+ // among other existing function if possible
+ Function* target = func;
+ if (!wasm.functions.empty() && !oneIn(wasm.functions.size())) {
+ target = pick(wasm.functions).get();
+ }
+ return builder.makeRefFunc(target->name);
+ }
+ return builder.makeRefNull();
}
-
auto* ret = wasm.allocator.alloc<Const>();
ret->value = makeLiteral(type);
ret->type = type;
@@ -1757,9 +1797,9 @@ private:
// give up
return makeTrivial(type);
}
- // There's no binary ops for exnref
- if (type == exnref) {
- makeTrivial(type);
+ // There's no unary ops for reference types
+ if (type.isRef()) {
+ return makeTrivial(type);
}
switch (type) {
@@ -1807,8 +1847,11 @@ private:
AllTrueVecI64x2),
make(v128)});
}
- case anyref: // there's no unary ops for anyref
- case exnref: // there's no unary ops for exnref
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
+ return makeTrivial(type);
case none:
case unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -1947,8 +1990,10 @@ private:
}
WASM_UNREACHABLE("invalid value");
}
- case anyref: // there's no unary ops for anyref
- case exnref: // there's no unary ops for exnref
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -1969,9 +2014,9 @@ private:
// give up
return makeTrivial(type);
}
- // There's no binary ops for exnref
+ // There's no binary ops for reference types
if (type.isRef()) {
- makeTrivial(type);
+ return makeTrivial(type);
}
switch (type) {
@@ -2180,8 +2225,10 @@ private:
make(v128),
make(v128)});
}
- case anyref: // there's no binary ops for anyref
- case exnref: // there's no binary ops for exnref
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -2189,12 +2236,15 @@ private:
WASM_UNREACHABLE("invalid type");
}
- Expression* buildSelect(const ThreeArgs& args) {
- return builder.makeSelect(args.a, args.b, args.c);
+ Expression* buildSelect(const ThreeArgs& args, Type type) {
+ return builder.makeSelect(args.a, args.b, args.c, type);
}
Expression* makeSelect(Type type) {
- return makeDeNanOp(buildSelect({make(i32), make(type), make(type)}));
+ Type subType1 = pick(getSubTypes(type));
+ Type subType2 = pick(getSubTypes(type));
+ return makeDeNanOp(
+ buildSelect({make(i32), make(subType1), make(subType2)}, type));
}
Expression* makeSwitch(Type type) {
@@ -2338,6 +2388,9 @@ private:
Expression* makeSIMD(Type type) {
assert(wasm.features.hasSIMD());
+ if (type.isRef()) {
+ return makeTrivial(type);
+ }
if (type != v128) {
return makeSIMDExtract(type);
}
@@ -2380,7 +2433,9 @@ private:
op = ExtractLaneVecF64x2;
break;
case v128:
+ case funcref:
case anyref:
+ case nullref:
case exnref:
case none:
case unreachable:
@@ -2549,6 +2604,18 @@ private:
WASM_UNREACHABLE("invalid value");
}
+ Expression* makeRefIsNull(Type type) {
+ assert(type == Type::i32);
+ assert(wasm.features.hasReferenceTypes());
+ Type refType;
+ if (wasm.features.hasExceptionHandling()) {
+ refType = pick(Type::funcref, Type::anyref, Type::nullref, Type::exnref);
+ } else {
+ refType = pick(Type::funcref, Type::anyref, Type::nullref);
+ }
+ return builder.makeRefIsNull(make(refType));
+ }
+
Expression* makeMemoryInit() {
if (!allowMemory) {
return makeTrivial(none);
@@ -2593,7 +2660,7 @@ private:
// special makers
Expression* makeLogging() {
- auto type = getConcreteType();
+ auto type = getLoggableType();
return builder.makeCall(
std::string("log-") + type.toString(), {make(type)}, none);
}
@@ -2605,20 +2672,64 @@ private:
// special getters
- Type getReachableType() {
- return pick(FeatureOptions<Type>()
- .add(FeatureSet::MVP, i32, i64, f32, f64, none)
- .add(FeatureSet::SIMD, v128));
- }
+ std::vector<Type> getReachableTypes() {
+ return items(FeatureOptions<Type>()
+ .add(FeatureSet::MVP,
+ Type::i32,
+ Type::i64,
+ Type::f32,
+ Type::f64,
+ Type::none)
+ .add(FeatureSet::SIMD, Type::v128)
+ .add(FeatureSet::ReferenceTypes,
+ Type::funcref,
+ Type::anyref,
+ Type::nullref)
+ .add((FeatureSet::Feature)(FeatureSet::ReferenceTypes |
+ FeatureSet::ExceptionHandling),
+ Type::exnref));
+ }
+ Type getReachableType() { return pick(getReachableTypes()); }
std::vector<Type> getConcreteTypes() {
- return items(FeatureOptions<Type>()
- .add(FeatureSet::MVP, i32, i64, f32, f64)
- .add(FeatureSet::SIMD, v128));
+ return items(
+ FeatureOptions<Type>()
+ .add(FeatureSet::MVP, Type::i32, Type::i64, Type::f32, Type::f64)
+ .add(FeatureSet::SIMD, Type::v128)
+ .add(FeatureSet::ReferenceTypes,
+ Type::funcref,
+ Type::anyref,
+ Type::nullref)
+ .add((FeatureSet::Feature)(FeatureSet::ReferenceTypes |
+ FeatureSet::ExceptionHandling),
+ Type::exnref));
}
-
Type getConcreteType() { return pick(getConcreteTypes()); }
+ // Get types that can be stored in memory
+ std::vector<Type> getStorableTypes() {
+ return items(
+ FeatureOptions<Type>()
+ .add(FeatureSet::MVP, Type::i32, Type::i64, Type::f32, Type::f64)
+ .add(FeatureSet::SIMD, Type::v128));
+ }
+ Type getStorableType() { return pick(getStorableTypes()); }
+
+ // - funcref cannot be logged because referenced functions can be inlined or
+ // removed during optimization
+ // - there's no point in logging anyref because it is opaque
+ std::vector<Type> getLoggableTypes() {
+ return items(
+ FeatureOptions<Type>()
+ .add(FeatureSet::MVP, Type::i32, Type::i64, Type::f32, Type::f64)
+ .add(FeatureSet::SIMD, Type::v128)
+ .add(FeatureSet::ReferenceTypes, Type::nullref)
+ .add((FeatureSet::Feature)(FeatureSet::ReferenceTypes |
+ FeatureSet::ExceptionHandling),
+ Type::exnref));
+ }
+ Type getLoggableType() { return pick(getLoggableTypes()); }
+
// statistical distributions
// 0 to the limit, logarithmic scale
@@ -2659,8 +2770,8 @@ private:
// low values
Index upToSquared(Index x) { return upTo(upTo(x)); }
- // pick from a vector
- template<typename T> const T& pick(const std::vector<T>& vec) {
+ // pick from a vector-like container
+ template<typename T> const typename T::value_type& pick(const T& vec) {
assert(!vec.empty());
auto index = upTo(vec.size());
return vec[index];
diff --git a/src/tools/spec-wrapper.h b/src/tools/spec-wrapper.h
index beada1b4b..f59291e55 100644
--- a/src/tools/spec-wrapper.h
+++ b/src/tools/spec-wrapper.h
@@ -48,8 +48,12 @@ static std::string generateSpecWrapper(Module& wasm) {
case v128:
ret += "(v128.const i32x4 0 0 0 0)";
break;
- case anyref: // there's no anyref.const
- case exnref: // there's no exnref.const
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
+ ret += "(ref.null)";
+ break;
case none:
case unreachable:
WASM_UNREACHABLE("unexpected type");
diff --git a/src/tools/wasm-reduce.cpp b/src/tools/wasm-reduce.cpp
index 274b6de29..6adb1e174 100644
--- a/src/tools/wasm-reduce.cpp
+++ b/src/tools/wasm-reduce.cpp
@@ -592,7 +592,9 @@ struct Reducer
fixed = builder->makeUnary(TruncSFloat64ToInt32, child);
break;
case v128:
+ case funcref:
case anyref:
+ case nullref:
case exnref:
continue; // not implemented yet
case none:
@@ -615,7 +617,9 @@ struct Reducer
fixed = builder->makeUnary(TruncSFloat64ToInt64, child);
break;
case v128:
+ case funcref:
case anyref:
+ case nullref:
case exnref:
continue; // not implemented yet
case none:
@@ -638,7 +642,9 @@ struct Reducer
fixed = builder->makeUnary(DemoteFloat64, child);
break;
case v128:
+ case funcref:
case anyref:
+ case nullref:
case exnref:
continue; // not implemented yet
case none:
@@ -661,7 +667,9 @@ struct Reducer
case f64:
WASM_UNREACHABLE("unexpected type");
case v128:
+ case funcref:
case anyref:
+ case nullref:
case exnref:
continue; // not implemented yet
case none:
@@ -671,7 +679,9 @@ struct Reducer
break;
}
case v128:
+ case funcref:
case anyref:
+ case nullref:
case exnref:
continue; // not implemented yet
case none:
@@ -999,6 +1009,10 @@ struct Reducer
return false;
}
// try to replace with a trivial value
+ if (curr->type.isRef()) {
+ RefNull* n = builder->makeRefNull();
+ return tryToReplaceCurrent(n);
+ }
Const* c = builder->makeConst(Literal(int32_t(0)));
if (tryToReplaceCurrent(c)) {
return true;
diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp
index d5ee60d8a..6c9d3f36a 100644
--- a/src/tools/wasm-shell.cpp
+++ b/src/tools/wasm-shell.cpp
@@ -74,7 +74,7 @@ struct Operation {
name = element[i++]->str();
for (size_t j = i; j < element.size(); j++) {
Expression* argument = builder.parseExpression(*element[j]);
- arguments.push_back(argument->dynCast<Const>()->value);
+ arguments.push_back(getLiteralFromConstExpression(argument));
}
}
@@ -214,7 +214,7 @@ static void run_asserts(Name moduleName,
assert(!trapped);
if (curr.size() >= 3) {
Literal expected =
- builder->parseExpression(*curr[2])->dynCast<Const>()->value;
+ getLiteralFromConstExpression(builder->parseExpression(*curr[2]));
std::cerr << "seen " << result << ", expected " << expected << '\n';
if (expected != result) {
std::cout << "unexpected, should be identical\n";
diff --git a/src/wasm-binary.h b/src/wasm-binary.h
index 4206defdf..f019d0792 100644
--- a/src/wasm-binary.h
+++ b/src/wasm-binary.h
@@ -343,10 +343,12 @@ enum EncodedType {
f32 = -0x3, // 0x7d
f64 = -0x4, // 0x7c
v128 = -0x5, // 0x7b
- // elem_type
- AnyFunc = -0x10, // 0x70
+ // function reference type
+ funcref = -0x10, // 0x70
// opaque reference type
anyref = -0x11, // 0x6f
+ // null reference type
+ nullref = -0x12, // 0x6e
// exception reference type
exnref = -0x18, // 0x68
// func_type form
@@ -402,6 +404,7 @@ enum ASTNodes {
Drop = 0x1a,
Select = 0x1b,
+ SelectWithType = 0x1c, // added in reference types proposal
LocalGet = 0x20,
LocalSet = 0x21,
@@ -867,6 +870,12 @@ enum ASTNodes {
MemoryCopy = 0x0a,
MemoryFill = 0x0b,
+ // reference types opcodes
+
+ RefNull = 0xd0,
+ RefIsNull = 0xd1,
+ RefFunc = 0xd2,
+
// exception handling opcodes
Try = 0x06,
@@ -914,9 +923,15 @@ inline S32LEB binaryType(Type type) {
case v128:
ret = BinaryConsts::EncodedType::v128;
break;
+ case funcref:
+ ret = BinaryConsts::EncodedType::funcref;
+ break;
case anyref:
ret = BinaryConsts::EncodedType::anyref;
break;
+ case nullref:
+ ret = BinaryConsts::EncodedType::nullref;
+ break;
case exnref:
ret = BinaryConsts::EncodedType::exnref;
break;
@@ -1143,8 +1158,8 @@ public:
// we store function imports here before wasm.addFunctionImport after we know
// their names
std::vector<Function*> functionImports;
- // at index i we have all calls to the function i
- std::map<Index, std::vector<Call*>> functionCalls;
+ // at index i we have all refs to the function i
+ std::map<Index, std::vector<Expression*>> functionRefs;
Function* currFunction = nullptr;
// before we see a function (like global init expressions), there is no end of
// function to check
@@ -1279,12 +1294,15 @@ public:
bool maybeVisitDataDrop(Expression*& out, uint32_t code);
bool maybeVisitMemoryCopy(Expression*& out, uint32_t code);
bool maybeVisitMemoryFill(Expression*& out, uint32_t code);
- void visitSelect(Select* curr);
+ void visitSelect(Select* curr, uint8_t code);
void visitReturn(Return* curr);
bool maybeVisitHost(Expression*& out, uint8_t code);
void visitNop(Nop* curr);
void visitUnreachable(Unreachable* curr);
void visitDrop(Drop* curr);
+ void visitRefNull(RefNull* curr);
+ void visitRefIsNull(RefIsNull* curr);
+ void visitRefFunc(RefFunc* curr);
void visitTry(Try* curr);
void visitThrow(Throw* curr);
void visitRethrow(Rethrow* curr);
diff --git a/src/wasm-builder.h b/src/wasm-builder.h
index 918e6a4ab..38009cb8a 100644
--- a/src/wasm-builder.h
+++ b/src/wasm-builder.h
@@ -110,6 +110,12 @@ public:
ret->finalize();
return ret;
}
+ Block* makeBlock(const std::vector<Expression*>& items, Type type) {
+ auto* ret = allocator.alloc<Block>();
+ ret->list.set(items);
+ ret->finalize(type);
+ return ret;
+ }
Block* makeBlock(const ExpressionList& items) {
auto* ret = allocator.alloc<Block>();
ret->list.set(items);
@@ -164,6 +170,13 @@ public:
ret->finalize();
return ret;
}
+ Loop* makeLoop(Name name, Expression* body, Type type) {
+ auto* ret = allocator.alloc<Loop>();
+ ret->name = name;
+ ret->body = body;
+ ret->finalize(type);
+ return ret;
+ }
Break* makeBreak(Name name,
Expression* value = nullptr,
Expression* condition = nullptr) {
@@ -459,6 +472,7 @@ public:
return ret;
}
Const* makeConst(Literal value) {
+ assert(value.type.isNumber());
auto* ret = allocator.alloc<Const>();
ret->value = value;
ret->type = value.type;
@@ -488,6 +502,17 @@ public:
ret->finalize();
return ret;
}
+ Select* makeSelect(Expression* condition,
+ Expression* ifTrue,
+ Expression* ifFalse,
+ Type type) {
+ auto* ret = allocator.alloc<Select>();
+ ret->condition = condition;
+ ret->ifTrue = ifTrue;
+ ret->ifFalse = ifFalse;
+ ret->finalize(type);
+ return ret;
+ }
Return* makeReturn(Expression* value = nullptr) {
auto* ret = allocator.alloc<Return>();
ret->value = value;
@@ -502,6 +527,23 @@ public:
ret->finalize();
return ret;
}
+ RefNull* makeRefNull() {
+ auto* ret = allocator.alloc<RefNull>();
+ ret->finalize();
+ return ret;
+ }
+ RefIsNull* makeRefIsNull(Expression* value) {
+ auto* ret = allocator.alloc<RefIsNull>();
+ ret->value = value;
+ ret->finalize();
+ return ret;
+ }
+ RefFunc* makeRefFunc(Name func) {
+ auto* ret = allocator.alloc<RefFunc>();
+ ret->func = func;
+ ret->finalize();
+ return ret;
+ }
Try* makeTry(Expression* body, Expression* catchBody) {
auto* ret = allocator.alloc<Try>();
ret->body = body;
@@ -569,6 +611,21 @@ public:
return ret;
}
+ Expression* makeConstExpression(Literal value) {
+ switch (value.type) {
+ case Type::nullref:
+ return makeRefNull();
+ case Type::funcref:
+ if (value.getFunc()[0] != 0) {
+ return makeRefFunc(value.getFunc());
+ }
+ return makeRefNull();
+ default:
+ assert(value.type.isNumber());
+ return makeConst(value);
+ }
+ }
+
// Additional utility functions for building on top of nodes
// Convenient to have these on Builder, as it has allocation built in
@@ -663,6 +720,13 @@ public:
return block;
}
+ Block* makeSequence(Expression* left, Expression* right, Type type) {
+ auto* block = makeBlock(left);
+ block->list.push_back(right);
+ block->finalize(type);
+ return block;
+ }
+
// Grab a slice out of a block, replacing it with nops, and returning
// either another block with the contents (if more than 1) or a single
// expression
@@ -728,16 +792,15 @@ public:
value = Literal(bytes.data());
break;
}
+ case funcref:
case anyref:
- // TODO Implement and return nullref
- assert(false && "anyref not implemented yet");
+ case nullref:
case exnref:
- // TODO Implement and return nullref
- assert(false && "exnref not implemented yet");
+ return ExpressionManipulator::refNull(curr);
case none:
return ExpressionManipulator::nop(curr);
case unreachable:
- return ExpressionManipulator::convert<T, Unreachable>(curr);
+ return ExpressionManipulator::unreachable(curr);
}
return makeConst(value);
}
diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h
index 571f0d1a5..f37a6edd6 100644
--- a/src/wasm-interpreter.h
+++ b/src/wasm-interpreter.h
@@ -143,13 +143,13 @@ public:
if (!ret.breaking() &&
(curr->type.isConcrete() || ret.value.type.isConcrete())) {
#if 1 // def WASM_INTERPRETER_DEBUG
- if (ret.value.type != curr->type) {
+ if (!Type::isSubType(ret.value.type, curr->type)) {
std::cerr << "expected " << curr->type << ", seeing " << ret.value.type
<< " from\n"
<< curr << '\n';
}
#endif
- assert(ret.value.type == curr->type);
+ assert(Type::isSubType(ret.value.type, curr->type));
}
depth--;
return ret;
@@ -1095,7 +1095,7 @@ public:
return Literal(uint64_t(val));
}
}
- Flow visitAtomicFence(AtomicFence*) {
+ Flow visitAtomicFence(AtomicFence* curr) {
// Wasm currently supports only sequentially consistent atomics, in which
// case atomic_fence can be lowered to nothing.
NOTE_ENTER("AtomicFence");
@@ -1123,6 +1123,26 @@ public:
Flow visitSIMDLoadExtend(SIMDLoad*) { WASM_UNREACHABLE("unimp"); }
Flow visitPush(Push*) { WASM_UNREACHABLE("unimp"); }
Flow visitPop(Pop*) { WASM_UNREACHABLE("unimp"); }
+ Flow visitRefNull(RefNull* curr) {
+ NOTE_ENTER("RefNull");
+ return Literal::makeNullref();
+ }
+ Flow visitRefIsNull(RefIsNull* curr) {
+ NOTE_ENTER("RefIsNull");
+ Flow flow = visit(curr->value);
+ if (flow.breaking()) {
+ return flow;
+ }
+ Literal value = flow.value;
+ NOTE_EVAL1(value);
+ return Literal(value.type == nullref);
+ }
+ Flow visitRefFunc(RefFunc* curr) {
+ NOTE_ENTER("RefFunc");
+ NOTE_NAME(curr->func);
+ return Literal::makeFuncref(curr->func);
+ }
+ // TODO Implement EH instructions
Flow visitTry(Try*) { WASM_UNREACHABLE("unimp"); }
Flow visitThrow(Throw*) { WASM_UNREACHABLE("unimp"); }
Flow visitRethrow(Rethrow*) { WASM_UNREACHABLE("unimp"); }
@@ -1217,8 +1237,10 @@ public:
return Literal(load64u(addr)).castToF64();
case v128:
return Literal(load128(addr).data());
- case anyref: // anyref cannot be loaded from memory
- case exnref: // exnref cannot be loaded from memory
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -1272,8 +1294,10 @@ public:
case v128:
store128(addr, value.getv128());
break;
- case anyref: // anyref cannot be stored from memory
- case exnref: // exnref cannot be stored in memory
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -1464,7 +1488,7 @@ private:
for (size_t i = 0; i < function->getNumLocals(); i++) {
if (i < arguments.size()) {
assert(i < params.size());
- if (params[i] != arguments[i].type) {
+ if (!Type::isSubType(arguments[i].type, params[i])) {
std::cerr << "Function `" << function->name << "` expects type "
<< params[i] << " for parameter " << i << ", got "
<< arguments[i].type << "." << std::endl;
@@ -1473,7 +1497,7 @@ private:
locals[i] = arguments[i];
} else {
assert(function->isVar(i));
- locals[i].type = function->getLocalType(i);
+ locals[i] = Literal::makeZero(function->getLocalType(i));
}
}
}
@@ -1580,7 +1604,8 @@ private:
}
NOTE_EVAL1(index);
NOTE_EVAL1(flow.value);
- assert(curr->isTee() ? flow.value.type == curr->type : true);
+ assert(curr->isTee() ? Type::isSubType(flow.value.type, curr->type)
+ : true);
scope.locals[index] = flow.value;
return curr->isTee() ? flow : Flow();
}
@@ -2067,7 +2092,7 @@ public:
// cannot still be breaking, it means we missed our stop
assert(!flow.breaking() || flow.breakTo == RETURN_FLOW);
Literal ret = flow.value;
- if (function->sig.results != ret.type) {
+ if (!Type::isSubType(ret.type, function->sig.results)) {
std::cerr << "calling " << function->name << " resulted in " << ret
<< " but the function type is " << function->sig.results
<< '\n';
diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h
index d7324d756..8cdcb88f4 100644
--- a/src/wasm-s-parser.h
+++ b/src/wasm-s-parser.h
@@ -225,6 +225,9 @@ private:
Expression* makeBreak(Element& s);
Expression* makeBreakTable(Element& s);
Expression* makeReturn(Element& s);
+ Expression* makeRefNull(Element& s);
+ Expression* makeRefIsNull(Element& s);
+ Expression* makeRefFunc(Element& s);
Expression* makeTry(Element& s);
Expression* makeCatch(Element& s, Type type);
Expression* makeThrow(Element& s);
diff --git a/src/wasm-stack.h b/src/wasm-stack.h
index fbd28b0d5..91c0c5383 100644
--- a/src/wasm-stack.h
+++ b/src/wasm-stack.h
@@ -128,6 +128,9 @@ public:
void visitSelect(Select* curr);
void visitReturn(Return* curr);
void visitHost(Host* curr);
+ void visitRefNull(RefNull* curr);
+ void visitRefIsNull(RefIsNull* curr);
+ void visitRefFunc(RefFunc* curr);
void visitTry(Try* curr);
void visitThrow(Throw* curr);
void visitRethrow(Rethrow* curr);
@@ -207,6 +210,9 @@ public:
void visitSelect(Select* curr);
void visitReturn(Return* curr);
void visitHost(Host* curr);
+ void visitRefNull(RefNull* curr);
+ void visitRefIsNull(RefIsNull* curr);
+ void visitRefFunc(RefFunc* curr);
void visitTry(Try* curr);
void visitThrow(Throw* curr);
void visitRethrow(Rethrow* curr);
@@ -698,6 +704,30 @@ void BinaryenIRWriter<SubType>::visitHost(Host* curr) {
emit(curr);
}
+template<typename SubType>
+void BinaryenIRWriter<SubType>::visitRefNull(RefNull* curr) {
+ emit(curr);
+}
+
+template<typename SubType>
+void BinaryenIRWriter<SubType>::visitRefIsNull(RefIsNull* curr) {
+ visit(curr->value);
+ if (curr->type == Type::unreachable) {
+ emitUnreachable();
+ return;
+ }
+ emit(curr);
+}
+
+template<typename SubType>
+void BinaryenIRWriter<SubType>::visitRefFunc(RefFunc* curr) {
+ if (curr->type == Type::unreachable) {
+ emitUnreachable();
+ return;
+ }
+ emit(curr);
+}
+
template<typename SubType> void BinaryenIRWriter<SubType>::visitTry(Try* curr) {
emit(curr);
visitPossibleBlockContents(curr->body);
diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h
index 9c6e78360..c9290cbab 100644
--- a/src/wasm-traversal.h
+++ b/src/wasm-traversal.h
@@ -72,6 +72,9 @@ template<typename SubType, typename ReturnType = void> struct Visitor {
ReturnType visitDrop(Drop* curr) { return ReturnType(); }
ReturnType visitReturn(Return* curr) { return ReturnType(); }
ReturnType visitHost(Host* curr) { return ReturnType(); }
+ ReturnType visitRefNull(RefNull* curr) { return ReturnType(); }
+ ReturnType visitRefIsNull(RefIsNull* curr) { return ReturnType(); }
+ ReturnType visitRefFunc(RefFunc* curr) { return ReturnType(); }
ReturnType visitTry(Try* curr) { return ReturnType(); }
ReturnType visitThrow(Throw* curr) { return ReturnType(); }
ReturnType visitRethrow(Rethrow* curr) { return ReturnType(); }
@@ -167,6 +170,12 @@ template<typename SubType, typename ReturnType = void> struct Visitor {
DELEGATE(Return);
case Expression::Id::HostId:
DELEGATE(Host);
+ case Expression::Id::RefNullId:
+ DELEGATE(RefNull);
+ case Expression::Id::RefIsNullId:
+ DELEGATE(RefIsNull);
+ case Expression::Id::RefFuncId:
+ DELEGATE(RefFunc);
case Expression::Id::TryId:
DELEGATE(Try);
case Expression::Id::ThrowId:
@@ -241,6 +250,9 @@ struct OverriddenVisitor {
UNIMPLEMENTED(Drop);
UNIMPLEMENTED(Return);
UNIMPLEMENTED(Host);
+ UNIMPLEMENTED(RefNull);
+ UNIMPLEMENTED(RefIsNull);
+ UNIMPLEMENTED(RefFunc);
UNIMPLEMENTED(Try);
UNIMPLEMENTED(Throw);
UNIMPLEMENTED(Rethrow);
@@ -337,6 +349,12 @@ struct OverriddenVisitor {
DELEGATE(Return);
case Expression::Id::HostId:
DELEGATE(Host);
+ case Expression::Id::RefNullId:
+ DELEGATE(RefNull);
+ case Expression::Id::RefIsNullId:
+ DELEGATE(RefIsNull);
+ case Expression::Id::RefFuncId:
+ DELEGATE(RefFunc);
case Expression::Id::TryId:
DELEGATE(Try);
case Expression::Id::ThrowId:
@@ -476,6 +494,15 @@ struct UnifiedExpressionVisitor : public Visitor<SubType, ReturnType> {
ReturnType visitHost(Host* curr) {
return static_cast<SubType*>(this)->visitExpression(curr);
}
+ ReturnType visitRefNull(RefNull* curr) {
+ return static_cast<SubType*>(this)->visitExpression(curr);
+ }
+ ReturnType visitRefIsNull(RefIsNull* curr) {
+ return static_cast<SubType*>(this)->visitExpression(curr);
+ }
+ ReturnType visitRefFunc(RefFunc* curr) {
+ return static_cast<SubType*>(this)->visitExpression(curr);
+ }
ReturnType visitTry(Try* curr) {
return static_cast<SubType*>(this)->visitExpression(curr);
}
@@ -778,6 +805,15 @@ struct Walker : public VisitorType {
static void doVisitHost(SubType* self, Expression** currp) {
self->visitHost((*currp)->cast<Host>());
}
+ static void doVisitRefNull(SubType* self, Expression** currp) {
+ self->visitRefNull((*currp)->cast<RefNull>());
+ }
+ static void doVisitRefIsNull(SubType* self, Expression** currp) {
+ self->visitRefIsNull((*currp)->cast<RefIsNull>());
+ }
+ static void doVisitRefFunc(SubType* self, Expression** currp) {
+ self->visitRefFunc((*currp)->cast<RefFunc>());
+ }
static void doVisitTry(SubType* self, Expression** currp) {
self->visitTry((*currp)->cast<Try>());
}
@@ -1036,6 +1072,19 @@ struct PostWalker : public Walker<SubType, VisitorType> {
}
break;
}
+ case Expression::Id::RefNullId: {
+ self->pushTask(SubType::doVisitRefNull, currp);
+ break;
+ }
+ case Expression::Id::RefIsNullId: {
+ self->pushTask(SubType::doVisitRefIsNull, currp);
+ self->pushTask(SubType::scan, &curr->cast<RefIsNull>()->value);
+ break;
+ }
+ case Expression::Id::RefFuncId: {
+ self->pushTask(SubType::doVisitRefFunc, currp);
+ break;
+ }
case Expression::Id::TryId: {
self->pushTask(SubType::doVisitTry, currp);
self->pushTask(SubType::scan, &curr->cast<Try>()->catchBody);
@@ -1099,7 +1148,7 @@ struct ControlFlowWalker : public PostWalker<SubType, VisitorType> {
Expression* findBreakTarget(Name name) {
assert(!controlFlowStack.empty());
Index i = controlFlowStack.size() - 1;
- while (1) {
+ while (true) {
auto* curr = controlFlowStack[i];
if (Block* block = curr->template dynCast<Block>()) {
if (name == block->name) {
@@ -1111,7 +1160,7 @@ struct ControlFlowWalker : public PostWalker<SubType, VisitorType> {
}
} else {
// an if, ignorable
- assert(curr->template is<If>());
+ assert(curr->template is<If>() || curr->template is<Try>());
}
if (i == 0) {
return nullptr;
@@ -1169,7 +1218,7 @@ struct ExpressionStackWalker : public PostWalker<SubType, VisitorType> {
Expression* findBreakTarget(Name name) {
assert(!expressionStack.empty());
Index i = expressionStack.size() - 1;
- while (1) {
+ while (true) {
auto* curr = expressionStack[i];
if (Block* block = curr->template dynCast<Block>()) {
if (name == block->name) {
@@ -1179,8 +1228,6 @@ struct ExpressionStackWalker : public PostWalker<SubType, VisitorType> {
if (name == loop->name) {
return curr;
}
- } else {
- WASM_UNREACHABLE("unexpected expression type");
}
if (i == 0) {
return nullptr;
diff --git a/src/wasm-type.h b/src/wasm-type.h
index 53ef39ef8..668ac3e4d 100644
--- a/src/wasm-type.h
+++ b/src/wasm-type.h
@@ -36,7 +36,9 @@ public:
f32,
f64,
v128,
+ funcref,
anyref,
+ nullref,
exnref,
_last_value_type,
};
@@ -64,7 +66,8 @@ public:
bool isInteger() const { return id == i32 || id == i64; }
bool isFloat() const { return id == f32 || id == f64; }
bool isVector() const { return id == v128; };
- bool isRef() const { return id == anyref || id == exnref; }
+ bool isNumber() const { return id >= i32 && id <= v128; }
+ bool isRef() const { return id >= funcref && id <= exnref; }
// (In)equality must be defined for both Type and ValueType because it is
// otherwise ambiguous whether to convert both this and other to int or
@@ -94,6 +97,23 @@ public:
// type.
static Type get(unsigned byteSize, bool float_);
+ // Returns true if left is a subtype of right. Subtype includes itself.
+ static bool isSubType(Type left, Type right);
+
+ // Computes the least upper bound from the type lattice.
+ // If one of the type is unreachable, the other type becomes the result. If
+ // the common supertype does not exist, returns none, a poison value.
+ static Type getLeastUpperBound(Type a, Type b);
+
+ // Computes the least upper bound for all types in the given list.
+ template<typename T> static Type mergeTypes(const T& types) {
+ Type type = Type::unreachable;
+ for (auto other : types) {
+ type = Type::getLeastUpperBound(type, other);
+ }
+ return type;
+ }
+
std::string toString() const;
};
@@ -134,7 +154,9 @@ constexpr Type i64 = Type::i64;
constexpr Type f32 = Type::f32;
constexpr Type f64 = Type::f64;
constexpr Type v128 = Type::v128;
+constexpr Type funcref = Type::funcref;
constexpr Type anyref = Type::anyref;
+constexpr Type nullref = Type::nullref;
constexpr Type exnref = Type::exnref;
constexpr Type unreachable = Type::unreachable;
diff --git a/src/wasm.h b/src/wasm.h
index 48adf103b..c4dbd2f3f 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -531,6 +531,9 @@ public:
MemoryFillId,
PushId,
PopId,
+ RefNullId,
+ RefIsNullId,
+ RefFuncId,
TryId,
ThrowId,
RethrowId,
@@ -569,6 +572,8 @@ public:
const char* getExpressionName(Expression* curr);
+Literal getLiteralFromConstExpression(Expression* curr);
+
typedef ArenaVector<Expression*> ExpressionList;
template<Expression::Id SID> class SpecificExpression : public Expression {
@@ -1008,6 +1013,7 @@ public:
Expression* condition;
void finalize();
+ void finalize(Type type_);
};
class Drop : public SpecificExpression<Expression::DropId> {
@@ -1070,6 +1076,32 @@ public:
Pop(MixedArena& allocator) {}
};
+class RefNull : public SpecificExpression<Expression::RefNullId> {
+public:
+ RefNull() = default;
+ RefNull(MixedArena& allocator) {}
+
+ void finalize();
+};
+
+class RefIsNull : public SpecificExpression<Expression::RefIsNullId> {
+public:
+ RefIsNull(MixedArena& allocator) {}
+
+ Expression* value;
+
+ void finalize();
+};
+
+class RefFunc : public SpecificExpression<Expression::RefFuncId> {
+public:
+ RefFunc(MixedArena& allocator) {}
+
+ Name func;
+
+ void finalize();
+};
+
class Try : public SpecificExpression<Expression::TryId> {
public:
Try(MixedArena& allocator) {}
diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp
index 82a150257..4f66b36e3 100644
--- a/src/wasm/literal.cpp
+++ b/src/wasm/literal.cpp
@@ -137,8 +137,11 @@ void Literal::getBits(uint8_t (&buf)[16]) const {
case Type::v128:
memcpy(buf, &v128, sizeof(v128));
break;
- case Type::anyref: // anyref type is opaque
- case Type::exnref: // exnref type is opaque
+ case Type::funcref:
+ case Type::nullref:
+ break;
+ case Type::anyref:
+ case Type::exnref:
case Type::none:
case Type::unreachable:
WASM_UNREACHABLE("invalid type");
@@ -146,10 +149,20 @@ void Literal::getBits(uint8_t (&buf)[16]) const {
}
bool Literal::operator==(const Literal& other) const {
+ if (type.isRef() && other.type.isRef()) {
+ if (type == Type::nullref && other.type == Type::nullref) {
+ return true;
+ }
+ if (type == Type::funcref && other.type == Type::funcref &&
+ func == other.func) {
+ return true;
+ }
+ return false;
+ }
if (type != other.type) {
return false;
}
- if (type == none) {
+ if (type == Type::none) {
return true;
}
uint8_t bits[16], other_bits[16];
@@ -273,8 +286,14 @@ std::ostream& operator<<(std::ostream& o, Literal literal) {
o << "i32x4 ";
literal.printVec128(o, literal.getv128());
break;
- case Type::anyref: // anyref type is opaque
- case Type::exnref: // exnref type is opaque
+ case Type::funcref:
+ o << "funcref(" << literal.getFunc() << ")";
+ break;
+ case Type::nullref:
+ o << "nullref";
+ break;
+ case Type::anyref:
+ case Type::exnref:
case Type::unreachable:
WASM_UNREACHABLE("invalid type");
}
@@ -477,7 +496,9 @@ Literal Literal::eqz() const {
case Type::f64:
return eq(Literal(double(0)));
case Type::v128:
+ case Type::funcref:
case Type::anyref:
+ case Type::nullref:
case Type::exnref:
case Type::none:
case Type::unreachable:
@@ -497,7 +518,9 @@ Literal Literal::neg() const {
case Type::f64:
return Literal(int64_t(i64 ^ 0x8000000000000000ULL)).castToF64();
case Type::v128:
+ case Type::funcref:
case Type::anyref:
+ case Type::nullref:
case Type::exnref:
case Type::none:
case Type::unreachable:
@@ -517,7 +540,9 @@ Literal Literal::abs() const {
case Type::f64:
return Literal(int64_t(i64 & 0x7fffffffffffffffULL)).castToF64();
case Type::v128:
+ case Type::funcref:
case Type::anyref:
+ case Type::nullref:
case Type::exnref:
case Type::none:
case Type::unreachable:
@@ -620,7 +645,9 @@ Literal Literal::add(const Literal& other) const {
case Type::f64:
return Literal(getf64() + other.getf64());
case Type::v128:
+ case Type::funcref:
case Type::anyref:
+ case Type::nullref:
case Type::exnref:
case Type::none:
case Type::unreachable:
@@ -640,7 +667,9 @@ Literal Literal::sub(const Literal& other) const {
case Type::f64:
return Literal(getf64() - other.getf64());
case Type::v128:
+ case Type::funcref:
case Type::anyref:
+ case Type::nullref:
case Type::exnref:
case Type::none:
case Type::unreachable:
@@ -731,7 +760,9 @@ Literal Literal::mul(const Literal& other) const {
case Type::f64:
return Literal(getf64() * other.getf64());
case Type::v128:
+ case Type::funcref:
case Type::anyref:
+ case Type::nullref:
case Type::exnref:
case Type::none:
case Type::unreachable:
@@ -967,7 +998,9 @@ Literal Literal::eq(const Literal& other) const {
case Type::f64:
return Literal(getf64() == other.getf64());
case Type::v128:
+ case Type::funcref:
case Type::anyref:
+ case Type::nullref:
case Type::exnref:
case Type::none:
case Type::unreachable:
@@ -987,7 +1020,9 @@ Literal Literal::ne(const Literal& other) const {
case Type::f64:
return Literal(getf64() != other.getf64());
case Type::v128:
+ case Type::funcref:
case Type::anyref:
+ case Type::nullref:
case Type::exnref:
case Type::none:
case Type::unreachable:
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp
index 82eb51d7e..ba5a8d3dd 100644
--- a/src/wasm/wasm-binary.cpp
+++ b/src/wasm/wasm-binary.cpp
@@ -262,7 +262,7 @@ void WasmBinaryWriter::writeImports() {
BYN_TRACE("write one table\n");
writeImportHeader(&wasm->table);
o << U32LEB(int32_t(ExternalKind::Table));
- o << S32LEB(BinaryConsts::EncodedType::AnyFunc);
+ o << S32LEB(BinaryConsts::EncodedType::funcref);
writeResizableLimits(wasm->table.initial,
wasm->table.max,
wasm->table.hasMax(),
@@ -463,7 +463,7 @@ void WasmBinaryWriter::writeFunctionTableDeclaration() {
BYN_TRACE("== writeFunctionTableDeclaration\n");
auto start = startSection(BinaryConsts::Section::Table);
o << U32LEB(1); // Declare 1 table.
- o << S32LEB(BinaryConsts::EncodedType::AnyFunc);
+ o << S32LEB(BinaryConsts::EncodedType::funcref);
writeResizableLimits(wasm->table.initial,
wasm->table.max,
wasm->table.hasMax(),
@@ -1059,8 +1059,12 @@ Type WasmBinaryBuilder::getType() {
return f64;
case BinaryConsts::EncodedType::v128:
return v128;
+ case BinaryConsts::EncodedType::funcref:
+ return funcref;
case BinaryConsts::EncodedType::anyref:
return anyref;
+ case BinaryConsts::EncodedType::nullref:
+ return nullref;
case BinaryConsts::EncodedType::exnref:
return exnref;
default:
@@ -1258,8 +1262,8 @@ void WasmBinaryBuilder::readImports() {
wasm.table.name = Name(std::string("timport$") + std::to_string(i));
auto elementType = getS32LEB();
WASM_UNUSED(elementType);
- if (elementType != BinaryConsts::EncodedType::AnyFunc) {
- throwError("Imported table type is not AnyFunc");
+ if (elementType != BinaryConsts::EncodedType::funcref) {
+ throwError("Imported table type is not funcref");
}
wasm.table.exists = true;
bool is_shared;
@@ -1802,11 +1806,17 @@ void WasmBinaryBuilder::processFunctions() {
wasm.addExport(curr);
}
- for (auto& iter : functionCalls) {
+ for (auto& iter : functionRefs) {
size_t index = iter.first;
- auto& calls = iter.second;
- for (auto* call : calls) {
- call->target = getFunctionName(index);
+ auto& refs = iter.second;
+ for (auto* ref : refs) {
+ if (auto* call = ref->dynCast<Call>()) {
+ call->target = getFunctionName(index);
+ } else if (auto* refFunc = ref->dynCast<RefFunc>()) {
+ refFunc->func = getFunctionName(index);
+ } else {
+ WASM_UNREACHABLE("Invalid type in function references");
+ }
}
}
@@ -1869,8 +1879,8 @@ void WasmBinaryBuilder::readFunctionTableDeclaration() {
}
wasm.table.exists = true;
auto elemType = getS32LEB();
- if (elemType != BinaryConsts::EncodedType::AnyFunc) {
- throwError("ElementType must be AnyFunc in MVP");
+ if (elemType != BinaryConsts::EncodedType::funcref) {
+ throwError("ElementType must be funcref in MVP");
}
bool is_shared;
getResizableLimits(
@@ -2117,7 +2127,8 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) {
visitGlobalSet((curr = allocator.alloc<GlobalSet>())->cast<GlobalSet>());
break;
case BinaryConsts::Select:
- visitSelect((curr = allocator.alloc<Select>())->cast<Select>());
+ case BinaryConsts::SelectWithType:
+ visitSelect((curr = allocator.alloc<Select>())->cast<Select>(), code);
break;
case BinaryConsts::Return:
visitReturn((curr = allocator.alloc<Return>())->cast<Return>());
@@ -2137,6 +2148,15 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) {
case BinaryConsts::Catch:
curr = nullptr;
break;
+ case BinaryConsts::RefNull:
+ visitRefNull((curr = allocator.alloc<RefNull>())->cast<RefNull>());
+ break;
+ case BinaryConsts::RefIsNull:
+ visitRefIsNull((curr = allocator.alloc<RefIsNull>())->cast<RefIsNull>());
+ break;
+ case BinaryConsts::RefFunc:
+ visitRefFunc((curr = allocator.alloc<RefFunc>())->cast<RefFunc>());
+ break;
case BinaryConsts::Try:
visitTry((curr = allocator.alloc<Try>())->cast<Try>());
break;
@@ -2510,7 +2530,7 @@ void WasmBinaryBuilder::visitCall(Call* curr) {
curr->operands[num - i - 1] = popNonVoidExpression();
}
curr->type = sig.results;
- functionCalls[index].push_back(curr); // we don't know function names yet
+ functionRefs[index].push_back(curr); // we don't know function names yet
curr->finalize();
}
@@ -4326,12 +4346,24 @@ bool WasmBinaryBuilder::maybeVisitSIMDLoad(Expression*& out, uint32_t code) {
return true;
}
-void WasmBinaryBuilder::visitSelect(Select* curr) {
- BYN_TRACE("zz node: Select\n");
+void WasmBinaryBuilder::visitSelect(Select* curr, uint8_t code) {
+ BYN_TRACE("zz node: Select, code " << int32_t(code) << std::endl);
+ if (code == BinaryConsts::SelectWithType) {
+ size_t numTypes = getU32LEB();
+ std::vector<Type> types;
+ for (size_t i = 0; i < numTypes; i++) {
+ types.push_back(getType());
+ }
+ curr->type = Type(types);
+ }
curr->condition = popNonVoidExpression();
curr->ifFalse = popNonVoidExpression();
curr->ifTrue = popNonVoidExpression();
- curr->finalize();
+ if (code == BinaryConsts::SelectWithType) {
+ curr->finalize(curr->type);
+ } else {
+ curr->finalize();
+ }
}
void WasmBinaryBuilder::visitReturn(Return* curr) {
@@ -4383,6 +4415,27 @@ void WasmBinaryBuilder::visitDrop(Drop* curr) {
curr->finalize();
}
+void WasmBinaryBuilder::visitRefNull(RefNull* curr) {
+ BYN_TRACE("zz node: RefNull\n");
+ curr->finalize();
+}
+
+void WasmBinaryBuilder::visitRefIsNull(RefIsNull* curr) {
+ BYN_TRACE("zz node: RefIsNull\n");
+ curr->value = popNonVoidExpression();
+ curr->finalize();
+}
+
+void WasmBinaryBuilder::visitRefFunc(RefFunc* curr) {
+ BYN_TRACE("zz node: RefFunc\n");
+ Index index = getU32LEB();
+ if (index >= functionImports.size() + functionSignatures.size()) {
+ throwError("ref.func: invalid call index");
+ }
+ functionRefs[index].push_back(curr); // we don't know function names yet
+ curr->finalize();
+}
+
void WasmBinaryBuilder::visitTry(Try* curr) {
BYN_TRACE("zz node: Try\n");
// For simplicity of implementation, like if scopes, we create a hidden block
diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp
index 20aff2091..3b12c4346 100644
--- a/src/wasm/wasm-s-parser.cpp
+++ b/src/wasm/wasm-s-parser.cpp
@@ -850,16 +850,22 @@ Type SExpressionWasmBuilder::stringToType(const char* str,
return v128;
}
}
+ if (strncmp(str, "funcref", 7) == 0 && (prefix || str[7] == 0)) {
+ return funcref;
+ }
if (strncmp(str, "anyref", 6) == 0 && (prefix || str[6] == 0)) {
return anyref;
}
+ if (strncmp(str, "nullref", 7) == 0 && (prefix || str[7] == 0)) {
+ return nullref;
+ }
if (strncmp(str, "exnref", 6) == 0 && (prefix || str[6] == 0)) {
return exnref;
}
if (allowError) {
return none;
}
- throw ParseException("invalid wasm type");
+ throw ParseException(std::string("invalid wasm type: ") + str);
}
Type SExpressionWasmBuilder::stringToLaneType(const char* str) {
@@ -936,10 +942,16 @@ Expression* SExpressionWasmBuilder::makeUnary(Element& s, UnaryOp op) {
Expression* SExpressionWasmBuilder::makeSelect(Element& s) {
auto ret = allocator.alloc<Select>();
- ret->ifTrue = parseExpression(s[1]);
- ret->ifFalse = parseExpression(s[2]);
- ret->condition = parseExpression(s[3]);
- ret->finalize();
+ Index i = 1;
+ Type type = parseOptionalResultType(s, i);
+ ret->ifTrue = parseExpression(s[i++]);
+ ret->ifFalse = parseExpression(s[i++]);
+ ret->condition = parseExpression(s[i]);
+ if (type.isConcrete()) {
+ ret->finalize(type);
+ } else {
+ ret->finalize();
+ }
return ret;
}
@@ -1718,6 +1730,27 @@ Expression* SExpressionWasmBuilder::makeReturn(Element& s) {
return ret;
}
+Expression* SExpressionWasmBuilder::makeRefNull(Element& s) {
+ auto ret = allocator.alloc<RefNull>();
+ ret->finalize();
+ return ret;
+}
+
+Expression* SExpressionWasmBuilder::makeRefIsNull(Element& s) {
+ auto ret = allocator.alloc<RefIsNull>();
+ ret->value = parseExpression(s[1]);
+ ret->finalize();
+ return ret;
+}
+
+Expression* SExpressionWasmBuilder::makeRefFunc(Element& s) {
+ auto func = getFunctionName(*s[1]);
+ auto ret = allocator.alloc<RefFunc>();
+ ret->func = func;
+ ret->finalize();
+ return ret;
+}
+
// try-catch-end is written in the folded wast format as
// (try
// ...
diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp
index 593214838..22d6a0036 100644
--- a/src/wasm/wasm-stack.cpp
+++ b/src/wasm/wasm-stack.cpp
@@ -147,8 +147,10 @@ void BinaryInstWriter::visitLoad(Load* curr) {
// the pointer is unreachable, so we are never reached; just don't emit
// a load
return;
- case anyref: // anyref cannot be loaded from memory
- case exnref: // exnref cannot be loaded from memory
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
WASM_UNREACHABLE("unexpected type");
}
@@ -247,8 +249,10 @@ void BinaryInstWriter::visitStore(Store* curr) {
o << int8_t(BinaryConsts::SIMDPrefix)
<< U32LEB(BinaryConsts::V128Store);
break;
- case anyref: // anyref cannot be stored from memory
- case exnref: // exnref cannot be stored in memory
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -642,8 +646,10 @@ void BinaryInstWriter::visitConst(Const* curr) {
}
break;
}
- case anyref: // there's no anyref.const
- case exnref: // there's no exnref.const
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
case unreachable:
WASM_UNREACHABLE("unexpected type");
@@ -1541,7 +1547,15 @@ void BinaryInstWriter::visitBinary(Binary* curr) {
}
void BinaryInstWriter::visitSelect(Select* curr) {
- o << int8_t(BinaryConsts::Select);
+ if (curr->type.isRef()) {
+ o << int8_t(BinaryConsts::SelectWithType) << U32LEB(curr->type.size());
+ for (size_t i = 0; i < curr->type.size(); i++) {
+ o << binaryType(curr->type != Type::unreachable ? curr->type
+ : Type::none);
+ }
+ } else {
+ o << int8_t(BinaryConsts::Select);
+ }
}
void BinaryInstWriter::visitReturn(Return* curr) {
@@ -1562,6 +1576,19 @@ void BinaryInstWriter::visitHost(Host* curr) {
o << U32LEB(0); // Reserved flags field
}
+void BinaryInstWriter::visitRefNull(RefNull* curr) {
+ o << int8_t(BinaryConsts::RefNull);
+}
+
+void BinaryInstWriter::visitRefIsNull(RefIsNull* curr) {
+ o << int8_t(BinaryConsts::RefIsNull);
+}
+
+void BinaryInstWriter::visitRefFunc(RefFunc* curr) {
+ o << int8_t(BinaryConsts::RefFunc)
+ << U32LEB(parent.getFunctionIndex(curr->func));
+}
+
void BinaryInstWriter::visitTry(Try* curr) {
breakStack.emplace_back(IMPOSSIBLE_CONTINUE);
o << int8_t(BinaryConsts::Try);
@@ -1659,11 +1686,21 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() {
continue;
}
index += numLocalsByType[v128];
+ if (type == funcref) {
+ mappedLocals[i] = index + currLocalsByType[funcref] - 1;
+ continue;
+ }
+ index += numLocalsByType[funcref];
if (type == anyref) {
mappedLocals[i] = index + currLocalsByType[anyref] - 1;
continue;
}
index += numLocalsByType[anyref];
+ if (type == nullref) {
+ mappedLocals[i] = index + currLocalsByType[nullref] - 1;
+ continue;
+ }
+ index += numLocalsByType[nullref];
if (type == exnref) {
mappedLocals[i] = index + currLocalsByType[exnref] - 1;
continue;
@@ -1671,11 +1708,12 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() {
WASM_UNREACHABLE("unexpected type");
}
// Emit them.
- o << U32LEB((numLocalsByType[i32] ? 1 : 0) + (numLocalsByType[i64] ? 1 : 0) +
- (numLocalsByType[f32] ? 1 : 0) + (numLocalsByType[f64] ? 1 : 0) +
- (numLocalsByType[v128] ? 1 : 0) +
- (numLocalsByType[anyref] ? 1 : 0) +
- (numLocalsByType[exnref] ? 1 : 0));
+ o << U32LEB(
+ (numLocalsByType[i32] ? 1 : 0) + (numLocalsByType[i64] ? 1 : 0) +
+ (numLocalsByType[f32] ? 1 : 0) + (numLocalsByType[f64] ? 1 : 0) +
+ (numLocalsByType[v128] ? 1 : 0) + (numLocalsByType[funcref] ? 1 : 0) +
+ (numLocalsByType[anyref] ? 1 : 0) + (numLocalsByType[nullref] ? 1 : 0) +
+ (numLocalsByType[exnref] ? 1 : 0));
if (numLocalsByType[i32]) {
o << U32LEB(numLocalsByType[i32]) << binaryType(i32);
}
@@ -1691,9 +1729,15 @@ void BinaryInstWriter::mapLocalsAndEmitHeader() {
if (numLocalsByType[v128]) {
o << U32LEB(numLocalsByType[v128]) << binaryType(v128);
}
+ if (numLocalsByType[funcref]) {
+ o << U32LEB(numLocalsByType[funcref]) << binaryType(funcref);
+ }
if (numLocalsByType[anyref]) {
o << U32LEB(numLocalsByType[anyref]) << binaryType(anyref);
}
+ if (numLocalsByType[nullref]) {
+ o << U32LEB(numLocalsByType[nullref]) << binaryType(nullref);
+ }
if (numLocalsByType[exnref]) {
o << U32LEB(numLocalsByType[exnref]) << binaryType(exnref);
}
@@ -1760,7 +1804,7 @@ StackInst* StackIRGenerator::makeStackInst(StackInst::Op op,
// type.
stackType = none;
} else if (op != StackInst::BlockEnd && op != StackInst::IfEnd &&
- op != StackInst::LoopEnd) {
+ op != StackInst::LoopEnd && op != StackInst::TryEnd) {
// If a concrete type is returned, we mark the end of the construct has
// having that type (as it is pushed to the value stack at that point),
// other parts are marked as none).
@@ -1781,13 +1825,15 @@ void StackIRToBinaryWriter::write() {
case StackInst::Basic:
case StackInst::BlockBegin:
case StackInst::IfBegin:
- case StackInst::LoopBegin: {
+ case StackInst::LoopBegin:
+ case StackInst::TryBegin: {
writer.visit(inst->origin);
break;
}
case StackInst::BlockEnd:
case StackInst::IfEnd:
- case StackInst::LoopEnd: {
+ case StackInst::LoopEnd:
+ case StackInst::TryEnd: {
writer.emitScopeEnd();
break;
}
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp
index 939ee8c93..62c30d1e0 100644
--- a/src/wasm/wasm-type.cpp
+++ b/src/wasm/wasm-type.cpp
@@ -62,7 +62,9 @@ std::vector<std::unique_ptr<std::vector<Type>>> typeLists = [] {
add({Type::f32});
add({Type::f64});
add({Type::v128});
+ add({Type::funcref});
add({Type::anyref});
+ add({Type::nullref});
add({Type::exnref});
return lists;
}();
@@ -75,7 +77,9 @@ std::unordered_map<std::vector<Type>, uint32_t> indices = {
{{Type::f32}, Type::f32},
{{Type::f64}, Type::f64},
{{Type::v128}, Type::v128},
+ {{Type::funcref}, Type::funcref},
{{Type::anyref}, Type::anyref},
+ {{Type::nullref}, Type::nullref},
{{Type::exnref}, Type::exnref},
};
@@ -154,8 +158,10 @@ unsigned Type::getByteSize() const {
return 8;
case Type::v128:
return 16;
- case Type::anyref: // anyref type is opaque
- case Type::exnref: // exnref type is opaque
+ case Type::funcref:
+ case Type::anyref:
+ case Type::nullref:
+ case Type::exnref:
case Type::none:
case Type::unreachable:
WASM_UNREACHABLE("invalid type");
@@ -164,7 +170,7 @@ unsigned Type::getByteSize() const {
}
Type Type::reinterpret() const {
- assert(isSingle() && "reinterpret only works with single types");
+ assert(isSingle() && "reinterpretType only works with single types");
Type singleType = *expand().begin();
switch (singleType) {
case Type::i32:
@@ -176,7 +182,9 @@ Type Type::reinterpret() const {
case Type::f64:
return i64;
case Type::v128:
+ case Type::funcref:
case Type::anyref:
+ case Type::nullref:
case Type::exnref:
case Type::none:
case Type::unreachable:
@@ -221,6 +229,39 @@ Type Type::get(unsigned byteSize, bool float_) {
WASM_UNREACHABLE("invalid size");
}
+bool Type::Type::isSubType(Type left, Type right) {
+ if (left == right) {
+ return true;
+ }
+ if (left.isRef() && right.isRef() &&
+ (right == Type::anyref || left == Type::nullref)) {
+ return true;
+ }
+ return false;
+}
+
+Type Type::Type::getLeastUpperBound(Type a, Type b) {
+ if (a == b) {
+ return a;
+ }
+ if (a == Type::unreachable) {
+ return b;
+ }
+ if (b == Type::unreachable) {
+ return a;
+ }
+ if (!a.isRef() || !b.isRef()) {
+ return none; // a poison value that must not be consumed
+ }
+ if (a == Type::nullref) {
+ return b;
+ }
+ if (b == Type::nullref) {
+ return a;
+ }
+ return Type::anyref;
+}
+
namespace {
std::ostream&
@@ -280,9 +321,15 @@ std::ostream& operator<<(std::ostream& os, Type type) {
case Type::v128:
os << "v128";
break;
+ case Type::funcref:
+ os << "funcref";
+ break;
case Type::anyref:
os << "anyref";
break;
+ case Type::nullref:
+ os << "nullref";
+ break;
case Type::exnref:
os << "exnref";
break;
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index 55e115d95..7bf51c5f7 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -21,6 +21,7 @@
#include "ir/branch-utils.h"
#include "ir/features.h"
+#include "ir/global-utils.h"
#include "ir/module-utils.h"
#include "ir/utils.h"
#include "support/colors.h"
@@ -181,6 +182,31 @@ struct ValidationInfo {
fail(text, curr, func);
}
}
+
+ // Type 'left' should be a subtype of 'right'.
+ bool shouldBeSubType(Type left,
+ Type right,
+ Expression* curr,
+ const char* text,
+ Function* func = nullptr) {
+ if (Type::isSubType(left, right)) {
+ return true;
+ }
+ fail(text, curr, func);
+ return false;
+ }
+
+ // Type 'left' should be a subtype of 'right', or unreachable.
+ bool shouldBeSubTypeOrUnreachable(Type left,
+ Type right,
+ Expression* curr,
+ const char* text,
+ Function* func = nullptr) {
+ if (left == Type::unreachable) {
+ return true;
+ }
+ return shouldBeSubType(left, right, curr, text, func);
+ }
};
struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> {
@@ -210,7 +236,7 @@ struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> {
std::unordered_map<Name, BreakInfo> breakInfos;
- Type returnType = unreachable; // type used in returns
+ std::set<Type> returnTypes; // types used in returns
// Binaryen IR requires that label names must be unique - IR generators must
// ensure that
@@ -287,6 +313,8 @@ public:
void visitDrop(Drop* curr);
void visitReturn(Return* curr);
void visitHost(Host* curr);
+ void visitRefIsNull(RefIsNull* curr);
+ void visitRefFunc(RefFunc* curr);
void visitTry(Try* curr);
void visitThrow(Throw* curr);
void visitRethrow(Rethrow* curr);
@@ -327,6 +355,19 @@ private:
return info.shouldBeIntOrUnreachable(ty, curr, text, getFunction());
}
+ bool
+ shouldBeSubType(Type left, Type right, Expression* curr, const char* text) {
+ return info.shouldBeSubType(left, right, curr, text, getFunction());
+ }
+
+ bool shouldBeSubTypeOrUnreachable(Type left,
+ Type right,
+ Expression* curr,
+ const char* text) {
+ return info.shouldBeSubTypeOrUnreachable(
+ left, right, curr, text, getFunction());
+ }
+
void validateAlignment(
size_t align, Type type, Index bytes, bool isAtomic, Expression* curr);
void validateMemBytes(uint8_t bytes, Type type, Expression* curr);
@@ -364,29 +405,23 @@ void FunctionValidator::visitBlock(Block* curr) {
// none or unreachable means a poison value that we should ignore - if
// consumed, it will error
if (info.type.isConcrete() && curr->type.isConcrete()) {
- shouldBeEqual(
- curr->type,
+ shouldBeSubType(
info.type,
+ curr->type,
curr,
"block+breaks must have right type if breaks return a value");
}
if (curr->type.isConcrete() && info.arity && info.type != unreachable) {
- shouldBeEqual(curr->type,
- info.type,
- curr,
- "block+breaks must have right type if breaks have arity");
+ shouldBeSubType(
+ info.type,
+ curr->type,
+ curr,
+ "block+breaks must have right type if breaks have arity");
}
shouldBeTrue(
info.arity != BreakInfo::PoisonArity, curr, "break arities must match");
if (curr->list.size() > 0) {
auto last = curr->list.back()->type;
- if (last.isConcrete() && info.type != unreachable) {
- shouldBeEqual(last,
- info.type,
- curr,
- "block+breaks must have right type if block ends with "
- "a reachable value");
- }
if (last == none) {
shouldBeTrue(info.arity == Index(0),
curr,
@@ -420,9 +455,9 @@ void FunctionValidator::visitBlock(Block* curr) {
"not flow out a value");
} else {
if (backType.isConcrete()) {
- shouldBeEqual(
- curr->type,
+ shouldBeSubType(
backType,
+ curr->type,
curr,
"block with value and last element with value must match types");
} else {
@@ -457,6 +492,23 @@ void FunctionValidator::visitLoop(Loop* curr) {
curr,
"bad body for a loop that has no value");
}
+
+ // When there are multiple instructions within a loop, they are wrapped in a
+ // Block internally, so visitBlock can take care of verification. Here we
+ // check cases when there is only one instruction in a Loop.
+ if (!curr->body->is<Block>()) {
+ if (!curr->type.isConcrete()) {
+ shouldBeFalse(curr->body->type.isConcrete(),
+ curr,
+ "if loop is not returning a value, final element should "
+ "not flow out a value");
+ } else {
+ shouldBeSubTypeOrUnreachable(curr->body->type,
+ curr->type,
+ curr,
+ "loop with value and body must match types");
+ }
+ }
}
void FunctionValidator::visitIf(If* curr) {
@@ -476,12 +528,12 @@ void FunctionValidator::visitIf(If* curr) {
}
} else {
if (curr->type != unreachable) {
- shouldBeEqualOrFirstIsUnreachable(
+ shouldBeSubTypeOrUnreachable(
curr->ifTrue->type,
curr->type,
curr,
"returning if-else's true must have right type");
- shouldBeEqualOrFirstIsUnreachable(
+ shouldBeSubTypeOrUnreachable(
curr->ifFalse->type,
curr->type,
curr,
@@ -499,25 +551,16 @@ void FunctionValidator::visitIf(If* curr) {
}
}
if (curr->ifTrue->type.isConcrete()) {
- shouldBeEqual(curr->type,
- curr->ifTrue->type,
- curr,
- "if type must match concrete ifTrue");
- shouldBeEqualOrFirstIsUnreachable(curr->ifFalse->type,
- curr->ifTrue->type,
- curr,
- "other arm must match concrete ifTrue");
+ shouldBeSubType(curr->ifTrue->type,
+ curr->type,
+ curr,
+ "if type must match concrete ifTrue");
}
if (curr->ifFalse->type.isConcrete()) {
- shouldBeEqual(curr->type,
- curr->ifFalse->type,
- curr,
- "if type must match concrete ifFalse");
- shouldBeEqualOrFirstIsUnreachable(
- curr->ifTrue->type,
- curr->ifFalse->type,
- curr,
- "other arm must match concrete ifFalse");
+ shouldBeSubType(curr->ifFalse->type,
+ curr->type,
+ curr,
+ "if type must match concrete ifFalse");
}
}
}
@@ -545,13 +588,7 @@ void FunctionValidator::noteBreak(Name name, Type valueType, Expression* curr) {
if (!info.hasBeenSet()) {
info = BreakInfo(valueType, arity);
} else {
- if (info.type == unreachable) {
- info.type = valueType;
- } else if (valueType != unreachable) {
- if (valueType != info.type) {
- info.type = none; // a poison value that must not be consumed
- }
- }
+ info.type = Type::getLeastUpperBound(info.type, valueType);
if (arity != info.arity) {
info.arity = BreakInfo::PoisonArity;
}
@@ -600,10 +637,10 @@ void FunctionValidator::visitCall(Call* curr) {
return;
}
for (size_t i = 0; i < curr->operands.size(); i++) {
- if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type,
- params[i],
- curr,
- "call param types must match") &&
+ if (!shouldBeSubTypeOrUnreachable(curr->operands[i]->type,
+ params[i],
+ curr,
+ "call param types must match") &&
!info.quiet) {
getStream() << "(on argument " << i << ")\n";
}
@@ -653,10 +690,10 @@ void FunctionValidator::visitCallIndirect(CallIndirect* curr) {
return;
}
for (size_t i = 0; i < curr->operands.size(); i++) {
- if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type,
- params[i],
- curr,
- "call param types must match") &&
+ if (!shouldBeSubTypeOrUnreachable(curr->operands[i]->type,
+ params[i],
+ curr,
+ "call param types must match") &&
!info.quiet) {
getStream() << "(on argument " << i << ")\n";
}
@@ -723,10 +760,10 @@ void FunctionValidator::visitLocalSet(LocalSet* curr) {
curr,
"local.set type must be correct");
}
- shouldBeEqual(curr->value->type,
- getFunction()->getLocalType(curr->index),
- curr,
- "local.set's value type must be correct");
+ shouldBeSubType(curr->value->type,
+ getFunction()->getLocalType(curr->index),
+ curr,
+ "local.set's value type must be correct");
}
}
}
@@ -750,10 +787,10 @@ void FunctionValidator::visitGlobalSet(GlobalSet* curr) {
"global.set name must be valid (and not an import; imports "
"can't be modified)")) {
shouldBeTrue(global->mutable_, curr, "global.set global must be mutable");
- shouldBeEqualOrFirstIsUnreachable(curr->value->type,
- global->type,
- curr,
- "global.set value must have right type");
+ shouldBeSubTypeOrUnreachable(curr->value->type,
+ global->type,
+ curr,
+ "global.set value must have right type");
}
}
@@ -1182,12 +1219,14 @@ void FunctionValidator::validateMemBytes(uint8_t bytes,
shouldBeEqual(
bytes, uint8_t(16), curr, "expected v128 operation to touch 16 bytes");
break;
- case anyref: // anyref cannot be stored in memory
- case exnref: // exnref cannot be stored in memory
- case none:
- WASM_UNREACHABLE("unexpected type");
case unreachable:
break;
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
+ case none:
+ WASM_UNREACHABLE("unexpected type");
}
}
@@ -1616,15 +1655,18 @@ void FunctionValidator::visitSelect(Select* curr) {
shouldBeUnequal(curr->ifTrue->type, none, curr, "select left must be valid");
shouldBeUnequal(
curr->ifFalse->type, none, curr, "select right must be valid");
+ shouldBeUnequal(curr->type, none, curr, "select type must be valid");
shouldBeTrue(curr->condition->type == unreachable ||
curr->condition->type == i32,
curr,
"select condition must be valid");
- if (curr->ifTrue->type != unreachable && curr->ifFalse->type != unreachable) {
- shouldBeEqual(curr->ifTrue->type,
- curr->ifFalse->type,
- curr,
- "select sides must be equal");
+ if (curr->type != unreachable) {
+ shouldBeTrue(Type::isSubType(curr->ifTrue->type, curr->type),
+ curr,
+ "select's left expression must be subtype of select's type");
+ shouldBeTrue(Type::isSubType(curr->ifFalse->type, curr->type),
+ curr,
+ "select's right expression must be subtype of select's type");
}
}
@@ -1636,16 +1678,7 @@ void FunctionValidator::visitDrop(Drop* curr) {
}
void FunctionValidator::visitReturn(Return* curr) {
- if (curr->value) {
- if (returnType == unreachable) {
- returnType = curr->value->type;
- } else if (curr->value->type != unreachable) {
- shouldBeEqual(
- curr->value->type, returnType, curr, "function results must match");
- }
- } else {
- returnType = none;
- }
+ returnTypes.insert(curr->value ? curr->value->type : Type::none);
}
void FunctionValidator::visitHost(Host* curr) {
@@ -1668,32 +1701,37 @@ void FunctionValidator::visitHost(Host* curr) {
}
}
+void FunctionValidator::visitRefIsNull(RefIsNull* curr) {
+ shouldBeTrue(curr->value->type == Type::unreachable ||
+ curr->value->type.isRef(),
+ curr->value,
+ "ref.is_null's argument should be a reference type");
+}
+
+void FunctionValidator::visitRefFunc(RefFunc* curr) {
+ auto* func = getModule()->getFunctionOrNull(curr->func);
+ shouldBeTrue(!!func, curr, "function argument of ref.func must exist");
+}
+
void FunctionValidator::visitTry(Try* curr) {
if (curr->type != unreachable) {
- shouldBeEqualOrFirstIsUnreachable(
- curr->body->type,
- curr->type,
- curr->body,
- "try's type does not match try body's type");
- shouldBeEqualOrFirstIsUnreachable(
- curr->catchBody->type,
- curr->type,
- curr->catchBody,
- "try's type does not match catch's body type");
- }
- if (curr->body->type.isConcrete()) {
- shouldBeEqualOrFirstIsUnreachable(
- curr->catchBody->type,
- curr->body->type,
- curr->catchBody,
- "try's body type must match catch's body type");
- }
- if (curr->catchBody->type.isConcrete()) {
- shouldBeEqualOrFirstIsUnreachable(
- curr->body->type,
- curr->catchBody->type,
- curr->body,
- "try's body type must match catch's body type");
+ shouldBeSubTypeOrUnreachable(curr->body->type,
+ curr->type,
+ curr->body,
+ "try's type does not match try body's type");
+ shouldBeSubTypeOrUnreachable(curr->catchBody->type,
+ curr->type,
+ curr->catchBody,
+ "try's type does not match catch's body type");
+ } else {
+ shouldBeEqual(curr->body->type,
+ unreachable,
+ curr,
+ "unreachable try-catch must have unreachable try body");
+ shouldBeEqual(curr->catchBody->type,
+ unreachable,
+ curr,
+ "unreachable try-catch must have unreachable catch body");
}
}
@@ -1727,10 +1765,10 @@ void FunctionValidator::visitThrow(Throw* curr) {
void FunctionValidator::visitRethrow(Rethrow* curr) {
shouldBeEqual(
curr->type, unreachable, curr, "rethrow's type must be unreachable");
- shouldBeEqual(curr->exnref->type,
- exnref,
- curr->exnref,
- "rethrow's argument must be exnref type");
+ shouldBeSubType(curr->exnref->type,
+ Type::exnref,
+ curr->exnref,
+ "rethrow's argument must be exnref type or its subtype");
}
void FunctionValidator::visitBrOnExn(BrOnExn* curr) {
@@ -1740,10 +1778,11 @@ void FunctionValidator::visitBrOnExn(BrOnExn* curr) {
curr,
"br_on_exn's event params and event's params are different");
noteBreak(curr->name, curr->sent, curr);
- shouldBeTrue(curr->exnref->type == unreachable ||
- curr->exnref->type == exnref,
- curr,
- "br_on_exn's argument must be unreachable or exnref type");
+ shouldBeSubTypeOrUnreachable(
+ curr->exnref->type,
+ Type::exnref,
+ curr,
+ "br_on_exn's argument must be unreachable or exnref type or its subtype");
if (curr->exnref->type == unreachable) {
shouldBeTrue(curr->type == unreachable,
curr,
@@ -1779,21 +1818,22 @@ void FunctionValidator::visitFunction(Function* curr) {
"all used types should be allowed");
// if function has no result, it is ignored
// if body is unreachable, it might be e.g. a return
- if (curr->body->type != unreachable) {
- shouldBeEqual(curr->sig.results,
- curr->body->type,
- curr->body,
- "function body type must match, if function returns");
- }
- if (returnType != unreachable) {
- shouldBeEqual(curr->sig.results,
- returnType,
- curr->body,
- "function result must match, if function has returns");
+ shouldBeSubTypeOrUnreachable(
+ curr->body->type,
+ curr->sig.results,
+ curr->body,
+ "function body type must match, if function returns");
+ for (Type returnType : returnTypes) {
+ shouldBeSubTypeOrUnreachable(
+ returnType,
+ curr->sig.results,
+ curr->body,
+ "function result must match, if function has returns");
}
+
shouldBeTrue(
breakInfos.empty(), curr->body, "all named break targets must exist");
- returnType = unreachable;
+ returnTypes.clear();
labelNames.clear();
// validate optional local names
std::set<Name> seen;
@@ -1858,8 +1898,10 @@ void FunctionValidator::validateAlignment(
case v128:
case unreachable:
break;
- case anyref: // anyref cannot be stored in memory
- case exnref: // exnref cannot be stored in memory
+ case funcref:
+ case anyref:
+ case nullref:
+ case exnref:
case none:
WASM_UNREACHABLE("invalid type");
}
@@ -1890,7 +1932,8 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) {
//
// The block has an added type, not derived from the ast itself, so it
// is ok for it to be either i32 or unreachable.
- if (!(oldType.isConcrete() && newType == unreachable)) {
+ if (!Type::isSubType(newType, oldType) &&
+ !(oldType.isConcrete() && newType == Type::unreachable)) {
std::ostringstream ss;
ss << "stale type found in " << scope << " on " << curr
<< "\n(marked as " << oldType << ", should be " << newType
@@ -2011,13 +2054,14 @@ static void validateGlobals(Module& module, ValidationInfo& info) {
info.shouldBeTrue(
curr->init != nullptr, curr->name, "global init must be non-null");
assert(curr->init);
- info.shouldBeTrue(curr->init->is<Const>() || curr->init->is<GlobalGet>(),
+ info.shouldBeTrue(GlobalUtils::canInitializeGlobal(curr->init),
curr->name,
"global init must be valid");
- if (!info.shouldBeEqual(curr->type,
- curr->init->type,
- curr->init,
- "global init must have correct type") &&
+
+ if (!info.shouldBeSubType(curr->init->type,
+ curr->type,
+ curr->init,
+ "global init must have correct type") &&
!info.quiet) {
info.getStream(nullptr) << "(on global " << curr->name << ")\n";
}
@@ -2118,9 +2162,9 @@ static void validateEvents(Module& module, ValidationInfo& info) {
curr->name,
"Event type's result type should be none");
for (auto type : curr->sig.params.expand()) {
- info.shouldBeTrue(type.isInteger() || type.isFloat(),
+ info.shouldBeTrue(type.isConcrete(),
curr->name,
- "Values in an event should have integer or float type");
+ "Values in an event should have concrete types");
}
}
}
diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp
index ff1295bad..11d203835 100644
--- a/src/wasm/wasm.cpp
+++ b/src/wasm/wasm.cpp
@@ -173,13 +173,19 @@ const char* getExpressionName(Expression* curr) {
return "push";
case Expression::Id::PopId:
return "pop";
- case Expression::TryId:
+ case Expression::Id::RefNullId:
+ return "ref.null";
+ case Expression::Id::RefIsNullId:
+ return "ref.is_null";
+ case Expression::Id::RefFuncId:
+ return "ref.func";
+ case Expression::Id::TryId:
return "try";
- case Expression::ThrowId:
+ case Expression::Id::ThrowId:
return "throw";
- case Expression::RethrowId:
+ case Expression::Id::RethrowId:
return "rethrow";
- case Expression::BrOnExnId:
+ case Expression::Id::BrOnExnId:
return "br_on_exn";
case Expression::Id::NumExpressionIds:
WASM_UNREACHABLE("invalid expr id");
@@ -187,6 +193,18 @@ const char* getExpressionName(Expression* curr) {
WASM_UNREACHABLE("invalid expr id");
}
+Literal getLiteralFromConstExpression(Expression* curr) {
+ if (auto* c = curr->dynCast<Const>()) {
+ return c->value;
+ } else if (curr->is<RefNull>()) {
+ return Literal::makeNullref();
+ } else if (auto* r = curr->dynCast<RefFunc>()) {
+ return Literal::makeFuncref(r->func);
+ } else {
+ WASM_UNREACHABLE("Not a constant expression");
+ }
+}
+
// core AST type checking
struct TypeSeeker : public PostWalker<TypeSeeker> {
@@ -248,27 +266,6 @@ struct TypeSeeker : public PostWalker<TypeSeeker> {
}
};
-static Type mergeTypes(std::vector<Type>& types) {
- Type type = unreachable;
- for (auto other : types) {
- // once none, stop. it then indicates a poison value, that must not be
- // consumed and ignore unreachable
- if (type != none) {
- if (other == none) {
- type = none;
- } else if (other != unreachable) {
- if (type == unreachable) {
- type = other;
- } else if (type != other) {
- // poison value, we saw multiple types; this should not be consumed
- type = none;
- }
- }
- }
- }
- return type;
-}
-
// a block is unreachable if one of its elements is unreachable,
// and there are no branches to it
static void handleUnreachable(Block* block,
@@ -336,7 +333,7 @@ void Block::finalize() {
}
TypeSeeker seeker(this, this->name);
- type = mergeTypes(seeker.types);
+ type = Type::mergeTypes(seeker.types);
handleUnreachable(this);
}
@@ -364,19 +361,8 @@ void If::finalize(Type type_) {
}
void If::finalize() {
- if (ifFalse) {
- if (ifTrue->type == ifFalse->type) {
- type = ifTrue->type;
- } else if (ifTrue->type.isConcrete() && ifFalse->type == unreachable) {
- type = ifTrue->type;
- } else if (ifFalse->type.isConcrete() && ifTrue->type == unreachable) {
- type = ifFalse->type;
- } else {
- type = none;
- }
- } else {
- type = none; // if without else
- }
+ type = ifFalse ? Type::getLeastUpperBound(ifTrue->type, ifFalse->type)
+ : Type::none;
// if the arms return a value, leave it even if the condition
// is unreachable, we still mark ourselves as having that type, e.g.
// (if (result i32)
@@ -828,13 +814,15 @@ void Binary::finalize() {
}
}
+void Select::finalize(Type type_) { type = type_; }
+
void Select::finalize() {
assert(ifTrue && ifFalse);
if (ifTrue->type == unreachable || ifFalse->type == unreachable ||
condition->type == unreachable) {
type = unreachable;
} else {
- type = ifTrue->type;
+ type = Type::getLeastUpperBound(ifTrue->type, ifFalse->type);
}
}
@@ -864,16 +852,20 @@ void Host::finalize() {
}
}
-void Try::finalize() {
- if (body->type == catchBody->type) {
- type = body->type;
- } else if (body->type.isConcrete() && catchBody->type == unreachable) {
- type = body->type;
- } else if (catchBody->type.isConcrete() && body->type == unreachable) {
- type = catchBody->type;
- } else {
- type = none;
+void RefNull::finalize() { type = Type::nullref; }
+
+void RefIsNull::finalize() {
+ if (value->type == Type::unreachable) {
+ type = Type::unreachable;
+ return;
}
+ type = Type::i32;
+}
+
+void RefFunc::finalize() { type = Type::funcref; }
+
+void Try::finalize() {
+ type = Type::getLeastUpperBound(body->type, catchBody->type);
}
void Try::finalize(Type type_) {
diff --git a/src/wasm2js.h b/src/wasm2js.h
index 2cacbe8e5..1adde23b8 100644
--- a/src/wasm2js.h
+++ b/src/wasm2js.h
@@ -1848,6 +1848,18 @@ Ref Wasm2JSBuilder::processFunctionBody(Module* m,
unimplemented(curr);
WASM_UNREACHABLE("unimp");
}
+ Ref visitRefNull(RefNull* curr) {
+ unimplemented(curr);
+ WASM_UNREACHABLE("unimp");
+ }
+ Ref visitRefIsNull(RefIsNull* curr) {
+ unimplemented(curr);
+ WASM_UNREACHABLE("unimp");
+ }
+ Ref visitRefFunc(RefFunc* curr) {
+ unimplemented(curr);
+ WASM_UNREACHABLE("unimp");
+ }
Ref visitTry(Try* curr) {
unimplemented(curr);
WASM_UNREACHABLE("unimp");