summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/ast_utils.h2
-rw-r--r--src/binaryen-shell.cpp1
-rw-r--r--src/wasm-binary.h6
-rw-r--r--src/wasm-s-parser.h3
-rw-r--r--src/wasm-validator.h183
-rw-r--r--src/wasm.cpp84
-rw-r--r--src/wasm.h17
7 files changed, 264 insertions, 32 deletions
diff --git a/src/ast_utils.h b/src/ast_utils.h
index 5b569435e..8aacb62da 100644
--- a/src/ast_utils.h
+++ b/src/ast_utils.h
@@ -23,7 +23,7 @@
namespace wasm {
struct BreakSeeker : public PostWalker<BreakSeeker, Visitor<BreakSeeker>> {
- Name target; // look for this one
+ Name target; // look for this one XXX looking by name may fall prey to duplicate names
size_t found;
BreakSeeker(Name target) : target(target), found(false) {}
diff --git a/src/binaryen-shell.cpp b/src/binaryen-shell.cpp
index 509bd97dc..f988379c4 100644
--- a/src/binaryen-shell.cpp
+++ b/src/binaryen-shell.cpp
@@ -217,6 +217,7 @@ int main(int argc, const char* argv[]) {
std::unique_ptr<SExpressionWasmBuilder> builder(
new SExpressionWasmBuilder(wasm, *root[i], [&]() { abort(); }));
i++;
+ assert(WasmValidator().validate(wasm));
MixedArena moreModuleAllocations;
diff --git a/src/wasm-binary.h b/src/wasm-binary.h
index 0bd86ffb6..89073b8be 100644
--- a/src/wasm-binary.h
+++ b/src/wasm-binary.h
@@ -31,6 +31,7 @@
#include "asm_v_wasm.h"
#include "wasm-builder.h"
#include "ast_utils.h"
+#include "wasm-validator.h"
namespace wasm {
@@ -1187,6 +1188,10 @@ public:
}
processFunctions();
+
+ if (!WasmValidator().validate(wasm)) {
+ abort();
+ }
}
bool more() {
@@ -1727,6 +1732,7 @@ public:
curr->name = getBreakName(getU32LEB());
if (code == BinaryConsts::BrIf) curr->condition = popExpression();
if (arity == 1) curr->value = popExpression();
+ curr->finalize();
}
void visitSwitch(Switch *curr) {
if (debug) std::cerr << "zz node: Switch" << std::endl;
diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h
index 674341d1e..05715f739 100644
--- a/src/wasm-s-parser.h
+++ b/src/wasm-s-parser.h
@@ -424,10 +424,10 @@ private:
if (!autoBlock) {
autoBlock = allocator.alloc<Block>();
autoBlock->list.push_back(body);
- autoBlock->finalize();
body = autoBlock;
}
autoBlock->list.push_back(ex);
+ autoBlock->finalize();
}
}
}
@@ -1023,6 +1023,7 @@ private:
} else {
ret->value = parseExpression(s[i]);
}
+ ret->finalize();
return ret;
}
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
index ca168e48b..d71a3413a 100644
--- a/src/wasm-validator.h
+++ b/src/wasm-validator.h
@@ -22,6 +22,7 @@
#define wasm_wasm_validator_h
#include "wasm.h"
+#include "wasm-printing.h"
namespace wasm {
@@ -42,32 +43,50 @@ public:
void visitBlock(Block *curr) {
// if we are break'ed to, then the value must be right for us
if (curr->name.is()) {
- if (breakTypes.count(curr->name) > 0) {
- shouldBeTrue(curr->type == breakTypes[curr->name]);
- breakTypes.erase(curr->name);
+ // none or unreachable means a poison value that we should ignore - if consumed, it will error
+ if (breakTypes.count(curr->name) > 0 && isConcreteWasmType(breakTypes[curr->name]) && isConcreteWasmType(curr->type)) {
+ shouldBeEqual(curr->type, breakTypes[curr->name], curr, "block+breaks must have right type if breaks return a value");
}
+ breakTypes.erase(curr->name);
}
}
void visitLoop(Loop *curr) {
if (curr->in.is()) {
LoopChildChecker childChecker(curr->in);
childChecker.walk(curr->body);
- shouldBeTrue(childChecker.valid);
+ shouldBeTrue(childChecker.valid, curr, "loop must return none");
+ breakTypes.erase(curr->in);
+ }
+ if (curr->out.is()) {
+ breakTypes.erase(curr->out);
}
}
- void visitBreak(Break *curr) {
+ void noteBreak(Name name, Expression* value) {
WasmType valueType = none;
- if (curr->value) {
- valueType = curr->value->type;
+ if (value) {
+ valueType = value->type;
}
- if (breakTypes.count(curr->name) == 0) {
- breakTypes[curr->name] = valueType;
+ if (breakTypes.count(name) == 0) {
+ breakTypes[name] = valueType;
} else {
- shouldBeTrue(valueType == breakTypes[curr->name]);
+ if (breakTypes[name] == unreachable) {
+ breakTypes[name] = valueType;
+ } else {
+ shouldBeEqual(valueType, breakTypes[name], name.str, "breaks to same target must have same type (ignoring unreachable)");
+ }
}
}
+ void visitBreak(Break *curr) {
+ noteBreak(curr->name, curr->value);
+ }
+ void visitSwitch(Switch *curr) {
+ for (auto& target : curr->targets) {
+ noteBreak(target, curr->value);
+ }
+ noteBreak(curr->default_, curr->value);
+ }
void visitSetLocal(SetLocal *curr) {
- shouldBeTrue(curr->type == curr->value->type);
+ shouldBeTrue(curr->type == curr->value->type, curr, "set_local type might be correct");
}
void visitLoad(Load *curr) {
validateAlignment(curr->align);
@@ -75,28 +94,77 @@ public:
void visitStore(Store *curr) {
validateAlignment(curr->align);
}
- void visitSwitch(Switch *curr) {
+ void visitBinary(Binary *curr) {
+ if (curr->left->type != unreachable && curr->right->type != unreachable) {
+ shouldBeEqual(curr->left->type, curr->right->type, curr, "binary child types must be equal");
+ }
}
void visitUnary(Unary *curr) {
- shouldBeTrue(curr->value->type == curr->type);
+ switch (curr->op) {
+ case Clz:
+ case Ctz:
+ case Popcnt:
+ case Neg:
+ case Abs:
+ case Ceil:
+ case Floor:
+ case Trunc:
+ case Nearest:
+ case Sqrt: {
+ //if (curr->value->type != unreachable) {
+ shouldBeEqual(curr->value->type, curr->type, curr, "non-conversion unaries must return the same type");
+ //}
+ break;
+ }
+ case EqZ: {
+ shouldBeEqual(curr->type, i32, curr, "relational unaries must return i32");
+ break;
+ }
+ case ExtendSInt32:
+ case ExtendUInt32:
+ case WrapInt64:
+ case TruncSFloat32:
+ case TruncUFloat32:
+ case TruncSFloat64:
+ case TruncUFloat64:
+ case ReinterpretFloat:
+ case ConvertUInt32:
+ case ConvertSInt32:
+ case ConvertUInt64:
+ case ConvertSInt64:
+ case PromoteFloat32:
+ case DemoteFloat64:
+ case ReinterpretInt: {
+ //if (curr->value->type != unreachable) {
+ shouldBeUnequal(curr->value->type, curr->type, curr, "conversion unaries must not return the same type");
+ //}
+ break;
+ }
+ default: abort();
+ }
}
void visitFunction(Function *curr) {
- shouldBeTrue(curr->result == curr->body->type);
+ // if function has no result, it is ignored
+ // if body is unreachable, it might be e.g. a return
+ if (curr->result != none && curr->body->type != unreachable) {
+ shouldBeEqual(curr->result, curr->body->type, curr->body, "function result must match, if function returns");
+ }
}
void visitMemory(Memory *curr) {
- shouldBeFalse(curr->initial > curr->max);
+ shouldBeFalse(curr->initial > curr->max, "memory", "memory max >= initial");
size_t top = 0;
for (auto& segment : curr->segments) {
- shouldBeFalse(segment.offset < top);
+ shouldBeFalse(segment.offset < top, "memory", "segment offset is small enough");
top = segment.offset + segment.data.size();
}
- shouldBeFalse(top > curr->initial);
+ shouldBeFalse(top > curr->initial * Memory::kPageSize, "memory", "total segments must be small enough");
}
void visitModule(Module *curr) {
// exports
+ std::set<Name> exportNames;
for (auto& exp : curr->exports) {
- Name name = exp->name;
+ Name name = exp->value;
bool found = false;
for (auto& func : curr->functions) {
if (func->name == name) {
@@ -104,17 +172,27 @@ public:
break;
}
}
- shouldBeTrue(found);
+ shouldBeTrue(found, name, "module exports must be found");
+ Name exportName = exp->name;
+ shouldBeFalse(exportNames.count(exportName) > 0, exportName, "module exports must be unique");
+ exportNames.insert(exportName);
}
// start
if (curr->start.is()) {
auto func = curr->checkFunction(curr->start);
- if (shouldBeTrue(func)) {
- shouldBeTrue(func->params.size() == 0); // must be nullary
+ if (shouldBeTrue(func, curr->start, "start must be found")) {
+ shouldBeTrue(func->params.size() == 0, curr, "start must have 0 params");
+ shouldBeTrue(func->result == none, curr, "start must not return a value");
}
}
}
+ void walk(Expression*& root) {
+ //std::cerr << "start a function " << getFunction()->name << "\n";
+ PostWalker<WasmValidator, Visitor<WasmValidator>>::walk(root);
+ shouldBeTrue(breakTypes.size() == 0, "break targets", "all break targets must be valid");
+ }
+
private:
// the "in" label has a none type, since no one can receive its value. make sure no one breaks to it with a value.
@@ -133,15 +211,69 @@ private:
// helpers
- bool shouldBeTrue(bool result) {
- if (!result) valid = false;
+ std::ostream& fail() {
+ Colors::red(std::cerr);
+ if (getFunction()) {
+ std::cerr << "[wasm-validator error in function ";
+ Colors::green(std::cerr);
+ std::cerr << getFunction()->name;
+ Colors::red(std::cerr);
+ std::cerr << "] ";
+ } else {
+ std::cerr << "[wasm-validator error in module] ";
+ }
+ Colors::normal(std::cerr);
+ return std::cerr;
+ }
+
+ template<typename T>
+ bool shouldBeTrue(bool result, T curr, const char* text) {
+ if (!result) {
+ fail() << "unexpected false: " << text << ", on " << curr << std::endl;
+ valid = false;
+ return false;
+ }
return result;
}
- bool shouldBeFalse(bool result) {
- if (result) valid = false;
+ template<typename T>
+ bool shouldBeFalse(bool result, T curr, const char* text) {
+ if (result) {
+ fail() << "unexpected true: " << text << ", on " << curr << std::endl;
+ valid = false;
+ return false;
+ }
return result;
}
+ template<typename T, typename S>
+ bool shouldBeEqual(S left, S right, T curr, const char* text) {
+ if (left != right) {
+ fail() << "" << left << " != " << right << ": " << text << ", on " << curr << std::endl;
+ valid = false;
+ return false;
+ }
+ return true;
+ }
+ template<typename T, typename S, typename U>
+ bool shouldBeEqual(S left, S right, T curr, U other, const char* text) {
+ if (left != right) {
+ fail() << "" << left << " != " << right << ": " << text << ", on " << curr << " / " << other << std::endl;
+ valid = false;
+ return false;
+ }
+ return true;
+ }
+
+ template<typename T, typename S>
+ bool shouldBeUnequal(S left, S right, T curr, const char* text) {
+ if (left == right) {
+ fail() << "" << left << " == " << right << ": " << text << ", on " << curr << std::endl;
+ valid = false;
+ return false;
+ }
+ return true;
+ }
+
void validateAlignment(size_t align) {
switch (align) {
case 1:
@@ -149,6 +281,7 @@ private:
case 4:
case 8: break;
default:{
+ fail() << "bad alignment: " << align << std::endl;
valid = false;
break;
}
diff --git a/src/wasm.cpp b/src/wasm.cpp
new file mode 100644
index 000000000..2157992e3
--- /dev/null
+++ b/src/wasm.cpp
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2016 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "wasm.h"
+#include "wasm-traversal.h"
+#include "ast_utils.h"
+
+namespace wasm {
+
+struct BlockTypeSeeker : public PostWalker<BlockTypeSeeker, Visitor<BlockTypeSeeker>> {
+ Block* target; // look for this one
+ WasmType type = unreachable;
+
+ BlockTypeSeeker(Block* target) : target(target) {}
+
+ void noteType(WasmType other) {
+ // once none, stop. it then indicates a poison value, that must not be consumed
+ // and ignore unreachable
+ if (type != none) {
+ if (other == none) {
+ type = none;
+ } else if (other != unreachable) {
+ type = other;
+ }
+ }
+ }
+
+ void visitBreak(Break *curr) {
+ if (curr->name == target->name) {
+ noteType(curr->value ? curr->value->type : none);
+ }
+ }
+
+ void visitSwitch(Switch *curr) {
+ for (auto name : curr->targets) {
+ if (name == target->name) noteType(curr->value ? curr->value->type : none);
+ }
+ }
+
+ void visitBlock(Block *curr) {
+ if (curr == target) {
+ if (curr->list.size() > 0) noteType(curr->list.back()->type);
+ } else {
+ type = unreachable; // ignore all breaks til now, they were captured by someone with the same name
+ }
+ }
+};
+
+void Block::finalize() {
+ if (list.size() > 0) {
+ auto last = list.back()->type;
+ if (last != unreachable) {
+ // well that was easy
+ type = last;
+ return;
+ }
+ }
+ if (!name.is()) {
+ // that was rather silly
+ type = unreachable;
+ return;
+ }
+ // oh no this is hard
+ BlockTypeSeeker seeker(this);
+ Expression* temp = this;
+ seeker.walk(temp);
+ type = seeker.type;
+}
+
+} // namespace wasm
+
diff --git a/src/wasm.h b/src/wasm.h
index b5dee81b8..b13011bae 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -833,11 +833,13 @@ public:
Name name;
ExpressionList list;
- void finalize() {
- if (list.size() > 0) {
- type = list.back()->type;
- }
+ // set the type of a block if you already know it
+ void finalize(WasmType type_) {
+ type = type;
}
+
+ // set the type of a block based on its contents. this scans the block, so it is not fast
+ void finalize();
};
class If : public SpecificExpression<Expression::IfId> {
@@ -877,6 +879,12 @@ public:
Name name;
Expression *value;
Expression *condition;
+
+ void finalize() {
+ if (condition) {
+ type = none;
+ }
+ }
};
class Switch : public SpecificExpression<Expression::SwitchId> {
@@ -1022,7 +1030,6 @@ public:
if (isRelational()) {
type = i32;
} else {
- assert(left->type != unreachable && right->type != unreachable ? left->type == right->type : true);
type = getReachableWasmType(left->type, right->type);
}
}