diff options
-rw-r--r-- | src/binaryen-shell.cpp | 1 | ||||
-rw-r--r-- | src/wasm-binary.h | 6 | ||||
-rw-r--r-- | src/wasm-s-parser.h | 3 | ||||
-rw-r--r-- | src/wasm-validator.h | 183 | ||||
-rw-r--r-- | test/dot_s/exit.s | 1 | ||||
-rw-r--r-- | test/dot_s/exit.wast | 1 | ||||
-rw-r--r-- | test/llvm_autogenerated/i64-load-store-alignment.s | 4 | ||||
-rw-r--r-- | test/llvm_autogenerated/i64-load-store-alignment.wast | 4 | ||||
-rw-r--r-- | test/passes/remove-unused-brs.txt | 4 | ||||
-rw-r--r-- | test/passes/remove-unused-brs.wast | 4 |
10 files changed, 177 insertions, 34 deletions
diff --git a/src/binaryen-shell.cpp b/src/binaryen-shell.cpp index 509bd97dc..f988379c4 100644 --- a/src/binaryen-shell.cpp +++ b/src/binaryen-shell.cpp @@ -217,6 +217,7 @@ int main(int argc, const char* argv[]) { std::unique_ptr<SExpressionWasmBuilder> builder( new SExpressionWasmBuilder(wasm, *root[i], [&]() { abort(); })); i++; + assert(WasmValidator().validate(wasm)); MixedArena moreModuleAllocations; diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 0bd86ffb6..89073b8be 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -31,6 +31,7 @@ #include "asm_v_wasm.h" #include "wasm-builder.h" #include "ast_utils.h" +#include "wasm-validator.h" namespace wasm { @@ -1187,6 +1188,10 @@ public: } processFunctions(); + + if (!WasmValidator().validate(wasm)) { + abort(); + } } bool more() { @@ -1727,6 +1732,7 @@ public: curr->name = getBreakName(getU32LEB()); if (code == BinaryConsts::BrIf) curr->condition = popExpression(); if (arity == 1) curr->value = popExpression(); + curr->finalize(); } void visitSwitch(Switch *curr) { if (debug) std::cerr << "zz node: Switch" << std::endl; diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index 674341d1e..05715f739 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -424,10 +424,10 @@ private: if (!autoBlock) { autoBlock = allocator.alloc<Block>(); autoBlock->list.push_back(body); - autoBlock->finalize(); body = autoBlock; } autoBlock->list.push_back(ex); + autoBlock->finalize(); } } } @@ -1023,6 +1023,7 @@ private: } else { ret->value = parseExpression(s[i]); } + ret->finalize(); return ret; } diff --git a/src/wasm-validator.h b/src/wasm-validator.h index ca168e48b..d71a3413a 100644 --- a/src/wasm-validator.h +++ b/src/wasm-validator.h @@ -22,6 +22,7 @@ #define wasm_wasm_validator_h #include "wasm.h" +#include "wasm-printing.h" namespace wasm { @@ -42,32 +43,50 @@ public: void visitBlock(Block *curr) { // if we are break'ed to, then the value must be right for us if (curr->name.is()) { - if (breakTypes.count(curr->name) > 0) { - shouldBeTrue(curr->type == breakTypes[curr->name]); - breakTypes.erase(curr->name); + // none or unreachable means a poison value that we should ignore - if consumed, it will error + if (breakTypes.count(curr->name) > 0 && isConcreteWasmType(breakTypes[curr->name]) && isConcreteWasmType(curr->type)) { + shouldBeEqual(curr->type, breakTypes[curr->name], curr, "block+breaks must have right type if breaks return a value"); } + breakTypes.erase(curr->name); } } void visitLoop(Loop *curr) { if (curr->in.is()) { LoopChildChecker childChecker(curr->in); childChecker.walk(curr->body); - shouldBeTrue(childChecker.valid); + shouldBeTrue(childChecker.valid, curr, "loop must return none"); + breakTypes.erase(curr->in); + } + if (curr->out.is()) { + breakTypes.erase(curr->out); } } - void visitBreak(Break *curr) { + void noteBreak(Name name, Expression* value) { WasmType valueType = none; - if (curr->value) { - valueType = curr->value->type; + if (value) { + valueType = value->type; } - if (breakTypes.count(curr->name) == 0) { - breakTypes[curr->name] = valueType; + if (breakTypes.count(name) == 0) { + breakTypes[name] = valueType; } else { - shouldBeTrue(valueType == breakTypes[curr->name]); + if (breakTypes[name] == unreachable) { + breakTypes[name] = valueType; + } else { + shouldBeEqual(valueType, breakTypes[name], name.str, "breaks to same target must have same type (ignoring unreachable)"); + } } } + void visitBreak(Break *curr) { + noteBreak(curr->name, curr->value); + } + void visitSwitch(Switch *curr) { + for (auto& target : curr->targets) { + noteBreak(target, curr->value); + } + noteBreak(curr->default_, curr->value); + } void visitSetLocal(SetLocal *curr) { - shouldBeTrue(curr->type == curr->value->type); + shouldBeTrue(curr->type == curr->value->type, curr, "set_local type might be correct"); } void visitLoad(Load *curr) { validateAlignment(curr->align); @@ -75,28 +94,77 @@ public: void visitStore(Store *curr) { validateAlignment(curr->align); } - void visitSwitch(Switch *curr) { + void visitBinary(Binary *curr) { + if (curr->left->type != unreachable && curr->right->type != unreachable) { + shouldBeEqual(curr->left->type, curr->right->type, curr, "binary child types must be equal"); + } } void visitUnary(Unary *curr) { - shouldBeTrue(curr->value->type == curr->type); + switch (curr->op) { + case Clz: + case Ctz: + case Popcnt: + case Neg: + case Abs: + case Ceil: + case Floor: + case Trunc: + case Nearest: + case Sqrt: { + //if (curr->value->type != unreachable) { + shouldBeEqual(curr->value->type, curr->type, curr, "non-conversion unaries must return the same type"); + //} + break; + } + case EqZ: { + shouldBeEqual(curr->type, i32, curr, "relational unaries must return i32"); + break; + } + case ExtendSInt32: + case ExtendUInt32: + case WrapInt64: + case TruncSFloat32: + case TruncUFloat32: + case TruncSFloat64: + case TruncUFloat64: + case ReinterpretFloat: + case ConvertUInt32: + case ConvertSInt32: + case ConvertUInt64: + case ConvertSInt64: + case PromoteFloat32: + case DemoteFloat64: + case ReinterpretInt: { + //if (curr->value->type != unreachable) { + shouldBeUnequal(curr->value->type, curr->type, curr, "conversion unaries must not return the same type"); + //} + break; + } + default: abort(); + } } void visitFunction(Function *curr) { - shouldBeTrue(curr->result == curr->body->type); + // if function has no result, it is ignored + // if body is unreachable, it might be e.g. a return + if (curr->result != none && curr->body->type != unreachable) { + shouldBeEqual(curr->result, curr->body->type, curr->body, "function result must match, if function returns"); + } } void visitMemory(Memory *curr) { - shouldBeFalse(curr->initial > curr->max); + shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial"); size_t top = 0; for (auto& segment : curr->segments) { - shouldBeFalse(segment.offset < top); + shouldBeFalse(segment.offset < top, "memory", "segment offset is small enough"); top = segment.offset + segment.data.size(); } - shouldBeFalse(top > curr->initial); + shouldBeFalse(top > curr->initial * Memory::kPageSize, "memory", "total segments must be small enough"); } void visitModule(Module *curr) { // exports + std::set<Name> exportNames; for (auto& exp : curr->exports) { - Name name = exp->name; + Name name = exp->value; bool found = false; for (auto& func : curr->functions) { if (func->name == name) { @@ -104,17 +172,27 @@ public: break; } } - shouldBeTrue(found); + shouldBeTrue(found, name, "module exports must be found"); + Name exportName = exp->name; + shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique"); + exportNames.insert(exportName); } // start if (curr->start.is()) { auto func = curr->checkFunction(curr->start); - if (shouldBeTrue(func)) { - shouldBeTrue(func->params.size() == 0); // must be nullary + if (shouldBeTrue(func, curr->start, "start must be found")) { + shouldBeTrue(func->params.size() == 0, curr, "start must have 0 params"); + shouldBeTrue(func->result == none, curr, "start must not return a value"); } } } + void walk(Expression*& root) { + //std::cerr << "start a function " << getFunction()->name << "\n"; + PostWalker<WasmValidator, Visitor<WasmValidator>>::walk(root); + shouldBeTrue(breakTypes.size() == 0, "break targets", "all break targets must be valid"); + } + private: // the "in" label has a none type, since no one can receive its value. make sure no one breaks to it with a value. @@ -133,15 +211,69 @@ private: // helpers - bool shouldBeTrue(bool result) { - if (!result) valid = false; + std::ostream& fail() { + Colors::red(std::cerr); + if (getFunction()) { + std::cerr << "[wasm-validator error in function "; + Colors::green(std::cerr); + std::cerr << getFunction()->name; + Colors::red(std::cerr); + std::cerr << "] "; + } else { + std::cerr << "[wasm-validator error in module] "; + } + Colors::normal(std::cerr); + return std::cerr; + } + + template<typename T> + bool shouldBeTrue(bool result, T curr, const char* text) { + if (!result) { + fail() << "unexpected false: " << text << ", on " << curr << std::endl; + valid = false; + return false; + } return result; } - bool shouldBeFalse(bool result) { - if (result) valid = false; + template<typename T> + bool shouldBeFalse(bool result, T curr, const char* text) { + if (result) { + fail() << "unexpected true: " << text << ", on " << curr << std::endl; + valid = false; + return false; + } return result; } + template<typename T, typename S> + bool shouldBeEqual(S left, S right, T curr, const char* text) { + if (left != right) { + fail() << "" << left << " != " << right << ": " << text << ", on " << curr << std::endl; + valid = false; + return false; + } + return true; + } + template<typename T, typename S, typename U> + bool shouldBeEqual(S left, S right, T curr, U other, const char* text) { + if (left != right) { + fail() << "" << left << " != " << right << ": " << text << ", on " << curr << " / " << other << std::endl; + valid = false; + return false; + } + return true; + } + + template<typename T, typename S> + bool shouldBeUnequal(S left, S right, T curr, const char* text) { + if (left == right) { + fail() << "" << left << " == " << right << ": " << text << ", on " << curr << std::endl; + valid = false; + return false; + } + return true; + } + void validateAlignment(size_t align) { switch (align) { case 1: @@ -149,6 +281,7 @@ private: case 4: case 8: break; default:{ + fail() << "bad alignment: " << align << std::endl; valid = false; break; } diff --git a/test/dot_s/exit.s b/test/dot_s/exit.s index 2fad9277f..bf24ead9e 100644 --- a/test/dot_s/exit.s +++ b/test/dot_s/exit.s @@ -7,5 +7,6 @@ main: .local i32 i32.const $push0=, 0 call exit@FUNCTION, $pop0 + unreachable .Lfunc_end0: .size main, .Lfunc_end0-main diff --git a/test/dot_s/exit.wast b/test/dot_s/exit.wast index 263fbf647..77fb13b0b 100644 --- a/test/dot_s/exit.wast +++ b/test/dot_s/exit.wast @@ -9,6 +9,7 @@ (call_import $exit (i32.const 0) ) + (unreachable) ) ) ;; METADATA: { "asmConsts": {},"staticBump": 12, "initializers": [] } diff --git a/test/llvm_autogenerated/i64-load-store-alignment.s b/test/llvm_autogenerated/i64-load-store-alignment.s index a61aa2965..2f290e044 100644 --- a/test/llvm_autogenerated/i64-load-store-alignment.s +++ b/test/llvm_autogenerated/i64-load-store-alignment.s @@ -60,7 +60,7 @@ ldi64: ldi64_a16: .param i32 .result i64 - i64.load $push0=, 0($0):p2align=4 + i64.load $push0=, 0($0):p2align=3 return $pop0 .endfunc .Lfunc_end5: @@ -219,7 +219,7 @@ sti64: .type sti64_a16,@function sti64_a16: .param i32, i64 - i64.store $discard=, 0($0):p2align=4, $1 + i64.store $discard=, 0($0):p2align=3, $1 return .endfunc .Lfunc_end20: diff --git a/test/llvm_autogenerated/i64-load-store-alignment.wast b/test/llvm_autogenerated/i64-load-store-alignment.wast index 9741e5a72..d6c76de7a 100644 --- a/test/llvm_autogenerated/i64-load-store-alignment.wast +++ b/test/llvm_autogenerated/i64-load-store-alignment.wast @@ -70,7 +70,7 @@ ) (func $ldi64_a16 (param $$0 i32) (result i64) (return - (i64.load align=16 + (i64.load (get_local $$0) ) ) @@ -174,7 +174,7 @@ (return) ) (func $sti64_a16 (param $$0 i32) (param $$1 i64) - (i64.store align=16 + (i64.store (get_local $$0) (get_local $$1) ) diff --git a/test/passes/remove-unused-brs.txt b/test/passes/remove-unused-brs.txt index a7bd70ecc..a8e7f29be 100644 --- a/test/passes/remove-unused-brs.txt +++ b/test/passes/remove-unused-brs.txt @@ -88,7 +88,7 @@ ) ) ) - (func $b12-yes (result i32) + (func $b12-yes (block $topmost (select (block $block1 @@ -144,7 +144,7 @@ ) ) ) - (func $b15 (result i32) + (func $b15 (block $topmost (br_if $topmost (i32.const 0) diff --git a/test/passes/remove-unused-brs.wast b/test/passes/remove-unused-brs.wast index dc67ad4da..8b9d42bc1 100644 --- a/test/passes/remove-unused-brs.wast +++ b/test/passes/remove-unused-brs.wast @@ -91,7 +91,7 @@ ) ) ) - (func $b12-yes (result i32) + (func $b12-yes (block $topmost (if_else (i32.const 1) (block @@ -149,7 +149,7 @@ ) ) ) - (func $b15 (result i32) + (func $b15 (block $topmost (if (i32.const 18) |