summaryrefslogtreecommitdiff
path: root/src/parser/wast-parser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/parser/wast-parser.cpp')
-rw-r--r--src/parser/wast-parser.cpp419
1 files changed, 419 insertions, 0 deletions
diff --git a/src/parser/wast-parser.cpp b/src/parser/wast-parser.cpp
new file mode 100644
index 000000000..fb0dce932
--- /dev/null
+++ b/src/parser/wast-parser.cpp
@@ -0,0 +1,419 @@
+/*
+ * Copyright 2024 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lexer.h"
+#include "literal.h"
+#include "wat-parser.h"
+
+namespace wasm::WATParser {
+
+using namespace std::string_view_literals;
+
+namespace {
+
+// Parse a single constant from `in` by delegating to `parseConst`.
+Result<Literal> const_(Lexer& in) {
+  // TODO: handle `ref.extern n` as well.
+  Result<Literal> parsed = parseConst(in);
+  return parsed;
+}
+
+// Parse constants one after another until the next token is a right paren.
+// The paren itself is only peeked at, so the caller still has to take it.
+// The first constant that fails to parse aborts with its error.
+Result<Literals> consts(Lexer& in) {
+  Literals parsed;
+  for (;;) {
+    if (in.peekRParen()) {
+      break;
+    }
+    auto next = const_(in);
+    CHECK_ERR(next);
+    parsed.push_back(*next);
+  }
+  return parsed;
+}
+
+// Parse an `invoke` or `get` action:
+//
+//   (invoke id? name const*)
+//   (get id? name)
+//
+// Returns an empty result when the input starts with neither form, and an
+// error when one of the forms has been started but is malformed.
+MaybeResult<Action> action(Lexer& in) {
+  if (in.takeSExprStart("invoke"sv)) {
+    // TODO: Do we need to use this optional id?
+    in.takeID();
+    auto exportName = in.takeName();
+    if (!exportName) {
+      return in.err("expected export name");
+    }
+    auto operands = consts(in);
+    CHECK_ERR(operands);
+    if (in.takeRParen()) {
+      return InvokeAction{*exportName, *operands};
+    }
+    return in.err("expected end of invoke action");
+  }
+
+  if (in.takeSExprStart("get"sv)) {
+    // TODO: Do we need to use this optional id?
+    in.takeID();
+    auto exportName = in.takeName();
+    if (!exportName) {
+      return in.err("expected export name");
+    }
+    if (in.takeRParen()) {
+      return GetAction{*exportName};
+    }
+    return in.err("expected end of get action");
+  }
+
+  // Not an action at all.
+  return {};
+}
+
+// (module id? binary string*)
+// (module id? quote string*)
+// (module ...)
+//
+// Parse one module command. `binary` and `quote` modules are returned as a
+// QuotedModule holding the concatenation of their string contents. Any other
+// module is an inline module: when `maybeInvalid` is false it is parsed with
+// `parseModule` and returned as a shared Module; when `maybeInvalid` is true
+// its source text is captured verbatim and returned as a quoted text module
+// instead of being parsed here. Returns an error if the input does not start
+// with `(module`, if the module is not terminated, or if inline parsing
+// fails.
+Result<WASTModule> wastModule(Lexer& in, bool maybeInvalid = false) {
+  // Keep a copy of the lexer positioned at the opening paren so that inline
+  // modules can be re-parsed, or have their text extracted, from the start.
+  Lexer reset = in;
+  if (!in.takeSExprStart("module"sv)) {
+    return in.err("expected module");
+  }
+  // Skip the optional module id so that `(module $id binary ...)` and
+  // `(module $id quote ...)` are still recognized as quoted modules.
+  // TODO: Do we need to use this optional id?
+  in.takeID();
+  QuotedModuleType type;
+  if (in.takeKeyword("quote"sv)) {
+    type = QuotedModuleType::Text;
+  } else if (in.takeKeyword("binary"sv)) {
+    type = QuotedModuleType::Binary;
+  } else if (maybeInvalid) {
+    // This is not a quoted text or binary module, so it must be a normal inline
+    // module, but we might not be able to parse it. Treat it as though it were
+    // a quoted module instead.
+    //
+    // `count` is the current paren depth; it starts at 1 for the `(module`
+    // that was already consumed and reaches 0 at the matching close paren.
+    int count = 1;
+    while (count && in.takeUntilParen()) {
+      if (in.takeLParen()) {
+        ++count;
+      } else if (in.takeRParen()) {
+        --count;
+      } else {
+        return in.err("unexpected end of script");
+      }
+    }
+    if (count) {
+      // The scan stopped before the module's closing paren was found, so the
+      // module is unterminated. Report it rather than returning partial text.
+      return in.err("unexpected end of script");
+    }
+    // The module text is everything between the saved start and the current
+    // position.
+    std::string mod(reset.next().substr(0, in.getPos() - reset.getPos()));
+    return QuotedModule{QuotedModuleType::Text, mod};
+  } else {
+    // This is a normal inline module that should be parseable. Reset to the
+    // start and parse it normally.
+    in = std::move(reset);
+    auto wasm = std::make_shared<Module>();
+    CHECK_ERR(parseModule(*wasm, in));
+    return wasm;
+  }
+
+  // We have a quote or binary module. Collect its contents.
+  std::stringstream ss;
+  while (auto s = in.takeString()) {
+    ss << *s;
+  }
+
+  if (!in.takeRParen()) {
+    return in.err("expected end of module");
+  }
+
+  return QuotedModule{type, ss.str()};
+}
+
+// Parse a NaN result pattern keyword: either `nan:canonical` or
+// `nan:arithmetic`. Anything else is an error.
+Result<NaNKind> nan(Lexer& in) {
+  if (in.takeKeyword("nan:canonical"sv)) {
+    return NaNKind::Canonical;
+  } else if (in.takeKeyword("nan:arithmetic"sv)) {
+    return NaNKind::Arithmetic;
+  } else {
+    return in.err("expected NaN result pattern");
+  }
+}
+
+// Parse a single expected result. This is tried in order as: an ordinary
+// constant, a scalar `f32.const`/`f64.const` NaN pattern, a `v128.const` with
+// `f32x4` or `f64x2` lanes that are each a float or a NaN pattern, and finally
+// the bare reference patterns `(ref.extern)` and `(ref.func)`.
+Result<ExpectedResult> result(Lexer& in) {
+  // Try a plain constant on a copy of the lexer so that `in` is left where it
+  // was if the attempt fails.
+  // TODO: Generating and discarding errors like this can lead to quadratic
+  // behavior. Optimize this if necessary.
+  Lexer attempt = in;
+  auto constant = const_(attempt);
+  if (!constant.getErr()) {
+    in = attempt;
+    return *constant;
+  }
+
+  // Not a constant, so this has to be a NaN pattern or a reference pattern.
+  if (in.takeSExprStart("f32.const"sv)) {
+    auto pattern = nan(in);
+    CHECK_ERR(pattern);
+    if (in.takeRParen()) {
+      return NaNResult{*pattern, Type::f32};
+    }
+    return in.err("expected end of f32.const");
+  }
+
+  if (in.takeSExprStart("f64.const"sv)) {
+    auto pattern = nan(in);
+    CHECK_ERR(pattern);
+    if (in.takeRParen()) {
+      return NaNResult{*pattern, Type::f64};
+    }
+    return in.err("expected end of f64.const");
+  }
+
+  if (in.takeSExprStart("v128.const"sv)) {
+    LaneResults lanes;
+    if (in.takeKeyword("f32x4"sv)) {
+      // Four lanes, each either a concrete f32 or a NaN pattern.
+      for (int lane = 0; lane < 4; ++lane) {
+        auto value = in.takeF32();
+        if (value) {
+          lanes.push_back(Literal(*value));
+          continue;
+        }
+        auto pattern = nan(in);
+        CHECK_ERR(pattern);
+        lanes.push_back(NaNResult{*pattern, Type::f32});
+      }
+    } else if (in.takeKeyword("f64x2"sv)) {
+      // Two lanes, each either a concrete f64 or a NaN pattern.
+      for (int lane = 0; lane < 2; ++lane) {
+        auto value = in.takeF64();
+        if (value) {
+          lanes.push_back(Literal(*value));
+          continue;
+        }
+        auto pattern = nan(in);
+        CHECK_ERR(pattern);
+        lanes.push_back(NaNResult{*pattern, Type::f64});
+      }
+    } else {
+      return in.err("unexpected vector shape");
+    }
+    if (in.takeRParen()) {
+      return lanes;
+    }
+    return in.err("expected end of v128.const");
+  }
+
+  if (in.takeSExprStart("ref.extern")) {
+    if (in.takeRParen()) {
+      return RefResult{HeapType::ext};
+    }
+    return in.err("expected end of ref.extern");
+  }
+
+  if (in.takeSExprStart("ref.func")) {
+    if (in.takeRParen()) {
+      return RefResult{HeapType::func};
+    }
+    return in.err("expected end of ref.func");
+  }
+
+  return in.err("unrecognized result");
+}
+
+// Parse expected results one after another until the next token is a right
+// paren, which is peeked at but left for the caller to take. The first result
+// that fails to parse aborts with its error.
+Result<ExpectedResults> results(Lexer& in) {
+  ExpectedResults parsed;
+  for (;;) {
+    if (in.peekRParen()) {
+      break;
+    }
+    auto next = result(in);
+    CHECK_ERR(next);
+    parsed.emplace_back(std::move(*next));
+  }
+  return parsed;
+}
+
+// (assert_return action result*)
+//
+// Returns an empty result if the input does not start with `(assert_return`.
+// Otherwise parses the action and the expected results, returning an error if
+// the action is missing or any part of the assertion is malformed.
+MaybeResult<AssertReturn> assertReturn(Lexer& in) {
+  if (!in.takeSExprStart("assert_return"sv)) {
+    return {};
+  }
+  auto a = action(in);
+  if (!a) {
+    // `action` matched neither `invoke` nor `get`. CHECK_ERR only catches
+    // errors, so without this check the empty result would be dereferenced
+    // below.
+    return in.err("expected action");
+  }
+  CHECK_ERR(a);
+  auto expected = results(in);
+  CHECK_ERR(expected);
+  if (!in.takeRParen()) {
+    return in.err("expected end of assert_return");
+  }
+  return AssertReturn{*a, *expected};
+}
+
+// (assert_exception action)
+//
+// Returns an empty result if the input does not start with
+// `(assert_exception`. Otherwise parses the action, returning an error if it
+// is missing or the assertion is not properly closed.
+MaybeResult<AssertException> assertException(Lexer& in) {
+  if (!in.takeSExprStart("assert_exception"sv)) {
+    return {};
+  }
+  auto a = action(in);
+  if (!a) {
+    // `action` matched neither `invoke` nor `get`. CHECK_ERR only catches
+    // errors, so without this check the empty result would be dereferenced
+    // below.
+    return in.err("expected action");
+  }
+  CHECK_ERR(a);
+  if (!in.takeRParen()) {
+    return in.err("expected end of assert_exception");
+  }
+  return AssertException{*a};
+}
+
+// (assert_exhaustion action msg)
+//
+// Returns an empty result if the input does not start with
+// `(assert_exhaustion`. Otherwise parses the action and the expected message,
+// returning an error if either is missing or the assertion is not properly
+// closed.
+MaybeResult<AssertAction> assertAction(Lexer& in) {
+  ActionAssertionType type;
+  if (in.takeSExprStart("assert_exhaustion"sv)) {
+    type = ActionAssertionType::Exhaustion;
+  } else {
+    return {};
+  }
+
+  auto a = action(in);
+  if (!a) {
+    // `action` matched neither `invoke` nor `get`. CHECK_ERR only catches
+    // errors, so without this check the empty result would be dereferenced
+    // below.
+    return in.err("expected action");
+  }
+  CHECK_ERR(a);
+  auto msg = in.takeString();
+  if (!msg) {
+    return in.err("expected error message");
+  }
+  if (!in.takeRParen()) {
+    return in.err("expected end of assertion");
+  }
+  return AssertAction{type, *a, *msg};
+}
+
+// (assert_malformed module msg)
+// (assert_invalid module msg)
+// (assert_unlinkable module msg)
+//
+// Returns an empty result if the input starts with none of these forms.
+// For `assert_invalid` the module is parsed with `maybeInvalid` set, so an
+// inline module is captured as text rather than parsed.
+MaybeResult<AssertModule> assertModule(Lexer& in) {
+  ModuleAssertionType type;
+  if (in.takeSExprStart("assert_malformed"sv)) {
+    type = ModuleAssertionType::Malformed;
+  } else if (in.takeSExprStart("assert_invalid"sv)) {
+    type = ModuleAssertionType::Invalid;
+  } else if (in.takeSExprStart("assert_unlinkable"sv)) {
+    type = ModuleAssertionType::Unlinkable;
+  } else {
+    return {};
+  }
+
+  const bool maybeInvalid = type == ModuleAssertionType::Invalid;
+  auto mod = wastModule(in, maybeInvalid);
+  CHECK_ERR(mod);
+
+  auto msg = in.takeString();
+  if (!msg) {
+    return in.err("expected error message");
+  }
+  if (in.takeRParen()) {
+    return AssertModule{type, *mod, *msg};
+  }
+  return in.err("expected end of assertion");
+}
+
+// (assert_trap action msg)
+// (assert_trap module msg)
+//
+// Returns an empty result if the input does not start with `(assert_trap`.
+// The operand is tried as an action first and as a module second; if neither
+// works, the error is reported at the position where the operand started.
+MaybeResult<Assertion> assertTrap(Lexer& in) {
+  if (!in.takeSExprStart("assert_trap"sv)) {
+    return {};
+  }
+  // Remember where the operand begins for the "neither" error below.
+  auto operandPos = in.getPos();
+
+  if (auto act = action(in)) {
+    CHECK_ERR(act);
+    auto message = in.takeString();
+    if (!message) {
+      return in.err("expected error message");
+    }
+    if (in.takeRParen()) {
+      return Assertion{
+        AssertAction{ActionAssertionType::Trap, *act, *message}};
+    }
+    return in.err("expected end of assertion");
+  }
+
+  // No action was present, so the operand has to be a module.
+  auto trapping = wastModule(in);
+  if (trapping.getErr()) {
+    return in.err(operandPos, "expected action or module");
+  }
+  auto message = in.takeString();
+  if (!message) {
+    return in.err("expected error message");
+  }
+  if (in.takeRParen()) {
+    return Assertion{
+      AssertModule{ModuleAssertionType::Trap, *trapping, *message}};
+  }
+  return in.err("expected end of assertion");
+}
+
+// Parse any assertion by trying each assertion parser in turn. The first one
+// that recognizes its keyword decides the outcome, whether success or error.
+// Returns an empty result if none of them recognizes the input.
+MaybeResult<Assertion> assertion(Lexer& in) {
+  if (auto ret = assertReturn(in)) {
+    CHECK_ERR(ret);
+    return Assertion{*ret};
+  }
+  if (auto exn = assertException(in)) {
+    CHECK_ERR(exn);
+    return Assertion{*exn};
+  }
+  if (auto act = assertAction(in)) {
+    CHECK_ERR(act);
+    return Assertion{*act};
+  }
+  if (auto mod = assertModule(in)) {
+    CHECK_ERR(mod);
+    return Assertion{*mod};
+  }
+  if (auto trap = assertTrap(in)) {
+    CHECK_ERR(trap);
+    // assertTrap already produces an Assertion, so no wrapping is needed.
+    return *trap;
+  }
+  return {};
+}
+
+// (register name id?)
+//
+// Returns an empty result if the input does not start with `(register`.
+// The optional id is consumed but not recorded.
+MaybeResult<Register> register_(Lexer& in) {
+  if (!in.takeSExprStart("register"sv)) {
+    return {};
+  }
+
+  auto exportedAs = in.takeName();
+  if (!exportedAs) {
+    return in.err("expected name");
+  }
+
+  // TODO: Do we need to use this optional id?
+  in.takeID();
+
+  if (in.takeRParen()) {
+    return Register{*exportedAs};
+  }
+  // TODO: handle optional module id.
+  return in.err("expected end of register command");
+}
+
+// module | register | action | assertion
+//
+// Parse one script command. Register, action and assertion are tried in that
+// order; whatever none of them recognizes has to be a module.
+Result<WASTCommand> command(Lexer& in) {
+  if (auto reg = register_(in)) {
+    CHECK_ERR(reg);
+    return *reg;
+  }
+  if (auto act = action(in)) {
+    CHECK_ERR(act);
+    return *act;
+  }
+  if (auto assn = assertion(in)) {
+    CHECK_ERR(assn);
+    return *assn;
+  }
+
+  // Nothing else matched, so a module is the only remaining possibility.
+  auto parsedModule = wastModule(in);
+  CHECK_ERR(parsedModule);
+  return *parsedModule;
+}
+
+// Parse a whole script as a sequence of commands until the lexer is empty.
+// If the very first command fails to parse, the entire buffer is instead
+// parsed as a single module and returned as a one-command script.
+Result<WASTScript> wast(Lexer& in) {
+  WASTScript script;
+  while (!in.empty()) {
+    auto next = command(in);
+    if (next.getErr() && script.empty()) {
+      // Nothing has parsed as a command yet, so the script may be a bare
+      // sequence of module fields. NOTE(review): presumably this is the form
+      // *without* an enclosing `(module ...)`, since `command` already handles
+      // the enclosed form; confirm. Retry on the whole buffer as one module.
+      auto wholeModule = std::make_shared<Module>();
+      CHECK_ERR(parseModule(*wholeModule, in.buffer));
+      script.emplace_back(std::move(wholeModule));
+      return script;
+    }
+    CHECK_ERR(next);
+    script.emplace_back(std::move(*next));
+  }
+  return script;
+}
+
+} // anonymous namespace
+
+// Public entry point: lex `in` and parse it as a WAST script.
+Result<WASTScript> parseScript(std::string_view in) {
+  Lexer scriptLexer(in);
+  auto script = wast(scriptLexer);
+  return script;
+}
+
+} // namespace wasm::WATParser