summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorThomas Lively <tlively@google.com>2024-05-17 17:49:45 -0700
committerGitHub <noreply@github.com>2024-05-17 17:49:45 -0700
commit921644ca65afbafb84fb82d58dacc4a028e2d720 (patch)
tree9253fbcf3f1dd9930dd1b9bc9f545234399b918e /src
parent369cddfb44ddbada2ef7742a9ebef54727d12dd5 (diff)
downloadbinaryen-921644ca65afbafb84fb82d58dacc4a028e2d720.tar.gz
binaryen-921644ca65afbafb84fb82d58dacc4a028e2d720.tar.bz2
binaryen-921644ca65afbafb84fb82d58dacc4a028e2d720.zip
Rewrite wasm-shell to use new wast parser (#6601)
Use the new wast parser to parse a full script up front, then traverse the parsed script data structure and execute the commands. wasm-shell had previously used the new wat parser for top-level modules, but it now uses the new parser for module assertions as well. Fix various bugs this uncovered. After this change, wasm-shell supports all the assertions used in the upstream spec tests (although not new kinds of assertions introduced in any proposals). Uncomment various `assert_exhaustion` tests that we can now execute. Other kinds of assertions remain commented out in our tests: wasm-shell now supports `assert_unlinkable`, but the interpreter does not eagerly check for the existence of imports, so those tests do not pass. Tests that check for NaNs also remain commented out because they do not yet use the standard syntax that wasm-shell now supports for canonical and arithmetic NaN results, and our interpreter would not pass all of those tests even if they did use the standard syntax.
Diffstat (limited to 'src')
-rw-r--r--src/literal.h2
-rw-r--r--src/parser/contexts.h5
-rw-r--r--src/parser/lexer.cpp4
-rw-r--r--src/parser/parsers.h22
-rw-r--r--src/parser/wast-parser.cpp32
-rw-r--r--src/parser/wat-parser.cpp8
-rw-r--r--src/parser/wat-parser.h24
-rw-r--r--src/tools/wasm-shell.cpp622
-rw-r--r--src/wasm-builder.h1
-rw-r--r--src/wasm-interpreter.h2
-rw-r--r--src/wasm-type.h2
-rw-r--r--src/wasm/literal.cpp16
-rw-r--r--src/wasm/wasm-type.cpp10
13 files changed, 429 insertions, 321 deletions
diff --git a/src/literal.h b/src/literal.h
index a4017f6ec..1268448fb 100644
--- a/src/literal.h
+++ b/src/literal.h
@@ -347,6 +347,8 @@ public:
bool operator!=(const Literal& other) const;
bool isNaN();
+ bool isCanonicalNaN();
+ bool isArithmeticNaN();
static uint32_t NaNPayload(float f);
static uint64_t NaNPayload(double f);
diff --git a/src/parser/contexts.h b/src/parser/contexts.h
index 5ad0c16de..bce051a93 100644
--- a/src/parser/contexts.h
+++ b/src/parser/contexts.h
@@ -1665,7 +1665,10 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> {
return Ok{};
}
- Result<> addExport(Index, Name value, Name name, ExternalKind kind) {
+ Result<> addExport(Index pos, Name value, Name name, ExternalKind kind) {
+ if (wasm.getExportOrNull(name)) {
+ return in.err(pos, "duplicate export");
+ }
wasm.addExport(builder.makeExport(name, value, kind));
return Ok{};
}
diff --git a/src/parser/lexer.cpp b/src/parser/lexer.cpp
index fd0a262b8..bb6428e87 100644
--- a/src/parser/lexer.cpp
+++ b/src/parser/lexer.cpp
@@ -23,6 +23,7 @@
#include <variant>
#include "lexer.h"
+#include "support/bits.h"
#include "support/string.h"
using namespace std::string_view_literals;
@@ -1005,6 +1006,9 @@ std::optional<uint32_t> Lexer::takeAlign() {
}
Lexer subLexer(result->span.substr(6));
if (auto o = subLexer.takeU32()) {
+ if (Bits::popCount(*o) != 1) {
+ return std::nullopt;
+ }
pos += result->span.size();
advance();
return o;
diff --git a/src/parser/parsers.h b/src/parser/parsers.h
index 88600fec3..5900deb27 100644
--- a/src/parser/parsers.h
+++ b/src/parser/parsers.h
@@ -30,7 +30,8 @@ template<typename Ctx> Result<typename Ctx::HeapTypeT> heaptype(Ctx&);
template<typename Ctx> MaybeResult<typename Ctx::RefTypeT> reftype(Ctx&);
template<typename Ctx> MaybeResult<typename Ctx::TypeT> tupletype(Ctx&);
template<typename Ctx> Result<typename Ctx::TypeT> valtype(Ctx&);
-template<typename Ctx> MaybeResult<typename Ctx::ParamsT> params(Ctx&);
+template<typename Ctx>
+MaybeResult<typename Ctx::ParamsT> params(Ctx&, bool allowNames = true);
template<typename Ctx> MaybeResult<typename Ctx::ResultsT> results(Ctx&);
template<typename Ctx> MaybeResult<typename Ctx::SignatureT> functype(Ctx&);
template<typename Ctx> Result<typename Ctx::FieldT> storagetype(Ctx&);
@@ -325,7 +326,8 @@ MaybeResult<typename Ctx::LabelIdxT> maybeLabelidx(Ctx&,
template<typename Ctx>
Result<typename Ctx::LabelIdxT> labelidx(Ctx&, bool inDelegate = false);
template<typename Ctx> Result<typename Ctx::TagIdxT> tagidx(Ctx&);
-template<typename Ctx> Result<typename Ctx::TypeUseT> typeuse(Ctx&);
+template<typename Ctx>
+Result<typename Ctx::TypeUseT> typeuse(Ctx&, bool allowNames = true);
MaybeResult<ImportNames> inlineImport(Lexer&);
Result<std::vector<Name>> inlineExports(Lexer&);
template<typename Ctx> Result<> strtype(Ctx&);
@@ -561,13 +563,18 @@ template<typename Ctx> Result<typename Ctx::TypeT> valtype(Ctx& ctx) {
// param ::= '(' 'param id? t:valtype ')' => [t]
// | '(' 'param t*:valtype* ')' => [t*]
// params ::= param*
-template<typename Ctx> MaybeResult<typename Ctx::ParamsT> params(Ctx& ctx) {
+template<typename Ctx>
+MaybeResult<typename Ctx::ParamsT> params(Ctx& ctx, bool allowNames) {
bool hasAny = false;
auto res = ctx.makeParams();
while (ctx.in.takeSExprStart("param"sv)) {
hasAny = true;
+ auto pos = ctx.in.getPos();
if (auto id = ctx.in.takeID()) {
// Single named param
+ if (!allowNames) {
+ return ctx.in.err(pos, "unexpected named parameter");
+ }
auto type = valtype(ctx);
CHECK_ERR(type);
if (!ctx.in.takeRParen()) {
@@ -1065,7 +1072,7 @@ template<typename Ctx> Result<typename Ctx::BlockTypeT> blocktype(Ctx& ctx) {
// We either had no results or multiple results. Reset and parse again as a
// type use.
ctx.in = initialLexer;
- auto use = typeuse(ctx);
+ auto use = typeuse(ctx, false);
CHECK_ERR(use);
auto type = ctx.getBlockTypeFromTypeUse(pos, *use);
@@ -1935,7 +1942,7 @@ Result<> makeCallIndirect(Ctx& ctx,
bool isReturn) {
auto table = maybeTableidx(ctx);
CHECK_ERR(table);
- auto type = typeuse(ctx);
+ auto type = typeuse(ctx, false);
CHECK_ERR(type);
return ctx.makeCallIndirect(
pos, annotations, table.getPtr(), *type, isReturn);
@@ -2669,7 +2676,8 @@ template<typename Ctx> Result<typename Ctx::TagIdxT> tagidx(Ctx& ctx) {
// (if typedefs[x] = [t1*] -> [t2*])
// | ((t1,IDs):param)* (t2:result)* => x, IDs
// (if x is minimum s.t. typedefs[x] = [t1*] -> [t2*])
-template<typename Ctx> Result<typename Ctx::TypeUseT> typeuse(Ctx& ctx) {
+template<typename Ctx>
+Result<typename Ctx::TypeUseT> typeuse(Ctx& ctx, bool allowNames) {
auto pos = ctx.in.getPos();
std::optional<typename Ctx::HeapTypeT> type;
if (ctx.in.takeSExprStart("type"sv)) {
@@ -2683,7 +2691,7 @@ template<typename Ctx> Result<typename Ctx::TypeUseT> typeuse(Ctx& ctx) {
type = *x;
}
- auto namedParams = params(ctx);
+ auto namedParams = params(ctx, allowNames);
CHECK_ERR(namedParams);
auto resultTypes = results(ctx);
diff --git a/src/parser/wast-parser.cpp b/src/parser/wast-parser.cpp
index fb0dce932..87060b9fc 100644
--- a/src/parser/wast-parser.cpp
+++ b/src/parser/wast-parser.cpp
@@ -41,8 +41,7 @@ Result<Literals> consts(Lexer& in) {
MaybeResult<Action> action(Lexer& in) {
if (in.takeSExprStart("invoke"sv)) {
- // TODO: Do we need to use this optional id?
- in.takeID();
+ auto id = in.takeID();
auto name = in.takeName();
if (!name) {
return in.err("expected export name");
@@ -52,12 +51,11 @@ MaybeResult<Action> action(Lexer& in) {
if (!in.takeRParen()) {
return in.err("expected end of invoke action");
}
- return InvokeAction{*name, *args};
+ return InvokeAction{id, *name, *args};
}
if (in.takeSExprStart("get"sv)) {
- // TODO: Do we need to use this optional id?
- in.takeID();
+ auto id = in.takeID();
auto name = in.takeName();
if (!name) {
return in.err("expected export name");
@@ -65,7 +63,7 @@ MaybeResult<Action> action(Lexer& in) {
if (!in.takeRParen()) {
return in.err("expected end of get action");
}
- return GetAction{*name};
+ return GetAction{id, *name};
}
return {};
@@ -236,7 +234,7 @@ MaybeResult<AssertReturn> assertReturn(Lexer& in) {
}
// (assert_exception action)
-MaybeResult<AssertException> assertException(Lexer& in) {
+MaybeResult<AssertAction> assertException(Lexer& in) {
if (!in.takeSExprStart("assert_exception"sv)) {
return {};
}
@@ -245,7 +243,7 @@ MaybeResult<AssertException> assertException(Lexer& in) {
if (!in.takeRParen()) {
return in.err("expected end of assert_exception");
}
- return AssertException{*a};
+ return AssertAction{ActionAssertionType::Exception, *a};
}
// (assert_exhaustion action msg)
@@ -266,7 +264,7 @@ MaybeResult<AssertAction> assertAction(Lexer& in) {
if (!in.takeRParen()) {
return in.err("expected end of assertion");
}
- return AssertAction{type, *a, *msg};
+ return AssertAction{type, *a};
}
// (assert_malformed module msg)
@@ -293,7 +291,7 @@ MaybeResult<AssertModule> assertModule(Lexer& in) {
if (!in.takeRParen()) {
return in.err("expected end of assertion");
}
- return AssertModule{type, *mod, *msg};
+ return AssertModule{type, *mod};
}
// (assert_trap action msg)
@@ -312,7 +310,7 @@ MaybeResult<Assertion> assertTrap(Lexer& in) {
if (!in.takeRParen()) {
return in.err("expected end of assertion");
}
- return Assertion{AssertAction{ActionAssertionType::Trap, *a, *msg}};
+ return Assertion{AssertAction{ActionAssertionType::Trap, *a}};
}
auto mod = wastModule(in);
if (mod.getErr()) {
@@ -325,7 +323,7 @@ MaybeResult<Assertion> assertTrap(Lexer& in) {
if (!in.takeRParen()) {
return in.err("expected end of assertion");
}
- return Assertion{AssertModule{ModuleAssertionType::Trap, *mod, *msg}};
+ return Assertion{AssertModule{ModuleAssertionType::Trap, *mod}};
}
MaybeResult<Assertion> assertion(Lexer& in) {
@@ -391,24 +389,30 @@ Result<WASTCommand> command(Lexer& in) {
return *mod;
}
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
+
Result<WASTScript> wast(Lexer& in) {
WASTScript cmds;
while (!in.empty()) {
+ size_t line = in.position().line;
auto cmd = command(in);
if (cmd.getErr() && cmds.empty()) {
// The entire script might be a single module comprising a sequence of
// module fields with a top-level `(module ...)`.
auto wasm = std::make_shared<Module>();
CHECK_ERR(parseModule(*wasm, in.buffer));
- cmds.emplace_back(std::move(wasm));
+ cmds.push_back({WASTModule{std::move(wasm)}, line});
return cmds;
}
CHECK_ERR(cmd);
- cmds.emplace_back(std::move(*cmd));
+ cmds.push_back(ScriptEntry{std::move(*cmd), line});
}
return cmds;
}
+#pragma GCC diagnostic pop
+
} // anonymous namespace
Result<WASTScript> parseScript(std::string_view in) {
diff --git a/src/parser/wat-parser.cpp b/src/parser/wat-parser.cpp
index 85ef1e80f..4763c69ec 100644
--- a/src/parser/wat-parser.cpp
+++ b/src/parser/wat-parser.cpp
@@ -237,14 +237,6 @@ Result<> parseModule(Module& wasm, Lexer& lexer) {
return doParseModule(wasm, lexer, true);
}
-Result<Expression*> parseExpression(Module& wasm, Lexer& lexer) {
- ParseDefsCtx ctx(lexer, wasm, {}, {}, {}, {}, {});
- auto e = expr(ctx);
- CHECK_ERR(e);
- lexer = ctx.in;
- return *e;
-}
-
Result<Literal> parseConst(Lexer& lexer) {
Module wasm;
ParseDefsCtx ctx(lexer, wasm, {}, {}, {}, {}, {});
diff --git a/src/parser/wat-parser.h b/src/parser/wat-parser.h
index 7fe6abfdd..041ba1d58 100644
--- a/src/parser/wat-parser.h
+++ b/src/parser/wat-parser.h
@@ -34,14 +34,14 @@ Result<> parseModule(Module& wasm, Lexer& lexer);
Result<Literal> parseConst(Lexer& lexer);
-Result<Expression*> parseExpression(Module& wasm, Lexer& lexer);
-
struct InvokeAction {
+ std::optional<Name> base;
Name name;
Literals args;
};
struct GetAction {
+ std::optional<Name> base;
Name name;
};
@@ -68,19 +68,14 @@ using ExpectedResults = std::vector<ExpectedResult>;
struct AssertReturn {
Action action;
- ExpectedResults results;
-};
-
-struct AssertException {
- Action action;
+ ExpectedResults expected;
};
-enum class ActionAssertionType { Trap, Exhaustion };
+enum class ActionAssertionType { Trap, Exhaustion, Exception };
struct AssertAction {
ActionAssertionType type;
Action action;
- std::string msg;
};
enum class QuotedModuleType { Text, Binary };
@@ -97,11 +92,9 @@ enum class ModuleAssertionType { Trap, Malformed, Invalid, Unlinkable };
struct AssertModule {
ModuleAssertionType type;
WASTModule wasm;
- std::string msg;
};
-using Assertion =
- std::variant<AssertReturn, AssertException, AssertAction, AssertModule>;
+using Assertion = std::variant<AssertReturn, AssertAction, AssertModule>;
struct Register {
Name name;
@@ -109,7 +102,12 @@ struct Register {
using WASTCommand = std::variant<WASTModule, Register, Action, Assertion>;
-using WASTScript = std::vector<WASTCommand>;
+struct ScriptEntry {
+ WASTCommand cmd;
+ size_t line;
+};
+
+using WASTScript = std::vector<ScriptEntry>;
Result<WASTScript> parseScript(std::string_view in);
diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp
index 3fe8b3505..913a7a58a 100644
--- a/src/tools/wasm-shell.cpp
+++ b/src/tools/wasm-shell.cpp
@@ -30,26 +30,16 @@
#include "support/command-line.h"
#include "support/file.h"
#include "support/result.h"
+#include "wasm-binary.h"
#include "wasm-interpreter.h"
#include "wasm-s-parser.h"
#include "wasm-validator.h"
using namespace wasm;
-using Lexer = WATParser::Lexer;
+using namespace wasm::WATParser;
-Name ASSERT_RETURN("assert_return");
-Name ASSERT_TRAP("assert_trap");
-Name ASSERT_EXCEPTION("assert_exception");
-Name ASSERT_INVALID("assert_invalid");
-Name ASSERT_MALFORMED("assert_malformed");
-Name ASSERT_UNLINKABLE("assert_unlinkable");
-Name INVOKE("invoke");
-Name REGISTER("register");
-Name GET("get");
-
-class Shell {
-protected:
+struct Shell {
std::map<Name, std::shared_ptr<Module>> modules;
std::map<Name, std::shared_ptr<ShellExternalInterface>> interfaces;
std::map<Name, std::shared_ptr<ModuleRunner>> instances;
@@ -58,284 +48,395 @@ protected:
Name lastModule;
- void instantiate(Module* wasm) {
- auto tempInterface =
- std::make_shared<ShellExternalInterface>(linkedInstances);
- auto tempInstance = std::make_shared<ModuleRunner>(
- *wasm, tempInterface.get(), linkedInstances);
- interfaces[wasm->name].swap(tempInterface);
- instances[wasm->name].swap(tempInstance);
- }
-
- Result<std::string> parseSExpr(Lexer& lexer) {
- auto begin = lexer.getPos();
+ Options& options;
- if (!lexer.takeLParen()) {
- return lexer.err("expected s-expression");
- }
+ Shell(Options& options) : options(options) { buildSpectestModule(); }
- size_t count = 1;
- while (count != 0 && lexer.takeUntilParen()) {
- if (lexer.takeLParen()) {
- ++count;
- } else if (lexer.takeRParen()) {
- --count;
+ Result<> run(WASTScript& script) {
+ size_t i = 0;
+ for (auto& entry : script) {
+ Colors::red(std::cerr);
+ std::cerr << i << ' ';
+ Colors::normal(std::cerr);
+ if (std::get_if<WASTModule>(&entry.cmd)) {
+ Colors::green(std::cerr);
+ std::cerr << "BUILDING MODULE [line: " << entry.line << "]\n";
+ Colors::normal(std::cerr);
+ } else if (auto* reg = std::get_if<Register>(&entry.cmd)) {
+ Colors::green(std::cerr);
+ std::cerr << "REGISTER MODULE INSTANCE AS \"" << reg->name
+ << "\" [line: " << entry.line << "]\n";
+ Colors::normal(std::cerr);
} else {
- WASM_UNREACHABLE("unexpected token");
+ Colors::green(std::cerr);
+ std::cerr << "CHECKING [line: " << entry.line << "]\n";
+ Colors::normal(std::cerr);
}
+ ++i;
+ CHECK_ERR(runCommand(entry.cmd));
}
-
- if (count != 0) {
- return lexer.err("unexpected unterminated s-expression");
- }
-
- return std::string(lexer.buffer.substr(begin, lexer.getPos() - begin));
+ return Ok{};
}
- Expression* parseExpression(Module& wasm, Element& s) {
- std::stringstream ss;
- ss << s;
- auto str = ss.str();
- Lexer lexer(str);
- auto arg = WATParser::parseExpression(wasm, lexer);
- if (auto* err = arg.getErr()) {
- Fatal() << err->msg << '\n';
+ Result<> runCommand(WASTCommand& cmd) {
+ if (auto* mod = std::get_if<WASTModule>(&cmd)) {
+ return addModule(*mod);
+ } else if (auto* reg = std::get_if<Register>(&cmd)) {
+ return addRegistration(*reg);
+ } else if (auto* act = std::get_if<Action>(&cmd)) {
+ doAction(*act);
+ return Ok{};
+ } else if (auto* assn = std::get_if<Assertion>(&cmd)) {
+ return doAssertion(*assn);
+ } else {
+ WASM_UNREACHABLE("unexpected command");
}
- return *arg;
}
- Result<> parse(Lexer& lexer) {
- if (auto res = parseModule(lexer)) {
- CHECK_ERR(res);
- return Ok{};
+ Result<std::shared_ptr<Module>> makeModule(WASTModule& mod) {
+ std::shared_ptr<Module> wasm;
+ if (auto* quoted = std::get_if<QuotedModule>(&mod)) {
+ wasm = std::make_shared<Module>();
+ switch (quoted->type) {
+ case QuotedModuleType::Text: {
+ CHECK_ERR(parseModule(*wasm, quoted->module));
+ break;
+ }
+ case QuotedModuleType::Binary: {
+ std::vector<char> buffer(quoted->module.begin(),
+ quoted->module.end());
+ WasmBinaryReader reader(*wasm, FeatureSet::All, buffer);
+ try {
+ reader.read();
+ } catch (ParseException& p) {
+ std::stringstream ss;
+ p.dump(ss);
+ return Err{ss.str()};
+ }
+ break;
+ }
+ }
+ } else if (auto* ptr = std::get_if<std::shared_ptr<Module>>(&mod)) {
+ wasm = *ptr;
+ } else {
+ WASM_UNREACHABLE("unexpected module kind");
}
+ wasm->features = FeatureSet::All;
+ return wasm;
+ }
- auto pos = lexer.getPos();
- auto sexpr = parseSExpr(lexer);
- CHECK_ERR(sexpr);
-
- SExpressionParser parser(sexpr->data());
- Element& s = *parser.root[0][0];
- IString id = s[0]->str();
- if (id == REGISTER) {
- parseRegister(s);
- } else if (id == INVOKE) {
- parseOperation(s);
- } else if (id == ASSERT_RETURN) {
- parseAssertReturn(s);
- } else if (id == ASSERT_TRAP) {
- parseAssertTrap(s);
- } else if (id == ASSERT_EXCEPTION) {
- parseAssertException(s);
- } else if ((id == ASSERT_INVALID) || (id == ASSERT_MALFORMED) ||
- (id == ASSERT_UNLINKABLE)) {
- parseModuleAssertion(s);
- } else {
- return lexer.err(pos, "unrecognized command");
+ Result<> validateModule(Module& wasm) {
+ if (!WasmValidator().validate(wasm)) {
+ return Err{"failed validation"};
}
return Ok{};
}
- MaybeResult<> parseModule(Lexer& lexer) {
- if (!lexer.peekSExprStart("module")) {
- return {};
- }
- Colors::green(std::cerr);
- std::cerr << "BUILDING MODULE [line: " << lexer.position().line << "]\n";
- Colors::normal(std::cerr);
- auto module = std::make_shared<Module>();
-
- CHECK_ERR(WATParser::parseModule(*module, lexer));
-
- auto moduleName = module->name;
- lastModule = module->name;
- modules[moduleName].swap(module);
- modules[moduleName]->features = FeatureSet::All;
- bool valid = WasmValidator().validate(*modules[moduleName]);
- if (!valid) {
- std::cout << *modules[moduleName] << '\n';
- Fatal() << "module failed to validate, see above";
+ using InstanceInfo = std::pair<std::shared_ptr<ShellExternalInterface>,
+ std::shared_ptr<ModuleRunner>>;
+
+ Result<InstanceInfo> instantiate(Module& wasm) {
+ try {
+ auto interface =
+ std::make_shared<ShellExternalInterface>(linkedInstances);
+ auto instance =
+ std::make_shared<ModuleRunner>(wasm, interface.get(), linkedInstances);
+ return {{std::move(interface), std::move(instance)}};
+ } catch (...) {
+ return Err{"failed to instantiate module"};
}
+ }
+
+ Result<> addModule(WASTModule& mod) {
+ auto module = makeModule(mod);
+ CHECK_ERR(module);
+
+ auto wasm = *module;
+ CHECK_ERR(validateModule(*wasm));
+
+ auto instanceInfo = instantiate(*wasm);
+ CHECK_ERR(instanceInfo);
+
+ auto& [interface, instance] = *instanceInfo;
+ lastModule = wasm->name;
+ modules[lastModule] = std::move(wasm);
+ interfaces[lastModule] = std::move(interface);
+ instances[lastModule] = std::move(instance);
- instantiate(modules[moduleName].get());
return Ok{};
}
- void parseRegister(Element& s) {
+ Result<> addRegistration(Register& reg) {
auto instance = instances[lastModule];
if (!instance) {
- Fatal() << "register called without a module";
+ return Err{"register called without a module"};
}
- auto name = s[1]->str();
- linkedInstances[name] = instance;
+ linkedInstances[reg.name] = instance;
- // we copy pointers as a registered module's name might still be used
+ // We copy pointers as a registered module's name might still be used
// in an assertion or invoke command.
- modules[name] = modules[lastModule];
- interfaces[name] = interfaces[lastModule];
- instances[name] = instances[lastModule];
-
- Colors::green(std::cerr);
- std::cerr << "REGISTER MODULE INSTANCE AS \"" << name.str
- << "\" [line: " << s.line << "]\n";
- Colors::normal(std::cerr);
+ modules[reg.name] = modules[lastModule];
+ interfaces[reg.name] = interfaces[lastModule];
+ instances[reg.name] = instances[lastModule];
+ return Ok{};
}
- Literals parseOperation(Element& s) {
- Index i = 1;
- Name moduleName = lastModule;
- if (s[i]->dollared()) {
- moduleName = s[i++]->str();
- }
- ModuleRunner* instance = instances[moduleName].get();
- assert(instance);
-
- std::string baseStr = std::string("\"") + s[i++]->str().toString() + "\"";
- auto base = Lexer(baseStr).takeString();
- if (!base) {
- Fatal() << "expected string\n";
+ struct TrapResult {};
+ struct HostLimitResult {};
+ struct ExceptionResult {};
+ using ActionResult =
+ std::variant<Literals, TrapResult, HostLimitResult, ExceptionResult>;
+
+ std::string resultToString(ActionResult& result) {
+ if (std::get_if<TrapResult>(&result)) {
+ return "trap";
+ } else if (std::get_if<HostLimitResult>(&result)) {
+ return "exceeded host limit";
+ } else if (std::get_if<ExceptionResult>(&result)) {
+ return "exception";
+ } else if (auto* vals = std::get_if<Literals>(&result)) {
+ std::stringstream ss;
+ ss << *vals;
+ return ss.str();
+ } else {
+ WASM_UNREACHABLE("unexpected result");
}
+ }
- if (s[0]->str() == INVOKE) {
- Literals args;
- while (i < s.size()) {
- auto* arg = parseExpression(*modules[moduleName], *s[i++]);
- args.push_back(getLiteralFromConstExpression(arg));
+ ActionResult doAction(Action& act) {
+ ModuleRunner* instance = instances[lastModule].get();
+ assert(instance);
+ if (auto* invoke = std::get_if<InvokeAction>(&act)) {
+ auto it = instances.find(invoke->base ? *invoke->base : lastModule);
+ if (it == instances.end()) {
+ return TrapResult{};
}
- return instance->callExport(*base, args);
- } else if (s[0]->str() == GET) {
- return instance->getExport(*base);
+ auto& instance = it->second;
+ try {
+ return instance->callExport(invoke->name, invoke->args);
+ } catch (TrapException&) {
+ return TrapResult{};
+ } catch (HostLimitException&) {
+ return HostLimitResult{};
+ } catch (WasmException&) {
+ return ExceptionResult{};
+ } catch (...) {
+ WASM_UNREACHABLE("unexpected error");
+ }
+ } else if (auto* get = std::get_if<GetAction>(&act)) {
+ auto it = instances.find(get->base ? *get->base : lastModule);
+ if (it == instances.end()) {
+ return TrapResult{};
+ }
+ auto& instance = it->second;
+ try {
+ return instance->getExport(get->name);
+ } catch (TrapException&) {
+ return TrapResult{};
+ } catch (...) {
+ WASM_UNREACHABLE("unexpected error");
+ }
+ } else {
+ WASM_UNREACHABLE("unexpected action");
}
+ }
- Fatal() << "Invalid operation " << s[0]->toString();
+ Result<> doAssertion(Assertion& assn) {
+ if (auto* ret = std::get_if<AssertReturn>(&assn)) {
+ return assertReturn(*ret);
+ } else if (auto* act = std::get_if<AssertAction>(&assn)) {
+ return assertAction(*act);
+ } else if (auto* mod = std::get_if<AssertModule>(&assn)) {
+ return assertModule(*mod);
+ } else {
+ WASM_UNREACHABLE("unexpected assertion");
+ }
}
- void parseAssertTrap(Element& s) {
- [[maybe_unused]] bool trapped = false;
- auto& inner = *s[1];
- if (inner[0]->str() == MODULE) {
- return parseModuleAssertion(s);
+ Result<> checkNaN(Literal val, NaNResult nan) {
+ std::stringstream err;
+ switch (nan.kind) {
+ case NaNKind::Canonical:
+ if (val.type != nan.type || !val.isCanonicalNaN()) {
+ err << "expected canonical " << nan.type << " NaN, got " << val;
+ return Err{err.str()};
+ }
+ break;
+ case NaNKind::Arithmetic:
+ if (val.type != nan.type || !val.isArithmeticNaN()) {
+ err << "expected arithmetic " << nan.type << " NaN, got " << val;
+ return Err{err.str()};
+ }
+ break;
}
+ return Ok{};
+ }
- try {
- parseOperation(inner);
- } catch (const TrapException&) {
- trapped = true;
+ Result<> checkLane(Literal val, LaneResult expected, Index index) {
+ std::stringstream err;
+ if (auto* e = std::get_if<Literal>(&expected)) {
+ if (*e != val) {
+ err << "expected " << *e << ", got " << val << " at lane " << index;
+ return Err{err.str()};
+ }
+ } else if (auto* nan = std::get_if<NaNResult>(&expected)) {
+ auto check = checkNaN(val, *nan);
+ if (auto* e = check.getErr()) {
+ err << e->msg << " at lane " << index;
+ return Err{err.str()};
+ }
+ } else {
+ WASM_UNREACHABLE("unexpected lane expectation");
}
- assert(trapped);
+ return Ok{};
}
- void parseAssertException(Element& s) {
- [[maybe_unused]] bool thrown = false;
- auto& inner = *s[1];
- if (inner[0]->str() == MODULE) {
- return parseModuleAssertion(s);
+ Result<> assertReturn(AssertReturn& assn) {
+ std::stringstream err;
+ auto result = doAction(assn.action);
+ auto* values = std::get_if<Literals>(&result);
+ if (!values) {
+ return Err{std::string("expected return, got ") + resultToString(result)};
}
+ if (values->size() != assn.expected.size()) {
+ err << "expected " << assn.expected.size() << " values, got "
+ << resultToString(result);
+ return Err{err.str()};
+ }
+ for (Index i = 0; i < values->size(); ++i) {
+ auto atIndex = [&]() {
+ if (values->size() <= 1) {
+ return std::string{};
+ }
+ std::stringstream ss;
+ ss << " at index " << i;
+ return ss.str();
+ };
- try {
- parseOperation(inner);
- } catch (const WasmException& e) {
- std::cout << "[exception thrown: " << e << "]" << std::endl;
- thrown = true;
+ Literal val = (*values)[i];
+ auto& expected = assn.expected[i];
+ if (auto* v = std::get_if<Literal>(&expected)) {
+ if (val != *v) {
+ err << "expected " << *v << ", got " << val << atIndex();
+ return Err{err.str()};
+ }
+ } else if (auto* ref = std::get_if<RefResult>(&expected)) {
+ if (!val.type.isRef() || val.type.getHeapType() != ref->type) {
+ err << "expected " << ref->type << " reference, got " << val
+ << atIndex();
+ return Err{err.str()};
+ }
+ } else if (auto* nan = std::get_if<NaNResult>(&expected)) {
+ auto check = checkNaN(val, *nan);
+ if (auto* e = check.getErr()) {
+ err << e->msg << atIndex();
+ return Err{err.str()};
+ }
+ } else if (auto* lanes = std::get_if<LaneResults>(&expected)) {
+ switch (lanes->size()) {
+ case 4: {
+ auto vals = val.getLanesF32x4();
+ for (Index i = 0; i < 4; ++i) {
+ auto check = checkLane(vals[i], (*lanes)[i], i);
+ if (auto* e = check.getErr()) {
+ err << e->msg << atIndex();
+ return Err{err.str()};
+ }
+ }
+ break;
+ }
+ case 2: {
+ auto vals = val.getLanesF64x2();
+ for (Index i = 0; i < 2; ++i) {
+ auto check = checkLane(vals[i], (*lanes)[i], i);
+ if (auto* e = check.getErr()) {
+ err << e->msg << atIndex();
+ return Err{err.str()};
+ }
+ }
+ break;
+ }
+ default:
+ WASM_UNREACHABLE("unexpected number of lanes");
+ }
+ } else {
+ WASM_UNREACHABLE("unexpected expectation");
+ }
}
- assert(thrown);
+ return Ok{};
}
- void parseAssertReturn(Element& s) {
- Literals actual;
- Literals expected;
- if (s.size() >= 3) {
- expected = getLiteralsFromConstExpression(
- parseExpression(*modules[lastModule], *s[2]));
- }
- [[maybe_unused]] bool trapped = false;
- try {
- actual = parseOperation(*s[1]);
- } catch (const TrapException&) {
- trapped = true;
- } catch (const WasmException& e) {
- std::cout << "[exception thrown: " << e << "]" << std::endl;
- trapped = true;
- }
- assert(!trapped);
- std::cerr << "seen " << actual << ", expected " << expected << '\n';
- if (expected != actual) {
- Fatal() << "unexpected, should be identical\n";
+ Result<> assertAction(AssertAction& assn) {
+ std::stringstream err;
+ auto result = doAction(assn.action);
+ switch (assn.type) {
+ case ActionAssertionType::Trap:
+ if (std::get_if<TrapResult>(&result)) {
+ return Ok{};
+ }
+ err << "expected trap";
+ break;
+ case ActionAssertionType::Exhaustion:
+ if (std::get_if<HostLimitResult>(&result)) {
+ return Ok{};
+ }
+ err << "expected exhaustion";
+ break;
+ case ActionAssertionType::Exception:
+ if (std::get_if<ExceptionResult>(&result)) {
+ return Ok{};
+ }
+ err << "expected exception";
+ break;
}
+ err << ", got " << resultToString(result);
+ return Err{err.str()};
}
- void parseModuleAssertion(Element& s) {
- Module wasm;
- wasm.features = FeatureSet::All;
- std::unique_ptr<SExpressionWasmBuilder> builder;
- auto id = s[0]->str();
+ Result<> assertModule(AssertModule& assn) {
+ auto wasm = makeModule(assn.wasm);
+ if (const auto* err = wasm.getErr()) {
+ if (assn.type == ModuleAssertionType::Malformed ||
+ assn.type == ModuleAssertionType::Invalid) {
+ return Ok{};
+ }
+ return Err{err->msg};
+ }
- bool invalid = false;
- try {
- SExpressionWasmBuilder(wasm, *s[1], IRProfile::Normal);
- } catch (const ParseException&) {
- invalid = true;
+ if (assn.type == ModuleAssertionType::Malformed) {
+ return Err{"expected malformed module"};
}
- if (!invalid) {
- // maybe parsed ok, but otherwise incorrect
- invalid = !WasmValidator().validate(wasm);
+ auto valid = validateModule(**wasm);
+ if (auto* err = valid.getErr()) {
+ if (assn.type == ModuleAssertionType::Invalid) {
+ return Ok{};
+ }
+ return Err{err->msg};
}
- if (!invalid && id == ASSERT_UNLINKABLE) {
- // validate "instantiating" the mdoule
- auto reportUnknownImport = [&](Importable* import) {
- auto it = linkedInstances.find(import->module);
- if (it == linkedInstances.end() ||
- it->second->wasm.getExportOrNull(import->base) == nullptr) {
- std::cerr << "unknown import: " << import->module << '.'
- << import->base << '\n';
- invalid = true;
- }
- };
- ModuleUtils::iterImportedGlobals(wasm, reportUnknownImport);
- ModuleUtils::iterImportedTables(wasm, reportUnknownImport);
- ModuleUtils::iterImportedFunctions(wasm, [&](Importable* import) {
- if (import->module == SPECTEST && import->base.startsWith(PRINT)) {
- // We can handle it.
- } else {
- reportUnknownImport(import);
- }
- });
- ElementUtils::iterAllElementFunctionNames(&wasm, [&](Name name) {
- // spec tests consider it illegal to use spectest.print in a table
- if (auto* import = wasm.getFunction(name)) {
- if (import->imported() && import->module == SPECTEST &&
- import->base.startsWith(PRINT)) {
- std::cerr << "cannot put spectest.print in table\n";
- invalid = true;
- }
- }
- });
- ModuleUtils::iterImportedMemories(wasm, reportUnknownImport);
+ if (assn.type == ModuleAssertionType::Invalid) {
+ return Err{"expected invalid module"};
}
- if (!invalid && (id == ASSERT_TRAP || id == ASSERT_EXCEPTION)) {
- try {
- instantiate(&wasm);
- } catch (const TrapException&) {
- invalid = true;
- } catch (const WasmException& e) {
- std::cout << "[exception thrown: " << e << "]" << std::endl;
- invalid = true;
+ auto instance = instantiate(**wasm);
+ if (auto* err = instance.getErr()) {
+ if (assn.type == ModuleAssertionType::Unlinkable ||
+ assn.type == ModuleAssertionType::Trap) {
+ return Ok{};
}
+ return Err{err->msg};
}
- if (!invalid) {
- Colors::red(std::cerr);
- std::cerr << "[should have been invalid]\n";
- Colors::normal(std::cerr);
- Fatal() << &wasm << '\n';
+ if (assn.type == ModuleAssertionType::Unlinkable) {
+ return Err{"expected unlinkable module"};
+ }
+ if (assn.type == ModuleAssertionType::Trap) {
+ return Err{"expected instantiation to trap"};
}
- }
-protected:
- Options& options;
+ WASM_UNREACHABLE("unexpected module assertion");
+ }
// spectest module is a default host-provided module defined by the spec's
// reference interpreter. It's been replaced by the `(register ...)`
@@ -345,7 +446,7 @@ protected:
// is actually removed from the spec test.
void buildSpectestModule() {
auto spectest = std::make_shared<Module>();
- spectest->name = "spectest";
+ spectest->features = FeatureSet::All;
Builder builder(*spectest);
spectest->addGlobal(builder.makeGlobal(Name::fromInt(0),
@@ -388,45 +489,18 @@ protected:
spectest->addExport(
builder.makeExport("memory", memory->name, ExternalKind::Memory));
- modules["spectest"].swap(spectest);
- modules["spectest"]->features = FeatureSet::All;
- instantiate(modules["spectest"].get());
- linkedInstances["spectest"] = instances["spectest"];
// print_* functions are handled separately, no need to define here.
- }
-
-public:
- Shell(Options& options) : options(options) { buildSpectestModule(); }
-
- MaybeResult<> parseAndRun(Lexer& lexer) {
- size_t i = 0;
- while (!lexer.empty()) {
- auto next = lexer.next();
- auto size = next.find('\n');
- if (size != std::string_view::npos) {
- next = next.substr(0, size);
- } else {
- next = "";
- }
-
- if (!lexer.peekSExprStart("module")) {
- Colors::red(std::cerr);
- std::cerr << i;
- Colors::green(std::cerr);
- std::cerr << " CHECKING: ";
- Colors::normal(std::cerr);
- std::cerr << next;
- Colors::green(std::cerr);
- std::cerr << " [line: " << lexer.position().line << "]\n";
- Colors::normal(std::cerr);
- }
- CHECK_ERR(parse(lexer));
-
- i += 1;
+ WASTModule mod = std::move(spectest);
+ auto added = addModule(mod);
+ if (added.getErr()) {
+ WASM_UNREACHABLE("error building spectest module");
+ }
+ Register registration{"spectest"};
+ auto registered = addRegistration(registration);
+ if (registered.getErr()) {
+ WASM_UNREACHABLE("error registering spectest module");
}
-
- return Ok{};
}
};
@@ -453,16 +527,14 @@ int main(int argc, const char* argv[]) {
}
Lexer lexer(input);
- auto result = Shell(options).parseAndRun(lexer);
+ auto result = Shell(options).run(*script);
if (auto* err = result.getErr()) {
std::cerr << err->msg << '\n';
exit(1);
}
- if (result) {
- Colors::green(std::cerr);
- Colors::bold(std::cerr);
- std::cerr << "all checks passed.\n";
- Colors::normal(std::cerr);
- }
+ Colors::green(std::cerr);
+ Colors::bold(std::cerr);
+ std::cerr << "all checks passed.\n";
+ Colors::normal(std::cerr);
}
diff --git a/src/wasm-builder.h b/src/wasm-builder.h
index 8f86b3647..f37e45240 100644
--- a/src/wasm-builder.h
+++ b/src/wasm-builder.h
@@ -420,7 +420,6 @@ public:
ret->valueType = type;
ret->memory = memory;
ret->finalize();
- assert(ret->value->type.isConcrete() ? ret->value->type == type : true);
return ret;
}
Store* makeAtomicStore(unsigned bytes,
diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h
index fa129f14a..f9b7e8c50 100644
--- a/src/wasm-interpreter.h
+++ b/src/wasm-interpreter.h
@@ -4103,7 +4103,7 @@ public:
// can use it (refactor?)
Literals callFunctionInternal(Name name, Literals arguments) {
if (callDepth > maxDepth) {
- externalInterface->trap("stack limit");
+ hostLimit("stack limit");
}
Flow flow;
diff --git a/src/wasm-type.h b/src/wasm-type.h
index 6422d09ad..b02484ae3 100644
--- a/src/wasm-type.h
+++ b/src/wasm-type.h
@@ -634,6 +634,8 @@ struct TypeBuilder {
ForwardSupertypeReference,
// A child of the type is an invalid forward reference.
ForwardChildReference,
+ // A continuation reference that does not refer to a function type.
+ InvalidFuncType,
};
struct Error {
diff --git a/src/wasm/literal.cpp b/src/wasm/literal.cpp
index 2c61be12a..d4e1bfc45 100644
--- a/src/wasm/literal.cpp
+++ b/src/wasm/literal.cpp
@@ -467,6 +467,22 @@ bool Literal::isNaN() {
return false;
}
+bool Literal::isCanonicalNaN() {
+ if (!isNaN()) {
+ return false;
+ }
+ return (type == Type::f32 && NaNPayload(getf32()) == (1u << 23) - 1) ||
+ (type == Type::f64 && NaNPayload(getf64()) == (1ull << 52) - 1);
+}
+
+bool Literal::isArithmeticNaN() {
+ if (!isNaN()) {
+ return false;
+ }
+ return (type == Type::f32 && NaNPayload(getf32()) > (1u << 23) - 1) ||
+ (type == Type::f64 && NaNPayload(getf64()) > (1ull << 52) - 1);
+}
+
uint32_t Literal::NaNPayload(float f) {
assert(std::isnan(f) && "expected a NaN");
// SEEEEEEE EFFFFFFF FFFFFFFF FFFFFFFF
diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp
index 2eb5369d0..3e215b23a 100644
--- a/src/wasm/wasm-type.cpp
+++ b/src/wasm/wasm-type.cpp
@@ -1666,6 +1666,8 @@ std::ostream& operator<<(std::ostream& os, TypeBuilder::ErrorReason reason) {
return os << "Heap type has an undeclared supertype";
case TypeBuilder::ErrorReason::ForwardChildReference:
return os << "Heap type has an undeclared child";
+ case TypeBuilder::ErrorReason::InvalidFuncType:
+ return os << "Continuation has invalid function type";
}
WASM_UNREACHABLE("Unexpected error reason");
}
@@ -2616,7 +2618,7 @@ buildRecGroup(std::unique_ptr<RecGroupInfo>&& groupInfo,
updateReferencedHeapTypes(info, canonicalized);
}
- // Collect the types and check supertype validity.
+ // Collect the types and check validity.
std::unordered_set<HeapType> seenTypes;
for (size_t i = 0; i < typeInfos.size(); ++i) {
auto& info = typeInfos[i];
@@ -2633,6 +2635,12 @@ buildRecGroup(std::unique_ptr<RecGroupInfo>&& groupInfo,
TypeBuilder::Error{i, TypeBuilder::ErrorReason::InvalidSupertype}};
}
}
+ if (info->isContinuation()) {
+ if (!info->continuation.type.isSignature()) {
+ return {
+ TypeBuilder::Error{i, TypeBuilder::ErrorReason::InvalidFuncType}};
+ }
+ }
seenTypes.insert(asHeapType(info));
}