diff options
author | Alon Zakai <alonzakai@gmail.com> | 2017-10-02 13:52:16 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-10-02 13:52:16 -0700 |
commit | 1f8d8a53e8fcee0791c11345fd7f328255cfa22c (patch) | |
tree | 2effc81da752bb71c486106bb5fe71e9e174389e /src | |
parent | a9f91b9774d117a13c231ef0f40861372456878f (diff) | |
download | binaryen-1f8d8a53e8fcee0791c11345fd7f328255cfa22c.tar.gz binaryen-1f8d8a53e8fcee0791c11345fd7f328255cfa22c.tar.bz2 binaryen-1f8d8a53e8fcee0791c11345fd7f328255cfa22c.zip |
Fast validation (#1204)
This makes wasm validation parallel (the function part). This makes loading+validating tanks (a 12MB wasm file) 2.3x faster on a 4-core machine (from 3.5 to 1.5 seconds). It's a big speedup because most of loading+validating was actually validating.
It's also noticeable during compilation, since we validate by default at the end. 8% faster on -O2 and 23% on -O0. So actually fairly significant on -O0 builds.
As a bonus, this PR also moves the code from being 99% in the header to be 1% in the header.
Diffstat (limited to 'src')
-rw-r--r-- | src/tools/wasm-reduce.cpp | 4 | ||||
-rw-r--r-- | src/wasm-validator.h | 186 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 668 |
3 files changed, 481 insertions, 377 deletions
diff --git a/src/tools/wasm-reduce.cpp b/src/tools/wasm-reduce.cpp index 355b70278..29484cc7a 100644 --- a/src/tools/wasm-reduce.cpp +++ b/src/tools/wasm-reduce.cpp @@ -414,9 +414,7 @@ struct Reducer : public WalkerPass<PostWalker<Reducer, UnifiedExpressionVisitor< } for (auto& func : functions) { curr->removeFunction(func.name); - WasmValidator validator; - validator.quiet = true; - if (validator.validate(*curr) && writeAndTestReduction()) { + if (WasmValidator().validate(*curr, false, true, true /* override quiet */) && writeAndTestReduction()) { std::cerr << "| removed function " << func.name << '\n'; noteReduction(); } else { diff --git a/src/wasm-validator.h b/src/wasm-validator.h index ceaaee890..250779b91 100644 --- a/src/wasm-validator.h +++ b/src/wasm-validator.h @@ -33,6 +33,8 @@ // about function B not existing yet, but we would care // if e.g. inside function A an i32.add receives an i64). // +// * quiet: Whether to log errors verbosely. +// #ifndef wasm_wasm_validator_h #define wasm_wasm_validator_h @@ -46,188 +48,8 @@ namespace wasm { -// Print anything that can be streamed to an ostream -template <typename T, - typename std::enable_if< - !std::is_base_of<Expression, typename std::remove_pointer<T>::type>::value - >::type* = nullptr> -inline std::ostream& printModuleComponent(T curr, std::ostream& stream) { - stream << curr << std::endl; - return stream; -} -// Extra overload for Expressions, to print type info too -inline std::ostream& printModuleComponent(Expression* curr, std::ostream& stream) { - WasmPrinter::printExpression(curr, stream, false, true) << std::endl; - return stream; -} - -struct WasmValidator : public PostWalker<WasmValidator> { - bool valid = true; - - // what to validate, see comment up top - bool validateWeb = false; - bool validateGlobally = true; - - bool quiet = false; // whether to log errors verbosely - - struct BreakInfo { - WasmType type; - Index arity; - BreakInfo() {} - BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {} - }; - - std::map<Name, Expression*> breakTargets; - std::map<Expression*, BreakInfo> breakInfos; - - WasmType returnType = unreachable; // type used in returns - - std::set<Name> labelNames; // Binaryen IR requires that label names must be unique - IR generators must ensure that - - std::unordered_set<Expression*> seenExpressions; // expressions must not appear twice - - void noteLabelName(Name name); - -public: - // TODO: If we want the validator to be part of libwasm rather than libpasses, then - // Using PassRunner::getPassDebug causes a circular dependence. We should fix that, - // perhaps by moving some of the pass infrastructure into libsupport. - bool validate(Module& module, bool validateWeb_ = false, bool validateGlobally_ = true) { - validateWeb = validateWeb_; - validateGlobally = validateGlobally_; - // wasm logic validation - walkModule(&module); - // validate additional internal IR details when in pass-debug mode - if (PassRunner::getPassDebug()) { - validateBinaryenIR(module); - } - // print if an error occurred - if (!valid && !quiet) { - WasmPrinter::printModule(&module, std::cerr); - } - return valid; - } - - // visitors - - static void visitPreBlock(WasmValidator* self, Expression** currp) { - auto* curr = (*currp)->cast<Block>(); - if (curr->name.is()) self->breakTargets[curr->name] = curr; - } - - void visitBlock(Block *curr); - - static void visitPreLoop(WasmValidator* self, Expression** currp) { - auto* curr = (*currp)->cast<Loop>(); - if (curr->name.is()) self->breakTargets[curr->name] = curr; - } - - void visitLoop(Loop *curr); - void visitIf(If *curr); - - // override scan to add a pre and a post check task to all nodes - static void scan(WasmValidator* self, Expression** currp) { - PostWalker<WasmValidator>::scan(self, currp); - - auto* curr = *currp; - if (curr->is<Block>()) self->pushTask(visitPreBlock, currp); - if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp); - } - - void noteBreak(Name name, Expression* value, Expression* curr); - void visitBreak(Break *curr); - void visitSwitch(Switch *curr); - void visitCall(Call *curr); - void visitCallImport(CallImport *curr); - void visitCallIndirect(CallIndirect *curr); - void visitGetLocal(GetLocal* curr); - void visitSetLocal(SetLocal *curr); - void visitLoad(Load *curr); - void visitStore(Store *curr); - void visitAtomicRMW(AtomicRMW *curr); - void visitAtomicCmpxchg(AtomicCmpxchg *curr); - void visitAtomicWait(AtomicWait *curr); - void visitAtomicWake(AtomicWake *curr); - void visitBinary(Binary *curr); - void visitUnary(Unary *curr); - void visitSelect(Select* curr); - void visitDrop(Drop* curr); - void visitReturn(Return* curr); - void visitHost(Host* curr); - void visitImport(Import* curr); - void visitExport(Export* curr); - void visitGlobal(Global* curr); - void visitFunction(Function *curr); - - void visitMemory(Memory *curr); - void visitTable(Table* curr); - void visitModule(Module *curr); - - void doWalkFunction(Function* func) { - PostWalker<WasmValidator>::doWalkFunction(func); - } - - // helpers - private: - template <typename T, typename S> - std::ostream& fail(S text, T curr); - std::ostream& printFailureHeader(); - - template<typename T> - bool shouldBeTrue(bool result, T curr, const char* text) { - if (!result) { - fail("unexpected false: " + std::string(text), curr); - return false; - } - return result; - } - template<typename T> - bool shouldBeFalse(bool result, T curr, const char* text) { - if (result) { - fail("unexpected true: " + std::string(text), curr); - return false; - } - return result; - } - - template<typename T, typename S> - bool shouldBeEqual(S left, S right, T curr, const char* text) { - if (left != right) { - std::ostringstream ss; - ss << left << " != " << right << ": " << text; - fail(ss.str(), curr); - return false; - } - return true; - } - - template<typename T, typename S> - bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) { - if (left != unreachable && left != right) { - std::ostringstream ss; - ss << left << " != " << right << ": " << text; - fail(ss.str(), curr); - return false; - } - return true; - } - - template<typename T, typename S> - bool shouldBeUnequal(S left, S right, T curr, const char* text) { - if (left == right) { - std::ostringstream ss; - ss << left << " == " << right << ": " << text; - fail(ss.str(), curr); - return false; - } - return true; - } - - void shouldBeIntOrUnreachable(WasmType ty, Expression* curr, const char* text); - void validateAlignment(size_t align, WasmType type, Index bytes, bool isAtomic, - Expression* curr); - void validateMemBytes(uint8_t bytes, WasmType type, Expression* curr); - void validateBinaryenIR(Module& wasm); +struct WasmValidator { + bool validate(Module& module, bool validateWeb = false, bool validateGlobally = true, bool quiet = false); }; } // namespace wasm diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index cfe84ac3e..ebefa8b65 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -14,21 +14,279 @@ * limitations under the License. */ -#include "wasm-validator.h" +#include <mutex> +#include <set> +#include <sstream> +#include <unordered_set> +#include "wasm.h" +#include "wasm-printing.h" +#include "wasm-validator.h" #include "ast_utils.h" #include "ast/branch-utils.h" #include "support/colors.h" namespace wasm { -void WasmValidator::noteLabelName(Name name) { + +// Print anything that can be streamed to an ostream +template <typename T, + typename std::enable_if< + !std::is_base_of<Expression, typename std::remove_pointer<T>::type>::value + >::type* = nullptr> +inline std::ostream& printModuleComponent(T curr, std::ostream& stream) { + stream << curr << std::endl; + return stream; +} + +// Extra overload for Expressions, to print type info too +inline std::ostream& printModuleComponent(Expression* curr, std::ostream& stream) { + WasmPrinter::printExpression(curr, stream, false, true) << std::endl; + return stream; +} + +// For parallel validation, we have a helper struct for coordination +struct ValidationInfo { + bool validateWeb; + bool validateGlobally; + bool quiet; + + std::atomic<bool> valid; + + // a stream of error test for each function. we print in the right order at + // the end, for deterministic output + // note errors are rare/unexpected, so it's ok to use a slow mutex here + std::mutex mutex; + std::unordered_map<Function*, std::unique_ptr<std::ostringstream>> outputs; + + ValidationInfo() { + valid.store(true); + } + + std::ostringstream& getStream(Function* func) { + std::unique_lock<std::mutex> lock(mutex); + auto iter = outputs.find(func); + if (iter != outputs.end()) return *(iter->second.get()); + auto& ret = outputs[func] = make_unique<std::ostringstream>(); + return *ret.get(); + } + + // printing and error handling support + + template <typename T, typename S> + std::ostream& fail(S text, T curr, Function* func) { + valid.store(false); + auto& stream = getStream(func); + if (quiet) return stream; + auto& ret = printFailureHeader(func); + ret << text << ", on \n"; + return printModuleComponent(curr, ret); + } + + std::ostream& printFailureHeader(Function* func) { + auto& stream = getStream(func); + if (quiet) return stream; + Colors::red(stream); + if (func) { + stream << "[wasm-validator error in function "; + Colors::green(stream); + stream << func->name; + Colors::red(stream); + stream << "] "; + } else { + stream << "[wasm-validator error in module] "; + } + Colors::normal(stream); + return stream; + } + + // checking utilities + + template<typename T> + bool shouldBeTrue(bool result, T curr, const char* text, Function* func = nullptr) { + if (!result) { + fail("unexpected false: " + std::string(text), curr, func); + return false; + } + return result; + } + template<typename T> + bool shouldBeFalse(bool result, T curr, const char* text, Function* func = nullptr) { + if (result) { + fail("unexpected true: " + std::string(text), curr, func); + return false; + } + return result; + } + + template<typename T, typename S> + bool shouldBeEqual(S left, S right, T curr, const char* text, Function* func = nullptr) { + if (left != right) { + std::ostringstream ss; + ss << left << " != " << right << ": " << text; + fail(ss.str(), curr, func); + return false; + } + return true; + } + + template<typename T, typename S> + bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text, Function* func = nullptr) { + if (left != unreachable && left != right) { + std::ostringstream ss; + ss << left << " != " << right << ": " << text; + fail(ss.str(), curr, func); + return false; + } + return true; + } + + template<typename T, typename S> + bool shouldBeUnequal(S left, S right, T curr, const char* text, Function* func = nullptr) { + if (left == right) { + std::ostringstream ss; + ss << left << " == " << right << ": " << text; + fail(ss.str(), curr, func); + return false; + } + return true; + } + + void shouldBeIntOrUnreachable(WasmType ty, Expression* curr, const char* text, Function* func = nullptr) { + switch (ty) { + case i32: + case i64: + case unreachable: { + break; + } + default: fail(text, curr, func); + } + } + +}; + +struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> { + bool isFunctionParallel() override { return true; } + + Pass* create() override { return new FunctionValidator(&info); } + + ValidationInfo& info; + + FunctionValidator(ValidationInfo* info) : info(*info) {} + + struct BreakInfo { + WasmType type; + Index arity; + BreakInfo() {} + BreakInfo(WasmType type, Index arity) : type(type), arity(arity) {} + }; + + std::map<Name, Expression*> breakTargets; + std::map<Expression*, BreakInfo> breakInfos; + + WasmType returnType = unreachable; // type used in returns + + std::set<Name> labelNames; // Binaryen IR requires that label names must be unique - IR generators must ensure that + + std::unordered_set<Expression*> seenExpressions; // expressions must not appear twice + + void noteLabelName(Name name); + +public: + // visitors + + static void visitPreBlock(FunctionValidator* self, Expression** currp) { + auto* curr = (*currp)->cast<Block>(); + if (curr->name.is()) self->breakTargets[curr->name] = curr; + } + + void visitBlock(Block *curr); + + static void visitPreLoop(FunctionValidator* self, Expression** currp) { + auto* curr = (*currp)->cast<Loop>(); + if (curr->name.is()) self->breakTargets[curr->name] = curr; + } + + void visitLoop(Loop *curr); + void visitIf(If *curr); + + // override scan to add a pre and a post check task to all nodes + static void scan(FunctionValidator* self, Expression** currp) { + PostWalker<FunctionValidator>::scan(self, currp); + + auto* curr = *currp; + if (curr->is<Block>()) self->pushTask(visitPreBlock, currp); + if (curr->is<Loop>()) self->pushTask(visitPreLoop, currp); + } + + void noteBreak(Name name, Expression* value, Expression* curr); + void visitBreak(Break *curr); + void visitSwitch(Switch *curr); + void visitCall(Call *curr); + void visitCallImport(CallImport *curr); + void visitCallIndirect(CallIndirect *curr); + void visitGetLocal(GetLocal* curr); + void visitSetLocal(SetLocal *curr); + void visitLoad(Load *curr); + void visitStore(Store *curr); + void visitAtomicRMW(AtomicRMW *curr); + void visitAtomicCmpxchg(AtomicCmpxchg *curr); + void visitAtomicWait(AtomicWait *curr); + void visitAtomicWake(AtomicWake *curr); + void visitBinary(Binary *curr); + void visitUnary(Unary *curr); + void visitSelect(Select* curr); + void visitDrop(Drop* curr); + void visitReturn(Return* curr); + void visitHost(Host* curr); + void visitFunction(Function *curr); + + // helpers +private: + std::ostream& getStream() { + return info.getStream(getFunction()); + } + + template<typename T> + bool shouldBeTrue(bool result, T curr, const char* text) { + return info.shouldBeTrue(result, curr, text, getFunction()); + } + template<typename T> + bool shouldBeFalse(bool result, T curr, const char* text) { + return info.shouldBeFalse(result, curr, text, getFunction()); + } + + template<typename T, typename S> + bool shouldBeEqual(S left, S right, T curr, const char* text) { + return info.shouldBeEqual(left, right, curr, text, getFunction()); + } + + template<typename T, typename S> + bool shouldBeEqualOrFirstIsUnreachable(S left, S right, T curr, const char* text) { + return info.shouldBeEqualOrFirstIsUnreachable(left, right, curr, text, getFunction()); + } + + template<typename T, typename S> + bool shouldBeUnequal(S left, S right, T curr, const char* text) { + return info.shouldBeUnequal(left, right, curr, text, getFunction()); + } + + void shouldBeIntOrUnreachable(WasmType ty, Expression* curr, const char* text) { + return info.shouldBeIntOrUnreachable(ty, curr, text, getFunction()); + } + + void validateAlignment(size_t align, WasmType type, Index bytes, bool isAtomic, + Expression* curr); + void validateMemBytes(uint8_t bytes, WasmType type, Expression* curr); +}; + +void FunctionValidator::noteLabelName(Name name) { if (!name.is()) return; shouldBeTrue(labelNames.find(name) == labelNames.end(), name, "names in Binaryen IR must be unique - IR generators must ensure that"); labelNames.insert(name); } -void WasmValidator::visitBlock(Block *curr) { +void FunctionValidator::visitBlock(Block *curr) { // if we are break'ed to, then the value must be right for us if (curr->name.is()) { noteLabelName(curr->name); @@ -61,8 +319,8 @@ void WasmValidator::visitBlock(Block *curr) { } if (curr->list.size() > 1) { for (Index i = 0; i < curr->list.size() - 1; i++) { - if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)") && !quiet) { - std::cerr << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n"; + if (!shouldBeTrue(!isConcreteWasmType(curr->list[i]->type), curr, "non-final block elements returning a value must be drop()ed (binaryen's autodrop option might help you)") && !info.quiet) { + getStream() << "(on index " << i << ":\n" << curr->list[i] << "\n), type: " << curr->list[i]->type << "\n"; } } } @@ -83,7 +341,7 @@ void WasmValidator::visitBlock(Block *curr) { } } -void WasmValidator::visitLoop(Loop *curr) { +void FunctionValidator::visitLoop(Loop *curr) { if (curr->name.is()) { noteLabelName(curr->name); breakTargets.erase(curr->name); @@ -97,7 +355,7 @@ void WasmValidator::visitLoop(Loop *curr) { } } -void WasmValidator::visitIf(If *curr) { +void FunctionValidator::visitIf(If *curr) { shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "if condition must be valid"); if (!curr->ifFalse) { shouldBeFalse(isConcreteWasmType(curr->ifTrue->type), curr, "if without else must not return a value in body"); @@ -125,7 +383,7 @@ void WasmValidator::visitIf(If *curr) { } } -void WasmValidator::noteBreak(Name name, Expression* value, Expression* curr) { +void FunctionValidator::noteBreak(Name name, Expression* value, Expression* curr) { WasmType valueType = none; Index arity = 0; if (value) { @@ -151,65 +409,70 @@ void WasmValidator::noteBreak(Name name, Expression* value, Expression* curr) { } } } -void WasmValidator::visitBreak(Break *curr) { +void FunctionValidator::visitBreak(Break *curr) { noteBreak(curr->name, curr->value, curr); if (curr->condition) { shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "break condition must be i32"); } } -void WasmValidator::visitSwitch(Switch *curr) { +void FunctionValidator::visitSwitch(Switch *curr) { for (auto& target : curr->targets) { noteBreak(target, curr->value, curr); } noteBreak(curr->default_, curr->value, curr); shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "br_table condition must be i32"); } -void WasmValidator::visitCall(Call *curr) { - if (!validateGlobally) return; + +void FunctionValidator::visitCall(Call *curr) { + if (!info.validateGlobally) return; auto* target = getModule()->getFunctionOrNull(curr->target); if (!shouldBeTrue(!!target, curr, "call target must exist")) { - if (getModule()->getImportOrNull(curr->target) && !quiet) { - std::cerr << "(perhaps it should be a CallImport instead of Call?)\n"; + if (getModule()->getImportOrNull(curr->target) && !info.quiet) { + getStream() << "(perhaps it should be a CallImport instead of Call?)\n"; } return; } if (!shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match")) return; for (size_t i = 0; i < curr->operands.size(); i++) { - if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match") && !quiet) { - std::cerr << "(on argument " << i << ")\n"; + if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match") && !info.quiet) { + getStream() << "(on argument " << i << ")\n"; } } } -void WasmValidator::visitCallImport(CallImport *curr) { - if (!validateGlobally) return; + +void FunctionValidator::visitCallImport(CallImport *curr) { + if (!info.validateGlobally) return; auto* import = getModule()->getImportOrNull(curr->target); if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return; if (!shouldBeTrue(!!import->functionType.is(), curr, "called import must be function")) return; auto* type = getModule()->getFunctionType(import->functionType); if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return; for (size_t i = 0; i < curr->operands.size(); i++) { - if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match") && !quiet) { - std::cerr << "(on argument " << i << ")\n"; + if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match") && !info.quiet) { + getStream() << "(on argument " << i << ")\n"; } } } -void WasmValidator::visitCallIndirect(CallIndirect *curr) { - if (!validateGlobally) return; + +void FunctionValidator::visitCallIndirect(CallIndirect *curr) { + if (!info.validateGlobally) return; auto* type = getModule()->getFunctionTypeOrNull(curr->fullType); if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) return; shouldBeEqualOrFirstIsUnreachable(curr->target->type, i32, curr, "indirect call target must be an i32"); if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return; for (size_t i = 0; i < curr->operands.size(); i++) { - if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match") && !quiet) { - std::cerr << "(on argument " << i << ")\n"; + if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match") && !info.quiet) { + getStream() << "(on argument " << i << ")\n"; } } } -void WasmValidator::visitGetLocal(GetLocal* curr) { + +void FunctionValidator::visitGetLocal(GetLocal* curr) { shouldBeTrue(isConcreteWasmType(curr->type), curr, "get_local must have a valid type - check what you provided when you constructed the node"); } -void WasmValidator::visitSetLocal(SetLocal *curr) { + +void FunctionValidator::visitSetLocal(SetLocal *curr) { shouldBeTrue(curr->index < getFunction()->getNumLocals(), curr, "set_local index must be small enough"); if (curr->value->type != unreachable) { if (curr->type != none) { // tee is ok anyhow @@ -218,39 +481,33 @@ void WasmValidator::visitSetLocal(SetLocal *curr) { shouldBeEqual(getFunction()->getLocalType(curr->index), curr->value->type, curr, "set_local type must match function"); } } -void WasmValidator::visitLoad(Load *curr) { - if (curr->isAtomic && !getModule()->memory.shared) fail("Atomic operation with non-shared memory", curr); + +void FunctionValidator::visitLoad(Load *curr) { + shouldBeFalse(curr->isAtomic && !getModule()->memory.shared, curr, "Atomic operation with non-shared memory"); validateMemBytes(curr->bytes, curr->type, curr); validateAlignment(curr->align, curr->type, curr->bytes, curr->isAtomic, curr); shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "load pointer type must be i32"); } -void WasmValidator::visitStore(Store *curr) { - if (curr->isAtomic && !getModule()->memory.shared) fail("Atomic operation with non-shared memory", curr); + +void FunctionValidator::visitStore(Store *curr) { + shouldBeFalse(curr->isAtomic && !getModule()->memory.shared, curr, "Atomic operation with non-shared memory"); validateMemBytes(curr->bytes, curr->valueType, curr); validateAlignment(curr->align, curr->type, curr->bytes, curr->isAtomic, curr); shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "store pointer type must be i32"); shouldBeUnequal(curr->value->type, none, curr, "store value type must not be none"); shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->valueType, curr, "store value type must match"); } -void WasmValidator::shouldBeIntOrUnreachable(WasmType ty, Expression* curr, const char* text) { - switch (ty) { - case i32: - case i64: - case unreachable: { - break; - } - default: fail(text, curr); - } -} -void WasmValidator::visitAtomicRMW(AtomicRMW* curr) { - if (!getModule()->memory.shared) fail("Atomic operation with non-shared memory", curr); + +void FunctionValidator::visitAtomicRMW(AtomicRMW* curr) { + shouldBeFalse(!getModule()->memory.shared, curr, "Atomic operation with non-shared memory"); validateMemBytes(curr->bytes, curr->type, curr); shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "AtomicRMW pointer type must be i32"); shouldBeEqualOrFirstIsUnreachable(curr->value->type, curr->type, curr, "AtomicRMW result type must match operand"); shouldBeIntOrUnreachable(curr->type, curr, "Atomic operations are only valid on int types"); } -void WasmValidator::visitAtomicCmpxchg(AtomicCmpxchg* curr) { - if (!getModule()->memory.shared) fail("Atomic operation with non-shared memory", curr); + +void FunctionValidator::visitAtomicCmpxchg(AtomicCmpxchg* curr) { + shouldBeFalse(!getModule()->memory.shared, curr, "Atomic operation with non-shared memory"); validateMemBytes(curr->bytes, curr->type, curr); shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "cmpxchg pointer type must be i32"); if (curr->expected->type != unreachable && curr->replacement->type != unreachable) { @@ -260,21 +517,24 @@ void WasmValidator::visitAtomicCmpxchg(AtomicCmpxchg* curr) { shouldBeEqualOrFirstIsUnreachable(curr->replacement->type, curr->type, curr, "Cmpxchg result type must match replacement"); shouldBeIntOrUnreachable(curr->expected->type, curr, "Atomic operations are only valid on int types"); } -void WasmValidator::visitAtomicWait(AtomicWait* curr) { - if (!getModule()->memory.shared) fail("Atomic operation with non-shared memory", curr); + +void FunctionValidator::visitAtomicWait(AtomicWait* curr) { + shouldBeFalse(!getModule()->memory.shared, curr, "Atomic operation with non-shared memory"); shouldBeEqualOrFirstIsUnreachable(curr->type, i32, curr, "AtomicWait must have type i32"); shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "AtomicWait pointer type must be i32"); shouldBeIntOrUnreachable(curr->expected->type, curr, "AtomicWait expected type must be int"); shouldBeEqualOrFirstIsUnreachable(curr->expected->type, curr->expectedType, curr, "AtomicWait expected type must match operand"); shouldBeEqualOrFirstIsUnreachable(curr->timeout->type, i64, curr, "AtomicWait timeout type must be i64"); } -void WasmValidator::visitAtomicWake(AtomicWake* curr) { - if (!getModule()->memory.shared) fail("Atomic operation with non-shared memory", curr); + +void FunctionValidator::visitAtomicWake(AtomicWake* curr) { + shouldBeFalse(!getModule()->memory.shared, curr, "Atomic operation with non-shared memory"); shouldBeEqualOrFirstIsUnreachable(curr->type, i32, curr, "AtomicWake must have type i32"); shouldBeEqualOrFirstIsUnreachable(curr->ptr->type, i32, curr, "AtomicWake pointer type must be i32"); shouldBeEqualOrFirstIsUnreachable(curr->wakeCount->type, i32, curr, "AtomicWake wakeCount type must be i32"); } -void WasmValidator::validateMemBytes(uint8_t bytes, WasmType type, Expression* curr) { + +void FunctionValidator::validateMemBytes(uint8_t bytes, WasmType type, Expression* curr) { switch (bytes) { case 1: case 2: @@ -287,10 +547,11 @@ void WasmValidator::validateMemBytes(uint8_t bytes, WasmType type, Expression* c } break; } - default: fail("Memory operations must be 1,2,4, or 8 bytes", curr); + default: info.fail("Memory operations must be 1,2,4, or 8 bytes", curr, getFunction()); } } -void WasmValidator::visitBinary(Binary *curr) { + +void FunctionValidator::visitBinary(Binary *curr) { if (curr->left->type != unreachable && curr->right->type != unreachable) { shouldBeEqual(curr->left->type, curr->right->type, curr, "binary child types must be equal"); } @@ -386,7 +647,8 @@ void WasmValidator::visitBinary(Binary *curr) { default: WASM_UNREACHABLE(); } } -void WasmValidator::visitUnary(Unary *curr) { + +void FunctionValidator::visitUnary(Unary *curr) { shouldBeUnequal(curr->value->type, none, curr, "unaries must not receive a none as their input"); if (curr->value->type == unreachable) return; // nothing to check switch (curr->op) { @@ -467,7 +729,8 @@ void WasmValidator::visitUnary(Unary *curr) { default: abort(); } } -void WasmValidator::visitSelect(Select* curr) { + +void FunctionValidator::visitSelect(Select* curr) { shouldBeUnequal(curr->ifTrue->type, none, curr, "select left must be valid"); shouldBeUnequal(curr->ifFalse->type, none, curr, "select right must be valid"); shouldBeTrue(curr->condition->type == unreachable || curr->condition->type == i32, curr, "select condition must be valid"); @@ -476,11 +739,11 @@ void WasmValidator::visitSelect(Select* curr) { } } -void WasmValidator::visitDrop(Drop* curr) { +void FunctionValidator::visitDrop(Drop* curr) { shouldBeTrue(isConcreteWasmType(curr->value->type) || curr->value->type == unreachable, curr, "can only drop a valid value"); } -void WasmValidator::visitReturn(Return* curr) { +void FunctionValidator::visitReturn(Return* curr) { if (curr->value) { if (returnType == unreachable) { returnType = curr->value->type; @@ -492,7 +755,7 @@ void WasmValidator::visitReturn(Return* curr) { } } -void WasmValidator::visitHost(Host* curr) { +void FunctionValidator::visitHost(Host* curr) { switch (curr->op) { case GrowMemory: { shouldBeEqual(curr->operands.size(), size_t(1), curr, "grow_memory must have 1 operand"); @@ -506,48 +769,7 @@ void WasmValidator::visitHost(Host* curr) { } } -void WasmValidator::visitImport(Import* curr) { - if (!validateGlobally) return; - if (curr->kind == ExternalKind::Function) { - if (validateWeb) { - auto* functionType = getModule()->getFunctionType(curr->functionType); - shouldBeUnequal(functionType->result, i64, curr->name, "Imported function must not have i64 return type"); - for (WasmType param : functionType->params) { - shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters"); - } - } - } - if (curr->kind == ExternalKind::Table) { - shouldBeTrue(getModule()->table.imported, curr->name, "Table import record exists but table is not marked as imported"); - } - if (curr->kind == ExternalKind::Memory) { - shouldBeTrue(getModule()->memory.imported, curr->name, "Memory import record exists but memory is not marked as imported"); - } -} - -void WasmValidator::visitExport(Export* curr) { - if (!validateGlobally) return; - if (curr->kind == ExternalKind::Function) { - if (validateWeb) { - Function* f = getModule()->getFunction(curr->value); - shouldBeUnequal(f->result, i64, f->name, "Exported function must not have i64 return type"); - for (auto param : f->params) { - shouldBeUnequal(param, i64, f->name, "Exported function must not have i64 parameters"); - } - } - } -} - -void WasmValidator::visitGlobal(Global* curr) { - if (!validateGlobally) return; - shouldBeTrue(curr->init != nullptr, curr->name, "global init must be non-null"); - shouldBeTrue(curr->init->is<Const>() || curr->init->is<GetGlobal>(), curr->name, "global init must be valid"); - if (!shouldBeEqual(curr->type, curr->init->type, curr->init, "global init must have correct type") && !quiet) { - std::cerr << "(on global " << curr->name << ")\n"; - } -} - -void WasmValidator::visitFunction(Function *curr) { +void FunctionValidator::visitFunction(Function *curr) { // if function has no result, it is ignored // if body is unreachable, it might be e.g. a return if (curr->body->type != unreachable) { @@ -575,7 +797,7 @@ void WasmValidator::visitFunction(Function *curr) { Walker walker(seenExpressions); walker.walk(curr->body); for (auto* bad : walker.dupes) { - fail("expression seen more than once in the tree", bad); + info.fail("expression seen more than once in the tree", bad, getFunction()); } } @@ -594,73 +816,7 @@ static bool checkOffset(Expression* curr, Address add, Address max) { return offset + add <= max; } -void WasmValidator::visitMemory(Memory *curr) { - shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial"); - shouldBeTrue(curr->max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB"); - shouldBeTrue(!curr->shared || curr->hasMax(), "memory", "shared memory must have max size"); - Index mustBeGreaterOrEqual = 0; - for (auto& segment : curr->segments) { - if (!shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32")) continue; - shouldBeTrue(checkOffset(segment.offset, segment.data.size(), getModule()->memory.initial * Memory::kPageSize), segment.offset, "segment offset should be reasonable"); - Index size = segment.data.size(); - shouldBeTrue(size <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory"); - if (segment.offset->is<Const>()) { - Index start = segment.offset->cast<Const>()->value.geti32(); - Index end = start + size; - shouldBeTrue(end <= curr->initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory"); - shouldBeTrue(start >= mustBeGreaterOrEqual, segment.data.size(), "segment size should fit in memory"); - mustBeGreaterOrEqual = end; - } - } -} -void WasmValidator::visitTable(Table* curr) { - for (auto& segment : curr->segments) { - shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32"); - shouldBeTrue(checkOffset(segment.offset, segment.data.size(), getModule()->table.initial * Table::kPageSize), segment.offset, "segment offset should be reasonable"); - for (auto name : segment.data) { - shouldBeTrue(getModule()->getFunctionOrNull(name) || getModule()->getImportOrNull(name), name, "segment name should be valid"); - } - } -} -void WasmValidator::visitModule(Module *curr) { - if (!validateGlobally) return; - // exports - std::set<Name> exportNames; - for (auto& exp : curr->exports) { - Name name = exp->value; - if (exp->kind == ExternalKind::Function) { - bool found = false; - for (auto& func : curr->functions) { - if (func->name == name) { - found = true; - break; - } - } - shouldBeTrue(found, name, "module function exports must be found"); - } else if (exp->kind == ExternalKind::Global) { - shouldBeTrue(curr->getGlobalOrNull(name), name, "module global exports must be found"); - } else if (exp->kind == ExternalKind::Table) { - shouldBeTrue(name == Name("0") || name == curr->table.name, name, "module table exports must be found"); - } else if (exp->kind == ExternalKind::Memory) { - shouldBeTrue(name == Name("0") || name == curr->memory.name, name, "module memory exports must be found"); - } else { - WASM_UNREACHABLE(); - } - Name exportName = exp->name; - shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique"); - exportNames.insert(exportName); - } - // start - if (curr->start.is()) { - auto func = curr->getFunctionOrNull(curr->start); - if (shouldBeTrue(func != nullptr, curr->start, "start must be found")) { - shouldBeTrue(func->params.size() == 0, curr, "start must have 0 params"); - shouldBeTrue(func->result == none, curr, "start must not return a value"); - } - } -} - -void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes, +void FunctionValidator::validateAlignment(size_t align, WasmType type, Index bytes, bool isAtomic, Expression* curr) { if (isAtomic) { shouldBeEqual(align, (size_t)bytes, curr, "atomic accesses must have natural alignment"); @@ -672,7 +828,7 @@ void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes, case 4: case 8: break; default:{ - fail("bad alignment: " + std::to_string(align), curr); + info.fail("bad alignment: " + std::to_string(align), curr, getFunction()); break; } } @@ -692,11 +848,11 @@ void WasmValidator::validateAlignment(size_t align, WasmType type, Index bytes, } } -void WasmValidator::validateBinaryenIR(Module& wasm) { +static void validateBinaryenIR(Module& wasm, ValidationInfo& info) { struct BinaryenIRValidator : public PostWalker<BinaryenIRValidator, UnifiedExpressionVisitor<BinaryenIRValidator>> { - WasmValidator& parent; + ValidationInfo& info; - BinaryenIRValidator(WasmValidator& parent) : parent(parent) {} + BinaryenIRValidator(ValidationInfo& info) : info(info) {} void visitExpression(Expression* curr) { // check if a node type is 'stale', i.e., we forgot to finalize() the node. @@ -712,40 +868,168 @@ void WasmValidator::validateBinaryenIR(Module& wasm) { // The block has an added type, not derived from the ast itself, so it is // ok for it to be either i32 or unreachable. if (!(isConcreteWasmType(oldType) && newType == unreachable)) { - parent.printFailureHeader() << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n"; - parent.valid = false; + std::ostringstream ss; + ss << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printWasmType(oldType) << ", should be " << printWasmType(newType) << ")\n"; + info.fail(ss.str(), curr, getFunction()); } curr->type = oldType; } } }; - BinaryenIRValidator binaryenIRValidator(*this); + BinaryenIRValidator binaryenIRValidator(info); binaryenIRValidator.walkModule(&wasm); } -template <typename T, typename S> -std::ostream& WasmValidator::fail(S text, T curr) { - valid = false; - if (quiet) return std::cerr; - auto& ret = printFailureHeader() << text << ", on \n"; - return printModuleComponent(curr, ret); +// Main validator class + +static void validateImports(Module& module, ValidationInfo& info) { + for (auto& curr : module.imports) { + if (curr->kind == ExternalKind::Function) { + if (info.validateWeb) { + auto* functionType = module.getFunctionType(curr->functionType); + info.shouldBeUnequal(functionType->result, i64, curr->name, "Imported function must not have i64 return type"); + for (WasmType param : functionType->params) { + info.shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters"); + } + } + } + if (curr->kind == ExternalKind::Table) { + info.shouldBeTrue(module.table.imported, curr->name, "Table import record exists but table is not marked as imported"); + } + if (curr->kind == ExternalKind::Memory) { + info.shouldBeTrue(module.memory.imported, curr->name, "Memory import record exists but memory is not marked as imported"); + } + } } -std::ostream& WasmValidator::printFailureHeader() { - if (quiet) return std::cerr; - Colors::red(std::cerr); - if (getFunction()) { - std::cerr << "[wasm-validator error in function "; - Colors::green(std::cerr); - std::cerr << getFunction()->name; - Colors::red(std::cerr); - std::cerr << "] "; - } else { - std::cerr << "[wasm-validator error in module] "; +static void validateExports(Module& module, ValidationInfo& info) { + for (auto& curr : module.exports) { + if (curr->kind == ExternalKind::Function) { + if (info.validateWeb) { + Function* f = module.getFunction(curr->value); + info.shouldBeUnequal(f->result, i64, f->name, "Exported function must not have i64 return type"); + for (auto param : f->params) { + info.shouldBeUnequal(param, i64, f->name, "Exported function must not have i64 parameters"); + } + } + } + } + std::set<Name> exportNames; + for (auto& exp : module.exports) { + Name name = exp->value; + if (exp->kind == ExternalKind::Function) { + bool found = false; + for (auto& func : module.functions) { + if (func->name == name) { + found = true; + break; + } + } + info.shouldBeTrue(found, name, "module function exports must be found"); + } else if (exp->kind == ExternalKind::Global) { + info.shouldBeTrue(module.getGlobalOrNull(name), name, "module global exports must be found"); + } else if (exp->kind == ExternalKind::Table) { + info.shouldBeTrue(name == Name("0") || name == module.table.name, name, "module table exports must be found"); + } else if (exp->kind == ExternalKind::Memory) { + info.shouldBeTrue(name == Name("0") || name == module.memory.name, name, "module memory exports must be found"); + } else { + WASM_UNREACHABLE(); + } + Name exportName = exp->name; + info.shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique"); + exportNames.insert(exportName); + } +} + +static void validateGlobals(Module& module, ValidationInfo& info) { + for (auto& curr : module.globals) { + info.shouldBeTrue(curr->init != nullptr, curr->name, "global init must be non-null"); + info.shouldBeTrue(curr->init->is<Const>() || curr->init->is<GetGlobal>(), curr->name, "global init must be valid"); + if (!info.shouldBeEqual(curr->type, curr->init->type, curr->init, "global init must have correct type") && !info.quiet) { + info.getStream(nullptr) << "(on global " << curr->name << ")\n"; + } } - Colors::normal(std::cerr); - return std::cerr; } +static void validateMemory(Module& module, ValidationInfo& info) { + auto& curr = module.memory; + info.shouldBeFalse(curr.initial > curr.max, "memory", "memory max >= initial"); + info.shouldBeTrue(curr.max <= Memory::kMaxSize, "memory", "max memory must be <= 4GB"); + info.shouldBeTrue(!curr.shared || curr.hasMax(), "memory", "shared memory must have max size"); + Index mustBeGreaterOrEqual = 0; + for (auto& segment : curr.segments) { + if (!info.shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32")) continue; + info.shouldBeTrue(checkOffset(segment.offset, segment.data.size(), module.memory.initial * Memory::kPageSize), segment.offset, "segment offset should be reasonable"); + Index size = segment.data.size(); + info.shouldBeTrue(size <= curr.initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory"); + if (segment.offset->is<Const>()) { + Index start = segment.offset->cast<Const>()->value.geti32(); + Index end = start + size; + info.shouldBeTrue(end <= curr.initial * Memory::kPageSize, segment.data.size(), "segment size should fit in memory"); + info.shouldBeTrue(start >= mustBeGreaterOrEqual, segment.data.size(), "segment size should fit in memory"); + mustBeGreaterOrEqual = end; + } + } +} + +static void validateTable(Module& module, ValidationInfo& info) { + auto& curr = module.table; + for (auto& segment : curr.segments) { + info.shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32"); + info.shouldBeTrue(checkOffset(segment.offset, segment.data.size(), module.table.initial * Table::kPageSize), segment.offset, "segment offset should be reasonable"); + for (auto name : segment.data) { + info.shouldBeTrue(module.getFunctionOrNull(name) || module.getImportOrNull(name), name, "segment name should be valid"); + } + } +} + +static void validateModule(Module& module, ValidationInfo& info) { + // start + if (module.start.is()) { + auto func = module.getFunctionOrNull(module.start); + if (info.shouldBeTrue(func != nullptr, module.start, "start must be found")) { + info.shouldBeTrue(func->params.size() == 0, module.start, "start must have 0 params"); + info.shouldBeTrue(func->result == none, module.start, "start must not return a value"); + } + } +} + +// TODO: If we want the validator to be part of libwasm rather than libpasses, then +// Using PassRunner::getPassDebug causes a circular dependence. We should fix that, +// perhaps by moving some of the pass infrastructure into libsupport. +bool WasmValidator::validate(Module& module, bool validateWeb, bool validateGlobally, bool quiet) { + ValidationInfo info; + info.validateWeb = validateWeb; + info.validateGlobally = validateGlobally; + info.quiet = quiet; + // parallel wasm logic validation + PassRunner runner(&module); + runner.add<FunctionValidator>(&info); + runner.setIsNested(true); + runner.run(); + // validate globally + if (validateGlobally) { + validateImports(module, info); + validateExports(module, info); + validateGlobals(module, info); + validateMemory(module, info); + validateTable(module, info); + validateModule(module, info); + } + // validate additional internal IR details when in pass-debug mode + if (PassRunner::getPassDebug()) { + validateBinaryenIR(module, info); + } + // print all the data + if (!info.valid.load() && !info.quiet) { + for (auto& func : module.functions) { + std::cerr << info.getStream(func.get()).str(); + } + std::cerr << info.getStream(nullptr).str(); + // also print the module + WasmPrinter::printModule(&module, std::cerr); + } + return info.valid.load(); +} } // namespace wasm |