summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlon Zakai <alonzakai@gmail.com>2017-10-02 13:52:16 -0700
committerGitHub <noreply@github.com>2017-10-02 13:52:16 -0700
commit1f8d8a53e8fcee0791c11345fd7f328255cfa22c (patch)
tree2effc81da752bb71c486106bb5fe71e9e174389e /src
parenta9f91b9774d117a13c231ef0f40861372456878f (diff)
downloadbinaryen-1f8d8a53e8fcee0791c11345fd7f328255cfa22c.tar.gz
binaryen-1f8d8a53e8fcee0791c11345fd7f328255cfa22c.tar.bz2
binaryen-1f8d8a53e8fcee0791c11345fd7f328255cfa22c.zip
Fast validation (#1204)
This makes wasm validation parallel (the function part). This makes loading+validating tanks (a 12MB wasm file) 2.3x faster on a 4-core machine (from 3.5 to 1.5 seconds). It's a big speedup because most of loading+validating was actually validating. It's also noticeable during compilation, since we validate by default at the end. 8% faster on -O2 and 23% on -O0. So actually fairly significant on -O0 builds. As a bonus, this PR also moves the code from being 99% in the header to be 1% in the header.
Diffstat (limited to 'src')
-rw-r--r--src/tools/wasm-reduce.cpp4
-rw-r--r--src/wasm-validator.h186
-rw-r--r--src/wasm/wasm-validator.cpp668
3 files changed, 481 insertions, 377 deletions
diff --git a/src/tools/wasm-reduce.cpp b/src/tools/wasm-reduce.cpp
index 355b70278..29484cc7a 100644
--- a/src/tools/wasm-reduce.cpp
+++ b/src/tools/wasm-reduce.cpp
@@ -414,9 +414,7 @@ struct Reducer : public WalkerPass<PostWalker<Reducer, UnifiedExpressionVisitor<
}
for (auto& func : functions) {
curr->removeFunction(func.name);
- WasmValidator validator;
- validator.quiet = true;
- if (validator.validate(*curr) && writeAndTestReduction()) {
+ if (WasmValidator().validate(*curr, false, true, true /* override quiet */) && writeAndTestReduction()) {
std::cerr << "| removed function " << func.name << '\n';
noteReduction();
} else {
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index ceaaee890..250779b91 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -33,6 +33,8 @@
// about function B not existing yet, but we would care
// if e.g. inside function A an i32.add receives an i64).
//
+// * quiet: Whether to log errors verbosely.
+//
#ifndef wasm_wasm_validator_h
#define wasm_wasm_validator_h
@@ -46,188 +48,8 @@
namespace wasm {
-// Print anything that can be streamed to an ostream
-template <typename T,
- typename std::enable_if<
- !std::is_base_of<Expression, typename std::remove_pointer<T>::type>::value
- >::type* = nullptr>
-inline std::ostream& printModuleComponent(T curr, std::ostream& stream) {
- stream << curr << std::endl;
- return stream;
-}
-// Extra overload for Expressions, to print type info too
-inline std::ostream& printModuleComponent(Expression* curr, std::ostream& stream) {
- WasmPrinter::printExpression(curr, stream, false, true) << std::endl;
- return stream;
-}
-
-struct WasmValidator : public PostWalker<WasmValidator> {
- bool valid = true;
-
- // what to validate, see comment up top
- bool validateWeb = false;
- bool validateGlobally = true;
-
- bool quiet = false; // whether to log errors verbosely
-
- struct BreakInfo {
- WasmType type;
- Index arity;
- BreakInfo() {}
- BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {}
- };
-
- std::map<Name, Expression*> breakTargets;
- std::map<Expression*, BreakInfo> breakInfos;
-
- WasmType returnType = unreachable; // type used in returns
-
- std::set<Name> labelNames; // Binaryen IR requires that label names must be unique - IR generators must ensure that
-
- std::unordered_set<Expression*> seenExpressions; // expressions must not appear twice
-
- void noteLabelName(Name name);
-
-public:
- // TODO: If we want the validator to be part of libwasm rather than libpasses, then
- // Using PassRunner::getPassDebug causes a circular dependence. We should fix that,
- // perhaps by moving some of the pass infrastructure into libsupport.
- bool validate(Module& module, bool validateWeb_ = false, bool validateGlobally_ = true) {
- validateWeb = validateWeb_;
- validateGlobally = validateGlobally_;
- // wasm logic validation
- walkModule(&module);
- // validate additional internal IR details when in pass-debug mode
- if (PassRunner::getPassDebug()) {
- validateBinaryenIR(module);
- }
- // print if an error occurred
- if (!valid && !quiet) {
- WasmPrinter::printModule(&module, std::cerr);
- }
- return valid;
- }
-
- // visitors
-
- static void visitPreBlock(WasmValidator* self, Expression** currp) {
- auto* curr = (*currp)->cast<Block>();
- if (curr->name.is()) self->breakTargets[curr->name] = curr;
- }
-
- void visitBlock(Block *curr);
-
- static void visitPreLoop(WasmValidator* self, Expression** currp) {
- auto* curr = (*currp)->cast<Loop>();
- if (curr->name.is()) self->breakTargets[curr->name] = curr;
- }
-
- void visitLoop(Loop *curr);
- void visitIf(If *curr);
-
- // override scan to add a pre and a post check task to all nodes
- static void scan(WasmValidator* self, Expression** currp) {
- PostWalker<WasmValidator>::scan(self, currp);
-
- auto* curr = *currp;
- if (curr->is<Block>()) self->pushTask(visitPreBlock, currp);
- if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp);
- }
-
- void noteBreak(Name name, Expression* value, Expression* curr);
- void visitBreak(Break *curr);
- void visitSwitch(Switch *curr);
- void visitCall(Call *curr);
- void visitCallImport(CallImport *curr);
- void visitCallIndirect(CallIndirect *curr);
- void visitGetLocal(GetLocal* curr);
- void visitSetLocal(SetLocal *curr);
- void visitLoad(Load *curr);
- void visitStore(Store *curr);
- void visitAtomicRMW(AtomicRMW *curr);
- void visitAtomicCmpxchg(AtomicCmpxchg *curr);
- void visitAtomicWait(AtomicWait *curr);
- void visitAtomicWake(AtomicWake *curr);
- void visitBinary(Binary *curr);
- void visitUnary(Unary *curr);
- void visitSelect(Select* curr);
- void visitDrop(Drop* curr);
- void visitReturn(Return* curr);
- void visitHost(Host* curr);
- void visitImport(Import* curr);
- void visitExport(Export* curr);
- void visitGlobal(Global* curr);
- void visitFunction(Function *curr);
-
- void visitMemory(Memory *curr);
- void visitTable(Table* curr);
- void visitModule(Module *curr);
-
- void doWalkFunction(Function* func) {
- PostWalker<WasmValidator>::doWalkFunction(func);
- }
-
- // helpers
- private:
- template <typename T, typename S>
- std::ostream& fail(S text, T curr);
- std::ostream& printFailureHeader();
-
- template<typename T>
- bool shouldBeTrue(bool result, T curr, const char* text) {
- if (!result) {
- fail("unexpected false: " + std::string(text), curr);
- return false;
- }
- return result;
- }
- template<typename T>
- bool shouldBeFalse(bool result, T curr, const char* text) {
- if (result) {
- fail("unexpected true: " + std::string(text), curr);
- return false;
- }
- return result;
- }
-
- template<typename T, typename S>
- bool shouldBeEqual(S left, S right, T curr, const char* text) {
- if (left != right) {
- std::ostringstream ss;
- ss << left << " != " << right << ": " << text;
- fail(ss.str(), curr);
- return false;
- }
- return true;
- }
-
- template<typename T, typename S>
- bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) {
- if (left != unreachable && left != right) {
- std::ostringstream ss;
- ss << left << " != " << right << ": " << text;
- fail(ss.str(), curr);
- return false;
- }
- return true;
- }
-
- template<typename T, typename S>
- bool shouldBeUnequal(S left, S right, T curr, const char* text) {
- if (left == right) {
- std::ostringstream ss;
- ss << left << " == " << right << ": " << text;
- fail(ss.str(), curr);
- return false;
- }
- return true;
- }
-
- void shouldBeIntOrUnreachable(WasmType ty, Expression* curr, const char* text);
- void validateAlignment(size_t align, WasmType type, Index bytes, bool isAtomic,
- Expression* curr);
- void validateMemBytes(uint8_t bytes, WasmType type, Expression* curr);
- void validateBinaryenIR(Module& wasm);
+struct WasmValidator {
+ bool validate(Module& module, bool validateWeb = false, bool validateGlobally = true, bool quiet = false);
};
} // namespace wasm
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index cfe84ac3e..ebefa8b65 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -14,21 +14,279 @@
* limitations under the License.
*/
-#include "wasm-validator.h"
+#include <mutex>
+#include <set>
+#include <sstream>
+#include <unordered_set>
+#include "wasm.h"
+#include "wasm-printing.h"
+#include "wasm-validator.h"
#include "ast_utils.h"
#include "ast/branch-utils.h"
#include "support/colors.h"
namespace wasm {
-void WasmValidator::noteLabelName(Name name) {
+
+// Print anything that can be streamed to an ostream
+template <typename T,
+ typename std::enable_if<
+ !std::is_base_of<Expression, typename std::remove_pointer<T>::type>::value
+ >::type* = nullptr>
+inline std::ostream& printModuleComponent(T curr, std::ostream& stream) {
+ stream << curr << std::endl;
+ return stream;
+}
+
+// Extra overload for Expressions, to print type info too
+inline std::ostream& printModuleComponent(Expression* curr, std::ostream& stream) {
+ WasmPrinter::printExpression(curr, stream, false, true) << std::endl;
+ return stream;
+}
+
+// For parallel validation, we have a helper struct for coordination
+struct ValidationInfo {
+ bool validateWeb;
+ bool validateGlobally;
+ bool quiet;
+
+ std::atomic<bool> valid;
+
+ // a stream of error test for each function. we print in the right order at
+ // the end, for deterministic output
+ // note errors are rare/unexpected, so it's ok to use a slow mutex here
+ std::mutex mutex;
+ std::unordered_map<Function*, std::unique_ptr<std::ostringstream>> outputs;
+
+ ValidationInfo() {
+ valid.store(true);
+ }
+
+ std::ostringstream& getStream(Function* func) {
+ std::unique_lock<std::mutex> lock(mutex);
+ auto iter = outputs.find(func);
+ if (iter != outputs.end()) return *(iter->second.get());
+ auto& ret = outputs[func] = make_unique<std::ostringstream>();
+ return *ret.get();
+ }
+
+ // printing and error handling support
+
+ template <typename T, typename S>
+ std::ostream& fail(S text, T curr, Function* func) {
+ valid.store(false);
+ auto& stream = getStream(func);
+ if (quiet) return stream;
+ auto& ret = printFailureHeader(func);
+ ret << text << ", on \n";
+ return printModuleComponent(curr, ret);
+ }
+
+ std::ostream& printFailureHeader(Function* func) {
+ auto& stream = getStream(func);
+ if (quiet) return stream;
+ Colors::red(stream);
+ if (func) {
+ stream << "[wasm-validator error in function ";
+ Colors::green(stream);
+ stream << func->name;
+ Colors::red(stream);
+ stream << "] ";
+ } else {
+ stream << "[wasm-validator error in module] ";
+ }
+ Colors::normal(stream);
+ return stream;
+ }
+
+ // checking utilities
+
+ template<typename T>
+ bool shouldBeTrue(bool result, T curr, const char* text, Function* func = nullptr) {
+ if (!result) {
+ fail("unexpected false: " + std::string(text), curr, func);
+ return false;
+ }
+ return result;
+ }
+ template<typename T>
+ bool shouldBeFalse(bool result, T curr, const char* text, Function* func = nullptr) {
+ if (result) {
+ fail("unexpected true: " + std::string(text), curr, func);
+ return false;
+ }
+ return result;
+ }
+
+ template<typename T, typename S>
+ bool shouldBeEqual(S left, S right, T curr, const char* text, Function* func = nullptr) {
+ if (left != right) {
+ std::ostringstream ss;
+ ss << left << " != " << right << ": " << text;
+ fail(ss.str(), curr, func);
+ return false;
+ }
+ return true;
+ }
+
+ template<typename T, typename S>
+ bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text, Function* func = nullptr) {
+ if (left != unreachable && left != right) {
+ std::ostringstream ss;
+ ss << left << " != " << right << ": " << text;
+ fail(ss.str(), curr, func);
+ return false;
+ }
+ return true;
+ }
+
+ template<typename T, typename S>
+ bool shouldBeUnequal(S left, S right, T curr, const char* text, Function* func = nullptr) {
+ if (left == right) {
+ std::ostringstream ss;
+ ss << left << " == " << right << ": " << text;
+ fail(ss.str(), curr, func);
+ return false;
+ }
+ return true;
+ }
+
+ void shouldBeIntOrUnreachable(WasmType ty, Expression* curr, const char* text, Function* func = nullptr) {
+ switch (ty) {
+ case i32:
+ case i64:
+ case unreachable: {
+ break;
+ }
+ default: fail(text, curr, func);
+ }
+ }
+
+};
+
+struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> {
+ bool isFunctionParallel() override { return true; }
+
+ Pass* create() override { return new FunctionValidator(&info); }
+
+ ValidationInfo& info;
+
+ FunctionValidator(ValidationInfo* info) : info(*info) {}
+
+ struct BreakInfo {
+ WasmType type;
+ Index arity;
+ BreakInfo() {}
+ BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {}
+ };
+
+ std::map<Name, Expression*> breakTargets;
+ std::map<Expression*, BreakInfo> breakInfos;
+
+ WasmType returnType = unreachable; // type used in returns
+
+ std::set<Name> labelNames; // Binaryen IR requires that label names must be unique - IR generators must ensure that
+
+ std::unordered_set<Expression*> seenExpressions; // expressions must not appear twice
+
+ void noteLabelName(Name name);
+
+public:
+ // visitors
+
+ static void visitPreBlock(FunctionValidator* self, Expression** currp) {
+ auto* curr = (*currp)->cast<Block>();
+ if (curr->name.is()) self->breakTargets[curr->name] = curr;
+ }
+
+ void visitBlock(Block *curr);
+
+ static void visitPreLoop(FunctionValidator* self, Expression** currp) {
+ auto* curr = (*currp)->cast<Loop>();
+ if (curr->name.is()) self->breakTargets[curr->name] = curr;
+ }
+
+ void visitLoop(Loop *curr);
+ void visitIf(If *curr);
+
+ // override scan to add a pre and a post check task to all nodes
+ static void scan(FunctionValidator* self, Expression** currp) {
+ PostWalker<FunctionValidator>::scan(self, currp);
+
+ auto* curr = *currp;
+ if (curr->is<Block>()) self->pushTask(visitPreBlock, currp);
+ if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp);
+ }
+
+ void noteBreak(Name name, Expression* value, Expression* curr);
+ void visitBreak(Break *curr);
+ void visitSwitch(Switch *curr);
+ void visitCall(Call *curr);
+ void visitCallImport(CallImport *curr);
+ void visitCallIndirect(CallIndirect *curr);
+ void visitGetLocal(GetLocal* curr);
+ void visitSetLocal(SetLocal *curr);
+ void visitLoad(Load *curr);
+ void visitStore(Store *curr);
+ void visitAtomicRMW(AtomicRMW *curr);
+ void visitAtomicCmpxchg(AtomicCmpxchg *curr);
+ void visitAtomicWait(AtomicWait *curr);
+ void visitAtomicWake(AtomicWake *curr);
+ void visitBinary(Binary *curr);
+ void visitUnary(Unary *curr);
+ void visitSelect(Select* curr);
+ void visitDrop(Drop* curr);
+ void visitReturn(Return* curr);
+ void visitHost(Host* curr);
+ void visitFunction(Function *curr);
+
+ // helpers
+private:
+ std::ostream& getStream() {
+ return info.getStream(getFunction());
+ }
+
+ template<typename T>
+ bool shouldBeTrue(bool result, T curr, const char* text) {
+ return info.shouldBeTrue(result, curr, text, getFunction());
+ }
+ template<typename T>
+ bool shouldBeFalse(bool result, T curr, const char* text) {
+ return info.shouldBeFalse(result, curr, text, getFunction());
+ }
+
+ template<typename T, typename S>
+ bool shouldBeEqual(S left, S right, T curr, const char* text) {
+ return info.shouldBeEqual(left, right, curr, text, getFunction());
+ }
+
+ template<typename T, typename S>
+ bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) {
+ return info.shouldBeEqualOrFirstIsUnreachable(left, right, curr, text, getFunction());
+ }
+
+ template<typename T, typename S>
+ bool shouldBeUnequal(S left, S right, T curr, const char* text) {
+ return info.shouldBeUnequal(left, right, curr, text, getFunction());
+ }
+
+ void shouldBeIntOrUnreachable(WasmType ty, Expression* curr, const char* text) {
+ return info.shouldBeIntOrUnreachable(ty, curr, text, getFunction());
+ }
+
+ void validateAlignment(size_t align, WasmType type, Index bytes, bool isAtomic,
+ Expression* curr);
+ void validateMemBytes(uint8_t bytes, WasmType type, Expression* curr);
+};
+
+void FunctionValidator::noteLabelName(Name name) {
if (!name.is()) return;
shouldBeTrue(labelNames.find(name) == labelNames.end(), name, "names in Binaryen IR must be unique - IR generators must ensure that");
labelNames.insert(name);
}
-void WasmValidator::visitBlock(Block *curr) {
+void FunctionValidator::visitBlock(Block *curr) {
// if we are break'ed to, then the value must be right for us
if (curr->name.is()) {
noteLabelName(curr->name);
@@ -61,8 +319,8 @@ void WasmValidator::visitBlock(Block *curr) {
}
if (curr->list.size() > 1) {
for (Index i = 0; i < curr->list.size() - 1; i++) {
- if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)") && !quiet) {
- std::cerr << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n";
+ if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)") && !info.quiet) {
+ getStream() << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n";
}
}
}
@@ -83,7 +341,7 @@ void WasmValidator::visitBlock(Block *curr) {
}
}
-void WasmValidator::visitLoop(Loop *curr) {
+void FunctionValidator::visitLoop(Loop *curr) {
if (curr->name.is()) {
noteLabelName(curr->name);
breakTargets.erase(curr->name);
@@ -97,7 +355,7 @@ void WasmValidator::visitLoop(Loop *curr) {
}
}
-void WasmValidator::visitIf(If *curr) {
+void FunctionValidator::visitIf(If *curr) {
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "if condition must be valid");
if (!curr->ifFalse) {
shouldBeFalse(isConcreteWasmType(curr->ifTrue->type), curr, "if without else must not return a value in body");
@@ -125,7 +383,7 @@ void WasmValidator::visitIf(If *curr) {
}
}
-void WasmValidator::noteBreak(Name name, Expression* value, Expression* curr) {
+void FunctionValidator::noteBreak(Name name, Expression* value, Expression* curr) {
WasmType valueType = none;
Index arity = 0;
if (value) {
@@ -151,65 +409,70 @@ void WasmValidator::noteBreak(Name name, Expression* value, Expression* curr) {
}
}
}
-void WasmValidator::visitBreak(Break *curr) {
+void FunctionValidator::visitBreak(Break *curr) {
noteBreak(curr->name, curr->value, curr);
if (curr->condition) {
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32");
}
}
-void WasmValidator::visitSwitch(Switch *curr) {
+void FunctionValidator::visitSwitch(Switch *curr) {
for (auto& target : curr->targets) {
noteBreak(target, curr->value, curr);
}
noteBreak(curr->default_, curr->value, curr);
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32");
}
-void WasmValidator::visitCall(Call *curr) {
- if (!validateGlobally) return;
+
+void FunctionValidator::visitCall(Call *curr) {
+ if (!info.validateGlobally) return;
auto* target = getModule()->getFunctionOrNull(curr->target);
if (!shouldBeTrue(!!target, curr, "call target must exist")) {
- if (getModule()->getImportOrNull(curr->target) && !quiet) {
- std::cerr << "(perhaps it should be a CallImport instead of Call?)\n";
+ if (getModule()->getImportOrNull(curr->target) && !info.quiet) {
+ getStream() << "(perhaps it should be a CallImport instead of Call?)\n";
}
return;
}
if (!shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match")) return;
for (size_t i = 0; i < curr->operands.size(); i++) {
- if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match") && !quiet) {
- std::cerr << "(on argument " << i << ")\n";
+ if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match") && !info.quiet) {
+ getStream() << "(on argument " << i << ")\n";
}
}
}
-void WasmValidator::visitCallImport(CallImport *curr) {
- if (!validateGlobally) return;
+
+void FunctionValidator::visitCallImport(CallImport *curr) {
+ if (!info.validateGlobally) return;
auto* import = getModule()->getImportOrNull(curr->target);
if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return;
if (!shouldBeTrue(!!import->functionType.is(), curr, "called import must be function")) return;
auto* type = getModule()->getFunctionType(import->functionType);
if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return;
for (size_t i = 0; i < curr->operands.size(); i++) {
- if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match") && !quiet) {
- std::cerr << "(on argument " << i << ")\n";
+ if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match") && !info.quiet) {
+ getStream() << "(on argument " << i << ")\n";
}
}
}
-void WasmValidator::visitCallIndirect(CallIndirect *curr) {
- if (!validateGlobally) return;
+
+void FunctionValidator::visitCallIndirect(CallIndirect *curr) {
+ if (!info.validateGlobally) return;
auto* type = getModule()->getFunctionTypeOrNull(curr->fullType);
if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) return;
shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32");
if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return;
for (size_t i = 0; i < curr->operands.size(); i++) {
- if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match") && !quiet) {
- std::cerr << "(on argument " << i << ")\n";
+ if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match") && !info.quiet) {
+ getStream() << "(on argument " << i << ")\n";
}
}
}
-void WasmValidator::visitGetLocal(GetLocal* curr) {
+
+void FunctionValidator::visitGetLocal(GetLocal* curr) {
shouldBeTrue(isConcreteWasmType(curr->type), curr, "get_local must have a valid type - check what you provided when you constructed the node");
}
-void WasmValidator::visitSetLocal(SetLocal *curr) {
+
+void FunctionValidator::visitSetLocal(SetLocal *curr) {
shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough");
if (curr->value->type != unreachable) {
if (curr->type != none) { // tee is ok anyhow
@@ -218,39 +481,33 @@ void WasmValidator::visitSetLocal(SetLocal *curr) {
shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function");
}
}
-void WasmValidator::visitLoad(Load *curr) {
- if (curr->isAtomic && !getModule()->memory.shared) fail("Atomic operation with non-shared memory", curr);
+
+void FunctionValidator::visitLoad(Load *curr) {
+ shouldBeFalse(curr->isAtomic && !getModule()->memory.shared, curr, "Atomic operation with non-shared memory");
validateMemBytes(curr->bytes, curr->type, curr);
validateAlignment(curr->align, curr->type, curr->bytes, curr->isAtomic, curr);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32");
}
-void WasmValidator::visitStore(Store *curr) {
- if (curr->isAtomic && !getModule()->memory.shared) fail("Atomic operation with non-shared memory", curr);
+
+void FunctionValidator::visitStore(Store *curr) {
+ shouldBeFalse(curr->isAtomic && !getModule()->memory.shared, curr, "Atomic operation with non-shared memory");
validateMemBytes(curr->bytes, curr->valueType, curr);
validateAlignment(curr->align, curr->type, curr->bytes, curr->isAtomic, curr);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32");
shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none");
shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->valueType, curr, "store value type must match");
}
-void WasmValidator::shouldBeIntOrUnreachable(WasmType ty, Expression* curr, const char* text) {
- switch (ty) {
- case i32:
- case i64:
- case unreachable: {
- break;
- }
- default: fail(text, curr);
- }
-}
-void WasmValidator::visitAtomicRMW(AtomicRMW* curr) {
- if (!getModule()->memory.shared) fail("Atomic operation with non-shared memory", curr);
+
+void FunctionValidator::visitAtomicRMW(AtomicRMW* curr) {
+ shouldBeFalse(!getModule()->memory.shared, curr, "Atomic operation with non-shared memory");
validateMemBytes(curr->bytes, curr->type, curr);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "AtomicRMW pointer type must be i32");
shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "AtomicRMW result type must match operand");
shouldBeIntOrUnreachable(curr->type, curr, "Atomic operations are only valid on int types");
}
-void WasmValidator::visitAtomicCmpxchg(AtomicCmpxchg* curr) {
- if (!getModule()->memory.shared) fail("Atomic operation with non-shared memory", curr);
+
+void FunctionValidator::visitAtomicCmpxchg(AtomicCmpxchg* curr) {
+ shouldBeFalse(!getModule()->memory.shared, curr, "Atomic operation with non-shared memory");
validateMemBytes(curr->bytes, curr->type, curr);
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "cmpxchg pointer type must be i32");
if (curr->expected->type != unreachable && curr->replacement->type != unreachable) {
@@ -260,21 +517,24 @@ void WasmValidator::visitAtomicCmpxchg(AtomicCmpxchg* curr) {
shouldBeEqualOrFirstIsUnreachable(curr->replacement->type, curr->type, curr, "Cmpxchg result type must match replacement");
shouldBeIntOrUnreachable(curr->expected->type, curr, "Atomic operations are only valid on int types");
}
-void WasmValidator::visitAtomicWait(AtomicWait* curr) {
- if (!getModule()->memory.shared) fail("Atomic operation with non-shared memory", curr);
+
+void FunctionValidator::visitAtomicWait(AtomicWait* curr) {
+ shouldBeFalse(!getModule()->memory.shared, curr, "Atomic operation with non-shared memory");
shouldBeEqualOrFirstIsUnreachable(curr->type, i32, curr, "AtomicWait must have type i32");
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "AtomicWait pointer type must be i32");
shouldBeIntOrUnreachable(curr->expected->type, curr, "AtomicWait expected type must be int");
shouldBeEqualOrFirstIsUnreachable(curr->expected->type, curr->expectedType, curr, "AtomicWait expected type must match operand");
shouldBeEqualOrFirstIsUnreachable(curr->timeout->type, i64, curr, "AtomicWait timeout type must be i64");
}
-void WasmValidator::visitAtomicWake(AtomicWake* curr) {
- if (!getModule()->memory.shared) fail("Atomic operation with non-shared memory", curr);
+
+void FunctionValidator::visitAtomicWake(AtomicWake* curr) {
+ shouldBeFalse(!getModule()->memory.shared, curr, "Atomic operation with non-shared memory");
shouldBeEqualOrFirstIsUnreachable(curr->type, i32, curr, "AtomicWake must have type i32");
shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "AtomicWake pointer type must be i32");
shouldBeEqualOrFirstIsUnreachable(curr->wakeCount->type, i32, curr, "AtomicWake wakeCount type must be i32");
}
-void WasmValidator::validateMemBytes(uint8_t bytes, WasmType type, Expression* curr) {
+
+void FunctionValidator::validateMemBytes(uint8_t bytes, WasmType type, Expression* curr) {
switch (bytes) {
case 1:
case 2:
@@ -287,10 +547,11 @@ void WasmValidator::validateMemBytes(uint8_t bytes, WasmType type, Expression* c
}
break;
}
- default: fail("Memory operations must be 1,2,4, or 8 bytes", curr);
+ default: info.fail("Memory operations must be 1,2,4, or 8 bytes", curr, getFunction());
}
}
-void WasmValidator::visitBinary(Binary *curr) {
+
+void FunctionValidator::visitBinary(Binary *curr) {
if (curr->left->type != unreachable && curr->right->type != unreachable) {
shouldBeEqual(curr->left->type, curr->right->type, curr, "binary child types must be equal");
}
@@ -386,7 +647,8 @@ void WasmValidator::visitBinary(Binary *curr) {
default: WASM_UNREACHABLE();
}
}
-void WasmValidator::visitUnary(Unary *curr) {
+
+void FunctionValidator::visitUnary(Unary *curr) {
shouldBeUnequal(curr->value->type, none, curr, "unaries must not receive a none as their input");
if (curr->value->type == unreachable) return; // nothing to check
switch (curr->op) {
@@ -467,7 +729,8 @@ void WasmValidator::visitUnary(Unary *curr) {
default: abort();
}
}
-void WasmValidator::visitSelect(Select* curr) {
+
+void FunctionValidator::visitSelect(Select* curr) {
shouldBeUnequal(curr->ifTrue->type, none, curr, "select left must be valid");
shouldBeUnequal(curr->ifFalse->type, none, curr, "select right must be valid");
shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "select condition must be valid");
@@ -476,11 +739,11 @@ void WasmValidator::visitSelect(Select* curr) {
}
}
-void WasmValidator::visitDrop(Drop* curr) {
+void FunctionValidator::visitDrop(Drop* curr) {
shouldBeTrue(isConcreteWasmType(curr->value->type) || curr->value->type == unreachable, curr, "can only drop a valid value");
}
-void WasmValidator::visitReturn(Return* curr) {
+void FunctionValidator::visitReturn(Return* curr) {
if (curr->value) {
if (returnType == unreachable) {
returnType = curr->value->type;
@@ -492,7 +755,7 @@ void WasmValidator::visitReturn(Return* curr) {
}
}
-void WasmValidator::visitHost(Host* curr) {
+void FunctionValidator::visitHost(Host* curr) {
switch (curr->op) {
case GrowMemory: {
shouldBeEqual(curr->operands.size(), size_t(1), curr, "grow_memory must have 1 operand");
@@ -506,48 +769,7 @@ void WasmValidator::visitHost(Host* curr) {
}
}
-void WasmValidator::visitImport(Import* curr) {
- if (!validateGlobally) return;
- if (curr->kind == ExternalKind::Function) {
- if (validateWeb) {
- auto* functionType = getModule()->getFunctionType(curr->functionType);
- shouldBeUnequal(functionType->result, i64, curr->name, "Imported function must not have i64 return type");
- for (WasmType param : functionType->params) {
- shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters");
- }
- }
- }
- if (curr->kind == ExternalKind::Table) {
- shouldBeTrue(getModule()->table.imported, curr->name, "Table import record exists but table is not marked as imported");
- }
- if (curr->kind == ExternalKind::Memory) {
- shouldBeTrue(getModule()->memory.imported, curr->name, "Memory import record exists but memory is not marked as imported");
- }
-}
-
-void WasmValidator::visitExport(Export* curr) {
- if (!validateGlobally) return;
- if (curr->kind == ExternalKind::Function) {
- if (validateWeb) {
- Function* f = getModule()->getFunction(curr->value);
- shouldBeUnequal(f->result, i64, f->name, "Exported function must not have i64 return type");
- for (auto param : f->params) {
- shouldBeUnequal(param, i64, f->name, "Exported function must not have i64 parameters");
- }
- }
- }
-}
-
-void WasmValidator::visitGlobal(Global* curr) {
- if (!validateGlobally) return;
- shouldBeTrue(curr->init != nullptr, curr->name, "global init must be non-null");
- shouldBeTrue(curr->init->is<Const>() || curr->init->is<GetGlobal>(), curr->name, "global init must be valid");
- if (!shouldBeEqual(curr->type, curr->init->type, curr->init, "global init must have correct type") && !quiet) {
- std::cerr << "(on global " << curr->name << ")\n";
- }
-}
-
-void WasmValidator::visitFunction(Function *curr) {
+void FunctionValidator::visitFunction(Function *curr) {
// if function has no result, it is ignored
// if body is unreachable, it might be e.g. a return
if (curr->body->type != unreachable) {
@@ -575,7 +797,7 @@ void WasmValidator::visitFunction(Function *curr) {
Walker walker(seenExpressions);
walker.walk(curr->body);
for (auto* bad : walker.dupes) {
- fail("expression seen more than once in the tree", bad);
+ info.fail("expression seen more than once in the tree", bad, getFunction());
}
}
@@ -594,73 +816,7 @@ static bool checkOffset(Expression* curr, Address add, Address max) {
return offset + add <= max;
}
-void WasmValidator::visitMemory(Memory *curr) {
- shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial");
- shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB");
- shouldBeTrue(!curr->shared || curr->hasMax(), "memory", "shared memory must have max size");
- Index mustBeGreaterOrEqual = 0;
- for (auto& segment : curr->segments) {
- if (!shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32")) continue;
- shouldBeTrue(checkOffset(segment.offset, segment.data.size(), getModule()->memory.initial * Memory::kPageSize), segment.offset, "segment offset should be reasonable");
- Index size = segment.data.size();
- shouldBeTrue(size <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory");
- if (segment.offset->is<Const>()) {
- Index start = segment.offset->cast<Const>()->value.geti32();
- Index end = start + size;
- shouldBeTrue(end <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory");
- shouldBeTrue(start >= mustBeGreaterOrEqual, segment.data.size(), "segment size should fit in memory");
- mustBeGreaterOrEqual = end;
- }
- }
-}
-void WasmValidator::visitTable(Table* curr) {
- for (auto& segment : curr->segments) {
- shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32");
- shouldBeTrue(checkOffset(segment.offset, segment.data.size(), getModule()->table.initial * Table::kPageSize), segment.offset, "segment offset should be reasonable");
- for (auto name : segment.data) {
- shouldBeTrue(getModule()->getFunctionOrNull(name) || getModule()->getImportOrNull(name), name, "segment name should be valid");
- }
- }
-}
-void WasmValidator::visitModule(Module *curr) {
- if (!validateGlobally) return;
- // exports
- std::set<Name> exportNames;
- for (auto& exp : curr->exports) {
- Name name = exp->value;
- if (exp->kind == ExternalKind::Function) {
- bool found = false;
- for (auto& func : curr->functions) {
- if (func->name == name) {
- found = true;
- break;
- }
- }
- shouldBeTrue(found, name, "module function exports must be found");
- } else if (exp->kind == ExternalKind::Global) {
- shouldBeTrue(curr->getGlobalOrNull(name), name, "module global exports must be found");
- } else if (exp->kind == ExternalKind::Table) {
- shouldBeTrue(name == Name("0") || name == curr->table.name, name, "module table exports must be found");
- } else if (exp->kind == ExternalKind::Memory) {
- shouldBeTrue(name == Name("0") || name == curr->memory.name, name, "module memory exports must be found");
- } else {
- WASM_UNREACHABLE();
- }
- Name exportName = exp->name;
- shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique");
- exportNames.insert(exportName);
- }
- // start
- if (curr->start.is()) {
- auto func = curr->getFunctionOrNull(curr->start);
- if (shouldBeTrue(func != nullptr, curr->start, "start must be found")) {
- shouldBeTrue(func->params.size() == 0, curr, "start must have 0 params");
- shouldBeTrue(func->result == none, curr, "start must not return a value");
- }
- }
-}
-
-void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes,
+void FunctionValidator::validateAlignment(size_t align, WasmType type, Index bytes,
bool isAtomic, Expression* curr) {
if (isAtomic) {
shouldBeEqual(align, (size_t)bytes, curr, "atomic accesses must have natural alignment");
@@ -672,7 +828,7 @@ void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes,
case 4:
case 8: break;
default:{
- fail("bad alignment: " + std::to_string(align), curr);
+ info.fail("bad alignment: " + std::to_string(align), curr, getFunction());
break;
}
}
@@ -692,11 +848,11 @@ void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes,
}
}
-void WasmValidator::validateBinaryenIR(Module& wasm) {
+static void validateBinaryenIR(Module& wasm, ValidationInfo& info) {
struct BinaryenIRValidator : public PostWalker<BinaryenIRValidator, UnifiedExpressionVisitor<BinaryenIRValidator>> {
- WasmValidator& parent;
+ ValidationInfo& info;
- BinaryenIRValidator(WasmValidator& parent) : parent(parent) {}
+ BinaryenIRValidator(ValidationInfo& info) : info(info) {}
void visitExpression(Expression* curr) {
// check if a node type is 'stale', i.e., we forgot to finalize() the node.
@@ -712,40 +868,168 @@ void WasmValidator::validateBinaryenIR(Module& wasm) {
// The block has an added type, not derived from the ast itself, so it is
// ok for it to be either i32 or unreachable.
if (!(isConcreteWasmType(oldType) && newType == unreachable)) {
- parent.printFailureHeader() << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n";
- parent.valid = false;
+ std::ostringstream ss;
+ ss << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n";
+ info.fail(ss.str(), curr, getFunction());
}
curr->type = oldType;
}
}
};
- BinaryenIRValidator binaryenIRValidator(*this);
+ BinaryenIRValidator binaryenIRValidator(info);
binaryenIRValidator.walkModule(&wasm);
}
-template <typename T, typename S>
-std::ostream& WasmValidator::fail(S text, T curr) {
- valid = false;
- if (quiet) return std::cerr;
- auto& ret = printFailureHeader() << text << ", on \n";
- return printModuleComponent(curr, ret);
+// Main validator class
+
+static void validateImports(Module& module, ValidationInfo& info) {
+ for (auto& curr : module.imports) {
+ if (curr->kind == ExternalKind::Function) {
+ if (info.validateWeb) {
+ auto* functionType = module.getFunctionType(curr->functionType);
+ info.shouldBeUnequal(functionType->result, i64, curr->name, "Imported function must not have i64 return type");
+ for (WasmType param : functionType->params) {
+ info.shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters");
+ }
+ }
+ }
+ if (curr->kind == ExternalKind::Table) {
+ info.shouldBeTrue(module.table.imported, curr->name, "Table import record exists but table is not marked as imported");
+ }
+ if (curr->kind == ExternalKind::Memory) {
+ info.shouldBeTrue(module.memory.imported, curr->name, "Memory import record exists but memory is not marked as imported");
+ }
+ }
}
-std::ostream& WasmValidator::printFailureHeader() {
- if (quiet) return std::cerr;
- Colors::red(std::cerr);
- if (getFunction()) {
- std::cerr << "[wasm-validator error in function ";
- Colors::green(std::cerr);
- std::cerr << getFunction()->name;
- Colors::red(std::cerr);
- std::cerr << "] ";
- } else {
- std::cerr << "[wasm-validator error in module] ";
+static void validateExports(Module& module, ValidationInfo& info) {
+ for (auto& curr : module.exports) {
+ if (curr->kind == ExternalKind::Function) {
+ if (info.validateWeb) {
+ Function* f = module.getFunction(curr->value);
+ info.shouldBeUnequal(f->result, i64, f->name, "Exported function must not have i64 return type");
+ for (auto param : f->params) {
+ info.shouldBeUnequal(param, i64, f->name, "Exported function must not have i64 parameters");
+ }
+ }
+ }
+ }
+ std::set<Name> exportNames;
+ for (auto& exp : module.exports) {
+ Name name = exp->value;
+ if (exp->kind == ExternalKind::Function) {
+ bool found = false;
+ for (auto& func : module.functions) {
+ if (func->name == name) {
+ found = true;
+ break;
+ }
+ }
+ info.shouldBeTrue(found, name, "module function exports must be found");
+ } else if (exp->kind == ExternalKind::Global) {
+ info.shouldBeTrue(module.getGlobalOrNull(name), name, "module global exports must be found");
+ } else if (exp->kind == ExternalKind::Table) {
+ info.shouldBeTrue(name == Name("0") || name == module.table.name, name, "module table exports must be found");
+ } else if (exp->kind == ExternalKind::Memory) {
+ info.shouldBeTrue(name == Name("0") || name == module.memory.name, name, "module memory exports must be found");
+ } else {
+ WASM_UNREACHABLE();
+ }
+ Name exportName = exp->name;
+ info.shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique");
+ exportNames.insert(exportName);
+ }
+}
+
+static void validateGlobals(Module& module, ValidationInfo& info) {
+ for (auto& curr : module.globals) {
+ info.shouldBeTrue(curr->init != nullptr, curr->name, "global init must be non-null");
+ info.shouldBeTrue(curr->init->is<Const>() || curr->init->is<GetGlobal>(), curr->name, "global init must be valid");
+ if (!info.shouldBeEqual(curr->type, curr->init->type, curr->init, "global init must have correct type") && !info.quiet) {
+ info.getStream(nullptr) << "(on global " << curr->name << ")\n";
+ }
}
- Colors::normal(std::cerr);
- return std::cerr;
}
+static void validateMemory(Module& module, ValidationInfo& info) {
+ auto& curr = module.memory;
+ info.shouldBeFalse(curr.initial > curr.max, "memory", "memory max >= initial");
+ info.shouldBeTrue(curr.max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB");
+ info.shouldBeTrue(!curr.shared || curr.hasMax(), "memory", "shared memory must have max size");
+ Index mustBeGreaterOrEqual = 0;
+ for (auto& segment : curr.segments) {
+ if (!info.shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32")) continue;
+ info.shouldBeTrue(checkOffset(segment.offset, segment.data.size(), module.memory.initial * Memory::kPageSize), segment.offset, "segment offset should be reasonable");
+ Index size = segment.data.size();
+ info.shouldBeTrue(size <= curr.initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory");
+ if (segment.offset->is<Const>()) {
+ Index start = segment.offset->cast<Const>()->value.geti32();
+ Index end = start + size;
+ info.shouldBeTrue(end <= curr.initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory");
+ info.shouldBeTrue(start >= mustBeGreaterOrEqual, segment.data.size(), "segment size should fit in memory");
+ mustBeGreaterOrEqual = end;
+ }
+ }
+}
+
+static void validateTable(Module& module, ValidationInfo& info) {
+ auto& curr = module.table;
+ for (auto& segment : curr.segments) {
+ info.shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32");
+ info.shouldBeTrue(checkOffset(segment.offset, segment.data.size(), module.table.initial * Table::kPageSize), segment.offset, "segment offset should be reasonable");
+ for (auto name : segment.data) {
+ info.shouldBeTrue(module.getFunctionOrNull(name) || module.getImportOrNull(name), name, "segment name should be valid");
+ }
+ }
+}
+
+static void validateModule(Module& module, ValidationInfo& info) {
+ // start
+ if (module.start.is()) {
+ auto func = module.getFunctionOrNull(module.start);
+ if (info.shouldBeTrue(func != nullptr, module.start, "start must be found")) {
+ info.shouldBeTrue(func->params.size() == 0, module.start, "start must have 0 params");
+ info.shouldBeTrue(func->result == none, module.start, "start must not return a value");
+ }
+ }
+}
+
+// TODO: If we want the validator to be part of libwasm rather than libpasses, then
+// Using PassRunner::getPassDebug causes a circular dependence. We should fix that,
+// perhaps by moving some of the pass infrastructure into libsupport.
+bool WasmValidator::validate(Module& module, bool validateWeb, bool validateGlobally, bool quiet) {
+ ValidationInfo info;
+ info.validateWeb = validateWeb;
+ info.validateGlobally = validateGlobally;
+ info.quiet = quiet;
+ // parallel wasm logic validation
+ PassRunner runner(&module);
+ runner.add<FunctionValidator>(&info);
+ runner.setIsNested(true);
+ runner.run();
+ // validate globally
+ if (validateGlobally) {
+ validateImports(module, info);
+ validateExports(module, info);
+ validateGlobals(module, info);
+ validateMemory(module, info);
+ validateTable(module, info);
+ validateModule(module, info);
+ }
+ // validate additional internal IR details when in pass-debug mode
+ if (PassRunner::getPassDebug()) {
+ validateBinaryenIR(module, info);
+ }
+ // print all the data
+ if (!info.valid.load() && !info.quiet) {
+ for (auto& func : module.functions) {
+ std::cerr << info.getStream(func.get()).str();
+ }
+ std::cerr << info.getStream(nullptr).str();
+ // also print the module
+ WasmPrinter::printModule(&module, std::cerr);
+ }
+ return info.valid.load();
+}
} // namespace wasm