diff options
-rw-r--r-- | src/literal.h | 2 | ||||
-rw-r--r-- | src/parser/contexts.h | 5 | ||||
-rw-r--r-- | src/parser/lexer.cpp | 4 | ||||
-rw-r--r-- | src/parser/parsers.h | 22 | ||||
-rw-r--r-- | src/parser/wast-parser.cpp | 32 | ||||
-rw-r--r-- | src/parser/wat-parser.cpp | 8 | ||||
-rw-r--r-- | src/parser/wat-parser.h | 24 | ||||
-rw-r--r-- | src/tools/wasm-shell.cpp | 622 | ||||
-rw-r--r-- | src/wasm-builder.h | 1 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 2 | ||||
-rw-r--r-- | src/wasm-type.h | 2 | ||||
-rw-r--r-- | src/wasm/literal.cpp | 16 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 10 | ||||
-rw-r--r-- | test/spec/address64.wast | 7 | ||||
-rw-r--r-- | test/spec/call.wast | 4 | ||||
-rw-r--r-- | test/spec/call_indirect.wast | 4 | ||||
-rw-r--r-- | test/spec/exception-handling-old.wast | 40 | ||||
-rw-r--r-- | test/spec/fac.wast | 2 | ||||
-rw-r--r-- | test/spec/inline-module.wast | 2 | ||||
-rw-r--r-- | test/spec/skip-stack-guard-page.wast | 20 |
20 files changed, 462 insertions, 367 deletions
diff --git a/src/literal.h b/src/literal.h index a4017f6ec..1268448fb 100644 --- a/src/literal.h +++ b/src/literal.h @@ -347,6 +347,8 @@ public: bool operator!=(const Literal& other) const; bool isNaN(); + bool isCanonicalNaN(); + bool isArithmeticNaN(); static uint32_t NaNPayload(float f); static uint64_t NaNPayload(double f); diff --git a/src/parser/contexts.h b/src/parser/contexts.h index 5ad0c16de..bce051a93 100644 --- a/src/parser/contexts.h +++ b/src/parser/contexts.h @@ -1665,7 +1665,10 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> { return Ok{}; } - Result<> addExport(Index, Name value, Name name, ExternalKind kind) { + Result<> addExport(Index pos, Name value, Name name, ExternalKind kind) { + if (wasm.getExportOrNull(name)) { + return in.err(pos, "duplicate export"); + } wasm.addExport(builder.makeExport(name, value, kind)); return Ok{}; } diff --git a/src/parser/lexer.cpp b/src/parser/lexer.cpp index fd0a262b8..bb6428e87 100644 --- a/src/parser/lexer.cpp +++ b/src/parser/lexer.cpp @@ -23,6 +23,7 @@ #include <variant> #include "lexer.h" +#include "support/bits.h" #include "support/string.h" using namespace std::string_view_literals; @@ -1005,6 +1006,9 @@ std::optional<uint32_t> Lexer::takeAlign() { } Lexer subLexer(result->span.substr(6)); if (auto o = subLexer.takeU32()) { + if (Bits::popCount(*o) != 1) { + return std::nullopt; + } pos += result->span.size(); advance(); return o; diff --git a/src/parser/parsers.h b/src/parser/parsers.h index 88600fec3..5900deb27 100644 --- a/src/parser/parsers.h +++ b/src/parser/parsers.h @@ -30,7 +30,8 @@ template<typename Ctx> Result<typename Ctx::HeapTypeT> heaptype(Ctx&); template<typename Ctx> MaybeResult<typename Ctx::RefTypeT> reftype(Ctx&); template<typename Ctx> MaybeResult<typename Ctx::TypeT> tupletype(Ctx&); template<typename Ctx> Result<typename Ctx::TypeT> valtype(Ctx&); -template<typename Ctx> MaybeResult<typename Ctx::ParamsT> params(Ctx&); +template<typename Ctx> +MaybeResult<typename Ctx::ParamsT> params(Ctx&, bool allowNames = true); template<typename Ctx> MaybeResult<typename Ctx::ResultsT> results(Ctx&); template<typename Ctx> MaybeResult<typename Ctx::SignatureT> functype(Ctx&); template<typename Ctx> Result<typename Ctx::FieldT> storagetype(Ctx&); @@ -325,7 +326,8 @@ MaybeResult<typename Ctx::LabelIdxT> maybeLabelidx(Ctx&, template<typename Ctx> Result<typename Ctx::LabelIdxT> labelidx(Ctx&, bool inDelegate = false); template<typename Ctx> Result<typename Ctx::TagIdxT> tagidx(Ctx&); -template<typename Ctx> Result<typename Ctx::TypeUseT> typeuse(Ctx&); +template<typename Ctx> +Result<typename Ctx::TypeUseT> typeuse(Ctx&, bool allowNames = true); MaybeResult<ImportNames> inlineImport(Lexer&); Result<std::vector<Name>> inlineExports(Lexer&); template<typename Ctx> Result<> strtype(Ctx&); @@ -561,13 +563,18 @@ template<typename Ctx> Result<typename Ctx::TypeT> valtype(Ctx& ctx) { // param ::= '(' 'param id? t:valtype ')' => [t] // | '(' 'param t*:valtype* ')' => [t*] // params ::= param* -template<typename Ctx> MaybeResult<typename Ctx::ParamsT> params(Ctx& ctx) { +template<typename Ctx> +MaybeResult<typename Ctx::ParamsT> params(Ctx& ctx, bool allowNames) { bool hasAny = false; auto res = ctx.makeParams(); while (ctx.in.takeSExprStart("param"sv)) { hasAny = true; + auto pos = ctx.in.getPos(); if (auto id = ctx.in.takeID()) { // Single named param + if (!allowNames) { + return ctx.in.err(pos, "unexpected named parameter"); + } auto type = valtype(ctx); CHECK_ERR(type); if (!ctx.in.takeRParen()) { @@ -1065,7 +1072,7 @@ template<typename Ctx> Result<typename Ctx::BlockTypeT> blocktype(Ctx& ctx) { // We either had no results or multiple results. Reset and parse again as a // type use. ctx.in = initialLexer; - auto use = typeuse(ctx); + auto use = typeuse(ctx, false); CHECK_ERR(use); auto type = ctx.getBlockTypeFromTypeUse(pos, *use); @@ -1935,7 +1942,7 @@ Result<> makeCallIndirect(Ctx& ctx, bool isReturn) { auto table = maybeTableidx(ctx); CHECK_ERR(table); - auto type = typeuse(ctx); + auto type = typeuse(ctx, false); CHECK_ERR(type); return ctx.makeCallIndirect( pos, annotations, table.getPtr(), *type, isReturn); @@ -2669,7 +2676,8 @@ template<typename Ctx> Result<typename Ctx::TagIdxT> tagidx(Ctx& ctx) { // (if typedefs[x] = [t1*] -> [t2*]) // | ((t1,IDs):param)* (t2:result)* => x, IDs // (if x is minimum s.t. typedefs[x] = [t1*] -> [t2*]) -template<typename Ctx> Result<typename Ctx::TypeUseT> typeuse(Ctx& ctx) { +template<typename Ctx> +Result<typename Ctx::TypeUseT> typeuse(Ctx& ctx, bool allowNames) { auto pos = ctx.in.getPos(); std::optional<typename Ctx::HeapTypeT> type; if (ctx.in.takeSExprStart("type"sv)) { @@ -2683,7 +2691,7 @@ template<typename Ctx> Result<typename Ctx::TypeUseT> typeuse(Ctx& ctx) { type = *x; } - auto namedParams = params(ctx); + auto namedParams = params(ctx, allowNames); CHECK_ERR(namedParams); auto resultTypes = results(ctx); diff --git a/src/parser/wast-parser.cpp b/src/parser/wast-parser.cpp index fb0dce932..87060b9fc 100644 --- a/src/parser/wast-parser.cpp +++ b/src/parser/wast-parser.cpp @@ -41,8 +41,7 @@ Result<Literals> consts(Lexer& in) { MaybeResult<Action> action(Lexer& in) { if (in.takeSExprStart("invoke"sv)) { - // TODO: Do we need to use this optional id? - in.takeID(); + auto id = in.takeID(); auto name = in.takeName(); if (!name) { return in.err("expected export name"); @@ -52,12 +51,11 @@ MaybeResult<Action> action(Lexer& in) { if (!in.takeRParen()) { return in.err("expected end of invoke action"); } - return InvokeAction{*name, *args}; + return InvokeAction{id, *name, *args}; } if (in.takeSExprStart("get"sv)) { - // TODO: Do we need to use this optional id? - in.takeID(); + auto id = in.takeID(); auto name = in.takeName(); if (!name) { return in.err("expected export name"); @@ -65,7 +63,7 @@ MaybeResult<Action> action(Lexer& in) { if (!in.takeRParen()) { return in.err("expected end of get action"); } - return GetAction{*name}; + return GetAction{id, *name}; } return {}; @@ -236,7 +234,7 @@ MaybeResult<AssertReturn> assertReturn(Lexer& in) { } // (assert_exception action) -MaybeResult<AssertException> assertException(Lexer& in) { +MaybeResult<AssertAction> assertException(Lexer& in) { if (!in.takeSExprStart("assert_exception"sv)) { return {}; } @@ -245,7 +243,7 @@ MaybeResult<AssertException> assertException(Lexer& in) { if (!in.takeRParen()) { return in.err("expected end of assert_exception"); } - return AssertException{*a}; + return AssertAction{ActionAssertionType::Exception, *a}; } // (assert_exhaustion action msg) @@ -266,7 +264,7 @@ MaybeResult<AssertAction> assertAction(Lexer& in) { if (!in.takeRParen()) { return in.err("expected end of assertion"); } - return AssertAction{type, *a, *msg}; + return AssertAction{type, *a}; } // (assert_malformed module msg) @@ -293,7 +291,7 @@ MaybeResult<AssertModule> assertModule(Lexer& in) { if (!in.takeRParen()) { return in.err("expected end of assertion"); } - return AssertModule{type, *mod, *msg}; + return AssertModule{type, *mod}; } // (assert_trap action msg) @@ -312,7 +310,7 @@ MaybeResult<Assertion> assertTrap(Lexer& in) { if (!in.takeRParen()) { return in.err("expected end of assertion"); } - return Assertion{AssertAction{ActionAssertionType::Trap, *a, *msg}}; + return Assertion{AssertAction{ActionAssertionType::Trap, *a}}; } auto mod = wastModule(in); if (mod.getErr()) { @@ -325,7 +323,7 @@ MaybeResult<Assertion> assertTrap(Lexer& in) { if (!in.takeRParen()) { return in.err("expected end of assertion"); } - return Assertion{AssertModule{ModuleAssertionType::Trap, *mod, *msg}}; + return Assertion{AssertModule{ModuleAssertionType::Trap, *mod}}; } MaybeResult<Assertion> assertion(Lexer& in) { @@ -391,24 +389,30 @@ Result<WASTCommand> command(Lexer& in) { return *mod; } +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" + Result<WASTScript> wast(Lexer& in) { WASTScript cmds; while (!in.empty()) { + size_t line = in.position().line; auto cmd = command(in); if (cmd.getErr() && cmds.empty()) { // The entire script might be a single module comprising a sequence of // module fields with a top-level `(module ...)`. auto wasm = std::make_shared<Module>(); CHECK_ERR(parseModule(*wasm, in.buffer)); - cmds.emplace_back(std::move(wasm)); + cmds.push_back({WASTModule{std::move(wasm)}, line}); return cmds; } CHECK_ERR(cmd); - cmds.emplace_back(std::move(*cmd)); + cmds.push_back(ScriptEntry{std::move(*cmd), line}); } return cmds; } +#pragma GCC diagnostic pop + } // anonymous namespace Result<WASTScript> parseScript(std::string_view in) { diff --git a/src/parser/wat-parser.cpp b/src/parser/wat-parser.cpp index 85ef1e80f..4763c69ec 100644 --- a/src/parser/wat-parser.cpp +++ b/src/parser/wat-parser.cpp @@ -237,14 +237,6 @@ Result<> parseModule(Module& wasm, Lexer& lexer) { return doParseModule(wasm, lexer, true); } -Result<Expression*> parseExpression(Module& wasm, Lexer& lexer) { - ParseDefsCtx ctx(lexer, wasm, {}, {}, {}, {}, {}); - auto e = expr(ctx); - CHECK_ERR(e); - lexer = ctx.in; - return *e; -} - Result<Literal> parseConst(Lexer& lexer) { Module wasm; ParseDefsCtx ctx(lexer, wasm, {}, {}, {}, {}, {}); diff --git a/src/parser/wat-parser.h b/src/parser/wat-parser.h index 7fe6abfdd..041ba1d58 100644 --- a/src/parser/wat-parser.h +++ b/src/parser/wat-parser.h @@ -34,14 +34,14 @@ Result<> parseModule(Module& wasm, Lexer& lexer); Result<Literal> parseConst(Lexer& lexer); -Result<Expression*> parseExpression(Module& wasm, Lexer& lexer); - struct InvokeAction { + std::optional<Name> base; Name name; Literals args; }; struct GetAction { + std::optional<Name> base; Name name; }; @@ -68,19 +68,14 @@ using ExpectedResults = std::vector<ExpectedResult>; struct AssertReturn { Action action; - ExpectedResults results; -}; - -struct AssertException { - Action action; + ExpectedResults expected; }; -enum class ActionAssertionType { Trap, Exhaustion }; +enum class ActionAssertionType { Trap, Exhaustion, Exception }; struct AssertAction { ActionAssertionType type; Action action; - std::string msg; }; enum class QuotedModuleType { Text, Binary }; @@ -97,11 +92,9 @@ enum class ModuleAssertionType { Trap, Malformed, Invalid, Unlinkable }; struct AssertModule { ModuleAssertionType type; WASTModule wasm; - std::string msg; }; -using Assertion = - std::variant<AssertReturn, AssertException, AssertAction, AssertModule>; +using Assertion = std::variant<AssertReturn, AssertAction, AssertModule>; struct Register { Name name; @@ -109,7 +102,12 @@ struct Register { using WASTCommand = std::variant<WASTModule, Register, Action, Assertion>; -using WASTScript = std::vector<WASTCommand>; +struct ScriptEntry { + WASTCommand cmd; + size_t line; +}; + +using WASTScript = std::vector<ScriptEntry>; Result<WASTScript> parseScript(std::string_view in); diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index 3fe8b3505..913a7a58a 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -30,26 +30,16 @@ #include "support/command-line.h" #include "support/file.h" #include "support/result.h" +#include "wasm-binary.h" #include "wasm-interpreter.h" #include "wasm-s-parser.h" #include "wasm-validator.h" using namespace wasm; -using Lexer = WATParser::Lexer; +using namespace wasm::WATParser; -Name ASSERT_RETURN("assert_return"); -Name ASSERT_TRAP("assert_trap"); -Name ASSERT_EXCEPTION("assert_exception"); -Name ASSERT_INVALID("assert_invalid"); -Name ASSERT_MALFORMED("assert_malformed"); -Name ASSERT_UNLINKABLE("assert_unlinkable"); -Name INVOKE("invoke"); -Name REGISTER("register"); -Name GET("get"); - -class Shell { -protected: +struct Shell { std::map<Name, std::shared_ptr<Module>> modules; std::map<Name, std::shared_ptr<ShellExternalInterface>> interfaces; std::map<Name, std::shared_ptr<ModuleRunner>> instances; @@ -58,284 +48,395 @@ protected: Name lastModule; - void instantiate(Module* wasm) { - auto tempInterface = - std::make_shared<ShellExternalInterface>(linkedInstances); - auto tempInstance = std::make_shared<ModuleRunner>( - *wasm, tempInterface.get(), linkedInstances); - interfaces[wasm->name].swap(tempInterface); - instances[wasm->name].swap(tempInstance); - } - - Result<std::string> parseSExpr(Lexer& lexer) { - auto begin = lexer.getPos(); + Options& options; - if (!lexer.takeLParen()) { - return lexer.err("expected s-expression"); - } + Shell(Options& options) : options(options) { buildSpectestModule(); } - size_t count = 1; - while (count != 0 && lexer.takeUntilParen()) { - if (lexer.takeLParen()) { - ++count; - } else if (lexer.takeRParen()) { - --count; + Result<> run(WASTScript& script) { + size_t i = 0; + for (auto& entry : script) { + Colors::red(std::cerr); + std::cerr << i << ' '; + Colors::normal(std::cerr); + if (std::get_if<WASTModule>(&entry.cmd)) { + Colors::green(std::cerr); + std::cerr << "BUILDING MODULE [line: " << entry.line << "]\n"; + Colors::normal(std::cerr); + } else if (auto* reg = std::get_if<Register>(&entry.cmd)) { + Colors::green(std::cerr); + std::cerr << "REGISTER MODULE INSTANCE AS \"" << reg->name + << "\" [line: " << entry.line << "]\n"; + Colors::normal(std::cerr); } else { - WASM_UNREACHABLE("unexpected token"); + Colors::green(std::cerr); + std::cerr << "CHECKING [line: " << entry.line << "]\n"; + Colors::normal(std::cerr); } + ++i; + CHECK_ERR(runCommand(entry.cmd)); } - - if (count != 0) { - return lexer.err("unexpected unterminated s-expression"); - } - - return std::string(lexer.buffer.substr(begin, lexer.getPos() - begin)); + return Ok{}; } - Expression* parseExpression(Module& wasm, Element& s) { - std::stringstream ss; - ss << s; - auto str = ss.str(); - Lexer lexer(str); - auto arg = WATParser::parseExpression(wasm, lexer); - if (auto* err = arg.getErr()) { - Fatal() << err->msg << '\n'; + Result<> runCommand(WASTCommand& cmd) { + if (auto* mod = std::get_if<WASTModule>(&cmd)) { + return addModule(*mod); + } else if (auto* reg = std::get_if<Register>(&cmd)) { + return addRegistration(*reg); + } else if (auto* act = std::get_if<Action>(&cmd)) { + doAction(*act); + return Ok{}; + } else if (auto* assn = std::get_if<Assertion>(&cmd)) { + return doAssertion(*assn); + } else { + WASM_UNREACHABLE("unexpected command"); } - return *arg; } - Result<> parse(Lexer& lexer) { - if (auto res = parseModule(lexer)) { - CHECK_ERR(res); - return Ok{}; + Result<std::shared_ptr<Module>> makeModule(WASTModule& mod) { + std::shared_ptr<Module> wasm; + if (auto* quoted = std::get_if<QuotedModule>(&mod)) { + wasm = std::make_shared<Module>(); + switch (quoted->type) { + case QuotedModuleType::Text: { + CHECK_ERR(parseModule(*wasm, quoted->module)); + break; + } + case QuotedModuleType::Binary: { + std::vector<char> buffer(quoted->module.begin(), + quoted->module.end()); + WasmBinaryReader reader(*wasm, FeatureSet::All, buffer); + try { + reader.read(); + } catch (ParseException& p) { + std::stringstream ss; + p.dump(ss); + return Err{ss.str()}; + } + break; + } + } + } else if (auto* ptr = std::get_if<std::shared_ptr<Module>>(&mod)) { + wasm = *ptr; + } else { + WASM_UNREACHABLE("unexpected module kind"); } + wasm->features = FeatureSet::All; + return wasm; + } - auto pos = lexer.getPos(); - auto sexpr = parseSExpr(lexer); - CHECK_ERR(sexpr); - - SExpressionParser parser(sexpr->data()); - Element& s = *parser.root[0][0]; - IString id = s[0]->str(); - if (id == REGISTER) { - parseRegister(s); - } else if (id == INVOKE) { - parseOperation(s); - } else if (id == ASSERT_RETURN) { - parseAssertReturn(s); - } else if (id == ASSERT_TRAP) { - parseAssertTrap(s); - } else if (id == ASSERT_EXCEPTION) { - parseAssertException(s); - } else if ((id == ASSERT_INVALID) || (id == ASSERT_MALFORMED) || - (id == ASSERT_UNLINKABLE)) { - parseModuleAssertion(s); - } else { - return lexer.err(pos, "unrecognized command"); + Result<> validateModule(Module& wasm) { + if (!WasmValidator().validate(wasm)) { + return Err{"failed validation"}; } return Ok{}; } - MaybeResult<> parseModule(Lexer& lexer) { - if (!lexer.peekSExprStart("module")) { - return {}; - } - Colors::green(std::cerr); - std::cerr << "BUILDING MODULE [line: " << lexer.position().line << "]\n"; - Colors::normal(std::cerr); - auto module = std::make_shared<Module>(); - - CHECK_ERR(WATParser::parseModule(*module, lexer)); - - auto moduleName = module->name; - lastModule = module->name; - modules[moduleName].swap(module); - modules[moduleName]->features = FeatureSet::All; - bool valid = WasmValidator().validate(*modules[moduleName]); - if (!valid) { - std::cout << *modules[moduleName] << '\n'; - Fatal() << "module failed to validate, see above"; + using InstanceInfo = std::pair<std::shared_ptr<ShellExternalInterface>, + std::shared_ptr<ModuleRunner>>; + + Result<InstanceInfo> instantiate(Module& wasm) { + try { + auto interface = + std::make_shared<ShellExternalInterface>(linkedInstances); + auto instance = + std::make_shared<ModuleRunner>(wasm, interface.get(), linkedInstances); + return {{std::move(interface), std::move(instance)}}; + } catch (...) { + return Err{"failed to instantiate module"}; } + } + + Result<> addModule(WASTModule& mod) { + auto module = makeModule(mod); + CHECK_ERR(module); + + auto wasm = *module; + CHECK_ERR(validateModule(*wasm)); + + auto instanceInfo = instantiate(*wasm); + CHECK_ERR(instanceInfo); + + auto& [interface, instance] = *instanceInfo; + lastModule = wasm->name; + modules[lastModule] = std::move(wasm); + interfaces[lastModule] = std::move(interface); + instances[lastModule] = std::move(instance); - instantiate(modules[moduleName].get()); return Ok{}; } - void parseRegister(Element& s) { + Result<> addRegistration(Register& reg) { auto instance = instances[lastModule]; if (!instance) { - Fatal() << "register called without a module"; + return Err{"register called without a module"}; } - auto name = s[1]->str(); - linkedInstances[name] = instance; + linkedInstances[reg.name] = instance; - // we copy pointers as a registered module's name might still be used + // We copy pointers as a registered module's name might still be used // in an assertion or invoke command. - modules[name] = modules[lastModule]; - interfaces[name] = interfaces[lastModule]; - instances[name] = instances[lastModule]; - - Colors::green(std::cerr); - std::cerr << "REGISTER MODULE INSTANCE AS \"" << name.str - << "\" [line: " << s.line << "]\n"; - Colors::normal(std::cerr); + modules[reg.name] = modules[lastModule]; + interfaces[reg.name] = interfaces[lastModule]; + instances[reg.name] = instances[lastModule]; + return Ok{}; } - Literals parseOperation(Element& s) { - Index i = 1; - Name moduleName = lastModule; - if (s[i]->dollared()) { - moduleName = s[i++]->str(); - } - ModuleRunner* instance = instances[moduleName].get(); - assert(instance); - - std::string baseStr = std::string("\"") + s[i++]->str().toString() + "\""; - auto base = Lexer(baseStr).takeString(); - if (!base) { - Fatal() << "expected string\n"; + struct TrapResult {}; + struct HostLimitResult {}; + struct ExceptionResult {}; + using ActionResult = + std::variant<Literals, TrapResult, HostLimitResult, ExceptionResult>; + + std::string resultToString(ActionResult& result) { + if (std::get_if<TrapResult>(&result)) { + return "trap"; + } else if (std::get_if<HostLimitResult>(&result)) { + return "exceeded host limit"; + } else if (std::get_if<ExceptionResult>(&result)) { + return "exception"; + } else if (auto* vals = std::get_if<Literals>(&result)) { + std::stringstream ss; + ss << *vals; + return ss.str(); + } else { + WASM_UNREACHABLE("unexpected result"); } + } - if (s[0]->str() == INVOKE) { - Literals args; - while (i < s.size()) { - auto* arg = parseExpression(*modules[moduleName], *s[i++]); - args.push_back(getLiteralFromConstExpression(arg)); + ActionResult doAction(Action& act) { + ModuleRunner* instance = instances[lastModule].get(); + assert(instance); + if (auto* invoke = std::get_if<InvokeAction>(&act)) { + auto it = instances.find(invoke->base ? *invoke->base : lastModule); + if (it == instances.end()) { + return TrapResult{}; } - return instance->callExport(*base, args); - } else if (s[0]->str() == GET) { - return instance->getExport(*base); + auto& instance = it->second; + try { + return instance->callExport(invoke->name, invoke->args); + } catch (TrapException&) { + return TrapResult{}; + } catch (HostLimitException&) { + return HostLimitResult{}; + } catch (WasmException&) { + return ExceptionResult{}; + } catch (...) { + WASM_UNREACHABLE("unexpected error"); + } + } else if (auto* get = std::get_if<GetAction>(&act)) { + auto it = instances.find(get->base ? *get->base : lastModule); + if (it == instances.end()) { + return TrapResult{}; + } + auto& instance = it->second; + try { + return instance->getExport(get->name); + } catch (TrapException&) { + return TrapResult{}; + } catch (...) { + WASM_UNREACHABLE("unexpected error"); + } + } else { + WASM_UNREACHABLE("unexpected action"); } + } - Fatal() << "Invalid operation " << s[0]->toString(); + Result<> doAssertion(Assertion& assn) { + if (auto* ret = std::get_if<AssertReturn>(&assn)) { + return assertReturn(*ret); + } else if (auto* act = std::get_if<AssertAction>(&assn)) { + return assertAction(*act); + } else if (auto* mod = std::get_if<AssertModule>(&assn)) { + return assertModule(*mod); + } else { + WASM_UNREACHABLE("unexpected assertion"); + } } - void parseAssertTrap(Element& s) { - [[maybe_unused]] bool trapped = false; - auto& inner = *s[1]; - if (inner[0]->str() == MODULE) { - return parseModuleAssertion(s); + Result<> checkNaN(Literal val, NaNResult nan) { + std::stringstream err; + switch (nan.kind) { + case NaNKind::Canonical: + if (val.type != nan.type || !val.isCanonicalNaN()) { + err << "expected canonical " << nan.type << " NaN, got " << val; + return Err{err.str()}; + } + break; + case NaNKind::Arithmetic: + if (val.type != nan.type || !val.isArithmeticNaN()) { + err << "expected arithmetic " << nan.type << " NaN, got " << val; + return Err{err.str()}; + } + break; } + return Ok{}; + } - try { - parseOperation(inner); - } catch (const TrapException&) { - trapped = true; + Result<> checkLane(Literal val, LaneResult expected, Index index) { + std::stringstream err; + if (auto* e = std::get_if<Literal>(&expected)) { + if (*e != val) { + err << "expected " << *e << ", got " << val << " at lane " << index; + return Err{err.str()}; + } + } else if (auto* nan = std::get_if<NaNResult>(&expected)) { + auto check = checkNaN(val, *nan); + if (auto* e = check.getErr()) { + err << e->msg << " at lane " << index; + return Err{err.str()}; + } + } else { + WASM_UNREACHABLE("unexpected lane expectation"); } - assert(trapped); + return Ok{}; } - void parseAssertException(Element& s) { - [[maybe_unused]] bool thrown = false; - auto& inner = *s[1]; - if (inner[0]->str() == MODULE) { - return parseModuleAssertion(s); + Result<> assertReturn(AssertReturn& assn) { + std::stringstream err; + auto result = doAction(assn.action); + auto* values = std::get_if<Literals>(&result); + if (!values) { + return Err{std::string("expected return, got ") + resultToString(result)}; } + if (values->size() != assn.expected.size()) { + err << "expected " << assn.expected.size() << " values, got " + << resultToString(result); + return Err{err.str()}; + } + for (Index i = 0; i < values->size(); ++i) { + auto atIndex = [&]() { + if (values->size() <= 1) { + return std::string{}; + } + std::stringstream ss; + ss << " at index " << i; + return ss.str(); + }; - try { - parseOperation(inner); - } catch (const WasmException& e) { - std::cout << "[exception thrown: " << e << "]" << std::endl; - thrown = true; + Literal val = (*values)[i]; + auto& expected = assn.expected[i]; + if (auto* v = std::get_if<Literal>(&expected)) { + if (val != *v) { + err << "expected " << *v << ", got " << val << atIndex(); + return Err{err.str()}; + } + } else if (auto* ref = std::get_if<RefResult>(&expected)) { + if (!val.type.isRef() || val.type.getHeapType() != ref->type) { + err << "expected " << ref->type << " reference, got " << val + << atIndex(); + return Err{err.str()}; + } + } else if (auto* nan = std::get_if<NaNResult>(&expected)) { + auto check = checkNaN(val, *nan); + if (auto* e = check.getErr()) { + err << e->msg << atIndex(); + return Err{err.str()}; + } + } else if (auto* lanes = std::get_if<LaneResults>(&expected)) { + switch (lanes->size()) { + case 4: { + auto vals = val.getLanesF32x4(); + for (Index i = 0; i < 4; ++i) { + auto check = checkLane(vals[i], (*lanes)[i], i); + if (auto* e = check.getErr()) { + err << e->msg << atIndex(); + return Err{err.str()}; + } + } + break; + } + case 2: { + auto vals = val.getLanesF64x2(); + for (Index i = 0; i < 2; ++i) { + auto check = checkLane(vals[i], (*lanes)[i], i); + if (auto* e = check.getErr()) { + err << e->msg << atIndex(); + return Err{err.str()}; + } + } + break; + } + default: + WASM_UNREACHABLE("unexpected number of lanes"); + } + } else { + WASM_UNREACHABLE("unexpected expectation"); + } } - assert(thrown); + return Ok{}; } - void parseAssertReturn(Element& s) { - Literals actual; - Literals expected; - if (s.size() >= 3) { - expected = getLiteralsFromConstExpression( - parseExpression(*modules[lastModule], *s[2])); - } - [[maybe_unused]] bool trapped = false; - try { - actual = parseOperation(*s[1]); - } catch (const TrapException&) { - trapped = true; - } catch (const WasmException& e) { - std::cout << "[exception thrown: " << e << "]" << std::endl; - trapped = true; - } - assert(!trapped); - std::cerr << "seen " << actual << ", expected " << expected << '\n'; - if (expected != actual) { - Fatal() << "unexpected, should be identical\n"; + Result<> assertAction(AssertAction& assn) { + std::stringstream err; + auto result = doAction(assn.action); + switch (assn.type) { + case ActionAssertionType::Trap: + if (std::get_if<TrapResult>(&result)) { + return Ok{}; + } + err << "expected trap"; + break; + case ActionAssertionType::Exhaustion: + if (std::get_if<HostLimitResult>(&result)) { + return Ok{}; + } + err << "expected exhaustion"; + break; + case ActionAssertionType::Exception: + if (std::get_if<ExceptionResult>(&result)) { + return Ok{}; + } + err << "expected exception"; + break; } + err << ", got " << resultToString(result); + return Err{err.str()}; } - void parseModuleAssertion(Element& s) { - Module wasm; - wasm.features = FeatureSet::All; - std::unique_ptr<SExpressionWasmBuilder> builder; - auto id = s[0]->str(); + Result<> assertModule(AssertModule& assn) { + auto wasm = makeModule(assn.wasm); + if (const auto* err = wasm.getErr()) { + if (assn.type == ModuleAssertionType::Malformed || + assn.type == ModuleAssertionType::Invalid) { + return Ok{}; + } + return Err{err->msg}; + } - bool invalid = false; - try { - SExpressionWasmBuilder(wasm, *s[1], IRProfile::Normal); - } catch (const ParseException&) { - invalid = true; + if (assn.type == ModuleAssertionType::Malformed) { + return Err{"expected malformed module"}; } - if (!invalid) { - // maybe parsed ok, but otherwise incorrect - invalid = !WasmValidator().validate(wasm); + auto valid = validateModule(**wasm); + if (auto* err = valid.getErr()) { + if (assn.type == ModuleAssertionType::Invalid) { + return Ok{}; + } + return Err{err->msg}; } - if (!invalid && id == ASSERT_UNLINKABLE) { - // validate "instantiating" the mdoule - auto reportUnknownImport = [&](Importable* import) { - auto it = linkedInstances.find(import->module); - if (it == linkedInstances.end() || - it->second->wasm.getExportOrNull(import->base) == nullptr) { - std::cerr << "unknown import: " << import->module << '.' - << import->base << '\n'; - invalid = true; - } - }; - ModuleUtils::iterImportedGlobals(wasm, reportUnknownImport); - ModuleUtils::iterImportedTables(wasm, reportUnknownImport); - ModuleUtils::iterImportedFunctions(wasm, [&](Importable* import) { - if (import->module == SPECTEST && import->base.startsWith(PRINT)) { - // We can handle it. - } else { - reportUnknownImport(import); - } - }); - ElementUtils::iterAllElementFunctionNames(&wasm, [&](Name name) { - // spec tests consider it illegal to use spectest.print in a table - if (auto* import = wasm.getFunction(name)) { - if (import->imported() && import->module == SPECTEST && - import->base.startsWith(PRINT)) { - std::cerr << "cannot put spectest.print in table\n"; - invalid = true; - } - } - }); - ModuleUtils::iterImportedMemories(wasm, reportUnknownImport); + if (assn.type == ModuleAssertionType::Invalid) { + return Err{"expected invalid module"}; } - if (!invalid && (id == ASSERT_TRAP || id == ASSERT_EXCEPTION)) { - try { - instantiate(&wasm); - } catch (const TrapException&) { - invalid = true; - } catch (const WasmException& e) { - std::cout << "[exception thrown: " << e << "]" << std::endl; - invalid = true; + auto instance = instantiate(**wasm); + if (auto* err = instance.getErr()) { + if (assn.type == ModuleAssertionType::Unlinkable || + assn.type == ModuleAssertionType::Trap) { + return Ok{}; } + return Err{err->msg}; } - if (!invalid) { - Colors::red(std::cerr); - std::cerr << "[should have been invalid]\n"; - Colors::normal(std::cerr); - Fatal() << &wasm << '\n'; + if (assn.type == ModuleAssertionType::Unlinkable) { + return Err{"expected unlinkable module"}; + } + if (assn.type == ModuleAssertionType::Trap) { + return Err{"expected instantiation to trap"}; } - } -protected: - Options& options; + WASM_UNREACHABLE("unexpected module assertion"); + } // spectest module is a default host-provided module defined by the spec's // reference interpreter. It's been replaced by the `(register ...)` @@ -345,7 +446,7 @@ protected: // is actually removed from the spec test. void buildSpectestModule() { auto spectest = std::make_shared<Module>(); - spectest->name = "spectest"; + spectest->features = FeatureSet::All; Builder builder(*spectest); spectest->addGlobal(builder.makeGlobal(Name::fromInt(0), @@ -388,45 +489,18 @@ protected: spectest->addExport( builder.makeExport("memory", memory->name, ExternalKind::Memory)); - modules["spectest"].swap(spectest); - modules["spectest"]->features = FeatureSet::All; - instantiate(modules["spectest"].get()); - linkedInstances["spectest"] = instances["spectest"]; // print_* functions are handled separately, no need to define here. - } - -public: - Shell(Options& options) : options(options) { buildSpectestModule(); } - - MaybeResult<> parseAndRun(Lexer& lexer) { - size_t i = 0; - while (!lexer.empty()) { - auto next = lexer.next(); - auto size = next.find('\n'); - if (size != std::string_view::npos) { - next = next.substr(0, size); - } else { - next = ""; - } - - if (!lexer.peekSExprStart("module")) { - Colors::red(std::cerr); - std::cerr << i; - Colors::green(std::cerr); - std::cerr << " CHECKING: "; - Colors::normal(std::cerr); - std::cerr << next; - Colors::green(std::cerr); - std::cerr << " [line: " << lexer.position().line << "]\n"; - Colors::normal(std::cerr); - } - CHECK_ERR(parse(lexer)); - - i += 1; + WASTModule mod = std::move(spectest); + auto added = addModule(mod); + if (added.getErr()) { + WASM_UNREACHABLE("error building spectest module"); + } + Register registration{"spectest"}; + auto registered = addRegistration(registration); + if (registered.getErr()) { + WASM_UNREACHABLE("error registering spectest module"); } - - return Ok{}; } }; @@ -453,16 +527,14 @@ int main(int argc, const char* argv[]) { } Lexer lexer(input); - auto result = Shell(options).parseAndRun(lexer); + auto result = Shell(options).run(*script); if (auto* err = result.getErr()) { std::cerr << err->msg << '\n'; exit(1); } - if (result) { - Colors::green(std::cerr); - Colors::bold(std::cerr); - std::cerr << "all checks passed.\n"; - Colors::normal(std::cerr); - } + Colors::green(std::cerr); + Colors::bold(std::cerr); + std::cerr << "all checks passed.\n"; + Colors::normal(std::cerr); } diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 8f86b3647..f37e45240 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -420,7 +420,6 @@ public: ret->valueType = type; ret->memory = memory; ret->finalize(); - assert(ret->value->type.isConcrete() ? ret->value->type == type : true); return ret; } Store* makeAtomicStore(unsigned bytes, diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index fa129f14a..f9b7e8c50 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -4103,7 +4103,7 @@ public: // can use it (refactor?) Literals callFunctionInternal(Name name, Literals arguments) { if (callDepth > maxDepth) { - externalInterface->trap("stack limit"); + hostLimit("stack limit"); } Flow flow; diff --git a/src/wasm-type.h b/src/wasm-type.h index 6422d09ad..b02484ae3 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -634,6 +634,8 @@ struct TypeBuilder { ForwardSupertypeReference, // A child of the type is an invalid forward reference. ForwardChildReference, + // A continuation reference that does not refer to a function type. + InvalidFuncType, }; struct Error { diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp index 2c61be12a..d4e1bfc45 100644 --- a/src/wasm/literal.cpp +++ b/src/wasm/literal.cpp @@ -467,6 +467,22 @@ bool Literal::isNaN() { return false; } +bool Literal::isCanonicalNaN() { + if (!isNaN()) { + return false; + } + return (type == Type::f32 && NaNPayload(getf32()) == (1u << 23) - 1) || + (type == Type::f64 && NaNPayload(getf64()) == (1ull << 52) - 1); +} + +bool Literal::isArithmeticNaN() { + if (!isNaN()) { + return false; + } + return (type == Type::f32 && NaNPayload(getf32()) > (1u << 23) - 1) || + (type == Type::f64 && NaNPayload(getf64()) > (1ull << 52) - 1); +} + uint32_t Literal::NaNPayload(float f) { assert(std::isnan(f) && "expected a NaN"); // SEEEEEEE EFFFFFFF FFFFFFFF FFFFFFFF diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index 2eb5369d0..3e215b23a 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -1666,6 +1666,8 @@ std::ostream& operator<<(std::ostream& os, TypeBuilder::ErrorReason reason) { return os << "Heap type has an undeclared supertype"; case TypeBuilder::ErrorReason::ForwardChildReference: return os << "Heap type has an undeclared child"; + case TypeBuilder::ErrorReason::InvalidFuncType: + return os << "Continuation has invalid function type"; } WASM_UNREACHABLE("Unexpected error reason"); } @@ -2616,7 +2618,7 @@ buildRecGroup(std::unique_ptr<RecGroupInfo>&& groupInfo, updateReferencedHeapTypes(info, canonicalized); } - // Collect the types and check supertype validity. + // Collect the types and check validity. std::unordered_set<HeapType> seenTypes; for (size_t i = 0; i < typeInfos.size(); ++i) { auto& info = typeInfos[i]; @@ -2633,6 +2635,12 @@ buildRecGroup(std::unique_ptr<RecGroupInfo>&& groupInfo, TypeBuilder::Error{i, TypeBuilder::ErrorReason::InvalidSupertype}}; } } + if (info->isContinuation()) { + if (!info->continuation.type.isSignature()) { + return { + TypeBuilder::Error{i, TypeBuilder::ErrorReason::InvalidFuncType}}; + } + } seenTypes.insert(asHeapType(info)); } diff --git a/test/spec/address64.wast b/test/spec/address64.wast index b3b009ae0..29771ae77 100644 --- a/test/spec/address64.wast +++ b/test/spec/address64.wast @@ -203,13 +203,6 @@ (assert_trap (invoke "16s_bad" (i64.const 1)) "out of bounds memory access") (assert_trap (invoke "32_bad" (i64.const 1)) "out of bounds memory access") -(assert_malformed - (module quote - "(memory i64 1)" - "(func (drop (i32.load offset=4294967296 (i64.const 0))))" - ) - "i32 constant" -) ;; Load i64 data with different offset/align arguments diff --git a/test/spec/call.wast b/test/spec/call.wast index 89082fbed..4d0f1a7c2 100644 --- a/test/spec/call.wast +++ b/test/spec/call.wast @@ -279,8 +279,8 @@ (assert_return (invoke "odd" (i64.const 200)) (i32.const 99)) (assert_return (invoke "odd" (i64.const 77)) (i32.const 44)) -;; (assert_exhaustion (invoke "runaway") "call stack exhausted") -;; (assert_exhaustion (invoke "mutual-runaway") "call stack exhausted") +(assert_exhaustion (invoke "runaway") "call stack exhausted") +(assert_exhaustion (invoke "mutual-runaway") "call stack exhausted") (assert_return (invoke "as-select-first") (i32.const 0x132)) (assert_return (invoke "as-select-mid") (i32.const 2)) diff --git a/test/spec/call_indirect.wast b/test/spec/call_indirect.wast index 791a756ca..87d1df75a 100644 --- a/test/spec/call_indirect.wast +++ b/test/spec/call_indirect.wast @@ -553,8 +553,8 @@ (assert_return (invoke "odd" (i32.const 200)) (i32.const 99)) (assert_return (invoke "odd" (i32.const 77)) (i32.const 44)) -;; (assert_exhaustion (invoke "runaway") "call stack exhausted") -;; (assert_exhaustion (invoke "mutual-runaway") "call stack exhausted") +(assert_exhaustion (invoke "runaway") "call stack exhausted") +(assert_exhaustion (invoke "mutual-runaway") "call stack exhausted") (assert_return (invoke "as-select-first") (i32.const 0x132)) (assert_return (invoke "as-select-mid") (i32.const 2)) diff --git a/test/spec/exception-handling-old.wast b/test/spec/exception-handling-old.wast index 5024fa734..6b4631877 100644 --- a/test/spec/exception-handling-old.wast +++ b/test/spec/exception-handling-old.wast @@ -352,40 +352,34 @@ "tag's param numbers must match" ) -(assert_invalid - (module - (func $f0 - (block $l0 - (try - (do - (try - (do) - (delegate $l0) ;; target is a block - ) +(module + (func $f0 + (block $l0 + (try + (do + (try + (do) + (delegate $l0) ;; target is a block ) - (catch_all) ) + (catch_all) ) ) ) - "all delegate targets must be valid" ) -(assert_invalid - (module - (func $f0 - (try $l0 - (do) - (catch_all - (try - (do) - (delegate $l0) ;; the target catch is above the delegate - ) +(module + (func $f0 + (try $l0 + (do) + (catch_all + (try + (do) + (delegate $l0) ;; the target catch is above the delegate ) ) ) ) - "all delegate targets must be valid" ) (assert_invalid diff --git a/test/spec/fac.wast b/test/spec/fac.wast index 521cdc459..ef10991a8 100644 --- a/test/spec/fac.wast +++ b/test/spec/fac.wast @@ -86,4 +86,4 @@ (assert_return (invoke "fac-rec-named" (i64.const 25)) (i64.const 7034535277573963776)) (assert_return (invoke "fac-iter-named" (i64.const 25)) (i64.const 7034535277573963776)) (assert_return (invoke "fac-opt" (i64.const 25)) (i64.const 7034535277573963776)) -;; (assert_exhaustion (invoke "fac-rec" (i64.const 1073741824)) "call stack exhausted") +(assert_exhaustion (invoke "fac-rec" (i64.const 1073741824)) "call stack exhausted") diff --git a/test/spec/inline-module.wast b/test/spec/inline-module.wast index a8871dfb2..dc7ead776 100644 --- a/test/spec/inline-module.wast +++ b/test/spec/inline-module.wast @@ -1 +1 @@ -;; (func) (memory 0) (func (export "f")) +(func) (memory 0) (func (export "f")) diff --git a/test/spec/skip-stack-guard-page.wast b/test/spec/skip-stack-guard-page.wast index 4f9273eb6..a472e6814 100644 --- a/test/spec/skip-stack-guard-page.wast +++ b/test/spec/skip-stack-guard-page.wast @@ -2272,13 +2272,13 @@ ) ) -;; (assert_exhaustion (invoke "test-guard-page-skip" (i32.const 0)) "call stack exhausted") -;; (assert_exhaustion (invoke "test-guard-page-skip" (i32.const 100)) "call stack exhausted") -;; (assert_exhaustion (invoke "test-guard-page-skip" (i32.const 200)) "call stack exhausted") -;; (assert_exhaustion (invoke "test-guard-page-skip" (i32.const 300)) "call stack exhausted") -;; (assert_exhaustion (invoke "test-guard-page-skip" (i32.const 400)) "call stack exhausted") -;; (assert_exhaustion (invoke "test-guard-page-skip" (i32.const 500)) "call stack exhausted") -;; (assert_exhaustion (invoke "test-guard-page-skip" (i32.const 600)) "call stack exhausted") -;; (assert_exhaustion (invoke "test-guard-page-skip" (i32.const 700)) "call stack exhausted") -;; (assert_exhaustion (invoke "test-guard-page-skip" (i32.const 800)) "call stack exhausted") -;; (assert_exhaustion (invoke "test-guard-page-skip" (i32.const 900)) "call stack exhausted") +(assert_exhaustion (invoke "test-guard-page-skip" (i32.const 0)) "call stack exhausted") +(assert_exhaustion (invoke "test-guard-page-skip" (i32.const 100)) "call stack exhausted") +(assert_exhaustion (invoke "test-guard-page-skip" (i32.const 200)) "call stack exhausted") +(assert_exhaustion (invoke "test-guard-page-skip" (i32.const 300)) "call stack exhausted") +(assert_exhaustion (invoke "test-guard-page-skip" (i32.const 400)) "call stack exhausted") +(assert_exhaustion (invoke "test-guard-page-skip" (i32.const 500)) "call stack exhausted") +(assert_exhaustion (invoke "test-guard-page-skip" (i32.const 600)) "call stack exhausted") +(assert_exhaustion (invoke "test-guard-page-skip" (i32.const 700)) "call stack exhausted") +(assert_exhaustion (invoke "test-guard-page-skip" (i32.const 800)) "call stack exhausted") +(assert_exhaustion (invoke "test-guard-page-skip" (i32.const 900)) "call stack exhausted") |