summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CMakeLists.txt19
-rw-r--r--src/wasm-validator.h646
-rw-r--r--src/wasm/CMakeLists.txt1
-rw-r--r--src/wasm/wasm-validator.cpp639
4 files changed, 685 insertions, 620 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b28bd370b..287a5589b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -59,8 +59,6 @@ FOREACH(SUFFIX "_DEBUG" "_RELEASE" "_RELWITHDEBINFO" "_MINSIZEREL" "")
SET(CMAKE_ARCHIVE_OUTPUT_DIRECTORY${SUFFIX} "${PROJECT_BINARY_DIR}/lib")
ENDFOREACH()
-SET(all_passes passes)
-
IF(MSVC)
IF(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "19.0") # VS2013 and older explicitly need /arch:sse2 set, VS2015 no longer has that option, but always enabled.
ADD_COMPILE_FLAG("/arch:sse2")
@@ -172,6 +170,9 @@ IF (UNIX AND
ENDIF()
# Static libraries
+# Current (partial) dependency structure is as follows:
+# passes -> wasm -> asmjs -> support
+# TODO: It's odd that wasm should depend on asmjs, maybe we should fix that.
ADD_SUBDIRECTORY(src/ast)
ADD_SUBDIRECTORY(src/asmjs)
ADD_SUBDIRECTORY(src/cfg)
@@ -191,7 +192,7 @@ IF(BUILD_STATIC_LIB)
ELSE()
ADD_LIBRARY(binaryen SHARED ${binaryen_SOURCES})
ENDIF()
-TARGET_LINK_LIBRARIES(binaryen ${all_passes} wasm asmjs ast cfg support)
+TARGET_LINK_LIBRARIES(binaryen passes wasm asmjs ast cfg support)
INSTALL(TARGETS binaryen DESTINATION ${CMAKE_INSTALL_LIBDIR})
INSTALL(FILES src/binaryen-c.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
@@ -204,7 +205,7 @@ SET(wasm-shell_SOURCES
)
ADD_EXECUTABLE(wasm-shell
${wasm-shell_SOURCES})
-TARGET_LINK_LIBRARIES(wasm-shell wasm asmjs emscripten-optimizer ${all_passes} ast cfg support)
+TARGET_LINK_LIBRARIES(wasm-shell wasm asmjs emscripten-optimizer passes ast cfg support)
SET_PROPERTY(TARGET wasm-shell PROPERTY CXX_STANDARD 11)
SET_PROPERTY(TARGET wasm-shell PROPERTY CXX_STANDARD_REQUIRED ON)
INSTALL(TARGETS wasm-shell DESTINATION ${CMAKE_INSTALL_BINDIR})
@@ -215,7 +216,7 @@ SET(wasm-opt_SOURCES
)
ADD_EXECUTABLE(wasm-opt
${wasm-opt_SOURCES})
-TARGET_LINK_LIBRARIES(wasm-opt wasm asmjs emscripten-optimizer ${all_passes} ast cfg support)
+TARGET_LINK_LIBRARIES(wasm-opt wasm asmjs emscripten-optimizer passes ast cfg support)
SET_PROPERTY(TARGET wasm-opt PROPERTY CXX_STANDARD 11)
SET_PROPERTY(TARGET wasm-opt PROPERTY CXX_STANDARD_REQUIRED ON)
INSTALL(TARGETS wasm-opt DESTINATION ${CMAKE_INSTALL_BINDIR})
@@ -225,7 +226,7 @@ SET(wasm-merge_SOURCES
)
ADD_EXECUTABLE(wasm-merge
${wasm-merge_SOURCES})
-TARGET_LINK_LIBRARIES(wasm-merge wasm asmjs emscripten-optimizer ${all_passes} ast cfg support)
+TARGET_LINK_LIBRARIES(wasm-merge wasm asmjs emscripten-optimizer passes ast cfg support)
SET_PROPERTY(TARGET wasm-merge PROPERTY CXX_STANDARD 11)
SET_PROPERTY(TARGET wasm-merge PROPERTY CXX_STANDARD_REQUIRED ON)
INSTALL(TARGETS wasm-merge DESTINATION bin)
@@ -236,7 +237,7 @@ SET(asm2wasm_SOURCES
)
ADD_EXECUTABLE(asm2wasm
${asm2wasm_SOURCES})
-TARGET_LINK_LIBRARIES(asm2wasm emscripten-optimizer ${all_passes} wasm asmjs ast cfg support)
+TARGET_LINK_LIBRARIES(asm2wasm emscripten-optimizer passes wasm asmjs ast cfg support)
SET_PROPERTY(TARGET asm2wasm PROPERTY CXX_STANDARD 11)
SET_PROPERTY(TARGET asm2wasm PROPERTY CXX_STANDARD_REQUIRED ON)
INSTALL(TARGETS asm2wasm DESTINATION ${CMAKE_INSTALL_BINDIR})
@@ -258,7 +259,7 @@ SET(wasm_as_SOURCES
)
ADD_EXECUTABLE(wasm-as
${wasm_as_SOURCES})
-TARGET_LINK_LIBRARIES(wasm-as wasm asmjs passes ast cfg support)
+TARGET_LINK_LIBRARIES(wasm-as passes wasm asmjs ast cfg support)
SET_PROPERTY(TARGET wasm-as PROPERTY CXX_STANDARD 11)
SET_PROPERTY(TARGET wasm-as PROPERTY CXX_STANDARD_REQUIRED ON)
INSTALL(TARGETS wasm-as DESTINATION ${CMAKE_INSTALL_BINDIR})
@@ -278,7 +279,7 @@ SET(wasm-ctor-eval_SOURCES
)
ADD_EXECUTABLE(wasm-ctor-eval
${wasm-ctor-eval_SOURCES})
-TARGET_LINK_LIBRARIES(wasm-ctor-eval wasm asmjs emscripten-optimizer ${all_passes} ast cfg support)
+TARGET_LINK_LIBRARIES(wasm-ctor-eval emscripten-optimizer passes wasm asmjs ast cfg support)
SET_PROPERTY(TARGET wasm-ctor-eval PROPERTY CXX_STANDARD 11)
SET_PROPERTY(TARGET wasm-ctor-eval PROPERTY CXX_STANDARD_REQUIRED ON)
INSTALL(TARGETS wasm-ctor-eval DESTINATION bin)
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index 9420179cf..403a057fe 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -39,11 +39,8 @@
#include <set>
-#include "support/colors.h"
#include "wasm.h"
#include "wasm-printing.h"
-#include "ast_utils.h"
-#include "ast/branch-utils.h"
namespace wasm {
@@ -68,13 +65,12 @@ struct WasmValidator : public PostWalker<WasmValidator> {
std::set<Name> labelNames; // Binaryen IR requires that label names must be unique - IR generators must ensure that
- void noteLabelName(Name name) {
- if (!name.is()) return;
- shouldBeTrue(labelNames.find(name) == labelNames.end(), name, "names in Binaryen IR must be unique - IR generators must ensure that");
- labelNames.insert(name);
- }
+ void noteLabelName(Name name);
public:
+ // TODO: If we want the validator to be part of libwasm rather than libpasses, then
+ // Using PassRunner::getPassDebug causes a circular dependence. We should fix that,
+ // perhaps by moving some of the pass infrastructure into libsupport.
bool validate(Module& module, bool validateWeb_ = false, bool validateGlobally_ = true) {
validateWeb = validateWeb_;
validateGlobally = validateGlobally_;
@@ -98,101 +94,15 @@ public:
if (curr->name.is()) self->breakTargets[curr->name].push_back(curr);
}
- void visitBlock(Block *curr) {
- // if we are break'ed to, then the value must be right for us
- if (curr->name.is()) {
- noteLabelName(curr->name);
- if (breakInfos.count(curr) > 0) {
- auto& info = breakInfos[curr];
- if (isConcreteWasmType(curr->type)) {
- shouldBeTrue(info.arity != 0, curr, "break arities must be > 0 if block has a value");
- } else {
- shouldBeTrue(info.arity == 0, curr, "break arities must be 0 if block has no value");
- }
- // none or unreachable means a poison value that we should ignore - if consumed, it will error
- if (isConcreteWasmType(info.type) && isConcreteWasmType(curr->type)) {
- shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks return a value");
- }
- if (isConcreteWasmType(curr->type) && info.arity && info.type != unreachable) {
- shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks have arity");
- }
- shouldBeTrue(info.arity != Index(-1), curr, "break arities must match");
- if (curr->list.size() > 0) {
- auto last = curr->list.back()->type;
- if (isConcreteWasmType(last) && info.type != unreachable) {
- shouldBeEqual(last, info.type, curr, "block+breaks must have right type if block ends with a reachable value");
- }
- if (last == none) {
- shouldBeTrue(info.arity == Index(0), curr, "if block ends with a none, breaks cannot send a value of any type");
- }
- }
- }
- breakTargets[curr->name].pop_back();
- }
- if (curr->list.size() > 1) {
- for (Index i = 0; i < curr->list.size() - 1; i++) {
- if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)")) {
- std::cerr << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n";
- }
- }
- }
- if (curr->list.size() > 0) {
- auto backType = curr->list.back()->type;
- if (!isConcreteWasmType(curr->type)) {
- if (isConcreteWasmType(backType)) {
- shouldBeTrue(curr->type == unreachable, curr, "block with no value and a last element with a value must be unreachable");
- }
- } else {
- if (isConcreteWasmType(backType)) {
- shouldBeEqual(curr->type, backType, curr, "block with value and last element with value must match types");
- } else {
- shouldBeUnequal(backType, none, curr, "block with value must not have last element that is none");
- }
- }
- }
- if (isConcreteWasmType(curr->type)) {
- shouldBeTrue(curr->list.size() > 0, curr, "block with a value must not be empty");
- }
- }
+ void visitBlock(Block *curr);
static void visitPreLoop(WasmValidator* self, Expression** currp) {
auto* curr = (*currp)->cast<Loop>();
if (curr->name.is()) self->breakTargets[curr->name].push_back(curr);
}
- void visitLoop(Loop *curr) {
- if (curr->name.is()) {
- noteLabelName(curr->name);
- breakTargets[curr->name].pop_back();
- if (breakInfos.count(curr) > 0) {
- auto& info = breakInfos[curr];
- shouldBeEqual(info.arity, Index(0), curr, "breaks to a loop cannot pass a value");
- }
- }
- if (curr->type == none) {
- shouldBeFalse(isConcreteWasmType(curr->body->type), curr, "bad body for a loop that has no value");
- }
- }
-
- void visitIf(If *curr) {
- shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "if condition must be valid");
- if (!curr->ifFalse) {
- shouldBeFalse(isConcreteWasmType(curr->ifTrue->type), curr, "if without else must not return a value in body");
- if (curr->condition->type != unreachable) {
- shouldBeEqual(curr->type, none, curr, "if without else and reachable condition must be none");
- }
- } else {
- if (curr->type != unreachable) {
- shouldBeEqualOrFirstIsUnreachable(curr->ifTrue->type, curr->type, curr, "returning if-else's true must have right type");
- shouldBeEqualOrFirstIsUnreachable(curr->ifFalse->type, curr->type, curr, "returning if-else's false must have right type");
- } else {
- if (curr->condition->type != unreachable) {
- shouldBeEqual(curr->ifTrue->type, unreachable, curr, "unreachable if-else must have unreachable true");
- shouldBeEqual(curr->ifFalse->type, unreachable, curr, "unreachable if-else must have unreachable false");
- }
- }
- }
- }
+ void visitLoop(Loop *curr);
+ void visitIf(If *curr);
// override scan to add a pre and a post check task to all nodes
static void scan(WasmValidator* self, Expression** currp) {
@@ -203,468 +113,38 @@ public:
if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp);
}
- void noteBreak(Name name, Expression* value, Expression* curr) {
- WasmType valueType = none;
- Index arity = 0;
- if (value) {
- valueType = value->type;
- shouldBeUnequal(valueType, none, curr, "breaks must have a valid value");
- arity = 1;
- }
- if (!shouldBeTrue(breakTargets[name].size() > 0, curr, "all break targets must be valid")) return;
- auto* target = breakTargets[name].back();
- if (breakInfos.count(target) == 0) {
- breakInfos[target] = BreakInfo(valueType, arity);
- } else {
- auto& info = breakInfos[target];
- if (info.type == unreachable) {
- info.type = valueType;
- } else if (valueType != unreachable) {
- if (valueType != info.type) {
- info.type = none; // a poison value that must not be consumed
- }
- }
- if (arity != info.arity) {
- info.arity = Index(-1); // a poison value
- }
- }
- }
- void visitBreak(Break *curr) {
- // note breaks (that are actually taken)
- if (BranchUtils::isBranchTaken(curr)) {
- noteBreak(curr->name, curr->value, curr);
- }
- if (curr->condition) {
- shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32");
- }
- }
- void visitSwitch(Switch *curr) {
- // note breaks (that are actually taken)
- if (BranchUtils::isBranchTaken(curr)) {
- for (auto& target : curr->targets) {
- noteBreak(target, curr->value, curr);
- }
- noteBreak(curr->default_, curr->value, curr);
- }
- shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32");
- }
- void visitCall(Call *curr) {
- if (!validateGlobally) return;
- auto* target = getModule()->getFunctionOrNull(curr->target);
- if (!shouldBeTrue(!!target, curr, "call target must exist")) return;
- if (!shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match")) return;
- for (size_t i = 0; i < curr->operands.size(); i++) {
- if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match")) {
- std::cerr << "(on argument " << i << ")\n";
- }
- }
- }
- void visitCallImport(CallImport *curr) {
- if (!validateGlobally) return;
- auto* import = getModule()->getImportOrNull(curr->target);
- if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return;
- if (!shouldBeTrue(!!import->functionType.is(), curr, "called import must be function")) return;
- auto* type = getModule()->getFunctionType(import->functionType);
- if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return;
- for (size_t i = 0; i < curr->operands.size(); i++) {
- if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match")) {
- std::cerr << "(on argument " << i << ")\n";
- }
- }
- }
- void visitCallIndirect(CallIndirect *curr) {
- if (!validateGlobally) return;
- auto* type = getModule()->getFunctionTypeOrNull(curr->fullType);
- if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) return;
- shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32");
- if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return;
- for (size_t i = 0; i < curr->operands.size(); i++) {
- if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match")) {
- std::cerr << "(on argument " << i << ")\n";
- }
- }
- }
- void visitGetLocal(GetLocal* curr) {
- shouldBeTrue(isConcreteWasmType(curr->type), curr, "get_local must have a valid type - check what you provided when you constructed the node");
- }
- void visitSetLocal(SetLocal *curr) {
- shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough");
- if (curr->value->type != unreachable) {
- if (curr->type != none) { // tee is ok anyhow
- shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct");
- }
- shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function");
- }
- }
- void visitLoad(Load *curr) {
- validateAlignment(curr->align, curr->type, curr->bytes);
- shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32");
- }
- void visitStore(Store *curr) {
- validateAlignment(curr->align, curr->type, curr->bytes);
- shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32");
- shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none");
- shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->valueType, curr, "store value type must match");
- }
- void visitBinary(Binary *curr) {
- if (curr->left->type != unreachable && curr->right->type != unreachable) {
- shouldBeEqual(curr->left->type, curr->right->type, curr, "binary child types must be equal");
- }
- switch (curr->op) {
- case AddInt32:
- case SubInt32:
- case MulInt32:
- case DivSInt32:
- case DivUInt32:
- case RemSInt32:
- case RemUInt32:
- case AndInt32:
- case OrInt32:
- case XorInt32:
- case ShlInt32:
- case ShrUInt32:
- case ShrSInt32:
- case RotLInt32:
- case RotRInt32:
- case EqInt32:
- case NeInt32:
- case LtSInt32:
- case LtUInt32:
- case LeSInt32:
- case LeUInt32:
- case GtSInt32:
- case GtUInt32:
- case GeSInt32:
- case GeUInt32: {
- shouldBeEqualOrFirstIsUnreachable(curr->left->type, i32, curr, "i32 op");
- break;
- }
- case AddInt64:
- case SubInt64:
- case MulInt64:
- case DivSInt64:
- case DivUInt64:
- case RemSInt64:
- case RemUInt64:
- case AndInt64:
- case OrInt64:
- case XorInt64:
- case ShlInt64:
- case ShrUInt64:
- case ShrSInt64:
- case RotLInt64:
- case RotRInt64:
- case EqInt64:
- case NeInt64:
- case LtSInt64:
- case LtUInt64:
- case LeSInt64:
- case LeUInt64:
- case GtSInt64:
- case GtUInt64:
- case GeSInt64:
- case GeUInt64: {
- shouldBeEqualOrFirstIsUnreachable(curr->left->type, i64, curr, "i64 op");
- break;
- }
- case AddFloat32:
- case SubFloat32:
- case MulFloat32:
- case DivFloat32:
- case CopySignFloat32:
- case MinFloat32:
- case MaxFloat32:
- case EqFloat32:
- case NeFloat32:
- case LtFloat32:
- case LeFloat32:
- case GtFloat32:
- case GeFloat32: {
- shouldBeEqualOrFirstIsUnreachable(curr->left->type, f32, curr, "f32 op");
- break;
- }
- case AddFloat64:
- case SubFloat64:
- case MulFloat64:
- case DivFloat64:
- case CopySignFloat64:
- case MinFloat64:
- case MaxFloat64:
- case EqFloat64:
- case NeFloat64:
- case LtFloat64:
- case LeFloat64:
- case GtFloat64:
- case GeFloat64: {
- shouldBeEqualOrFirstIsUnreachable(curr->left->type, f64, curr, "f64 op");
- break;
- }
- default: WASM_UNREACHABLE();
- }
- }
- void visitUnary(Unary *curr) {
- shouldBeUnequal(curr->value->type, none, curr, "unaries must not receive a none as their input");
- if (curr->value->type == unreachable) return; // nothing to check
- switch (curr->op) {
- case ClzInt32:
- case CtzInt32:
- case PopcntInt32: {
- shouldBeEqual(curr->value->type, i32, curr, "i32 unary value type must be correct");
- break;
- }
- case ClzInt64:
- case CtzInt64:
- case PopcntInt64: {
- shouldBeEqual(curr->value->type, i64, curr, "i64 unary value type must be correct");
- break;
- }
- case NegFloat32:
- case AbsFloat32:
- case CeilFloat32:
- case FloorFloat32:
- case TruncFloat32:
- case NearestFloat32:
- case SqrtFloat32: {
- shouldBeEqual(curr->value->type, f32, curr, "f32 unary value type must be correct");
- break;
- }
- case NegFloat64:
- case AbsFloat64:
- case CeilFloat64:
- case FloorFloat64:
- case TruncFloat64:
- case NearestFloat64:
- case SqrtFloat64: {
- shouldBeEqual(curr->value->type, f64, curr, "f64 unary value type must be correct");
- break;
- }
- case EqZInt32: {
- shouldBeTrue(curr->value->type == i32, curr, "i32.eqz input must be i32");
- break;
- }
- case EqZInt64: {
- shouldBeTrue(curr->value->type == i64, curr, "i64.eqz input must be i64");
- break;
- }
- case ExtendSInt32: shouldBeEqual(curr->value->type, i32, curr, "extend type must be correct"); break;
- case ExtendUInt32: shouldBeEqual(curr->value->type, i32, curr, "extend type must be correct"); break;
- case WrapInt64: shouldBeEqual(curr->value->type, i64, curr, "wrap type must be correct"); break;
- case TruncSFloat32ToInt32: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break;
- case TruncSFloat32ToInt64: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break;
- case TruncUFloat32ToInt32: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break;
- case TruncUFloat32ToInt64: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break;
- case TruncSFloat64ToInt32: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break;
- case TruncSFloat64ToInt64: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break;
- case TruncUFloat64ToInt32: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break;
- case TruncUFloat64ToInt64: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break;
- case ReinterpretFloat32: shouldBeEqual(curr->value->type, f32, curr, "reinterpret/f32 type must be correct"); break;
- case ReinterpretFloat64: shouldBeEqual(curr->value->type, f64, curr, "reinterpret/f64 type must be correct"); break;
- case ConvertUInt32ToFloat32: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break;
- case ConvertUInt32ToFloat64: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break;
- case ConvertSInt32ToFloat32: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break;
- case ConvertSInt32ToFloat64: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break;
- case ConvertUInt64ToFloat32: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break;
- case ConvertUInt64ToFloat64: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break;
- case ConvertSInt64ToFloat32: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break;
- case ConvertSInt64ToFloat64: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break;
- case PromoteFloat32: shouldBeEqual(curr->value->type, f32, curr, "promote type must be correct"); break;
- case DemoteFloat64: shouldBeEqual(curr->value->type, f64, curr, "demote type must be correct"); break;
- case ReinterpretInt32: shouldBeEqual(curr->value->type, i32, curr, "reinterpret/i32 type must be correct"); break;
- case ReinterpretInt64: shouldBeEqual(curr->value->type, i64, curr, "reinterpret/i64 type must be correct"); break;
- default: abort();
- }
- }
- void visitSelect(Select* curr) {
- shouldBeUnequal(curr->ifTrue->type, none, curr, "select left must be valid");
- shouldBeUnequal(curr->ifFalse->type, none, curr, "select right must be valid");
- shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "select condition must be valid");
- if (curr->ifTrue->type != unreachable && curr->ifFalse->type != unreachable) {
- shouldBeEqual(curr->ifTrue->type, curr->ifFalse->type, curr, "select sides must be equal");
- }
- }
-
- void visitDrop(Drop* curr) {
- shouldBeTrue(isConcreteWasmType(curr->value->type) || curr->value->type == unreachable, curr, "can only drop a valid value");
- }
-
- void visitReturn(Return* curr) {
- if (curr->value) {
- if (returnType == unreachable) {
- returnType = curr->value->type;
- } else if (curr->value->type != unreachable) {
- shouldBeEqual(curr->value->type, returnType, curr, "function results must match");
- }
- } else {
- returnType = none;
- }
- }
-
- void visitHost(Host* curr) {
- switch (curr->op) {
- case GrowMemory: {
- shouldBeEqual(curr->operands.size(), size_t(1), curr, "grow_memory must have 1 operand");
- shouldBeEqualOrFirstIsUnreachable(curr->operands[0]->type, i32, curr, "grow_memory must have i32 operand");
- break;
- }
- case PageSize:
- case CurrentMemory:
- case HasFeature: break;
- default: WASM_UNREACHABLE();
- }
- }
-
- void visitImport(Import* curr) {
- if (!validateGlobally) return;
- if (curr->kind == ExternalKind::Function) {
- if (validateWeb) {
- auto* functionType = getModule()->getFunctionType(curr->functionType);
- shouldBeUnequal(functionType->result, i64, curr->name, "Imported function must not have i64 return type");
- for (WasmType param : functionType->params) {
- shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters");
- }
- }
- }
- if (curr->kind == ExternalKind::Table) {
- shouldBeTrue(getModule()->table.imported, curr->name, "Table import record exists but table is not marked as imported");
- }
- if (curr->kind == ExternalKind::Memory) {
- shouldBeTrue(getModule()->memory.imported, curr->name, "Memory import record exists but memory is not marked as imported");
- }
- }
-
- void visitExport(Export* curr) {
- if (!validateGlobally) return;
- if (curr->kind == ExternalKind::Function) {
- if (validateWeb) {
- Function* f = getModule()->getFunction(curr->value);
- shouldBeUnequal(f->result, i64, f->name, "Exported function must not have i64 return type");
- for (auto param : f->params) {
- shouldBeUnequal(param, i64, f->name, "Exported function must not have i64 parameters");
- }
- }
- }
- }
-
- void visitGlobal(Global* curr) {
- if (!validateGlobally) return;
- shouldBeTrue(curr->init != nullptr, curr->name, "global init must be non-null");
- shouldBeTrue(curr->init->is<Const>() || curr->init->is<GetGlobal>(), curr->name, "global init must be valid");
- if (!shouldBeEqual(curr->type, curr->init->type, curr->init, "global init must have correct type")) {
- std::cerr << "(on global " << curr->name << '\n';
- }
- }
-
- void visitFunction(Function *curr) {
- // if function has no result, it is ignored
- // if body is unreachable, it might be e.g. a return
- if (curr->body->type != unreachable) {
- shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns");
- }
- if (returnType != unreachable) {
- shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function has returns");
- }
- returnType = unreachable;
- labelNames.clear();
- }
-
- bool checkOffset(Expression* curr, Address add, Address max) {
- if (curr->is<GetGlobal>()) return true;
- auto* c = curr->dynCast<Const>();
- if (!c) return false;
- uint64_t raw = c->value.getInteger();
- if (raw > std::numeric_limits<Address::address_t>::max()) {
- return false;
- }
- if (raw + uint64_t(add) > std::numeric_limits<Address::address_t>::max()) {
- return false;
- }
- Address offset = raw;
- return offset + add <= max;
- }
-
- void visitMemory(Memory *curr) {
- shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial");
- shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB");
- Index mustBeGreaterOrEqual = 0;
- for (auto& segment : curr->segments) {
- if (!shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32")) continue;
- shouldBeTrue(checkOffset(segment.offset, segment.data.size(), getModule()->memory.initial * Memory::kPageSize), segment.offset, "segment offset should be reasonable");
- Index size = segment.data.size();
- shouldBeTrue(size <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory");
- if (segment.offset->is<Const>()) {
- Index start = segment.offset->cast<Const>()->value.geti32();
- Index end = start + size;
- shouldBeTrue(end <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory");
- shouldBeTrue(start >= mustBeGreaterOrEqual, segment.data.size(), "segment size should fit in memory");
- mustBeGreaterOrEqual = end;
- }
- }
- }
- void visitTable(Table* curr) {
- for (auto& segment : curr->segments) {
- shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32");
- shouldBeTrue(checkOffset(segment.offset, segment.data.size(), getModule()->table.initial * Table::kPageSize), segment.offset, "segment offset should be reasonable");
- }
- }
- void visitModule(Module *curr) {
- if (!validateGlobally) return;
- // exports
- std::set<Name> exportNames;
- for (auto& exp : curr->exports) {
- Name name = exp->value;
- if (exp->kind == ExternalKind::Function) {
- bool found = false;
- for (auto& func : curr->functions) {
- if (func->name == name) {
- found = true;
- break;
- }
- }
- shouldBeTrue(found, name, "module function exports must be found");
- } else if (exp->kind == ExternalKind::Global) {
- shouldBeTrue(curr->getGlobalOrNull(name), name, "module global exports must be found");
- } else if (exp->kind == ExternalKind::Table) {
- shouldBeTrue(name == Name("0") || name == curr->table.name, name, "module table exports must be found");
- } else if (exp->kind == ExternalKind::Memory) {
- shouldBeTrue(name == Name("0") || name == curr->memory.name, name, "module memory exports must be found");
- } else {
- WASM_UNREACHABLE();
- }
- Name exportName = exp->name;
- shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique");
- exportNames.insert(exportName);
- }
- // start
- if (curr->start.is()) {
- auto func = curr->getFunctionOrNull(curr->start);
- if (shouldBeTrue(func != nullptr, curr->start, "start must be found")) {
- shouldBeTrue(func->params.size() == 0, curr, "start must have 0 params");
- shouldBeTrue(func->result == none, curr, "start must not return a value");
- }
- }
- }
+ void noteBreak(Name name, Expression* value, Expression* curr);
+ void visitBreak(Break *curr);
+ void visitSwitch(Switch *curr);
+ void visitCall(Call *curr);
+ void visitCallImport(CallImport *curr);
+ void visitCallIndirect(CallIndirect *curr);
+ void visitGetLocal(GetLocal* curr);
+ void visitSetLocal(SetLocal *curr);
+ void visitLoad(Load *curr);
+ void visitStore(Store *curr);
+ void visitBinary(Binary *curr);
+ void visitUnary(Unary *curr);
+ void visitSelect(Select* curr);
+ void visitDrop(Drop* curr);
+ void visitReturn(Return* curr);
+ void visitHost(Host* curr);
+ void visitImport(Import* curr);
+ void visitExport(Export* curr);
+ void visitGlobal(Global* curr);
+ void visitFunction(Function *curr);
+
+ void visitMemory(Memory *curr);
+ void visitTable(Table* curr);
+ void visitModule(Module *curr);
void doWalkFunction(Function* func) {
PostWalker<WasmValidator>::doWalkFunction(func);
}
// helpers
-
- std::ostream& fail() {
- Colors::red(std::cerr);
- if (getFunction()) {
- std::cerr << "[wasm-validator error in function ";
- Colors::green(std::cerr);
- std::cerr << getFunction()->name;
- Colors::red(std::cerr);
- std::cerr << "] ";
- } else {
- std::cerr << "[wasm-validator error in module] ";
- }
- Colors::normal(std::cerr);
- return std::cerr;
- }
-
+ private:
+ std::ostream& fail();
template<typename T>
bool shouldBeTrue(bool result, T curr, const char* text) {
if (!result) {
@@ -725,64 +205,8 @@ public:
return true;
}
- void validateAlignment(size_t align, WasmType type, Index bytes) {
- switch (align) {
- case 1:
- case 2:
- case 4:
- case 8: break;
- default:{
- fail() << "bad alignment: " << align << std::endl;
- valid = false;
- break;
- }
- }
- shouldBeTrue(align <= bytes, align, "alignment must not exceed natural");
- switch (type) {
- case i32:
- case f32: {
- shouldBeTrue(align <= 4, align, "alignment must not exceed natural");
- break;
- }
- case i64:
- case f64: {
- shouldBeTrue(align <= 8, align, "alignment must not exceed natural");
- break;
- }
- default: {}
- }
- }
-
- void validateBinaryenIR(Module& wasm) {
- struct BinaryenIRValidator : public PostWalker<BinaryenIRValidator, UnifiedExpressionVisitor<BinaryenIRValidator>> {
- WasmValidator& parent;
-
- BinaryenIRValidator(WasmValidator& parent) : parent(parent) {}
-
- void visitExpression(Expression* curr) {
- // check if a node type is 'stale', i.e., we forgot to finalize() the node.
- auto oldType = curr->type;
- ReFinalizeNode().visit(curr);
- auto newType = curr->type;
- if (newType != oldType) {
- // We accept concrete => undefined,
- // e.g.
- //
- // (drop (block (result i32) (unreachable)))
- //
- // The block has an added type, not derived from the ast itself, so it is
- // ok for it to be either i32 or unreachable.
- if (!(isConcreteWasmType(oldType) && newType == unreachable)) {
- parent.fail() << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n";
- parent.valid = false;
- }
- curr->type = oldType;
- }
- }
- };
- BinaryenIRValidator binaryenIRValidator(*this);
- binaryenIRValidator.walkModule(&wasm);
- }
+ void validateAlignment(size_t align, WasmType type, Index bytes);
+ void validateBinaryenIR(Module& wasm);
};
} // namespace wasm
diff --git a/src/wasm/CMakeLists.txt b/src/wasm/CMakeLists.txt
index b4b607934..1a8a9b8ba 100644
--- a/src/wasm/CMakeLists.txt
+++ b/src/wasm/CMakeLists.txt
@@ -5,5 +5,6 @@ SET(wasm_SOURCES
wasm-io.cpp
wasm-s-parser.cpp
wasm-type.cpp
+ wasm-validator.cpp
)
ADD_LIBRARY(wasm STATIC ${wasm_SOURCES})
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
new file mode 100644
index 000000000..9421070e7
--- /dev/null
+++ b/src/wasm/wasm-validator.cpp
@@ -0,0 +1,639 @@
+/*
+ * Copyright 2017 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "wasm-validator.h"
+
+#include "ast_utils.h"
+#include "ast/branch-utils.h"
+#include "support/colors.h"
+
+
+namespace wasm {
+void WasmValidator::noteLabelName(Name name) {
+ if (!name.is()) return;
+ shouldBeTrue(labelNames.find(name) == labelNames.end(), name, "names in Binaryen IR must be unique - IR generators must ensure that");
+ labelNames.insert(name);
+}
+
+void WasmValidator::visitBlock(Block *curr) {
+ // if we are break'ed to, then the value must be right for us
+ if (curr->name.is()) {
+ noteLabelName(curr->name);
+ if (breakInfos.count(curr) > 0) {
+ auto& info = breakInfos[curr];
+ if (isConcreteWasmType(curr->type)) {
+ shouldBeTrue(info.arity != 0, curr, "break arities must be > 0 if block has a value");
+ } else {
+ shouldBeTrue(info.arity == 0, curr, "break arities must be 0 if block has no value");
+ }
+ // none or unreachable means a poison value that we should ignore - if consumed, it will error
+ if (isConcreteWasmType(info.type) && isConcreteWasmType(curr->type)) {
+ shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks return a value");
+ }
+ if (isConcreteWasmType(curr->type) && info.arity && info.type != unreachable) {
+ shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks have arity");
+ }
+ shouldBeTrue(info.arity != Index(-1), curr, "break arities must match");
+ if (curr->list.size() > 0) {
+ auto last = curr->list.back()->type;
+ if (isConcreteWasmType(last) && info.type != unreachable) {
+ shouldBeEqual(last, info.type, curr, "block+breaks must have right type if block ends with a reachable value");
+ }
+ if (last == none) {
+ shouldBeTrue(info.arity == Index(0), curr, "if block ends with a none, breaks cannot send a value of any type");
+ }
+ }
+ }
+ breakTargets[curr->name].pop_back();
+ }
+ if (curr->list.size() > 1) {
+ for (Index i = 0; i < curr->list.size() - 1; i++) {
+ if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)")) {
+ std::cerr << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n";
+ }
+ }
+ }
+ if (curr->list.size() > 0) {
+ auto backType = curr->list.back()->type;
+ if (!isConcreteWasmType(curr->type)) {
+ if (isConcreteWasmType(backType)) {
+ shouldBeTrue(curr->type == unreachable, curr, "block with no value and a last element with a value must be unreachable");
+ }
+ } else {
+ if (isConcreteWasmType(backType)) {
+ shouldBeEqual(curr->type, backType, curr, "block with value and last element with value must match types");
+ } else {
+ shouldBeUnequal(backType, none, curr, "block with value must not have last element that is none");
+ }
+ }
+ }
+ if (isConcreteWasmType(curr->type)) {
+ shouldBeTrue(curr->list.size() > 0, curr, "block with a value must not be empty");
+ }
+}
+
+void WasmValidator::visitLoop(Loop *curr) {
+ if (curr->name.is()) {
+ noteLabelName(curr->name);
+ breakTargets[curr->name].pop_back();
+ if (breakInfos.count(curr) > 0) {
+ auto& info = breakInfos[curr];
+ shouldBeEqual(info.arity, Index(0), curr, "breaks to a loop cannot pass a value");
+ }
+ }
+ if (curr->type == none) {
+ shouldBeFalse(isConcreteWasmType(curr->body->type), curr, "bad body for a loop that has no value");
+ }
+}
+
+void WasmValidator::visitIf(If *curr) {
+ shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "if condition must be valid");
+ if (!curr->ifFalse) {
+ shouldBeFalse(isConcreteWasmType(curr->ifTrue->type), curr, "if without else must not return a value in body");
+ if (curr->condition->type != unreachable) {
+ shouldBeEqual(curr->type, none, curr, "if without else and reachable condition must be none");
+ }
+ } else {
+ if (curr->type != unreachable) {
+ shouldBeEqualOrFirstIsUnreachable(curr->ifTrue->type, curr->type, curr, "returning if-else's true must have right type");
+ shouldBeEqualOrFirstIsUnreachable(curr->ifFalse->type, curr->type, curr, "returning if-else's false must have right type");
+ } else {
+ if (curr->condition->type != unreachable) {
+ shouldBeEqual(curr->ifTrue->type, unreachable, curr, "unreachable if-else must have unreachable true");
+ shouldBeEqual(curr->ifFalse->type, unreachable, curr, "unreachable if-else must have unreachable false");
+ }
+ }
+ }
+}
+
+void WasmValidator::noteBreak(Name name, Expression* value, Expression* curr) {
+ WasmType valueType = none;
+ Index arity = 0;
+ if (value) {
+ valueType = value->type;
+ shouldBeUnequal(valueType, none, curr, "breaks must have a valid value");
+ arity = 1;
+ }
+ if (!shouldBeTrue(breakTargets[name].size() > 0, curr, "all break targets must be valid")) return;
+ auto* target = breakTargets[name].back();
+ if (breakInfos.count(target) == 0) {
+ breakInfos[target] = BreakInfo(valueType, arity);
+ } else {
+ auto& info = breakInfos[target];
+ if (info.type == unreachable) {
+ info.type = valueType;
+ } else if (valueType != unreachable) {
+ if (valueType != info.type) {
+ info.type = none; // a poison value that must not be consumed
+ }
+ }
+ if (arity != info.arity) {
+ info.arity = Index(-1); // a poison value
+ }
+ }
+}
+void WasmValidator::visitBreak(Break *curr) {
+ // note breaks (that are actually taken)
+ if (BranchUtils::isBranchTaken(curr)) {
+ noteBreak(curr->name, curr->value, curr);
+ }
+ if (curr->condition) {
+ shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32");
+ }
+}
+
+void WasmValidator::visitSwitch(Switch *curr) {
+ // note breaks (that are actually taken)
+ if (BranchUtils::isBranchTaken(curr)) {
+ for (auto& target : curr->targets) {
+ noteBreak(target, curr->value, curr);
+ }
+ noteBreak(curr->default_, curr->value, curr);
+ }
+ shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32");
+}
+void WasmValidator::visitCall(Call *curr) {
+ if (!validateGlobally) return;
+ auto* target = getModule()->getFunctionOrNull(curr->target);
+ if (!shouldBeTrue(!!target, curr, "call target must exist")) return;
+ if (!shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match")) return;
+ for (size_t i = 0; i < curr->operands.size(); i++) {
+ if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match")) {
+ std::cerr << "(on argument " << i << ")\n";
+ }
+ }
+}
+void WasmValidator::visitCallImport(CallImport *curr) {
+ if (!validateGlobally) return;
+ auto* import = getModule()->getImportOrNull(curr->target);
+ if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return;
+ if (!shouldBeTrue(!!import->functionType.is(), curr, "called import must be function")) return;
+ auto* type = getModule()->getFunctionType(import->functionType);
+ if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return;
+ for (size_t i = 0; i < curr->operands.size(); i++) {
+ if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match")) {
+ std::cerr << "(on argument " << i << ")\n";
+ }
+ }
+}
+void WasmValidator::visitCallIndirect(CallIndirect *curr) {
+ if (!validateGlobally) return;
+ auto* type = getModule()->getFunctionTypeOrNull(curr->fullType);
+ if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) return;
+ shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32");
+ if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return;
+ for (size_t i = 0; i < curr->operands.size(); i++) {
+ if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match")) {
+ std::cerr << "(on argument " << i << ")\n";
+ }
+ }
+}
+void WasmValidator::visitGetLocal(GetLocal* curr) {
+ shouldBeTrue(isConcreteWasmType(curr->type), curr, "get_local must have a valid type - check what you provided when you constructed the node");
+}
+void WasmValidator::visitSetLocal(SetLocal *curr) {
+ shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough");
+ if (curr->value->type != unreachable) {
+ if (curr->type != none) { // tee is ok anyhow
+ shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "set_local type must be correct");
+ }
+ shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function");
+ }
+}
+void WasmValidator::visitLoad(Load *curr) {
+ validateAlignment(curr->align, curr->type, curr->bytes);
+ shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32");
+}
+void WasmValidator::visitStore(Store *curr) {
+ validateAlignment(curr->align, curr->type, curr->bytes);
+ shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32");
+ shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none");
+ shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->valueType, curr, "store value type must match");
+}
+void WasmValidator::visitBinary(Binary *curr) {
+ if (curr->left->type != unreachable && curr->right->type != unreachable) {
+ shouldBeEqual(curr->left->type, curr->right->type, curr, "binary child types must be equal");
+ }
+ switch (curr->op) {
+ case AddInt32:
+ case SubInt32:
+ case MulInt32:
+ case DivSInt32:
+ case DivUInt32:
+ case RemSInt32:
+ case RemUInt32:
+ case AndInt32:
+ case OrInt32:
+ case XorInt32:
+ case ShlInt32:
+ case ShrUInt32:
+ case ShrSInt32:
+ case RotLInt32:
+ case RotRInt32:
+ case EqInt32:
+ case NeInt32:
+ case LtSInt32:
+ case LtUInt32:
+ case LeSInt32:
+ case LeUInt32:
+ case GtSInt32:
+ case GtUInt32:
+ case GeSInt32:
+ case GeUInt32: {
+ shouldBeEqualOrFirstIsUnreachable(curr->left->type, i32, curr, "i32 op");
+ break;
+ }
+ case AddInt64:
+ case SubInt64:
+ case MulInt64:
+ case DivSInt64:
+ case DivUInt64:
+ case RemSInt64:
+ case RemUInt64:
+ case AndInt64:
+ case OrInt64:
+ case XorInt64:
+ case ShlInt64:
+ case ShrUInt64:
+ case ShrSInt64:
+ case RotLInt64:
+ case RotRInt64:
+ case EqInt64:
+ case NeInt64:
+ case LtSInt64:
+ case LtUInt64:
+ case LeSInt64:
+ case LeUInt64:
+ case GtSInt64:
+ case GtUInt64:
+ case GeSInt64:
+ case GeUInt64: {
+ shouldBeEqualOrFirstIsUnreachable(curr->left->type, i64, curr, "i64 op");
+ break;
+ }
+ case AddFloat32:
+ case SubFloat32:
+ case MulFloat32:
+ case DivFloat32:
+ case CopySignFloat32:
+ case MinFloat32:
+ case MaxFloat32:
+ case EqFloat32:
+ case NeFloat32:
+ case LtFloat32:
+ case LeFloat32:
+ case GtFloat32:
+ case GeFloat32: {
+ shouldBeEqualOrFirstIsUnreachable(curr->left->type, f32, curr, "f32 op");
+ break;
+ }
+ case AddFloat64:
+ case SubFloat64:
+ case MulFloat64:
+ case DivFloat64:
+ case CopySignFloat64:
+ case MinFloat64:
+ case MaxFloat64:
+ case EqFloat64:
+ case NeFloat64:
+ case LtFloat64:
+ case LeFloat64:
+ case GtFloat64:
+ case GeFloat64: {
+ shouldBeEqualOrFirstIsUnreachable(curr->left->type, f64, curr, "f64 op");
+ break;
+ }
+ default: WASM_UNREACHABLE();
+ }
+}
+void WasmValidator::visitUnary(Unary *curr) {
+ shouldBeUnequal(curr->value->type, none, curr, "unaries must not receive a none as their input");
+ if (curr->value->type == unreachable) return; // nothing to check
+ switch (curr->op) {
+ case ClzInt32:
+ case CtzInt32:
+ case PopcntInt32: {
+ shouldBeEqual(curr->value->type, i32, curr, "i32 unary value type must be correct");
+ break;
+ }
+ case ClzInt64:
+ case CtzInt64:
+ case PopcntInt64: {
+ shouldBeEqual(curr->value->type, i64, curr, "i64 unary value type must be correct");
+ break;
+ }
+ case NegFloat32:
+ case AbsFloat32:
+ case CeilFloat32:
+ case FloorFloat32:
+ case TruncFloat32:
+ case NearestFloat32:
+ case SqrtFloat32: {
+ shouldBeEqual(curr->value->type, f32, curr, "f32 unary value type must be correct");
+ break;
+ }
+ case NegFloat64:
+ case AbsFloat64:
+ case CeilFloat64:
+ case FloorFloat64:
+ case TruncFloat64:
+ case NearestFloat64:
+ case SqrtFloat64: {
+ shouldBeEqual(curr->value->type, f64, curr, "f64 unary value type must be correct");
+ break;
+ }
+ case EqZInt32: {
+ shouldBeTrue(curr->value->type == i32, curr, "i32.eqz input must be i32");
+ break;
+ }
+ case EqZInt64: {
+ shouldBeTrue(curr->value->type == i64, curr, "i64.eqz input must be i64");
+ break;
+ }
+ case ExtendSInt32: shouldBeEqual(curr->value->type, i32, curr, "extend type must be correct"); break;
+ case ExtendUInt32: shouldBeEqual(curr->value->type, i32, curr, "extend type must be correct"); break;
+ case WrapInt64: shouldBeEqual(curr->value->type, i64, curr, "wrap type must be correct"); break;
+ case TruncSFloat32ToInt32: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break;
+ case TruncSFloat32ToInt64: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break;
+ case TruncUFloat32ToInt32: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break;
+ case TruncUFloat32ToInt64: shouldBeEqual(curr->value->type, f32, curr, "trunc type must be correct"); break;
+ case TruncSFloat64ToInt32: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break;
+ case TruncSFloat64ToInt64: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break;
+ case TruncUFloat64ToInt32: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break;
+ case TruncUFloat64ToInt64: shouldBeEqual(curr->value->type, f64, curr, "trunc type must be correct"); break;
+ case ReinterpretFloat32: shouldBeEqual(curr->value->type, f32, curr, "reinterpret/f32 type must be correct"); break;
+ case ReinterpretFloat64: shouldBeEqual(curr->value->type, f64, curr, "reinterpret/f64 type must be correct"); break;
+ case ConvertUInt32ToFloat32: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break;
+ case ConvertUInt32ToFloat64: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break;
+ case ConvertSInt32ToFloat32: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break;
+ case ConvertSInt32ToFloat64: shouldBeEqual(curr->value->type, i32, curr, "convert type must be correct"); break;
+ case ConvertUInt64ToFloat32: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break;
+ case ConvertUInt64ToFloat64: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break;
+ case ConvertSInt64ToFloat32: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break;
+ case ConvertSInt64ToFloat64: shouldBeEqual(curr->value->type, i64, curr, "convert type must be correct"); break;
+ case PromoteFloat32: shouldBeEqual(curr->value->type, f32, curr, "promote type must be correct"); break;
+ case DemoteFloat64: shouldBeEqual(curr->value->type, f64, curr, "demote type must be correct"); break;
+ case ReinterpretInt32: shouldBeEqual(curr->value->type, i32, curr, "reinterpret/i32 type must be correct"); break;
+ case ReinterpretInt64: shouldBeEqual(curr->value->type, i64, curr, "reinterpret/i64 type must be correct"); break;
+ default: abort();
+ }
+}
+void WasmValidator::visitSelect(Select* curr) {
+ shouldBeUnequal(curr->ifTrue->type, none, curr, "select left must be valid");
+ shouldBeUnequal(curr->ifFalse->type, none, curr, "select right must be valid");
+ shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "select condition must be valid");
+ if (curr->ifTrue->type != unreachable && curr->ifFalse->type != unreachable) {
+ shouldBeEqual(curr->ifTrue->type, curr->ifFalse->type, curr, "select sides must be equal");
+ }
+}
+
+void WasmValidator::visitDrop(Drop* curr) {
+ shouldBeTrue(isConcreteWasmType(curr->value->type) || curr->value->type == unreachable, curr, "can only drop a valid value");
+}
+
+void WasmValidator::visitReturn(Return* curr) {
+ if (curr->value) {
+ if (returnType == unreachable) {
+ returnType = curr->value->type;
+ } else if (curr->value->type != unreachable) {
+ shouldBeEqual(curr->value->type, returnType, curr, "function results must match");
+ }
+ } else {
+ returnType = none;
+ }
+}
+
+void WasmValidator::visitHost(Host* curr) {
+ switch (curr->op) {
+ case GrowMemory: {
+ shouldBeEqual(curr->operands.size(), size_t(1), curr, "grow_memory must have 1 operand");
+ shouldBeEqualOrFirstIsUnreachable(curr->operands[0]->type, i32, curr, "grow_memory must have i32 operand");
+ break;
+ }
+ case PageSize:
+ case CurrentMemory:
+ case HasFeature: break;
+ default: WASM_UNREACHABLE();
+ }
+}
+
+void WasmValidator::visitImport(Import* curr) {
+ if (!validateGlobally) return;
+ if (curr->kind == ExternalKind::Function) {
+ if (validateWeb) {
+ auto* functionType = getModule()->getFunctionType(curr->functionType);
+ shouldBeUnequal(functionType->result, i64, curr->name, "Imported function must not have i64 return type");
+ for (WasmType param : functionType->params) {
+ shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters");
+ }
+ }
+ }
+ if (curr->kind == ExternalKind::Table) {
+ shouldBeTrue(getModule()->table.imported, curr->name, "Table import record exists but table is not marked as imported");
+ }
+ if (curr->kind == ExternalKind::Memory) {
+ shouldBeTrue(getModule()->memory.imported, curr->name, "Memory import record exists but memory is not marked as imported");
+ }
+}
+
+void WasmValidator::visitExport(Export* curr) {
+ if (!validateGlobally) return;
+ if (curr->kind == ExternalKind::Function) {
+ if (validateWeb) {
+ Function* f = getModule()->getFunction(curr->value);
+ shouldBeUnequal(f->result, i64, f->name, "Exported function must not have i64 return type");
+ for (auto param : f->params) {
+ shouldBeUnequal(param, i64, f->name, "Exported function must not have i64 parameters");
+ }
+ }
+ }
+}
+
+void WasmValidator::visitGlobal(Global* curr) {
+ if (!validateGlobally) return;
+ shouldBeTrue(curr->init != nullptr, curr->name, "global init must be non-null");
+ shouldBeTrue(curr->init->is<Const>() || curr->init->is<GetGlobal>(), curr->name, "global init must be valid");
+ if (!shouldBeEqual(curr->type, curr->init->type, curr->init, "global init must have correct type")) {
+ std::cerr << "(on global " << curr->name << '\n';
+ }
+}
+
+void WasmValidator::visitFunction(Function *curr) {
+ // if function has no result, it is ignored
+ // if body is unreachable, it might be e.g. a return
+ if (curr->body->type != unreachable) {
+ shouldBeEqual(curr->result, curr->body->type, curr->body, "function body type must match, if function returns");
+ }
+ if (returnType != unreachable) {
+ shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function has returns");
+ }
+ returnType = unreachable;
+ labelNames.clear();
+}
+
+static bool checkOffset(Expression* curr, Address add, Address max) {
+ if (curr->is<GetGlobal>()) return true;
+ auto* c = curr->dynCast<Const>();
+ if (!c) return false;
+ uint64_t raw = c->value.getInteger();
+ if (raw > std::numeric_limits<Address::address_t>::max()) {
+ return false;
+ }
+ if (raw + uint64_t(add) > std::numeric_limits<Address::address_t>::max()) {
+ return false;
+ }
+ Address offset = raw;
+ return offset + add <= max;
+}
+
+void WasmValidator::visitMemory(Memory *curr) {
+ shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial");
+ shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB");
+ Index mustBeGreaterOrEqual = 0;
+ for (auto& segment : curr->segments) {
+ if (!shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32")) continue;
+ shouldBeTrue(checkOffset(segment.offset, segment.data.size(), getModule()->memory.initial * Memory::kPageSize), segment.offset, "segment offset should be reasonable");
+ Index size = segment.data.size();
+ shouldBeTrue(size <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory");
+ if (segment.offset->is<Const>()) {
+ Index start = segment.offset->cast<Const>()->value.geti32();
+ Index end = start + size;
+ shouldBeTrue(end <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory");
+ shouldBeTrue(start >= mustBeGreaterOrEqual, segment.data.size(), "segment size should fit in memory");
+ mustBeGreaterOrEqual = end;
+ }
+ }
+}
+void WasmValidator::visitTable(Table* curr) {
+ for (auto& segment : curr->segments) {
+ shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32");
+ shouldBeTrue(checkOffset(segment.offset, segment.data.size(), getModule()->table.initial * Table::kPageSize), segment.offset, "segment offset should be reasonable");
+ }
+}
+void WasmValidator::visitModule(Module *curr) {
+ if (!validateGlobally) return;
+ // exports
+ std::set<Name> exportNames;
+ for (auto& exp : curr->exports) {
+ Name name = exp->value;
+ if (exp->kind == ExternalKind::Function) {
+ bool found = false;
+ for (auto& func : curr->functions) {
+ if (func->name == name) {
+ found = true;
+ break;
+ }
+ }
+ shouldBeTrue(found, name, "module function exports must be found");
+ } else if (exp->kind == ExternalKind::Global) {
+ shouldBeTrue(curr->getGlobalOrNull(name), name, "module global exports must be found");
+ } else if (exp->kind == ExternalKind::Table) {
+ shouldBeTrue(name == Name("0") || name == curr->table.name, name, "module table exports must be found");
+ } else if (exp->kind == ExternalKind::Memory) {
+ shouldBeTrue(name == Name("0") || name == curr->memory.name, name, "module memory exports must be found");
+ } else {
+ WASM_UNREACHABLE();
+ }
+ Name exportName = exp->name;
+ shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique");
+ exportNames.insert(exportName);
+ }
+ // start
+ if (curr->start.is()) {
+ auto func = curr->getFunctionOrNull(curr->start);
+ if (shouldBeTrue(func != nullptr, curr->start, "start must be found")) {
+ shouldBeTrue(func->params.size() == 0, curr, "start must have 0 params");
+ shouldBeTrue(func->result == none, curr, "start must not return a value");
+ }
+ }
+}
+
+void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes) {
+ switch (align) {
+ case 1:
+ case 2:
+ case 4:
+ case 8: break;
+ default:{
+ fail() << "bad alignment: " << align << std::endl;
+ valid = false;
+ break;
+ }
+ }
+ shouldBeTrue(align <= bytes, align, "alignment must not exceed natural");
+ switch (type) {
+ case i32:
+ case f32: {
+ shouldBeTrue(align <= 4, align, "alignment must not exceed natural");
+ break;
+ }
+ case i64:
+ case f64: {
+ shouldBeTrue(align <= 8, align, "alignment must not exceed natural");
+ break;
+ }
+ default: {}
+ }
+}
+
+void WasmValidator::validateBinaryenIR(Module& wasm) {
+ struct BinaryenIRValidator : public PostWalker<BinaryenIRValidator, UnifiedExpressionVisitor<BinaryenIRValidator>> {
+ WasmValidator& parent;
+
+ BinaryenIRValidator(WasmValidator& parent) : parent(parent) {}
+
+ void visitExpression(Expression* curr) {
+ // check if a node type is 'stale', i.e., we forgot to finalize() the node.
+ auto oldType = curr->type;
+ ReFinalizeNode().visit(curr);
+ auto newType = curr->type;
+ if (newType != oldType) {
+ // We accept concrete => undefined,
+ // e.g.
+ //
+ // (drop (block (result i32) (unreachable)))
+ //
+ // The block has an added type, not derived from the ast itself, so it is
+ // ok for it to be either i32 or unreachable.
+ if (!(isConcreteWasmType(oldType) && newType == unreachable)) {
+ parent.fail() << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n";
+ parent.valid = false;
+ }
+ curr->type = oldType;
+ }
+ }
+ };
+ BinaryenIRValidator binaryenIRValidator(*this);
+ binaryenIRValidator.walkModule(&wasm);
+}
+
+std::ostream& WasmValidator::fail() {
+ Colors::red(std::cerr);
+ if (getFunction()) {
+ std::cerr << "[wasm-validator error in function ";
+ Colors::green(std::cerr);
+ std::cerr << getFunction()->name;
+ Colors::red(std::cerr);
+ std::cerr << "] ";
+ } else {
+ std::cerr << "[wasm-validator error in module] ";
+ }
+ Colors::normal(std::cerr);
+ return std::cerr;
+}
+
+
+} // namespace wasm