diff options
author | Thomas Lively <tlively@google.com> | 2024-05-17 17:49:45 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-17 17:49:45 -0700 |
commit | 921644ca65afbafb84fb82d58dacc4a028e2d720 (patch) | |
tree | 9253fbcf3f1dd9930dd1b9bc9f545234399b918e /src/tools | |
parent | 369cddfb44ddbada2ef7742a9ebef54727d12dd5 (diff) | |
download | binaryen-921644ca65afbafb84fb82d58dacc4a028e2d720.tar.gz binaryen-921644ca65afbafb84fb82d58dacc4a028e2d720.tar.bz2 binaryen-921644ca65afbafb84fb82d58dacc4a028e2d720.zip |
Rewrite wasm-shell to use new wast parser (#6601)
Use the new wast parser to parse a full script up front, then traverse the
parsed script data structure and execute the commands. wasm-shell had previously
used the new wat parser for top-level modules, but it now uses the new parser
for module assertions as well. Fix various bugs this uncovered.
After this change, wasm-shell supports all the assertions used in the upstream
spec tests (although not new kinds of assertions introduced in any proposals).
Uncomment various `assert_exhaustion` tests that we can now execute.
Other kinds of assertions remain commented out in our tests: wasm-shell now
supports `assert_unlinkable`, but the interpreter does not eagerly check for the
existence of imports, so those tests do not pass. Tests that check for NaNs also
remain commented out because they do not yet use the standard syntax that
wasm-shell now supports for canonical and arithmetic NaN results, and our
interpreter would not pass all of those tests even if they did use the standard
syntax.
Diffstat (limited to 'src/tools')
-rw-r--r-- | src/tools/wasm-shell.cpp | 622 |
1 files changed, 347 insertions, 275 deletions
diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index 3fe8b3505..913a7a58a 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -30,26 +30,16 @@ #include "support/command-line.h" #include "support/file.h" #include "support/result.h" +#include "wasm-binary.h" #include "wasm-interpreter.h" #include "wasm-s-parser.h" #include "wasm-validator.h" using namespace wasm; -using Lexer = WATParser::Lexer; +using namespace wasm::WATParser; -Name ASSERT_RETURN("assert_return"); -Name ASSERT_TRAP("assert_trap"); -Name ASSERT_EXCEPTION("assert_exception"); -Name ASSERT_INVALID("assert_invalid"); -Name ASSERT_MALFORMED("assert_malformed"); -Name ASSERT_UNLINKABLE("assert_unlinkable"); -Name INVOKE("invoke"); -Name REGISTER("register"); -Name GET("get"); - -class Shell { -protected: +struct Shell { std::map<Name, std::shared_ptr<Module>> modules; std::map<Name, std::shared_ptr<ShellExternalInterface>> interfaces; std::map<Name, std::shared_ptr<ModuleRunner>> instances; @@ -58,284 +48,395 @@ protected: Name lastModule; - void instantiate(Module* wasm) { - auto tempInterface = - std::make_shared<ShellExternalInterface>(linkedInstances); - auto tempInstance = std::make_shared<ModuleRunner>( - *wasm, tempInterface.get(), linkedInstances); - interfaces[wasm->name].swap(tempInterface); - instances[wasm->name].swap(tempInstance); - } - - Result<std::string> parseSExpr(Lexer& lexer) { - auto begin = lexer.getPos(); + Options& options; - if (!lexer.takeLParen()) { - return lexer.err("expected s-expression"); - } + Shell(Options& options) : options(options) { buildSpectestModule(); } - size_t count = 1; - while (count != 0 && lexer.takeUntilParen()) { - if (lexer.takeLParen()) { - ++count; - } else if (lexer.takeRParen()) { - --count; + Result<> run(WASTScript& script) { + size_t i = 0; + for (auto& entry : script) { + Colors::red(std::cerr); + std::cerr << i << ' '; + Colors::normal(std::cerr); + if (std::get_if<WASTModule>(&entry.cmd)) { + Colors::green(std::cerr); + std::cerr << "BUILDING MODULE [line: " << entry.line << "]\n"; + Colors::normal(std::cerr); + } else if (auto* reg = std::get_if<Register>(&entry.cmd)) { + Colors::green(std::cerr); + std::cerr << "REGISTER MODULE INSTANCE AS \"" << reg->name + << "\" [line: " << entry.line << "]\n"; + Colors::normal(std::cerr); } else { - WASM_UNREACHABLE("unexpected token"); + Colors::green(std::cerr); + std::cerr << "CHECKING [line: " << entry.line << "]\n"; + Colors::normal(std::cerr); } + ++i; + CHECK_ERR(runCommand(entry.cmd)); } - - if (count != 0) { - return lexer.err("unexpected unterminated s-expression"); - } - - return std::string(lexer.buffer.substr(begin, lexer.getPos() - begin)); + return Ok{}; } - Expression* parseExpression(Module& wasm, Element& s) { - std::stringstream ss; - ss << s; - auto str = ss.str(); - Lexer lexer(str); - auto arg = WATParser::parseExpression(wasm, lexer); - if (auto* err = arg.getErr()) { - Fatal() << err->msg << '\n'; + Result<> runCommand(WASTCommand& cmd) { + if (auto* mod = std::get_if<WASTModule>(&cmd)) { + return addModule(*mod); + } else if (auto* reg = std::get_if<Register>(&cmd)) { + return addRegistration(*reg); + } else if (auto* act = std::get_if<Action>(&cmd)) { + doAction(*act); + return Ok{}; + } else if (auto* assn = std::get_if<Assertion>(&cmd)) { + return doAssertion(*assn); + } else { + WASM_UNREACHABLE("unexpected command"); } - return *arg; } - Result<> parse(Lexer& lexer) { - if (auto res = parseModule(lexer)) { - CHECK_ERR(res); - return Ok{}; + Result<std::shared_ptr<Module>> makeModule(WASTModule& mod) { + std::shared_ptr<Module> wasm; + if (auto* quoted = std::get_if<QuotedModule>(&mod)) { + wasm = std::make_shared<Module>(); + switch (quoted->type) { + case QuotedModuleType::Text: { + CHECK_ERR(parseModule(*wasm, quoted->module)); + break; + } + case QuotedModuleType::Binary: { + std::vector<char> buffer(quoted->module.begin(), + quoted->module.end()); + WasmBinaryReader reader(*wasm, FeatureSet::All, buffer); + try { + reader.read(); + } catch (ParseException& p) { + std::stringstream ss; + p.dump(ss); + return Err{ss.str()}; + } + break; + } + } + } else if (auto* ptr = std::get_if<std::shared_ptr<Module>>(&mod)) { + wasm = *ptr; + } else { + WASM_UNREACHABLE("unexpected module kind"); } + wasm->features = FeatureSet::All; + return wasm; + } - auto pos = lexer.getPos(); - auto sexpr = parseSExpr(lexer); - CHECK_ERR(sexpr); - - SExpressionParser parser(sexpr->data()); - Element& s = *parser.root[0][0]; - IString id = s[0]->str(); - if (id == REGISTER) { - parseRegister(s); - } else if (id == INVOKE) { - parseOperation(s); - } else if (id == ASSERT_RETURN) { - parseAssertReturn(s); - } else if (id == ASSERT_TRAP) { - parseAssertTrap(s); - } else if (id == ASSERT_EXCEPTION) { - parseAssertException(s); - } else if ((id == ASSERT_INVALID) || (id == ASSERT_MALFORMED) || - (id == ASSERT_UNLINKABLE)) { - parseModuleAssertion(s); - } else { - return lexer.err(pos, "unrecognized command"); + Result<> validateModule(Module& wasm) { + if (!WasmValidator().validate(wasm)) { + return Err{"failed validation"}; } return Ok{}; } - MaybeResult<> parseModule(Lexer& lexer) { - if (!lexer.peekSExprStart("module")) { - return {}; - } - Colors::green(std::cerr); - std::cerr << "BUILDING MODULE [line: " << lexer.position().line << "]\n"; - Colors::normal(std::cerr); - auto module = std::make_shared<Module>(); - - CHECK_ERR(WATParser::parseModule(*module, lexer)); - - auto moduleName = module->name; - lastModule = module->name; - modules[moduleName].swap(module); - modules[moduleName]->features = FeatureSet::All; - bool valid = WasmValidator().validate(*modules[moduleName]); - if (!valid) { - std::cout << *modules[moduleName] << '\n'; - Fatal() << "module failed to validate, see above"; + using InstanceInfo = std::pair<std::shared_ptr<ShellExternalInterface>, + std::shared_ptr<ModuleRunner>>; + + Result<InstanceInfo> instantiate(Module& wasm) { + try { + auto interface = + std::make_shared<ShellExternalInterface>(linkedInstances); + auto instance = + std::make_shared<ModuleRunner>(wasm, interface.get(), linkedInstances); + return {{std::move(interface), std::move(instance)}}; + } catch (...) { + return Err{"failed to instantiate module"}; } + } + + Result<> addModule(WASTModule& mod) { + auto module = makeModule(mod); + CHECK_ERR(module); + + auto wasm = *module; + CHECK_ERR(validateModule(*wasm)); + + auto instanceInfo = instantiate(*wasm); + CHECK_ERR(instanceInfo); + + auto& [interface, instance] = *instanceInfo; + lastModule = wasm->name; + modules[lastModule] = std::move(wasm); + interfaces[lastModule] = std::move(interface); + instances[lastModule] = std::move(instance); - instantiate(modules[moduleName].get()); return Ok{}; } - void parseRegister(Element& s) { + Result<> addRegistration(Register& reg) { auto instance = instances[lastModule]; if (!instance) { - Fatal() << "register called without a module"; + return Err{"register called without a module"}; } - auto name = s[1]->str(); - linkedInstances[name] = instance; + linkedInstances[reg.name] = instance; - // we copy pointers as a registered module's name might still be used + // We copy pointers as a registered module's name might still be used // in an assertion or invoke command. - modules[name] = modules[lastModule]; - interfaces[name] = interfaces[lastModule]; - instances[name] = instances[lastModule]; - - Colors::green(std::cerr); - std::cerr << "REGISTER MODULE INSTANCE AS \"" << name.str - << "\" [line: " << s.line << "]\n"; - Colors::normal(std::cerr); + modules[reg.name] = modules[lastModule]; + interfaces[reg.name] = interfaces[lastModule]; + instances[reg.name] = instances[lastModule]; + return Ok{}; } - Literals parseOperation(Element& s) { - Index i = 1; - Name moduleName = lastModule; - if (s[i]->dollared()) { - moduleName = s[i++]->str(); - } - ModuleRunner* instance = instances[moduleName].get(); - assert(instance); - - std::string baseStr = std::string("\"") + s[i++]->str().toString() + "\""; - auto base = Lexer(baseStr).takeString(); - if (!base) { - Fatal() << "expected string\n"; + struct TrapResult {}; + struct HostLimitResult {}; + struct ExceptionResult {}; + using ActionResult = + std::variant<Literals, TrapResult, HostLimitResult, ExceptionResult>; + + std::string resultToString(ActionResult& result) { + if (std::get_if<TrapResult>(&result)) { + return "trap"; + } else if (std::get_if<HostLimitResult>(&result)) { + return "exceeded host limit"; + } else if (std::get_if<ExceptionResult>(&result)) { + return "exception"; + } else if (auto* vals = std::get_if<Literals>(&result)) { + std::stringstream ss; + ss << *vals; + return ss.str(); + } else { + WASM_UNREACHABLE("unexpected result"); } + } - if (s[0]->str() == INVOKE) { - Literals args; - while (i < s.size()) { - auto* arg = parseExpression(*modules[moduleName], *s[i++]); - args.push_back(getLiteralFromConstExpression(arg)); + ActionResult doAction(Action& act) { + ModuleRunner* instance = instances[lastModule].get(); + assert(instance); + if (auto* invoke = std::get_if<InvokeAction>(&act)) { + auto it = instances.find(invoke->base ? *invoke->base : lastModule); + if (it == instances.end()) { + return TrapResult{}; } - return instance->callExport(*base, args); - } else if (s[0]->str() == GET) { - return instance->getExport(*base); + auto& instance = it->second; + try { + return instance->callExport(invoke->name, invoke->args); + } catch (TrapException&) { + return TrapResult{}; + } catch (HostLimitException&) { + return HostLimitResult{}; + } catch (WasmException&) { + return ExceptionResult{}; + } catch (...) { + WASM_UNREACHABLE("unexpected error"); + } + } else if (auto* get = std::get_if<GetAction>(&act)) { + auto it = instances.find(get->base ? *get->base : lastModule); + if (it == instances.end()) { + return TrapResult{}; + } + auto& instance = it->second; + try { + return instance->getExport(get->name); + } catch (TrapException&) { + return TrapResult{}; + } catch (...) { + WASM_UNREACHABLE("unexpected error"); + } + } else { + WASM_UNREACHABLE("unexpected action"); } + } - Fatal() << "Invalid operation " << s[0]->toString(); + Result<> doAssertion(Assertion& assn) { + if (auto* ret = std::get_if<AssertReturn>(&assn)) { + return assertReturn(*ret); + } else if (auto* act = std::get_if<AssertAction>(&assn)) { + return assertAction(*act); + } else if (auto* mod = std::get_if<AssertModule>(&assn)) { + return assertModule(*mod); + } else { + WASM_UNREACHABLE("unexpected assertion"); + } } - void parseAssertTrap(Element& s) { - [[maybe_unused]] bool trapped = false; - auto& inner = *s[1]; - if (inner[0]->str() == MODULE) { - return parseModuleAssertion(s); + Result<> checkNaN(Literal val, NaNResult nan) { + std::stringstream err; + switch (nan.kind) { + case NaNKind::Canonical: + if (val.type != nan.type || !val.isCanonicalNaN()) { + err << "expected canonical " << nan.type << " NaN, got " << val; + return Err{err.str()}; + } + break; + case NaNKind::Arithmetic: + if (val.type != nan.type || !val.isArithmeticNaN()) { + err << "expected arithmetic " << nan.type << " NaN, got " << val; + return Err{err.str()}; + } + break; } + return Ok{}; + } - try { - parseOperation(inner); - } catch (const TrapException&) { - trapped = true; + Result<> checkLane(Literal val, LaneResult expected, Index index) { + std::stringstream err; + if (auto* e = std::get_if<Literal>(&expected)) { + if (*e != val) { + err << "expected " << *e << ", got " << val << " at lane " << index; + return Err{err.str()}; + } + } else if (auto* nan = std::get_if<NaNResult>(&expected)) { + auto check = checkNaN(val, *nan); + if (auto* e = check.getErr()) { + err << e->msg << " at lane " << index; + return Err{err.str()}; + } + } else { + WASM_UNREACHABLE("unexpected lane expectation"); } - assert(trapped); + return Ok{}; } - void parseAssertException(Element& s) { - [[maybe_unused]] bool thrown = false; - auto& inner = *s[1]; - if (inner[0]->str() == MODULE) { - return parseModuleAssertion(s); + Result<> assertReturn(AssertReturn& assn) { + std::stringstream err; + auto result = doAction(assn.action); + auto* values = std::get_if<Literals>(&result); + if (!values) { + return Err{std::string("expected return, got ") + resultToString(result)}; } + if (values->size() != assn.expected.size()) { + err << "expected " << assn.expected.size() << " values, got " + << resultToString(result); + return Err{err.str()}; + } + for (Index i = 0; i < values->size(); ++i) { + auto atIndex = [&]() { + if (values->size() <= 1) { + return std::string{}; + } + std::stringstream ss; + ss << " at index " << i; + return ss.str(); + }; - try { - parseOperation(inner); - } catch (const WasmException& e) { - std::cout << "[exception thrown: " << e << "]" << std::endl; - thrown = true; + Literal val = (*values)[i]; + auto& expected = assn.expected[i]; + if (auto* v = std::get_if<Literal>(&expected)) { + if (val != *v) { + err << "expected " << *v << ", got " << val << atIndex(); + return Err{err.str()}; + } + } else if (auto* ref = std::get_if<RefResult>(&expected)) { + if (!val.type.isRef() || val.type.getHeapType() != ref->type) { + err << "expected " << ref->type << " reference, got " << val + << atIndex(); + return Err{err.str()}; + } + } else if (auto* nan = std::get_if<NaNResult>(&expected)) { + auto check = checkNaN(val, *nan); + if (auto* e = check.getErr()) { + err << e->msg << atIndex(); + return Err{err.str()}; + } + } else if (auto* lanes = std::get_if<LaneResults>(&expected)) { + switch (lanes->size()) { + case 4: { + auto vals = val.getLanesF32x4(); + for (Index i = 0; i < 4; ++i) { + auto check = checkLane(vals[i], (*lanes)[i], i); + if (auto* e = check.getErr()) { + err << e->msg << atIndex(); + return Err{err.str()}; + } + } + break; + } + case 2: { + auto vals = val.getLanesF64x2(); + for (Index i = 0; i < 2; ++i) { + auto check = checkLane(vals[i], (*lanes)[i], i); + if (auto* e = check.getErr()) { + err << e->msg << atIndex(); + return Err{err.str()}; + } + } + break; + } + default: + WASM_UNREACHABLE("unexpected number of lanes"); + } + } else { + WASM_UNREACHABLE("unexpected expectation"); + } } - assert(thrown); + return Ok{}; } - void parseAssertReturn(Element& s) { - Literals actual; - Literals expected; - if (s.size() >= 3) { - expected = getLiteralsFromConstExpression( - parseExpression(*modules[lastModule], *s[2])); - } - [[maybe_unused]] bool trapped = false; - try { - actual = parseOperation(*s[1]); - } catch (const TrapException&) { - trapped = true; - } catch (const WasmException& e) { - std::cout << "[exception thrown: " << e << "]" << std::endl; - trapped = true; - } - assert(!trapped); - std::cerr << "seen " << actual << ", expected " << expected << '\n'; - if (expected != actual) { - Fatal() << "unexpected, should be identical\n"; + Result<> assertAction(AssertAction& assn) { + std::stringstream err; + auto result = doAction(assn.action); + switch (assn.type) { + case ActionAssertionType::Trap: + if (std::get_if<TrapResult>(&result)) { + return Ok{}; + } + err << "expected trap"; + break; + case ActionAssertionType::Exhaustion: + if (std::get_if<HostLimitResult>(&result)) { + return Ok{}; + } + err << "expected exhaustion"; + break; + case ActionAssertionType::Exception: + if (std::get_if<ExceptionResult>(&result)) { + return Ok{}; + } + err << "expected exception"; + break; } + err << ", got " << resultToString(result); + return Err{err.str()}; } - void parseModuleAssertion(Element& s) { - Module wasm; - wasm.features = FeatureSet::All; - std::unique_ptr<SExpressionWasmBuilder> builder; - auto id = s[0]->str(); + Result<> assertModule(AssertModule& assn) { + auto wasm = makeModule(assn.wasm); + if (const auto* err = wasm.getErr()) { + if (assn.type == ModuleAssertionType::Malformed || + assn.type == ModuleAssertionType::Invalid) { + return Ok{}; + } + return Err{err->msg}; + } - bool invalid = false; - try { - SExpressionWasmBuilder(wasm, *s[1], IRProfile::Normal); - } catch (const ParseException&) { - invalid = true; + if (assn.type == ModuleAssertionType::Malformed) { + return Err{"expected malformed module"}; } - if (!invalid) { - // maybe parsed ok, but otherwise incorrect - invalid = !WasmValidator().validate(wasm); + auto valid = validateModule(**wasm); + if (auto* err = valid.getErr()) { + if (assn.type == ModuleAssertionType::Invalid) { + return Ok{}; + } + return Err{err->msg}; } - if (!invalid && id == ASSERT_UNLINKABLE) { - // validate "instantiating" the mdoule - auto reportUnknownImport = [&](Importable* import) { - auto it = linkedInstances.find(import->module); - if (it == linkedInstances.end() || - it->second->wasm.getExportOrNull(import->base) == nullptr) { - std::cerr << "unknown import: " << import->module << '.' - << import->base << '\n'; - invalid = true; - } - }; - ModuleUtils::iterImportedGlobals(wasm, reportUnknownImport); - ModuleUtils::iterImportedTables(wasm, reportUnknownImport); - ModuleUtils::iterImportedFunctions(wasm, [&](Importable* import) { - if (import->module == SPECTEST && import->base.startsWith(PRINT)) { - // We can handle it. - } else { - reportUnknownImport(import); - } - }); - ElementUtils::iterAllElementFunctionNames(&wasm, [&](Name name) { - // spec tests consider it illegal to use spectest.print in a table - if (auto* import = wasm.getFunction(name)) { - if (import->imported() && import->module == SPECTEST && - import->base.startsWith(PRINT)) { - std::cerr << "cannot put spectest.print in table\n"; - invalid = true; - } - } - }); - ModuleUtils::iterImportedMemories(wasm, reportUnknownImport); + if (assn.type == ModuleAssertionType::Invalid) { + return Err{"expected invalid module"}; } - if (!invalid && (id == ASSERT_TRAP || id == ASSERT_EXCEPTION)) { - try { - instantiate(&wasm); - } catch (const TrapException&) { - invalid = true; - } catch (const WasmException& e) { - std::cout << "[exception thrown: " << e << "]" << std::endl; - invalid = true; + auto instance = instantiate(**wasm); + if (auto* err = instance.getErr()) { + if (assn.type == ModuleAssertionType::Unlinkable || + assn.type == ModuleAssertionType::Trap) { + return Ok{}; } + return Err{err->msg}; } - if (!invalid) { - Colors::red(std::cerr); - std::cerr << "[should have been invalid]\n"; - Colors::normal(std::cerr); - Fatal() << &wasm << '\n'; + if (assn.type == ModuleAssertionType::Unlinkable) { + return Err{"expected unlinkable module"}; + } + if (assn.type == ModuleAssertionType::Trap) { + return Err{"expected instantiation to trap"}; } - } -protected: - Options& options; + WASM_UNREACHABLE("unexpected module assertion"); + } // spectest module is a default host-provided module defined by the spec's // reference interpreter. It's been replaced by the `(register ...)` @@ -345,7 +446,7 @@ protected: // is actually removed from the spec test. void buildSpectestModule() { auto spectest = std::make_shared<Module>(); - spectest->name = "spectest"; + spectest->features = FeatureSet::All; Builder builder(*spectest); spectest->addGlobal(builder.makeGlobal(Name::fromInt(0), @@ -388,45 +489,18 @@ protected: spectest->addExport( builder.makeExport("memory", memory->name, ExternalKind::Memory)); - modules["spectest"].swap(spectest); - modules["spectest"]->features = FeatureSet::All; - instantiate(modules["spectest"].get()); - linkedInstances["spectest"] = instances["spectest"]; // print_* functions are handled separately, no need to define here. - } - -public: - Shell(Options& options) : options(options) { buildSpectestModule(); } - - MaybeResult<> parseAndRun(Lexer& lexer) { - size_t i = 0; - while (!lexer.empty()) { - auto next = lexer.next(); - auto size = next.find('\n'); - if (size != std::string_view::npos) { - next = next.substr(0, size); - } else { - next = ""; - } - - if (!lexer.peekSExprStart("module")) { - Colors::red(std::cerr); - std::cerr << i; - Colors::green(std::cerr); - std::cerr << " CHECKING: "; - Colors::normal(std::cerr); - std::cerr << next; - Colors::green(std::cerr); - std::cerr << " [line: " << lexer.position().line << "]\n"; - Colors::normal(std::cerr); - } - CHECK_ERR(parse(lexer)); - - i += 1; + WASTModule mod = std::move(spectest); + auto added = addModule(mod); + if (added.getErr()) { + WASM_UNREACHABLE("error building spectest module"); + } + Register registration{"spectest"}; + auto registered = addRegistration(registration); + if (registered.getErr()) { + WASM_UNREACHABLE("error registering spectest module"); } - - return Ok{}; } }; @@ -453,16 +527,14 @@ int main(int argc, const char* argv[]) { } Lexer lexer(input); - auto result = Shell(options).parseAndRun(lexer); + auto result = Shell(options).run(*script); if (auto* err = result.getErr()) { std::cerr << err->msg << '\n'; exit(1); } - if (result) { - Colors::green(std::cerr); - Colors::bold(std::cerr); - std::cerr << "all checks passed.\n"; - Colors::normal(std::cerr); - } + Colors::green(std::cerr); + Colors::bold(std::cerr); + std::cerr << "all checks passed.\n"; + Colors::normal(std::cerr); } |