diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/wasm-shell.cpp | 3 | ||||
-rw-r--r-- | src/wasm-validator.h | 91 | ||||
-rw-r--r-- | src/wasm.h | 126 |
3 files changed, 139 insertions, 81 deletions
diff --git a/src/wasm-shell.cpp b/src/wasm-shell.cpp index 6b7586719..701b6b6c2 100644 --- a/src/wasm-shell.cpp +++ b/src/wasm-shell.cpp @@ -9,6 +9,7 @@ #include "wasm-s-parser.h" #include "wasm-interpreter.h" +#include "wasm-validator.h" using namespace cashew; using namespace wasm; @@ -208,7 +209,7 @@ int main(int argc, char **argv) { } if (!invalid) { // maybe parsed ok, but otherwise incorrect - invalid = !wasm.validate(); + invalid = !WasmValidator().validate(wasm); } assert(invalid); } else if (id == INVOKE) { diff --git a/src/wasm-validator.h b/src/wasm-validator.h new file mode 100644 index 000000000..49e7a6a33 --- /dev/null +++ b/src/wasm-validator.h @@ -0,0 +1,91 @@ + +// +// Simple WebAssembly module validator. +// + +#include "wasm.h" + +namespace wasm { + +struct WasmValidator : public WasmWalker { + bool valid; + +public: + bool validate(Module& module) { + valid = true; + startWalk(&module); + return valid; + } + + // visitors + + void visitSetLocal(SetLocal *curr) override { + shouldBeTrue(curr->type == curr->value->type); + } + void visitLoad(Load *curr) override { + validateAlignment(curr->align); + } + void visitStore(Store *curr) override { + validateAlignment(curr->align); + } + void visitSwitch(Switch *curr) override { + std::set<Name> inTable; + for (auto target : curr->targets) { + if (target.is()) { + inTable.insert(target); + } + } + for (auto& c : curr->cases) { + shouldBeFalse(c.name.is() && inTable.find(c.name) == inTable.end()); + } + shouldBeFalse(curr->default_.is() && inTable.find(curr->default_) == inTable.end()); + } + void visitMemory(Memory *curr) override { + shouldBeFalse(curr->initial > curr->max); + size_t top = 0; + for (auto segment : curr->segments) { + shouldBeFalse(segment.offset < top); + top = segment.offset + segment.size; + } + shouldBeFalse(top > curr->initial); + } + void visitModule(Module *curr) override { + for (auto& exp : curr->exports) { + Name name = exp->name; + bool found = false; + for (auto& func : curr->functions) { + if (func->name == name) { + found = true; + break; + } + } + shouldBeTrue(found); + } + } + +private: + // helpers + + void shouldBeTrue(bool result) { + if (!result) valid = false; + } + void shouldBeFalse(bool result) { + if (result) valid = false; + } + + void validateAlignment(size_t align) { + switch (align) { + case 1: + case 2: + case 4: + case 8: break; + default:{ + valid = false; + break; + } + } + } +}; + +} // namespace wasm + diff --git a/src/wasm.h b/src/wasm.h index 9dc140efe..769c0ecad 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -12,6 +12,7 @@ // * Interpreting: See wasm-interpreter.h. // * Optimizing: See asm2wasm.h, which performs some optimizations // after code generation. +// * Validation: See wasm-validator.h. // #ifndef __wasm_h__ @@ -556,7 +557,7 @@ public: unsigned bytes; bool signed_; - uint64_t offset; // XXX https://github.com/WebAssembly/spec/issues/161 + uint32_t offset; unsigned align; Expression *ptr; @@ -924,19 +925,18 @@ public: }; class Module { - // wasm contents (don't access these from outside; use add*() and the *Map objects) +public: + // internal wasm contents (don't access these from outside; use add*() and the *Map objects) std::vector<FunctionType*> functionTypes; std::vector<Import*> imports; std::vector<Export*> exports; std::vector<Function*> functions; -public: - // utility maps + // publicly-accessible content std::map<Name, FunctionType*> functionTypesMap; std::map<Name, Import*> importsMap; std::map<Name, Export*> exportsMap; std::map<Name, Function*> functionsMap; - Table table; Memory memory; @@ -1033,33 +1033,6 @@ public: return o << '\n'; } - bool validate() { - if (memory.initial > memory.max) return false; - size_t top = 0; - for (auto segment : memory.segments) { - if (segment.offset < top) return false; - top = segment.offset + segment.size; - } - if (top > memory.initial) return false; - for (auto& exp : exports) { - Name name = exp->name; - bool found = false; - for (auto& func : functions) { - if (func->name == name) { - found = true; - break; - } - } - if (!found) return false; - } - for (auto& curr : functions) { - if (!validateFunction(curr)) return false; - } - return true; - } - - bool validateFunction(Function *func); - private: size_t functionTypeIndex, importIndex, exportIndex, functionIndex; }; @@ -1073,6 +1046,7 @@ private: template<class ReturnType> struct WasmVisitor { // should be pure virtual, but https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51048 + // Expression visitors virtual ReturnType visitBlock(Block *curr) { abort(); } virtual ReturnType visitIf(If *curr) { abort(); } virtual ReturnType visitLoop(Loop *curr) { abort(); } @@ -1095,6 +1069,14 @@ struct WasmVisitor { virtual ReturnType visitHost(Host *curr) { abort(); } virtual ReturnType visitNop(Nop *curr) { abort(); } virtual ReturnType visitUnreachable(Unreachable *curr) { abort(); } + // Module-level visitors + virtual ReturnType visitFunctionType(FunctionType *curr) { abort(); } + virtual ReturnType visitImport(Import *curr) { abort(); } + virtual ReturnType visitExport(Export *curr) { abort(); } + virtual ReturnType visitFunction(Function *curr) { abort(); } + virtual ReturnType visitTable(Table *curr) { abort(); } + virtual ReturnType visitMemory(Memory *curr) { abort(); } + virtual ReturnType visitModule(Module *curr) { abort(); } ReturnType visit(Expression *curr) { assert(curr); @@ -1166,8 +1148,8 @@ std::ostream& Expression::print(std::ostream &o, unsigned indent) { } // -// Simple WebAssembly children-first walking, with the ability to -// replace the current node. Useful for writing optimization passes. +// Simple WebAssembly children-first walking, with the ability to replace +// the current expression node. Useful for writing optimization passes. // struct WasmWalker : public WasmVisitor<void> { @@ -1204,6 +1186,14 @@ struct WasmWalker : public WasmVisitor<void> { void visitNop(Nop *curr) override {} void visitUnreachable(Unreachable *curr) override {} + void visitFunctionType(FunctionType *curr) override {} + void visitImport(Import *curr) override {} + void visitExport(Export *curr) override {} + void visitFunction(Function *curr) override {} + void visitTable(Table *curr) override {} + void visitMemory(Memory *curr) override {} + void visitModule(Module *curr) override {} + // children-first void walk(Expression*& curr) { if (!curr) return; @@ -1310,57 +1300,33 @@ struct WasmWalker : public WasmVisitor<void> { void startWalk(Function *func) { walk(func->body); } -}; - -bool Module::validateFunction(Function *func) { - struct Validator : public WasmWalker { - bool valid = true; - void should(bool result) { - if (!result) valid = false; + void startWalk(Module *module) { + for (auto curr : module->functionTypes) { + visitFunctionType(curr); + assert(!replace); } - - void visitSetLocal(SetLocal *curr) override { - should(curr->type == curr->value->type); - } - void visitLoad(Load *curr) override { - if (!validateAlignment(curr->align)) valid = false; + for (auto curr : module->imports) { + visitImport(curr); + assert(!replace); } - void visitStore(Store *curr) override { - if (!validateAlignment(curr->align)) valid = false; + for (auto curr : module->exports) { + visitExport(curr); + assert(!replace); } - void visitSwitch(Switch *curr) override { - std::set<Name> inTable; - for (auto target : curr->targets) { - if (target.is()) { - inTable.insert(target); - } - } - for (auto& c : curr->cases) { - if (c.name.is() && inTable.find(c.name) == inTable.end()) { - valid = false; - } - } - if (curr->default_.is() && inTable.find(curr->default_) == inTable.end()) { - valid = false; - } + for (auto curr : module->functions) { + startWalk(curr); + visitFunction(curr); + assert(!replace); } - - bool validateAlignment(size_t align) { - switch (align) { - case 1: - case 2: - case 4: - case 8: return true; - default: return false; - } - } - }; - - Validator validator; - validator.walk(func->body); - return validator.valid; -} + visitTable(&module->table); + assert(!replace); + visitMemory(&module->memory); + assert(!replace); + visitModule(module); + assert(!replace); + } +}; } // namespace wasm |