diff options
Diffstat (limited to 'src/tools/wasm-shell.cpp')
-rw-r--r-- | src/tools/wasm-shell.cpp | 622 |
1 files changed, 347 insertions, 275 deletions
diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index 3fe8b3505..913a7a58a 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -30,26 +30,16 @@ #include "support/command-line.h" #include "support/file.h" #include "support/result.h" +#include "wasm-binary.h" #include "wasm-interpreter.h" #include "wasm-s-parser.h" #include "wasm-validator.h" using namespace wasm; -using Lexer = WATParser::Lexer; +using namespace wasm::WATParser; -Name ASSERT_RETURN("assert_return"); -Name ASSERT_TRAP("assert_trap"); -Name ASSERT_EXCEPTION("assert_exception"); -Name ASSERT_INVALID("assert_invalid"); -Name ASSERT_MALFORMED("assert_malformed"); -Name ASSERT_UNLINKABLE("assert_unlinkable"); -Name INVOKE("invoke"); -Name REGISTER("register"); -Name GET("get"); - -class Shell { -protected: +struct Shell { std::map<Name, std::shared_ptr<Module>> modules; std::map<Name, std::shared_ptr<ShellExternalInterface>> interfaces; std::map<Name, std::shared_ptr<ModuleRunner>> instances; @@ -58,284 +48,395 @@ protected: Name lastModule; - void instantiate(Module* wasm) { - auto tempInterface = - std::make_shared<ShellExternalInterface>(linkedInstances); - auto tempInstance = std::make_shared<ModuleRunner>( - *wasm, tempInterface.get(), linkedInstances); - interfaces[wasm->name].swap(tempInterface); - instances[wasm->name].swap(tempInstance); - } - - Result<std::string> parseSExpr(Lexer& lexer) { - auto begin = lexer.getPos(); + Options& options; - if (!lexer.takeLParen()) { - return lexer.err("expected s-expression"); - } + Shell(Options& options) : options(options) { buildSpectestModule(); } - size_t count = 1; - while (count != 0 && lexer.takeUntilParen()) { - if (lexer.takeLParen()) { - ++count; - } else if (lexer.takeRParen()) { - --count; + Result<> run(WASTScript& script) { + size_t i = 0; + for (auto& entry : script) { + Colors::red(std::cerr); + std::cerr << i << ' '; + Colors::normal(std::cerr); + if (std::get_if<WASTModule>(&entry.cmd)) { + Colors::green(std::cerr); + std::cerr << "BUILDING MODULE [line: " << entry.line << "]\n"; + Colors::normal(std::cerr); + } else if (auto* reg = std::get_if<Register>(&entry.cmd)) { + Colors::green(std::cerr); + std::cerr << "REGISTER MODULE INSTANCE AS \"" << reg->name + << "\" [line: " << entry.line << "]\n"; + Colors::normal(std::cerr); } else { - WASM_UNREACHABLE("unexpected token"); + Colors::green(std::cerr); + std::cerr << "CHECKING [line: " << entry.line << "]\n"; + Colors::normal(std::cerr); } + ++i; + CHECK_ERR(runCommand(entry.cmd)); } - - if (count != 0) { - return lexer.err("unexpected unterminated s-expression"); - } - - return std::string(lexer.buffer.substr(begin, lexer.getPos() - begin)); + return Ok{}; } - Expression* parseExpression(Module& wasm, Element& s) { - std::stringstream ss; - ss << s; - auto str = ss.str(); - Lexer lexer(str); - auto arg = WATParser::parseExpression(wasm, lexer); - if (auto* err = arg.getErr()) { - Fatal() << err->msg << '\n'; + Result<> runCommand(WASTCommand& cmd) { + if (auto* mod = std::get_if<WASTModule>(&cmd)) { + return addModule(*mod); + } else if (auto* reg = std::get_if<Register>(&cmd)) { + return addRegistration(*reg); + } else if (auto* act = std::get_if<Action>(&cmd)) { + doAction(*act); + return Ok{}; + } else if (auto* assn = std::get_if<Assertion>(&cmd)) { + return doAssertion(*assn); + } else { + WASM_UNREACHABLE("unexpected command"); } - return *arg; } - Result<> parse(Lexer& lexer) { - if (auto res = parseModule(lexer)) { - CHECK_ERR(res); - return Ok{}; + Result<std::shared_ptr<Module>> makeModule(WASTModule& mod) { + std::shared_ptr<Module> wasm; + if (auto* quoted = std::get_if<QuotedModule>(&mod)) { + wasm = std::make_shared<Module>(); + switch (quoted->type) { + case QuotedModuleType::Text: { + CHECK_ERR(parseModule(*wasm, quoted->module)); + break; + } + case QuotedModuleType::Binary: { + std::vector<char> buffer(quoted->module.begin(), + quoted->module.end()); + WasmBinaryReader reader(*wasm, FeatureSet::All, buffer); + try { + reader.read(); + } catch (ParseException& p) { + std::stringstream ss; + p.dump(ss); + return Err{ss.str()}; + } + break; + } + } + } else if (auto* ptr = std::get_if<std::shared_ptr<Module>>(&mod)) { + wasm = *ptr; + } else { + WASM_UNREACHABLE("unexpected module kind"); } + wasm->features = FeatureSet::All; + return wasm; + } - auto pos = lexer.getPos(); - auto sexpr = parseSExpr(lexer); - CHECK_ERR(sexpr); - - SExpressionParser parser(sexpr->data()); - Element& s = *parser.root[0][0]; - IString id = s[0]->str(); - if (id == REGISTER) { - parseRegister(s); - } else if (id == INVOKE) { - parseOperation(s); - } else if (id == ASSERT_RETURN) { - parseAssertReturn(s); - } else if (id == ASSERT_TRAP) { - parseAssertTrap(s); - } else if (id == ASSERT_EXCEPTION) { - parseAssertException(s); - } else if ((id == ASSERT_INVALID) || (id == ASSERT_MALFORMED) || - (id == ASSERT_UNLINKABLE)) { - parseModuleAssertion(s); - } else { - return lexer.err(pos, "unrecognized command"); + Result<> validateModule(Module& wasm) { + if (!WasmValidator().validate(wasm)) { + return Err{"failed validation"}; } return Ok{}; } - MaybeResult<> parseModule(Lexer& lexer) { - if (!lexer.peekSExprStart("module")) { - return {}; - } - Colors::green(std::cerr); - std::cerr << "BUILDING MODULE [line: " << lexer.position().line << "]\n"; - Colors::normal(std::cerr); - auto module = std::make_shared<Module>(); - - CHECK_ERR(WATParser::parseModule(*module, lexer)); - - auto moduleName = module->name; - lastModule = module->name; - modules[moduleName].swap(module); - modules[moduleName]->features = FeatureSet::All; - bool valid = WasmValidator().validate(*modules[moduleName]); - if (!valid) { - std::cout << *modules[moduleName] << '\n'; - Fatal() << "module failed to validate, see above"; + using InstanceInfo = std::pair<std::shared_ptr<ShellExternalInterface>, + std::shared_ptr<ModuleRunner>>; + + Result<InstanceInfo> instantiate(Module& wasm) { + try { + auto interface = + std::make_shared<ShellExternalInterface>(linkedInstances); + auto instance = + std::make_shared<ModuleRunner>(wasm, interface.get(), linkedInstances); + return {{std::move(interface), std::move(instance)}}; + } catch (...) { + return Err{"failed to instantiate module"}; } + } + + Result<> addModule(WASTModule& mod) { + auto module = makeModule(mod); + CHECK_ERR(module); + + auto wasm = *module; + CHECK_ERR(validateModule(*wasm)); + + auto instanceInfo = instantiate(*wasm); + CHECK_ERR(instanceInfo); + + auto& [interface, instance] = *instanceInfo; + lastModule = wasm->name; + modules[lastModule] = std::move(wasm); + interfaces[lastModule] = std::move(interface); + instances[lastModule] = std::move(instance); - instantiate(modules[moduleName].get()); return Ok{}; } - void parseRegister(Element& s) { + Result<> addRegistration(Register& reg) { auto instance = instances[lastModule]; if (!instance) { - Fatal() << "register called without a module"; + return Err{"register called without a module"}; } - auto name = s[1]->str(); - linkedInstances[name] = instance; + linkedInstances[reg.name] = instance; - // we copy pointers as a registered module's name might still be used + // We copy pointers as a registered module's name might still be used // in an assertion or invoke command. - modules[name] = modules[lastModule]; - interfaces[name] = interfaces[lastModule]; - instances[name] = instances[lastModule]; - - Colors::green(std::cerr); - std::cerr << "REGISTER MODULE INSTANCE AS \"" << name.str - << "\" [line: " << s.line << "]\n"; - Colors::normal(std::cerr); + modules[reg.name] = modules[lastModule]; + interfaces[reg.name] = interfaces[lastModule]; + instances[reg.name] = instances[lastModule]; + return Ok{}; } - Literals parseOperation(Element& s) { - Index i = 1; - Name moduleName = lastModule; - if (s[i]->dollared()) { - moduleName = s[i++]->str(); - } - ModuleRunner* instance = instances[moduleName].get(); - assert(instance); - - std::string baseStr = std::string("\"") + s[i++]->str().toString() + "\""; - auto base = Lexer(baseStr).takeString(); - if (!base) { - Fatal() << "expected string\n"; + struct TrapResult {}; + struct HostLimitResult {}; + struct ExceptionResult {}; + using ActionResult = + std::variant<Literals, TrapResult, HostLimitResult, ExceptionResult>; + + std::string resultToString(ActionResult& result) { + if (std::get_if<TrapResult>(&result)) { + return "trap"; + } else if (std::get_if<HostLimitResult>(&result)) { + return "exceeded host limit"; + } else if (std::get_if<ExceptionResult>(&result)) { + return "exception"; + } else if (auto* vals = std::get_if<Literals>(&result)) { + std::stringstream ss; + ss << *vals; + return ss.str(); + } else { + WASM_UNREACHABLE("unexpected result"); } + } - if (s[0]->str() == INVOKE) { - Literals args; - while (i < s.size()) { - auto* arg = parseExpression(*modules[moduleName], *s[i++]); - args.push_back(getLiteralFromConstExpression(arg)); + ActionResult doAction(Action& act) { + ModuleRunner* instance = instances[lastModule].get(); + assert(instance); + if (auto* invoke = std::get_if<InvokeAction>(&act)) { + auto it = instances.find(invoke->base ? *invoke->base : lastModule); + if (it == instances.end()) { + return TrapResult{}; } - return instance->callExport(*base, args); - } else if (s[0]->str() == GET) { - return instance->getExport(*base); + auto& instance = it->second; + try { + return instance->callExport(invoke->name, invoke->args); + } catch (TrapException&) { + return TrapResult{}; + } catch (HostLimitException&) { + return HostLimitResult{}; + } catch (WasmException&) { + return ExceptionResult{}; + } catch (...) { + WASM_UNREACHABLE("unexpected error"); + } + } else if (auto* get = std::get_if<GetAction>(&act)) { + auto it = instances.find(get->base ? *get->base : lastModule); + if (it == instances.end()) { + return TrapResult{}; + } + auto& instance = it->second; + try { + return instance->getExport(get->name); + } catch (TrapException&) { + return TrapResult{}; + } catch (...) { + WASM_UNREACHABLE("unexpected error"); + } + } else { + WASM_UNREACHABLE("unexpected action"); } + } - Fatal() << "Invalid operation " << s[0]->toString(); + Result<> doAssertion(Assertion& assn) { + if (auto* ret = std::get_if<AssertReturn>(&assn)) { + return assertReturn(*ret); + } else if (auto* act = std::get_if<AssertAction>(&assn)) { + return assertAction(*act); + } else if (auto* mod = std::get_if<AssertModule>(&assn)) { + return assertModule(*mod); + } else { + WASM_UNREACHABLE("unexpected assertion"); + } } - void parseAssertTrap(Element& s) { - [[maybe_unused]] bool trapped = false; - auto& inner = *s[1]; - if (inner[0]->str() == MODULE) { - return parseModuleAssertion(s); + Result<> checkNaN(Literal val, NaNResult nan) { + std::stringstream err; + switch (nan.kind) { + case NaNKind::Canonical: + if (val.type != nan.type || !val.isCanonicalNaN()) { + err << "expected canonical " << nan.type << " NaN, got " << val; + return Err{err.str()}; + } + break; + case NaNKind::Arithmetic: + if (val.type != nan.type || !val.isArithmeticNaN()) { + err << "expected arithmetic " << nan.type << " NaN, got " << val; + return Err{err.str()}; + } + break; } + return Ok{}; + } - try { - parseOperation(inner); - } catch (const TrapException&) { - trapped = true; + Result<> checkLane(Literal val, LaneResult expected, Index index) { + std::stringstream err; + if (auto* e = std::get_if<Literal>(&expected)) { + if (*e != val) { + err << "expected " << *e << ", got " << val << " at lane " << index; + return Err{err.str()}; + } + } else if (auto* nan = std::get_if<NaNResult>(&expected)) { + auto check = checkNaN(val, *nan); + if (auto* e = check.getErr()) { + err << e->msg << " at lane " << index; + return Err{err.str()}; + } + } else { + WASM_UNREACHABLE("unexpected lane expectation"); } - assert(trapped); + return Ok{}; } - void parseAssertException(Element& s) { - [[maybe_unused]] bool thrown = false; - auto& inner = *s[1]; - if (inner[0]->str() == MODULE) { - return parseModuleAssertion(s); + Result<> assertReturn(AssertReturn& assn) { + std::stringstream err; + auto result = doAction(assn.action); + auto* values = std::get_if<Literals>(&result); + if (!values) { + return Err{std::string("expected return, got ") + resultToString(result)}; } + if (values->size() != assn.expected.size()) { + err << "expected " << assn.expected.size() << " values, got " + << resultToString(result); + return Err{err.str()}; + } + for (Index i = 0; i < values->size(); ++i) { + auto atIndex = [&]() { + if (values->size() <= 1) { + return std::string{}; + } + std::stringstream ss; + ss << " at index " << i; + return ss.str(); + }; - try { - parseOperation(inner); - } catch (const WasmException& e) { - std::cout << "[exception thrown: " << e << "]" << std::endl; - thrown = true; + Literal val = (*values)[i]; + auto& expected = assn.expected[i]; + if (auto* v = std::get_if<Literal>(&expected)) { + if (val != *v) { + err << "expected " << *v << ", got " << val << atIndex(); + return Err{err.str()}; + } + } else if (auto* ref = std::get_if<RefResult>(&expected)) { + if (!val.type.isRef() || val.type.getHeapType() != ref->type) { + err << "expected " << ref->type << " reference, got " << val + << atIndex(); + return Err{err.str()}; + } + } else if (auto* nan = std::get_if<NaNResult>(&expected)) { + auto check = checkNaN(val, *nan); + if (auto* e = check.getErr()) { + err << e->msg << atIndex(); + return Err{err.str()}; + } + } else if (auto* lanes = std::get_if<LaneResults>(&expected)) { + switch (lanes->size()) { + case 4: { + auto vals = val.getLanesF32x4(); + for (Index i = 0; i < 4; ++i) { + auto check = checkLane(vals[i], (*lanes)[i], i); + if (auto* e = check.getErr()) { + err << e->msg << atIndex(); + return Err{err.str()}; + } + } + break; + } + case 2: { + auto vals = val.getLanesF64x2(); + for (Index i = 0; i < 2; ++i) { + auto check = checkLane(vals[i], (*lanes)[i], i); + if (auto* e = check.getErr()) { + err << e->msg << atIndex(); + return Err{err.str()}; + } + } + break; + } + default: + WASM_UNREACHABLE("unexpected number of lanes"); + } + } else { + WASM_UNREACHABLE("unexpected expectation"); + } } - assert(thrown); + return Ok{}; } - void parseAssertReturn(Element& s) { - Literals actual; - Literals expected; - if (s.size() >= 3) { - expected = getLiteralsFromConstExpression( - parseExpression(*modules[lastModule], *s[2])); - } - [[maybe_unused]] bool trapped = false; - try { - actual = parseOperation(*s[1]); - } catch (const TrapException&) { - trapped = true; - } catch (const WasmException& e) { - std::cout << "[exception thrown: " << e << "]" << std::endl; - trapped = true; - } - assert(!trapped); - std::cerr << "seen " << actual << ", expected " << expected << '\n'; - if (expected != actual) { - Fatal() << "unexpected, should be identical\n"; + Result<> assertAction(AssertAction& assn) { + std::stringstream err; + auto result = doAction(assn.action); + switch (assn.type) { + case ActionAssertionType::Trap: + if (std::get_if<TrapResult>(&result)) { + return Ok{}; + } + err << "expected trap"; + break; + case ActionAssertionType::Exhaustion: + if (std::get_if<HostLimitResult>(&result)) { + return Ok{}; + } + err << "expected exhaustion"; + break; + case ActionAssertionType::Exception: + if (std::get_if<ExceptionResult>(&result)) { + return Ok{}; + } + err << "expected exception"; + break; } + err << ", got " << resultToString(result); + return Err{err.str()}; } - void parseModuleAssertion(Element& s) { - Module wasm; - wasm.features = FeatureSet::All; - std::unique_ptr<SExpressionWasmBuilder> builder; - auto id = s[0]->str(); + Result<> assertModule(AssertModule& assn) { + auto wasm = makeModule(assn.wasm); + if (const auto* err = wasm.getErr()) { + if (assn.type == ModuleAssertionType::Malformed || + assn.type == ModuleAssertionType::Invalid) { + return Ok{}; + } + return Err{err->msg}; + } - bool invalid = false; - try { - SExpressionWasmBuilder(wasm, *s[1], IRProfile::Normal); - } catch (const ParseException&) { - invalid = true; + if (assn.type == ModuleAssertionType::Malformed) { + return Err{"expected malformed module"}; } - if (!invalid) { - // maybe parsed ok, but otherwise incorrect - invalid = !WasmValidator().validate(wasm); + auto valid = validateModule(**wasm); + if (auto* err = valid.getErr()) { + if (assn.type == ModuleAssertionType::Invalid) { + return Ok{}; + } + return Err{err->msg}; } - if (!invalid && id == ASSERT_UNLINKABLE) { - // validate "instantiating" the mdoule - auto reportUnknownImport = [&](Importable* import) { - auto it = linkedInstances.find(import->module); - if (it == linkedInstances.end() || - it->second->wasm.getExportOrNull(import->base) == nullptr) { - std::cerr << "unknown import: " << import->module << '.' - << import->base << '\n'; - invalid = true; - } - }; - ModuleUtils::iterImportedGlobals(wasm, reportUnknownImport); - ModuleUtils::iterImportedTables(wasm, reportUnknownImport); - ModuleUtils::iterImportedFunctions(wasm, [&](Importable* import) { - if (import->module == SPECTEST && import->base.startsWith(PRINT)) { - // We can handle it. - } else { - reportUnknownImport(import); - } - }); - ElementUtils::iterAllElementFunctionNames(&wasm, [&](Name name) { - // spec tests consider it illegal to use spectest.print in a table - if (auto* import = wasm.getFunction(name)) { - if (import->imported() && import->module == SPECTEST && - import->base.startsWith(PRINT)) { - std::cerr << "cannot put spectest.print in table\n"; - invalid = true; - } - } - }); - ModuleUtils::iterImportedMemories(wasm, reportUnknownImport); + if (assn.type == ModuleAssertionType::Invalid) { + return Err{"expected invalid module"}; } - if (!invalid && (id == ASSERT_TRAP || id == ASSERT_EXCEPTION)) { - try { - instantiate(&wasm); - } catch (const TrapException&) { - invalid = true; - } catch (const WasmException& e) { - std::cout << "[exception thrown: " << e << "]" << std::endl; - invalid = true; + auto instance = instantiate(**wasm); + if (auto* err = instance.getErr()) { + if (assn.type == ModuleAssertionType::Unlinkable || + assn.type == ModuleAssertionType::Trap) { + return Ok{}; } + return Err{err->msg}; } - if (!invalid) { - Colors::red(std::cerr); - std::cerr << "[should have been invalid]\n"; - Colors::normal(std::cerr); - Fatal() << &wasm << '\n'; + if (assn.type == ModuleAssertionType::Unlinkable) { + return Err{"expected unlinkable module"}; + } + if (assn.type == ModuleAssertionType::Trap) { + return Err{"expected instantiation to trap"}; } - } -protected: - Options& options; + WASM_UNREACHABLE("unexpected module assertion"); + } // spectest module is a default host-provided module defined by the spec's // reference interpreter. It's been replaced by the `(register ...)` @@ -345,7 +446,7 @@ protected: // is actually removed from the spec test. void buildSpectestModule() { auto spectest = std::make_shared<Module>(); - spectest->name = "spectest"; + spectest->features = FeatureSet::All; Builder builder(*spectest); spectest->addGlobal(builder.makeGlobal(Name::fromInt(0), @@ -388,45 +489,18 @@ protected: spectest->addExport( builder.makeExport("memory", memory->name, ExternalKind::Memory)); - modules["spectest"].swap(spectest); - modules["spectest"]->features = FeatureSet::All; - instantiate(modules["spectest"].get()); - linkedInstances["spectest"] = instances["spectest"]; // print_* functions are handled separately, no need to define here. - } - -public: - Shell(Options& options) : options(options) { buildSpectestModule(); } - - MaybeResult<> parseAndRun(Lexer& lexer) { - size_t i = 0; - while (!lexer.empty()) { - auto next = lexer.next(); - auto size = next.find('\n'); - if (size != std::string_view::npos) { - next = next.substr(0, size); - } else { - next = ""; - } - - if (!lexer.peekSExprStart("module")) { - Colors::red(std::cerr); - std::cerr << i; - Colors::green(std::cerr); - std::cerr << " CHECKING: "; - Colors::normal(std::cerr); - std::cerr << next; - Colors::green(std::cerr); - std::cerr << " [line: " << lexer.position().line << "]\n"; - Colors::normal(std::cerr); - } - CHECK_ERR(parse(lexer)); - - i += 1; + WASTModule mod = std::move(spectest); + auto added = addModule(mod); + if (added.getErr()) { + WASM_UNREACHABLE("error building spectest module"); + } + Register registration{"spectest"}; + auto registered = addRegistration(registration); + if (registered.getErr()) { + WASM_UNREACHABLE("error registering spectest module"); } - - return Ok{}; } }; @@ -453,16 +527,14 @@ int main(int argc, const char* argv[]) { } Lexer lexer(input); - auto result = Shell(options).parseAndRun(lexer); + auto result = Shell(options).run(*script); if (auto* err = result.getErr()) { std::cerr << err->msg << '\n'; exit(1); } - if (result) { - Colors::green(std::cerr); - Colors::bold(std::cerr); - std::cerr << "all checks passed.\n"; - Colors::normal(std::cerr); - } + Colors::green(std::cerr); + Colors::bold(std::cerr); + std::cerr << "all checks passed.\n"; + Colors::normal(std::cerr); } |