summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/wasm-shell.cpp3
-rw-r--r--src/wasm-validator.h91
-rw-r--r--src/wasm.h126
3 files changed, 139 insertions, 81 deletions
diff --git a/src/wasm-shell.cpp b/src/wasm-shell.cpp
index 6b7586719..701b6b6c2 100644
--- a/src/wasm-shell.cpp
+++ b/src/wasm-shell.cpp
@@ -9,6 +9,7 @@
#include "wasm-s-parser.h"
#include "wasm-interpreter.h"
+#include "wasm-validator.h"
using namespace cashew;
using namespace wasm;
@@ -208,7 +209,7 @@ int main(int argc, char **argv) {
}
if (!invalid) {
// maybe parsed ok, but otherwise incorrect
- invalid = !wasm.validate();
+ invalid = !WasmValidator().validate(wasm);
}
assert(invalid);
} else if (id == INVOKE) {
diff --git a/src/wasm-validator.h b/src/wasm-validator.h
new file mode 100644
index 000000000..49e7a6a33
--- /dev/null
+++ b/src/wasm-validator.h
@@ -0,0 +1,91 @@
+
+//
+// Simple WebAssembly module validator.
+//
+
+#include "wasm.h"
+
+namespace wasm {
+
+struct WasmValidator : public WasmWalker {
+ bool valid;
+
+public:
+ bool validate(Module& module) {
+ valid = true;
+ startWalk(&module);
+ return valid;
+ }
+
+ // visitors
+
+ void visitSetLocal(SetLocal *curr) override {
+ shouldBeTrue(curr->type == curr->value->type);
+ }
+ void visitLoad(Load *curr) override {
+ validateAlignment(curr->align);
+ }
+ void visitStore(Store *curr) override {
+ validateAlignment(curr->align);
+ }
+ void visitSwitch(Switch *curr) override {
+ std::set<Name> inTable;
+ for (auto target : curr->targets) {
+ if (target.is()) {
+ inTable.insert(target);
+ }
+ }
+ for (auto& c : curr->cases) {
+ shouldBeFalse(c.name.is() && inTable.find(c.name) == inTable.end());
+ }
+ shouldBeFalse(curr->default_.is() && inTable.find(curr->default_) == inTable.end());
+ }
+ void visitMemory(Memory *curr) override {
+ shouldBeFalse(curr->initial > curr->max);
+ size_t top = 0;
+ for (auto segment : curr->segments) {
+ shouldBeFalse(segment.offset < top);
+ top = segment.offset + segment.size;
+ }
+ shouldBeFalse(top > curr->initial);
+ }
+ void visitModule(Module *curr) override {
+ for (auto& exp : curr->exports) {
+ Name name = exp->name;
+ bool found = false;
+ for (auto& func : curr->functions) {
+ if (func->name == name) {
+ found = true;
+ break;
+ }
+ }
+ shouldBeTrue(found);
+ }
+ }
+
+private:
+ // helpers
+
+ void shouldBeTrue(bool result) {
+ if (!result) valid = false;
+ }
+ void shouldBeFalse(bool result) {
+ if (result) valid = false;
+ }
+
+ void validateAlignment(size_t align) {
+ switch (align) {
+ case 1:
+ case 2:
+ case 4:
+ case 8: break;
+ default:{
+ valid = false;
+ break;
+ }
+ }
+ }
+};
+
+} // namespace wasm
+
diff --git a/src/wasm.h b/src/wasm.h
index 9dc140efe..769c0ecad 100644
--- a/src/wasm.h
+++ b/src/wasm.h
@@ -12,6 +12,7 @@
// * Interpreting: See wasm-interpreter.h.
// * Optimizing: See asm2wasm.h, which performs some optimizations
// after code generation.
+// * Validation: See wasm-validator.h.
//
#ifndef __wasm_h__
@@ -556,7 +557,7 @@ public:
unsigned bytes;
bool signed_;
- uint64_t offset; // XXX https://github.com/WebAssembly/spec/issues/161
+ uint32_t offset;
unsigned align;
Expression *ptr;
@@ -924,19 +925,18 @@ public:
};
class Module {
- // wasm contents (don't access these from outside; use add*() and the *Map objects)
+public:
+ // internal wasm contents (don't access these from outside; use add*() and the *Map objects)
std::vector<FunctionType*> functionTypes;
std::vector<Import*> imports;
std::vector<Export*> exports;
std::vector<Function*> functions;
-public:
- // utility maps
+ // publicly-accessible content
std::map<Name, FunctionType*> functionTypesMap;
std::map<Name, Import*> importsMap;
std::map<Name, Export*> exportsMap;
std::map<Name, Function*> functionsMap;
-
Table table;
Memory memory;
@@ -1033,33 +1033,6 @@ public:
return o << '\n';
}
- bool validate() {
- if (memory.initial > memory.max) return false;
- size_t top = 0;
- for (auto segment : memory.segments) {
- if (segment.offset < top) return false;
- top = segment.offset + segment.size;
- }
- if (top > memory.initial) return false;
- for (auto& exp : exports) {
- Name name = exp->name;
- bool found = false;
- for (auto& func : functions) {
- if (func->name == name) {
- found = true;
- break;
- }
- }
- if (!found) return false;
- }
- for (auto& curr : functions) {
- if (!validateFunction(curr)) return false;
- }
- return true;
- }
-
- bool validateFunction(Function *func);
-
private:
size_t functionTypeIndex, importIndex, exportIndex, functionIndex;
};
@@ -1073,6 +1046,7 @@ private:
template<class ReturnType>
struct WasmVisitor {
// should be pure virtual, but https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51048
+ // Expression visitors
virtual ReturnType visitBlock(Block *curr) { abort(); }
virtual ReturnType visitIf(If *curr) { abort(); }
virtual ReturnType visitLoop(Loop *curr) { abort(); }
@@ -1095,6 +1069,14 @@ struct WasmVisitor {
virtual ReturnType visitHost(Host *curr) { abort(); }
virtual ReturnType visitNop(Nop *curr) { abort(); }
virtual ReturnType visitUnreachable(Unreachable *curr) { abort(); }
+ // Module-level visitors
+ virtual ReturnType visitFunctionType(FunctionType *curr) { abort(); }
+ virtual ReturnType visitImport(Import *curr) { abort(); }
+ virtual ReturnType visitExport(Export *curr) { abort(); }
+ virtual ReturnType visitFunction(Function *curr) { abort(); }
+ virtual ReturnType visitTable(Table *curr) { abort(); }
+ virtual ReturnType visitMemory(Memory *curr) { abort(); }
+ virtual ReturnType visitModule(Module *curr) { abort(); }
ReturnType visit(Expression *curr) {
assert(curr);
@@ -1166,8 +1148,8 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) {
}
//
-// Simple WebAssembly children-first walking, with the ability to
-// replace the current node. Useful for writing optimization passes.
+// Simple WebAssembly children-first walking, with the ability to replace
+// the current expression node. Useful for writing optimization passes.
//
struct WasmWalker : public WasmVisitor<void> {
@@ -1204,6 +1186,14 @@ struct WasmWalker : public WasmVisitor<void> {
void visitNop(Nop *curr) override {}
void visitUnreachable(Unreachable *curr) override {}
+ void visitFunctionType(FunctionType *curr) override {}
+ void visitImport(Import *curr) override {}
+ void visitExport(Export *curr) override {}
+ void visitFunction(Function *curr) override {}
+ void visitTable(Table *curr) override {}
+ void visitMemory(Memory *curr) override {}
+ void visitModule(Module *curr) override {}
+
// children-first
void walk(Expression*& curr) {
if (!curr) return;
@@ -1310,57 +1300,33 @@ struct WasmWalker : public WasmVisitor<void> {
void startWalk(Function *func) {
walk(func->body);
}
-};
-
-bool Module::validateFunction(Function *func) {
- struct Validator : public WasmWalker {
- bool valid = true;
- void should(bool result) {
- if (!result) valid = false;
+ void startWalk(Module *module) {
+ for (auto curr : module->functionTypes) {
+ visitFunctionType(curr);
+ assert(!replace);
}
-
- void visitSetLocal(SetLocal *curr) override {
- should(curr->type == curr->value->type);
- }
- void visitLoad(Load *curr) override {
- if (!validateAlignment(curr->align)) valid = false;
+ for (auto curr : module->imports) {
+ visitImport(curr);
+ assert(!replace);
}
- void visitStore(Store *curr) override {
- if (!validateAlignment(curr->align)) valid = false;
+ for (auto curr : module->exports) {
+ visitExport(curr);
+ assert(!replace);
}
- void visitSwitch(Switch *curr) override {
- std::set<Name> inTable;
- for (auto target : curr->targets) {
- if (target.is()) {
- inTable.insert(target);
- }
- }
- for (auto& c : curr->cases) {
- if (c.name.is() && inTable.find(c.name) == inTable.end()) {
- valid = false;
- }
- }
- if (curr->default_.is() && inTable.find(curr->default_) == inTable.end()) {
- valid = false;
- }
+ for (auto curr : module->functions) {
+ startWalk(curr);
+ visitFunction(curr);
+ assert(!replace);
}
-
- bool validateAlignment(size_t align) {
- switch (align) {
- case 1:
- case 2:
- case 4:
- case 8: return true;
- default: return false;
- }
- }
- };
-
- Validator validator;
- validator.walk(func->body);
- return validator.valid;
-}
+ visitTable(&module->table);
+ assert(!replace);
+ visitMemory(&module->memory);
+ assert(!replace);
+ visitModule(module);
+ assert(!replace);
+ }
+};
} // namespace wasm