summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/binaryen-shell.cpp1
-rw-r--r--src/wasm-binary.h6
-rw-r--r--src/wasm-s-parser.h3
-rw-r--r--src/wasm-validator.h183
4 files changed, 167 insertions, 26 deletions
diff --git a/src/binaryen-shell.cpp b/src/binaryen-shell.cpp
index 509bd97dc..f988379c4 100644
--- a/src/binaryen-shell.cpp
+++ b/src/binaryen-shell.cpp
@@ -217,6 +217,7 @@ int main(int argc, const char* argv[]) {
std::unique_ptr<SExpressionWasmBuilder> builder(
new SExpressionWasmBuilder(wasm, *root[i], [&]() { abort(); }));
i++;
+ assert(WasmValidator().validate(wasm));
MixedArena moreModuleAllocations;
diff --git a/src/wasm-binary.h b/src/wasm-binary.h
index 0bd86ffb6..89073b8be 100644
--- a/src/wasm-binary.h
+++ b/src/wasm-binary.h
@@ -31,6 +31,7 @@
#include "asm_v_wasm.h"
#include "wasm-builder.h"
#include "ast_utils.h"
+#include "wasm-validator.h"
namespace wasm {
@@ -1187,6 +1188,10 @@ public:
}
processFunctions();
+
+ if (!WasmValidator().validate(wasm)) {
+ abort();
+ }
}
bool more() {
@@ -1727,6 +1732,7 @@ public:
curr->name = getBreakName(getU32LEB());
if (code == BinaryConsts::BrIf) curr->condition = popExpression();
if (arity == 1) curr->value = popExpression();
+ curr->finalize();
}
void visitSwitch(Switch *curr) {
if (debug) std::cerr << "zz node: Switch" << std::endl;
diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h
index 674341d1e..05715f739 100644
--- a/src/wasm-s-parser.h
+++ b/src/wasm-s-parser.h
@@ -424,10 +424,10 @@ private:
if (!autoBlock) {
autoBlock = allocator.alloc<Block>();
autoBlock->list.push_back(body);
- autoBlock->finalize();
body = autoBlock;
}
autoBlock->list.push_back(ex);
+ autoBlock->finalize();
}
}
}
@@ -1023,6 +1023,7 @@ private:
} else {
ret->value = parseExpression(s[i]);
}
+ ret->finalize();
return ret;
}
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index ca168e48b..d71a3413a 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -22,6 +22,7 @@
#define wasm_wasm_validator_h
#include "wasm.h"
+#include "wasm-printing.h"
namespace wasm {
@@ -42,32 +43,50 @@ public:
void visitBlock(Block *curr) {
// if we are break'ed to, then the value must be right for us
if (curr->name.is()) {
- if (breakTypes.count(curr->name) > 0) {
- shouldBeTrue(curr->type == breakTypes[curr->name]);
- breakTypes.erase(curr->name);
+ // none or unreachable means a poison value that we should ignore - if consumed, it will error
+ if (breakTypes.count(curr->name) > 0 && isConcreteWasmType(breakTypes[curr->name]) && isConcreteWasmType(curr->type)) {
+ shouldBeEqual(curr->type, breakTypes[curr->name], curr, "block+breaks must have right type if breaks return a value");
}
+ breakTypes.erase(curr->name);
}
}
void visitLoop(Loop *curr) {
if (curr->in.is()) {
LoopChildChecker childChecker(curr->in);
childChecker.walk(curr->body);
- shouldBeTrue(childChecker.valid);
+ shouldBeTrue(childChecker.valid, curr, "loop must return none");
+ breakTypes.erase(curr->in);
+ }
+ if (curr->out.is()) {
+ breakTypes.erase(curr->out);
}
}
- void visitBreak(Break *curr) {
+ void noteBreak(Name name, Expression* value) {
WasmType valueType = none;
- if (curr->value) {
- valueType = curr->value->type;
+ if (value) {
+ valueType = value->type;
}
- if (breakTypes.count(curr->name) == 0) {
- breakTypes[curr->name] = valueType;
+ if (breakTypes.count(name) == 0) {
+ breakTypes[name] = valueType;
} else {
- shouldBeTrue(valueType == breakTypes[curr->name]);
+ if (breakTypes[name] == unreachable) {
+ breakTypes[name] = valueType;
+ } else {
+ shouldBeEqual(valueType, breakTypes[name], name.str, "breaks to same target must have same type (ignoring unreachable)");
+ }
}
}
+ void visitBreak(Break *curr) {
+ noteBreak(curr->name, curr->value);
+ }
+ void visitSwitch(Switch *curr) {
+ for (auto& target : curr->targets) {
+ noteBreak(target, curr->value);
+ }
+ noteBreak(curr->default_, curr->value);
+ }
void visitSetLocal(SetLocal *curr) {
- shouldBeTrue(curr->type == curr->value->type);
+ shouldBeTrue(curr->type == curr->value->type, curr, "set_local type might be correct");
}
void visitLoad(Load *curr) {
validateAlignment(curr->align);
@@ -75,28 +94,77 @@ public:
void visitStore(Store *curr) {
validateAlignment(curr->align);
}
- void visitSwitch(Switch *curr) {
+ void visitBinary(Binary *curr) {
+ if (curr->left->type != unreachable && curr->right->type != unreachable) {
+ shouldBeEqual(curr->left->type, curr->right->type, curr, "binary child types must be equal");
+ }
}
void visitUnary(Unary *curr) {
- shouldBeTrue(curr->value->type == curr->type);
+ switch (curr->op) {
+ case Clz:
+ case Ctz:
+ case Popcnt:
+ case Neg:
+ case Abs:
+ case Ceil:
+ case Floor:
+ case Trunc:
+ case Nearest:
+ case Sqrt: {
+ //if (curr->value->type != unreachable) {
+ shouldBeEqual(curr->value->type, curr->type, curr, "non-conversion unaries must return the same type");
+ //}
+ break;
+ }
+ case EqZ: {
+ shouldBeEqual(curr->type, i32, curr, "relational unaries must return i32");
+ break;
+ }
+ case ExtendSInt32:
+ case ExtendUInt32:
+ case WrapInt64:
+ case TruncSFloat32:
+ case TruncUFloat32:
+ case TruncSFloat64:
+ case TruncUFloat64:
+ case ReinterpretFloat:
+ case ConvertUInt32:
+ case ConvertSInt32:
+ case ConvertUInt64:
+ case ConvertSInt64:
+ case PromoteFloat32:
+ case DemoteFloat64:
+ case ReinterpretInt: {
+ //if (curr->value->type != unreachable) {
+ shouldBeUnequal(curr->value->type, curr->type, curr, "conversion unaries must not return the same type");
+ //}
+ break;
+ }
+ default: abort();
+ }
}
void visitFunction(Function *curr) {
- shouldBeTrue(curr->result == curr->body->type);
+ // if function has no result, it is ignored
+ // if body is unreachable, it might be e.g. a return
+ if (curr->result != none && curr->body->type != unreachable) {
+ shouldBeEqual(curr->result, curr->body->type, curr->body, "function result must match, if function returns");
+ }
}
void visitMemory(Memory *curr) {
- shouldBeFalse(curr->initial > curr->max);
+ shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial");
size_t top = 0;
for (auto& segment : curr->segments) {
- shouldBeFalse(segment.offset < top);
+ shouldBeFalse(segment.offset < top, "memory", "segment offset is small enough");
top = segment.offset + segment.data.size();
}
- shouldBeFalse(top > curr->initial);
+ shouldBeFalse(top > curr->initial * Memory::kPageSize, "memory", "total segments must be small enough");
}
void visitModule(Module *curr) {
// exports
+ std::set<Name> exportNames;
for (auto& exp : curr->exports) {
- Name name = exp->name;
+ Name name = exp->value;
bool found = false;
for (auto& func : curr->functions) {
if (func->name == name) {
@@ -104,17 +172,27 @@ public:
break;
}
}
- shouldBeTrue(found);
+ shouldBeTrue(found, name, "module exports must be found");
+ Name exportName = exp->name;
+ shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique");
+ exportNames.insert(exportName);
}
// start
if (curr->start.is()) {
auto func = curr->checkFunction(curr->start);
- if (shouldBeTrue(func)) {
- shouldBeTrue(func->params.size() == 0); // must be nullary
+ if (shouldBeTrue(func, curr->start, "start must be found")) {
+ shouldBeTrue(func->params.size() == 0, curr, "start must have 0 params");
+ shouldBeTrue(func->result == none, curr, "start must not return a value");
}
}
}
+ void walk(Expression*& root) {
+ //std::cerr << "start a function " << getFunction()->name << "\n";
+ PostWalker<WasmValidator, Visitor<WasmValidator>>::walk(root);
+ shouldBeTrue(breakTypes.size() == 0, "break targets", "all break targets must be valid");
+ }
+
private:
// the "in" label has a none type, since no one can receive its value. make sure no one breaks to it with a value.
@@ -133,15 +211,69 @@ private:
// helpers
- bool shouldBeTrue(bool result) {
- if (!result) valid = false;
+ std::ostream& fail() {
+ Colors::red(std::cerr);
+ if (getFunction()) {
+ std::cerr << "[wasm-validator error in function ";
+ Colors::green(std::cerr);
+ std::cerr << getFunction()->name;
+ Colors::red(std::cerr);
+ std::cerr << "] ";
+ } else {
+ std::cerr << "[wasm-validator error in module] ";
+ }
+ Colors::normal(std::cerr);
+ return std::cerr;
+ }
+
+ template<typename T>
+ bool shouldBeTrue(bool result, T curr, const char* text) {
+ if (!result) {
+ fail() << "unexpected false: " << text << ", on " << curr << std::endl;
+ valid = false;
+ return false;
+ }
return result;
}
- bool shouldBeFalse(bool result) {
- if (result) valid = false;
+ template<typename T>
+ bool shouldBeFalse(bool result, T curr, const char* text) {
+ if (result) {
+ fail() << "unexpected true: " << text << ", on " << curr << std::endl;
+ valid = false;
+ return false;
+ }
return result;
}
+ template<typename T, typename S>
+ bool shouldBeEqual(S left, S right, T curr, const char* text) {
+ if (left != right) {
+ fail() << "" << left << " != " << right << ": " << text << ", on " << curr << std::endl;
+ valid = false;
+ return false;
+ }
+ return true;
+ }
+ template<typename T, typename S, typename U>
+ bool shouldBeEqual(S left, S right, T curr, U other, const char* text) {
+ if (left != right) {
+ fail() << "" << left << " != " << right << ": " << text << ", on " << curr << " / " << other << std::endl;
+ valid = false;
+ return false;
+ }
+ return true;
+ }
+
+ template<typename T, typename S>
+ bool shouldBeUnequal(S left, S right, T curr, const char* text) {
+ if (left == right) {
+ fail() << "" << left << " == " << right << ": " << text << ", on " << curr << std::endl;
+ valid = false;
+ return false;
+ }
+ return true;
+ }
+
void validateAlignment(size_t align) {
switch (align) {
case 1:
@@ -149,6 +281,7 @@ private:
case 4:
case 8: break;
default:{
+ fail() << "bad alignment: " << align << std::endl;
valid = false;
break;
}