diff options
Diffstat (limited to 'src/parser')
-rw-r--r-- | src/parser/CMakeLists.txt | 9 | ||||
-rw-r--r-- | src/parser/common.h | 31 | ||||
-rw-r--r-- | src/parser/context-decls.cpp | 194 | ||||
-rw-r--r-- | src/parser/context-defs.cpp | 98 | ||||
-rw-r--r-- | src/parser/contexts.h | 1275 | ||||
-rw-r--r-- | src/parser/input-impl.h | 273 | ||||
-rw-r--r-- | src/parser/input.h | 75 | ||||
-rw-r--r-- | src/parser/lexer.cpp | 1038 | ||||
-rw-r--r-- | src/parser/lexer.h | 227 | ||||
-rw-r--r-- | src/parser/parsers.h | 2036 | ||||
-rw-r--r-- | src/parser/wat-parser.cpp | 172 | ||||
-rw-r--r-- | src/parser/wat-parser.h | 32 |
12 files changed, 5460 insertions, 0 deletions
diff --git a/src/parser/CMakeLists.txt b/src/parser/CMakeLists.txt new file mode 100644 index 000000000..bae90379e --- /dev/null +++ b/src/parser/CMakeLists.txt @@ -0,0 +1,9 @@ +FILE(GLOB parser_HEADERS *.h) +set(parser_SOURCES + context-decls.cpp + context-defs.cpp + lexer.cpp + wat-parser.cpp + ${parser_HEADERS} +) +add_library(parser OBJECT ${parser_SOURCES}) diff --git a/src/parser/common.h b/src/parser/common.h new file mode 100644 index 000000000..7adf2e5fa --- /dev/null +++ b/src/parser/common.h @@ -0,0 +1,31 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef parser_common_h +#define parser_common_h + +#include "support/name.h" + +namespace wasm::WATParser { + +struct ImportNames { + Name mod; + Name nm; +}; + +} // namespace wasm::WATParser + +#endif // parser_common_h diff --git a/src/parser/context-decls.cpp b/src/parser/context-decls.cpp new file mode 100644 index 000000000..f668c67ae --- /dev/null +++ b/src/parser/context-decls.cpp @@ -0,0 +1,194 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contexts.h" + +namespace wasm::WATParser { + +namespace { + +void applyImportNames(Importable& item, ImportNames* names) { + if (names) { + item.module = names->mod; + item.base = names->nm; + } +} + +Result<> addExports(ParseInput& in, + Module& wasm, + const Named* item, + const std::vector<Name>& exports, + ExternalKind kind) { + for (auto name : exports) { + if (wasm.getExportOrNull(name)) { + // TODO: Fix error location + return in.err("repeated export name"); + } + wasm.addExport(Builder(wasm).makeExport(name, item->name, kind)); + } + return Ok{}; +} + +} // anonymous namespace + +Result<Function*> +ParseDeclsCtx::addFuncDecl(Index pos, Name name, ImportNames* importNames) { + auto f = std::make_unique<Function>(); + if (name.is()) { + if (wasm.getFunctionOrNull(name)) { + // TDOO: if the existing function is not explicitly named, fix its name + // and continue. + return in.err(pos, "repeated function name"); + } + f->setExplicitName(name); + } else { + name = (importNames ? "fimport$" : "") + std::to_string(funcCounter++); + name = Names::getValidFunctionName(wasm, name); + f->name = name; + } + applyImportNames(*f, importNames); + return wasm.addFunction(std::move(f)); +} + +Result<> ParseDeclsCtx::addFunc(Name name, + const std::vector<Name>& exports, + ImportNames* import, + TypeUseT type, + std::optional<LocalsT>, + std::optional<InstrsT>, + Index pos) { + if (import && hasNonImport) { + return in.err(pos, "import after non-import"); + } + auto f = addFuncDecl(pos, name, import); + CHECK_ERR(f); + CHECK_ERR(addExports(in, wasm, *f, exports, ExternalKind::Function)); + funcDefs.push_back({name, pos, Index(funcDefs.size())}); + return Ok{}; +} + +Result<Memory*> ParseDeclsCtx::addMemoryDecl(Index pos, + Name name, + ImportNames* importNames, + MemType type) { + auto m = std::make_unique<Memory>(); + m->indexType = type.type; + m->initial = type.limits.initial; + m->max = type.limits.max; + m->shared = type.shared; + if (name) { + // TODO: if the existing memory is not explicitly named, fix its name + // and continue. + if (wasm.getMemoryOrNull(name)) { + return in.err(pos, "repeated memory name"); + } + m->setExplicitName(name); + } else { + name = (importNames ? "mimport$" : "") + std::to_string(memoryCounter++); + name = Names::getValidMemoryName(wasm, name); + m->name = name; + } + applyImportNames(*m, importNames); + return wasm.addMemory(std::move(m)); +} + +Result<> ParseDeclsCtx::addMemory(Name name, + const std::vector<Name>& exports, + ImportNames* import, + MemType type, + Index pos) { + if (import && hasNonImport) { + return in.err(pos, "import after non-import"); + } + auto m = addMemoryDecl(pos, name, import, type); + CHECK_ERR(m); + CHECK_ERR(addExports(in, wasm, *m, exports, ExternalKind::Memory)); + memoryDefs.push_back({name, pos, Index(memoryDefs.size())}); + return Ok{}; +} + +Result<> ParseDeclsCtx::addImplicitData(DataStringT&& data) { + auto& mem = *wasm.memories.back(); + auto d = std::make_unique<DataSegment>(); + d->memory = mem.name; + d->isPassive = false; + d->offset = Builder(wasm).makeConstPtr(0, mem.indexType); + d->data = std::move(data); + d->name = Names::getValidDataSegmentName(wasm, "implicit-data"); + wasm.addDataSegment(std::move(d)); + return Ok{}; +} + +Result<Global*> +ParseDeclsCtx::addGlobalDecl(Index pos, Name name, ImportNames* importNames) { + auto g = std::make_unique<Global>(); + if (name) { + if (wasm.getGlobalOrNull(name)) { + // TODO: if the existing global is not explicitly named, fix its name + // and continue. + return in.err(pos, "repeated global name"); + } + g->setExplicitName(name); + } else { + name = (importNames ? "gimport$" : "") + std::to_string(globalCounter++); + name = Names::getValidGlobalName(wasm, name); + g->name = name; + } + applyImportNames(*g, importNames); + return wasm.addGlobal(std::move(g)); +} + +Result<> ParseDeclsCtx::addGlobal(Name name, + const std::vector<Name>& exports, + ImportNames* import, + GlobalTypeT, + std::optional<ExprT>, + Index pos) { + if (import && hasNonImport) { + return in.err(pos, "import after non-import"); + } + auto g = addGlobalDecl(pos, name, import); + CHECK_ERR(g); + CHECK_ERR(addExports(in, wasm, *g, exports, ExternalKind::Global)); + globalDefs.push_back({name, pos, Index(globalDefs.size())}); + return Ok{}; +} + +Result<> ParseDeclsCtx::addData(Name name, + MemoryIdxT*, + std::optional<ExprT>, + std::vector<char>&& data, + Index pos) { + auto d = std::make_unique<DataSegment>(); + if (name) { + if (wasm.getDataSegmentOrNull(name)) { + // TODO: if the existing segment is not explicitly named, fix its name + // and continue. + return in.err(pos, "repeated data segment name"); + } + d->setExplicitName(name); + } else { + name = std::to_string(dataCounter++); + name = Names::getValidDataSegmentName(wasm, name); + d->name = name; + } + d->data = std::move(data); + dataDefs.push_back({name, pos, Index(wasm.dataSegments.size())}); + wasm.addDataSegment(std::move(d)); + return Ok{}; +} + +} // namespace wasm::WATParser diff --git a/src/parser/context-defs.cpp b/src/parser/context-defs.cpp new file mode 100644 index 000000000..ca8f61ec3 --- /dev/null +++ b/src/parser/context-defs.cpp @@ -0,0 +1,98 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contexts.h" + +namespace wasm::WATParser { + +Result<typename ParseDefsCtx::TypeUseT> +ParseDefsCtx::makeTypeUse(Index pos, + std::optional<HeapTypeT> type, + ParamsT* params, + ResultsT* results) { + if (type && (params || results)) { + std::vector<Type> paramTypes; + if (params) { + paramTypes = getUnnamedTypes(*params); + } + + std::vector<Type> resultTypes; + if (results) { + resultTypes = *results; + } + + auto sig = Signature(Type(paramTypes), Type(resultTypes)); + + if (!type->isSignature() || type->getSignature() != sig) { + return in.err(pos, "type does not match provided signature"); + } + } + + if (type) { + return *type; + } + + auto it = implicitTypes.find(pos); + assert(it != implicitTypes.end()); + return it->second; +} + +Result<> ParseDefsCtx::addFunc(Name, + const std::vector<Name>&, + ImportNames*, + TypeUseT, + std::optional<LocalsT>, + std::optional<InstrsT>, + Index pos) { + CHECK_ERR(withLoc(pos, irBuilder.visitEnd())); + auto body = irBuilder.build(); + CHECK_ERR(withLoc(pos, body)); + wasm.functions[index]->body = *body; + return Ok{}; +} + +Result<> ParseDefsCtx::addGlobal(Name, + const std::vector<Name>&, + ImportNames*, + GlobalTypeT, + std::optional<ExprT> exp, + Index) { + if (exp) { + wasm.globals[index]->init = *exp; + } + return Ok{}; +} + +Result<> ParseDefsCtx::addData( + Name, Name* mem, std::optional<ExprT> offset, DataStringT, Index pos) { + auto& d = wasm.dataSegments[index]; + if (offset) { + d->isPassive = false; + d->offset = *offset; + if (mem) { + d->memory = *mem; + } else if (wasm.memories.size() > 0) { + d->memory = wasm.memories[0]->name; + } else { + return in.err(pos, "active segment with no memory"); + } + } else { + d->isPassive = true; + } + return Ok{}; +} + +} // namespace wasm::WATParser diff --git a/src/parser/contexts.h b/src/parser/contexts.h new file mode 100644 index 000000000..210945e8d --- /dev/null +++ b/src/parser/contexts.h @@ -0,0 +1,1275 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef parser_context_h +#define parser_context_h + +#include "common.h" +#include "input.h" +#include "ir/names.h" +#include "support/name.h" +#include "support/result.h" +#include "wasm-builder.h" +#include "wasm-ir-builder.h" +#include "wasm.h" + +namespace wasm::WATParser { + +using IndexMap = std::unordered_map<Name, Index>; + +inline std::vector<Type> getUnnamedTypes(const std::vector<NameType>& named) { + std::vector<Type> types; + types.reserve(named.size()); + for (auto& t : named) { + types.push_back(t.type); + } + return types; +} + +struct Limits { + uint64_t initial; + uint64_t max; +}; + +struct MemType { + Type type; + Limits limits; + bool shared; +}; + +struct Memarg { + uint64_t offset; + uint32_t align; +}; + +// The location, possible name, and index in the respective module index space +// of a module-level definition in the input. +struct DefPos { + Name name; + Index pos; + Index index; +}; + +struct GlobalType { + Mutability mutability; + Type type; +}; + +// A signature type and parameter names (possibly empty), used for parsing +// function types. +struct TypeUse { + HeapType type; + std::vector<Name> names; +}; + +struct NullTypeParserCtx { + using IndexT = Ok; + using HeapTypeT = Ok; + using TypeT = Ok; + using ParamsT = Ok; + using ResultsT = size_t; + using BlockTypeT = Ok; + using SignatureT = Ok; + using StorageT = Ok; + using FieldT = Ok; + using FieldsT = Ok; + using StructT = Ok; + using ArrayT = Ok; + using LimitsT = Ok; + using MemTypeT = Ok; + using GlobalTypeT = Ok; + using TypeUseT = Ok; + using LocalsT = Ok; + using DataStringT = Ok; + + HeapTypeT makeFunc() { return Ok{}; } + HeapTypeT makeAny() { return Ok{}; } + HeapTypeT makeExtern() { return Ok{}; } + HeapTypeT makeEq() { return Ok{}; } + HeapTypeT makeI31() { return Ok{}; } + HeapTypeT makeStructType() { return Ok{}; } + HeapTypeT makeArrayType() { return Ok{}; } + + TypeT makeI32() { return Ok{}; } + TypeT makeI64() { return Ok{}; } + TypeT makeF32() { return Ok{}; } + TypeT makeF64() { return Ok{}; } + TypeT makeV128() { return Ok{}; } + + TypeT makeRefType(HeapTypeT, Nullability) { return Ok{}; } + + ParamsT makeParams() { return Ok{}; } + void appendParam(ParamsT&, Name, TypeT) {} + + // We have to count results because whether or not a block introduces a + // typeuse that may implicitly define a type depends on how many results it + // has. + size_t makeResults() { return 0; } + void appendResult(size_t& results, TypeT) { ++results; } + size_t getResultsSize(size_t results) { return results; } + + SignatureT makeFuncType(ParamsT*, ResultsT*) { return Ok{}; } + + StorageT makeI8() { return Ok{}; } + StorageT makeI16() { return Ok{}; } + StorageT makeStorageType(TypeT) { return Ok{}; } + + FieldT makeFieldType(StorageT, Mutability) { return Ok{}; } + + FieldsT makeFields() { return Ok{}; } + void appendField(FieldsT&, Name, FieldT) {} + + StructT makeStruct(FieldsT&) { return Ok{}; } + + std::optional<ArrayT> makeArray(FieldsT&) { return Ok{}; } + + GlobalTypeT makeGlobalType(Mutability, TypeT) { return Ok{}; } + + LocalsT makeLocals() { return Ok{}; } + void appendLocal(LocalsT&, Name, TypeT) {} + + Result<Index> getTypeIndex(Name) { return 1; } + Result<HeapTypeT> getHeapTypeFromIdx(Index) { return Ok{}; } + + DataStringT makeDataString() { return Ok{}; } + void appendDataString(DataStringT&, std::string_view) {} + + MemTypeT makeMemType(Type, LimitsT, bool) { return Ok{}; } + + BlockTypeT getBlockTypeFromResult(size_t results) { return Ok{}; } + + Result<> getBlockTypeFromTypeUse(Index, TypeUseT) { return Ok{}; } +}; + +template<typename Ctx> struct TypeParserCtx { + using IndexT = Index; + using HeapTypeT = HeapType; + using TypeT = Type; + using ParamsT = std::vector<NameType>; + using ResultsT = std::vector<Type>; + using BlockTypeT = HeapType; + using SignatureT = Signature; + using StorageT = Field; + using FieldT = Field; + using FieldsT = std::pair<std::vector<Name>, std::vector<Field>>; + using StructT = std::pair<std::vector<Name>, Struct>; + using ArrayT = Array; + using LimitsT = Ok; + using MemTypeT = Ok; + using LocalsT = std::vector<NameType>; + using DataStringT = Ok; + + // Map heap type names to their indices. + const IndexMap& typeIndices; + + TypeParserCtx(const IndexMap& typeIndices) : typeIndices(typeIndices) {} + + Ctx& self() { return *static_cast<Ctx*>(this); } + + HeapTypeT makeFunc() { return HeapType::func; } + HeapTypeT makeAny() { return HeapType::any; } + HeapTypeT makeExtern() { return HeapType::ext; } + HeapTypeT makeEq() { return HeapType::eq; } + HeapTypeT makeI31() { return HeapType::i31; } + HeapTypeT makeStructType() { return HeapType::struct_; } + HeapTypeT makeArrayType() { return HeapType::array; } + + TypeT makeI32() { return Type::i32; } + TypeT makeI64() { return Type::i64; } + TypeT makeF32() { return Type::f32; } + TypeT makeF64() { return Type::f64; } + TypeT makeV128() { return Type::v128; } + + TypeT makeRefType(HeapTypeT ht, Nullability nullability) { + return Type(ht, nullability); + } + + TypeT makeTupleType(const std::vector<Type> types) { return Tuple(types); } + + ParamsT makeParams() { return {}; } + void appendParam(ParamsT& params, Name id, TypeT type) { + params.push_back({id, type}); + } + + ResultsT makeResults() { return {}; } + void appendResult(ResultsT& results, TypeT type) { results.push_back(type); } + size_t getResultsSize(const ResultsT& results) { return results.size(); } + + SignatureT makeFuncType(ParamsT* params, ResultsT* results) { + std::vector<Type> empty; + const auto& paramTypes = params ? getUnnamedTypes(*params) : empty; + const auto& resultTypes = results ? *results : empty; + return Signature(self().makeTupleType(paramTypes), + self().makeTupleType(resultTypes)); + } + + StorageT makeI8() { return Field(Field::i8, Immutable); } + StorageT makeI16() { return Field(Field::i16, Immutable); } + StorageT makeStorageType(TypeT type) { return Field(type, Immutable); } + + FieldT makeFieldType(FieldT field, Mutability mutability) { + if (field.packedType == Field::not_packed) { + return Field(field.type, mutability); + } + return Field(field.packedType, mutability); + } + + FieldsT makeFields() { return {}; } + void appendField(FieldsT& fields, Name name, FieldT field) { + fields.first.push_back(name); + fields.second.push_back(field); + } + + StructT makeStruct(FieldsT& fields) { + return {std::move(fields.first), Struct(std::move(fields.second))}; + } + + std::optional<ArrayT> makeArray(FieldsT& fields) { + if (fields.second.size() == 1) { + return Array(fields.second[0]); + } + return {}; + } + + LocalsT makeLocals() { return {}; } + void appendLocal(LocalsT& locals, Name id, TypeT type) { + locals.push_back({id, type}); + } + + Result<Index> getTypeIndex(Name id) { + auto it = typeIndices.find(id); + if (it == typeIndices.end()) { + return self().in.err("unknown type identifier"); + } + return it->second; + } + + DataStringT makeDataString() { return Ok{}; } + void appendDataString(DataStringT&, std::string_view) {} + + LimitsT makeLimits(uint64_t, std::optional<uint64_t>) { return Ok{}; } + LimitsT getLimitsFromData(DataStringT) { return Ok{}; } + + MemTypeT makeMemType(Type, LimitsT, bool) { return Ok{}; } + + HeapType getBlockTypeFromResult(const std::vector<Type> results) { + assert(results.size() == 1); + return HeapType(Signature(Type::none, results[0])); + } +}; + +struct NullInstrParserCtx { + using InstrT = Ok; + using InstrsT = Ok; + using ExprT = Ok; + + using FieldIdxT = Ok; + using LocalIdxT = Ok; + using GlobalIdxT = Ok; + using MemoryIdxT = Ok; + using DataIdxT = Ok; + + using MemargT = Ok; + + InstrsT makeInstrs() { return Ok{}; } + void appendInstr(InstrsT&, InstrT) {} + InstrsT finishInstrs(InstrsT&) { return Ok{}; } + + ExprT makeExpr(InstrsT) { return Ok{}; } + Result<ExprT> instrToExpr(InstrT) { return Ok{}; } + + template<typename HeapTypeT> FieldIdxT getFieldFromIdx(HeapTypeT, uint32_t) { + return Ok{}; + } + template<typename HeapTypeT> FieldIdxT getFieldFromName(HeapTypeT, Name) { + return Ok{}; + } + LocalIdxT getLocalFromIdx(uint32_t) { return Ok{}; } + LocalIdxT getLocalFromName(Name) { return Ok{}; } + GlobalIdxT getGlobalFromIdx(uint32_t) { return Ok{}; } + GlobalIdxT getGlobalFromName(Name) { return Ok{}; } + MemoryIdxT getMemoryFromIdx(uint32_t) { return Ok{}; } + MemoryIdxT getMemoryFromName(Name) { return Ok{}; } + DataIdxT getDataFromIdx(uint32_t) { return Ok{}; } + DataIdxT getDataFromName(Name) { return Ok{}; } + + MemargT getMemarg(uint64_t, uint32_t) { return Ok{}; } + + template<typename BlockTypeT> + InstrT makeBlock(Index, std::optional<Name>, BlockTypeT) { + return Ok{}; + } + InstrT finishBlock(Index, InstrsT) { return Ok{}; } + + InstrT makeUnreachable(Index) { return Ok{}; } + InstrT makeNop(Index) { return Ok{}; } + InstrT makeBinary(Index, BinaryOp) { return Ok{}; } + InstrT makeUnary(Index, UnaryOp) { return Ok{}; } + template<typename ResultsT> InstrT makeSelect(Index, ResultsT*) { + return Ok{}; + } + InstrT makeDrop(Index) { return Ok{}; } + InstrT makeMemorySize(Index, MemoryIdxT*) { return Ok{}; } + InstrT makeMemoryGrow(Index, MemoryIdxT*) { return Ok{}; } + InstrT makeLocalGet(Index, LocalIdxT) { return Ok{}; } + InstrT makeLocalTee(Index, LocalIdxT) { return Ok{}; } + InstrT makeLocalSet(Index, LocalIdxT) { return Ok{}; } + InstrT makeGlobalGet(Index, GlobalIdxT) { return Ok{}; } + InstrT makeGlobalSet(Index, GlobalIdxT) { return Ok{}; } + + InstrT makeI32Const(Index, uint32_t) { return Ok{}; } + InstrT makeI64Const(Index, uint64_t) { return Ok{}; } + InstrT makeF32Const(Index, float) { return Ok{}; } + InstrT makeF64Const(Index, double) { return Ok{}; } + InstrT makeLoad(Index, Type, bool, int, bool, MemoryIdxT*, MemargT) { + return Ok{}; + } + InstrT makeStore(Index, Type, int, bool, MemoryIdxT*, MemargT) { + return Ok{}; + } + InstrT makeAtomicRMW(Index, AtomicRMWOp, Type, int, MemoryIdxT*, MemargT) { + return Ok{}; + } + InstrT makeAtomicCmpxchg(Index, Type, int, MemoryIdxT*, MemargT) { + return Ok{}; + } + InstrT makeAtomicWait(Index, Type, MemoryIdxT*, MemargT) { return Ok{}; } + InstrT makeAtomicNotify(Index, MemoryIdxT*, MemargT) { return Ok{}; } + InstrT makeAtomicFence(Index) { return Ok{}; } + InstrT makeSIMDExtract(Index, SIMDExtractOp, uint8_t) { return Ok{}; } + InstrT makeSIMDReplace(Index, SIMDReplaceOp, uint8_t) { return Ok{}; } + InstrT makeSIMDShuffle(Index, const std::array<uint8_t, 16>&) { return Ok{}; } + InstrT makeSIMDTernary(Index, SIMDTernaryOp) { return Ok{}; } + InstrT makeSIMDShift(Index, SIMDShiftOp) { return Ok{}; } + InstrT makeSIMDLoad(Index, SIMDLoadOp, MemoryIdxT*, MemargT) { return Ok{}; } + InstrT makeSIMDLoadStoreLane( + Index, SIMDLoadStoreLaneOp, MemoryIdxT*, MemargT, uint8_t) { + return Ok{}; + } + InstrT makeMemoryInit(Index, MemoryIdxT*, DataIdxT) { return Ok{}; } + InstrT makeDataDrop(Index, DataIdxT) { return Ok{}; } + + InstrT makeMemoryCopy(Index, MemoryIdxT*, MemoryIdxT*) { return Ok{}; } + InstrT makeMemoryFill(Index, MemoryIdxT*) { return Ok{}; } + + InstrT makeReturn(Index) { return Ok{}; } + template<typename HeapTypeT> InstrT makeRefNull(Index, HeapTypeT) { + return Ok{}; + } + InstrT makeRefIsNull(Index) { return Ok{}; } + + InstrT makeRefEq(Index) { return Ok{}; } + + InstrT makeRefI31(Index) { return Ok{}; } + InstrT makeI31Get(Index, bool) { return Ok{}; } + + template<typename HeapTypeT> InstrT makeStructNew(Index, HeapTypeT) { + return Ok{}; + } + template<typename HeapTypeT> InstrT makeStructNewDefault(Index, HeapTypeT) { + return Ok{}; + } + template<typename HeapTypeT> + InstrT makeStructGet(Index, HeapTypeT, FieldIdxT, bool) { + return Ok{}; + } + template<typename HeapTypeT> + InstrT makeStructSet(Index, HeapTypeT, FieldIdxT) { + return Ok{}; + } + template<typename HeapTypeT> InstrT makeArrayNew(Index, HeapTypeT) { + return Ok{}; + } + template<typename HeapTypeT> InstrT makeArrayNewDefault(Index, HeapTypeT) { + return Ok{}; + } + template<typename HeapTypeT> + InstrT makeArrayNewData(Index, HeapTypeT, DataIdxT) { + return Ok{}; + } + template<typename HeapTypeT> + InstrT makeArrayNewElem(Index, HeapTypeT, DataIdxT) { + return Ok{}; + } + template<typename HeapTypeT> InstrT makeArrayGet(Index, HeapTypeT, bool) { + return Ok{}; + } + template<typename HeapTypeT> InstrT makeArraySet(Index, HeapTypeT) { + return Ok{}; + } + InstrT makeArrayLen(Index) { return Ok{}; } + template<typename HeapTypeT> + InstrT makeArrayCopy(Index, HeapTypeT, HeapTypeT) { + return Ok{}; + } + template<typename HeapTypeT> InstrT makeArrayFill(Index, HeapTypeT) { + return Ok{}; + } +}; + +// Phase 1: Parse definition spans for top-level module elements and determine +// their indices and names. +struct ParseDeclsCtx : NullTypeParserCtx, NullInstrParserCtx { + using DataStringT = std::vector<char>; + using LimitsT = Limits; + using MemTypeT = MemType; + + ParseInput in; + + // At this stage we only look at types to find implicit type definitions, + // which are inserted directly into the context. We cannot materialize or + // validate any types because we don't know what types exist yet. + // + // Declared module elements are inserted into the module, but their bodies are + // not filled out until later parsing phases. + Module& wasm; + + // The module element definitions we are parsing in this phase. + std::vector<DefPos> typeDefs; + std::vector<DefPos> subtypeDefs; + std::vector<DefPos> funcDefs; + std::vector<DefPos> memoryDefs; + std::vector<DefPos> globalDefs; + std::vector<DefPos> dataDefs; + + // Positions of typeuses that might implicitly define new types. + std::vector<Index> implicitTypeDefs; + + // Counters used for generating names for module elements. + int funcCounter = 0; + int memoryCounter = 0; + int globalCounter = 0; + int dataCounter = 0; + + // Used to verify that all imports come before all non-imports. + bool hasNonImport = false; + + ParseDeclsCtx(std::string_view in, Module& wasm) : in(in), wasm(wasm) {} + + void addFuncType(SignatureT) {} + void addStructType(StructT) {} + void addArrayType(ArrayT) {} + void setOpen() {} + Result<> addSubtype(Index) { return Ok{}; } + void finishSubtype(Name name, Index pos) { + subtypeDefs.push_back({name, pos, Index(subtypeDefs.size())}); + } + size_t getRecGroupStartIndex() { return 0; } + void addRecGroup(Index, size_t) {} + void finishDeftype(Index pos) { + typeDefs.push_back({{}, pos, Index(typeDefs.size())}); + } + + std::vector<char> makeDataString() { return {}; } + void appendDataString(std::vector<char>& data, std::string_view str) { + data.insert(data.end(), str.begin(), str.end()); + } + + Limits makeLimits(uint64_t n, std::optional<uint64_t> m) { + return m ? Limits{n, *m} : Limits{n, Memory::kUnlimitedSize}; + } + Limits getLimitsFromData(const std::vector<char>& data) { + uint64_t size = (data.size() + Memory::kPageSize - 1) / Memory::kPageSize; + return {size, size}; + } + + MemType makeMemType(Type type, Limits limits, bool shared) { + return {type, limits, shared}; + } + + Result<TypeUseT> + makeTypeUse(Index pos, std::optional<HeapTypeT> type, ParamsT*, ResultsT*) { + if (!type) { + implicitTypeDefs.push_back(pos); + } + return Ok{}; + } + + Result<Function*> addFuncDecl(Index pos, Name name, ImportNames* importNames); + Result<> addFunc(Name name, + const std::vector<Name>& exports, + ImportNames* import, + TypeUseT type, + std::optional<LocalsT>, + std::optional<InstrsT>, + Index pos); + + Result<Memory*> + addMemoryDecl(Index pos, Name name, ImportNames* importNames, MemType type); + + Result<> addMemory(Name name, + const std::vector<Name>& exports, + ImportNames* import, + MemType type, + Index pos); + + Result<> addImplicitData(DataStringT&& data); + + Result<Global*> addGlobalDecl(Index pos, Name name, ImportNames* importNames); + + Result<> addGlobal(Name name, + const std::vector<Name>& exports, + ImportNames* import, + GlobalTypeT, + std::optional<ExprT>, + Index pos); + + Result<> addData(Name name, + MemoryIdxT*, + std::optional<ExprT>, + std::vector<char>&& data, + Index pos); +}; + +// Phase 2: Parse type definitions into a TypeBuilder. +struct ParseTypeDefsCtx : TypeParserCtx<ParseTypeDefsCtx> { + ParseInput in; + + // We update slots in this builder as we parse type definitions. + TypeBuilder& builder; + + // Parse the names of types and fields as we go. + std::vector<TypeNames> names; + + // The index of the subtype definition we are parsing. + Index index = 0; + + ParseTypeDefsCtx(std::string_view in, + TypeBuilder& builder, + const IndexMap& typeIndices) + : TypeParserCtx<ParseTypeDefsCtx>(typeIndices), in(in), builder(builder), + names(builder.size()) {} + + TypeT makeRefType(HeapTypeT ht, Nullability nullability) { + return builder.getTempRefType(ht, nullability); + } + + TypeT makeTupleType(const std::vector<Type> types) { + return builder.getTempTupleType(types); + } + + Result<HeapTypeT> getHeapTypeFromIdx(Index idx) { + if (idx >= builder.size()) { + return in.err("type index out of bounds"); + } + return builder[idx]; + } + + void addFuncType(SignatureT& type) { builder[index] = type; } + + void addStructType(StructT& type) { + auto& [fieldNames, str] = type; + builder[index] = str; + for (Index i = 0; i < fieldNames.size(); ++i) { + if (auto name = fieldNames[i]; name.is()) { + names[index].fieldNames[i] = name; + } + } + } + + void addArrayType(ArrayT& type) { builder[index] = type; } + + void setOpen() { builder[index].setOpen(); } + + Result<> addSubtype(Index super) { + if (super >= builder.size()) { + return in.err("supertype index out of bounds"); + } + builder[index].subTypeOf(builder[super]); + return Ok{}; + } + + void finishSubtype(Name name, Index pos) { names[index++].name = name; } + + size_t getRecGroupStartIndex() { return index; } + + void addRecGroup(Index start, size_t len) { + builder.createRecGroup(start, len); + } + + void finishDeftype(Index) {} +}; + +// Phase 3: Parse type uses to find implicitly defined types. +struct ParseImplicitTypeDefsCtx : TypeParserCtx<ParseImplicitTypeDefsCtx> { + using TypeUseT = Ok; + + ParseInput in; + + // Types parsed so far. + std::vector<HeapType>& types; + + // Map typeuse positions without an explicit type to the correct type. + std::unordered_map<Index, HeapType>& implicitTypes; + + // Map signatures to the first defined heap type they match. + std::unordered_map<Signature, HeapType> sigTypes; + + ParseImplicitTypeDefsCtx(std::string_view in, + std::vector<HeapType>& types, + std::unordered_map<Index, HeapType>& implicitTypes, + const IndexMap& typeIndices) + : TypeParserCtx<ParseImplicitTypeDefsCtx>(typeIndices), in(in), + types(types), implicitTypes(implicitTypes) { + for (auto type : types) { + if (type.isSignature() && type.getRecGroup().size() == 1) { + sigTypes.insert({type.getSignature(), type}); + } + } + } + + Result<HeapTypeT> getHeapTypeFromIdx(Index idx) { + if (idx >= types.size()) { + return in.err("type index out of bounds"); + } + return types[idx]; + } + + Result<TypeUseT> makeTypeUse(Index pos, + std::optional<HeapTypeT>, + ParamsT* params, + ResultsT* results) { + std::vector<Type> paramTypes; + if (params) { + paramTypes = getUnnamedTypes(*params); + } + + std::vector<Type> resultTypes; + if (results) { + resultTypes = *results; + } + + auto sig = Signature(Type(paramTypes), Type(resultTypes)); + auto [it, inserted] = sigTypes.insert({sig, HeapType::func}); + if (inserted) { + auto type = HeapType(sig); + it->second = type; + types.push_back(type); + } + implicitTypes.insert({pos, it->second}); + + return Ok{}; + } +}; + +// Phase 4: Parse and set the types of module elements. +struct ParseModuleTypesCtx : TypeParserCtx<ParseModuleTypesCtx>, + NullInstrParserCtx { + // In this phase we have constructed all the types, so we can materialize and + // validate them when they are used. + + using GlobalTypeT = GlobalType; + using TypeUseT = TypeUse; + + ParseInput in; + + Module& wasm; + + const std::vector<HeapType>& types; + const std::unordered_map<Index, HeapType>& implicitTypes; + + // The index of the current type. + Index index = 0; + + ParseModuleTypesCtx(std::string_view in, + Module& wasm, + const std::vector<HeapType>& types, + const std::unordered_map<Index, HeapType>& implicitTypes, + const IndexMap& typeIndices) + : TypeParserCtx<ParseModuleTypesCtx>(typeIndices), in(in), wasm(wasm), + types(types), implicitTypes(implicitTypes) {} + + Result<HeapTypeT> getHeapTypeFromIdx(Index idx) { + if (idx >= types.size()) { + return in.err("type index out of bounds"); + } + return types[idx]; + } + + Result<TypeUseT> makeTypeUse(Index pos, + std::optional<HeapTypeT> type, + ParamsT* params, + ResultsT* results) { + std::vector<Name> ids; + if (params) { + ids.reserve(params->size()); + for (auto& p : *params) { + ids.push_back(p.name); + } + } + + if (type) { + return TypeUse{*type, ids}; + } + + auto it = implicitTypes.find(pos); + assert(it != implicitTypes.end()); + + return TypeUse{it->second, ids}; + } + + Result<HeapType> getBlockTypeFromTypeUse(Index pos, TypeUse use) { + assert(use.type.isSignature()); + if (use.type.getSignature().params != Type::none) { + return in.err(pos, "block parameters not yet supported"); + } + // TODO: Once we support block parameters, return an error here if any of + // them are named. + return use.type; + } + + GlobalTypeT makeGlobalType(Mutability mutability, TypeT type) { + return {mutability, type}; + } + + Result<> addFunc(Name name, + const std::vector<Name>&, + ImportNames*, + TypeUse type, + std::optional<LocalsT> locals, + std::optional<InstrsT>, + Index pos) { + auto& f = wasm.functions[index]; + if (!type.type.isSignature()) { + return in.err(pos, "expected signature type"); + } + f->type = type.type; + for (Index i = 0; i < type.names.size(); ++i) { + if (type.names[i].is()) { + f->setLocalName(i, type.names[i]); + } + } + if (locals) { + for (auto& l : *locals) { + Builder::addVar(f.get(), l.name, l.type); + } + } + return Ok{}; + } + + Result<> + addMemory(Name, const std::vector<Name>&, ImportNames*, MemTypeT, Index) { + return Ok{}; + } + + Result<> addImplicitData(DataStringT&& data) { return Ok{}; } + + Result<> addGlobal(Name, + const std::vector<Name>&, + ImportNames*, + GlobalType type, + std::optional<ExprT>, + Index) { + auto& g = wasm.globals[index]; + g->mutable_ = type.mutability; + g->type = type.type; + return Ok{}; + } +}; + +// Phase 5: Parse module element definitions, including instructions. +struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> { + using GlobalTypeT = Ok; + using TypeUseT = HeapType; + + // Keep track of instructions internally rather than letting the general + // parser collect them. + using InstrT = Ok; + using InstrsT = Ok; + using ExprT = Expression*; + + using FieldIdxT = Index; + using LocalIdxT = Index; + using GlobalIdxT = Name; + using MemoryIdxT = Name; + using DataIdxT = Name; + + using MemargT = Memarg; + + ParseInput in; + + Module& wasm; + Builder builder; + + const std::vector<HeapType>& types; + const std::unordered_map<Index, HeapType>& implicitTypes; + + // The index of the current module element. + Index index = 0; + + // The current function being parsed, used to create scratch locals, type + // local.get, etc. + Function* func = nullptr; + + IRBuilder irBuilder; + + void setFunction(Function* func) { + this->func = func; + irBuilder.setFunction(func); + } + + ParseDefsCtx(std::string_view in, + Module& wasm, + const std::vector<HeapType>& types, + const std::unordered_map<Index, HeapType>& implicitTypes, + const IndexMap& typeIndices) + : TypeParserCtx(typeIndices), in(in), wasm(wasm), builder(wasm), + types(types), implicitTypes(implicitTypes), irBuilder(wasm) {} + + template<typename T> Result<T> withLoc(Index pos, Result<T> res) { + if (auto err = res.getErr()) { + return in.err(pos, err->msg); + } + return res; + } + + template<typename T> Result<T> withLoc(Result<T> res) { + return withLoc(in.getPos(), res); + } + + HeapType getBlockTypeFromResult(const std::vector<Type> results) { + assert(results.size() == 1); + return HeapType(Signature(Type::none, results[0])); + } + + Result<HeapType> getBlockTypeFromTypeUse(Index pos, HeapType type) { + return type; + } + + Ok makeInstrs() { return Ok{}; } + + void appendInstr(Ok&, InstrT instr) {} + + Result<InstrsT> finishInstrs(Ok&) { return Ok{}; } + + Result<Expression*> instrToExpr(Ok&) { return irBuilder.build(); } + + GlobalTypeT makeGlobalType(Mutability, TypeT) { return Ok{}; } + + Result<HeapTypeT> getHeapTypeFromIdx(Index idx) { + if (idx >= types.size()) { + return in.err("type index out of bounds"); + } + return types[idx]; + } + + Result<Index> getFieldFromIdx(HeapType type, uint32_t idx) { + if (!type.isStruct()) { + return in.err("expected struct type"); + } + if (idx >= type.getStruct().fields.size()) { + return in.err("struct index out of bounds"); + } + return idx; + } + + Result<Index> getFieldFromName(HeapType type, Name name) { + // TODO: Field names + return in.err("symbolic field names note yet supported"); + } + + Result<Index> getLocalFromIdx(uint32_t idx) { + if (!func) { + return in.err("cannot access locals outside of a function"); + } + if (idx >= func->getNumLocals()) { + return in.err("local index out of bounds"); + } + return idx; + } + + Result<Index> getLocalFromName(Name name) { + if (!func) { + return in.err("cannot access locals outside of a function"); + } + if (!func->hasLocalIndex(name)) { + return in.err("local $" + name.toString() + " does not exist"); + } + return func->getLocalIndex(name); + } + + Result<Name> getGlobalFromIdx(uint32_t idx) { + if (idx >= wasm.globals.size()) { + return in.err("global index out of bounds"); + } + return wasm.globals[idx]->name; + } + + Result<Name> getGlobalFromName(Name name) { + if (!wasm.getGlobalOrNull(name)) { + return in.err("global $" + name.toString() + " does not exist"); + } + return name; + } + + Result<Name> getMemoryFromIdx(uint32_t idx) { + if (idx >= wasm.memories.size()) { + return in.err("memory index out of bounds"); + } + return wasm.memories[idx]->name; + } + + Result<Name> getMemoryFromName(Name name) { + if (!wasm.getMemoryOrNull(name)) { + return in.err("memory $" + name.toString() + " does not exist"); + } + return name; + } + + Result<Name> getDataFromIdx(uint32_t idx) { + if (idx >= wasm.dataSegments.size()) { + return in.err("data index out of bounds"); + } + return wasm.dataSegments[idx]->name; + } + + Result<Name> getDataFromName(Name name) { + if (!wasm.getDataSegmentOrNull(name)) { + return in.err("data $" + name.toString() + " does not exist"); + } + return name; + } + + Result<TypeUseT> makeTypeUse(Index pos, + std::optional<HeapTypeT> type, + ParamsT* params, + ResultsT* results); + Result<> addFunc(Name, + const std::vector<Name>&, + ImportNames*, + TypeUseT, + std::optional<LocalsT>, + std::optional<InstrsT>, + Index pos); + + Result<> addGlobal(Name, + const std::vector<Name>&, + ImportNames*, + GlobalTypeT, + std::optional<ExprT> exp, + Index); + Result<> + addData(Name, Name* mem, std::optional<ExprT> offset, DataStringT, Index pos); + Result<Index> addScratchLocal(Index pos, Type type) { + if (!func) { + return in.err(pos, + "scratch local required, but there is no function context"); + } + Name name = Names::getValidLocalName(*func, "scratch"); + return Builder::addVar(func, name, type); + } + + Result<Expression*> makeExpr(InstrsT& instrs) { return irBuilder.build(); } + + Memarg getMemarg(uint64_t offset, uint32_t align) { return {offset, align}; } + + Result<Name> getMemory(Index pos, Name* mem) { + if (mem) { + return *mem; + } + if (wasm.memories.empty()) { + return in.err(pos, "memory required, but there is no memory"); + } + return wasm.memories[0]->name; + } + + Result<> makeBlock(Index pos, std::optional<Name> label, HeapType type) { + // TODO: validate labels? + // TODO: Move error on input types to here? + return withLoc(pos, + irBuilder.makeBlock(label ? *label : Name{}, + type.getSignature().results)); + } + + Result<> finishBlock(Index pos, InstrsT) { + return withLoc(pos, irBuilder.visitEnd()); + } + + Result<> makeUnreachable(Index pos) { + return withLoc(pos, irBuilder.makeUnreachable()); + } + + Result<> makeNop(Index pos) { return withLoc(pos, irBuilder.makeNop()); } + + Result<> makeBinary(Index pos, BinaryOp op) { + return withLoc(pos, irBuilder.makeBinary(op)); + } + + Result<> makeUnary(Index pos, UnaryOp op) { + return withLoc(pos, irBuilder.makeUnary(op)); + } + + Result<> makeSelect(Index pos, std::vector<Type>* res) { + if (res && res->size()) { + if (res->size() > 1) { + return in.err(pos, "select may not have more than one result type"); + } + return withLoc(pos, irBuilder.makeSelect((*res)[0])); + } + return withLoc(pos, irBuilder.makeSelect()); + } + + Result<> makeDrop(Index pos) { return withLoc(pos, irBuilder.makeDrop()); } + + Result<> makeMemorySize(Index pos, Name* mem) { + auto m = getMemory(pos, mem); + CHECK_ERR(m); + return withLoc(pos, irBuilder.makeMemorySize(*m)); + } + + Result<> makeMemoryGrow(Index pos, Name* mem) { + auto m = getMemory(pos, mem); + CHECK_ERR(m); + return withLoc(pos, irBuilder.makeMemoryGrow(*m)); + } + + Result<> makeLocalGet(Index pos, Index local) { + return withLoc(pos, irBuilder.makeLocalGet(local)); + } + + Result<> makeLocalTee(Index pos, Index local) { + return withLoc(pos, irBuilder.makeLocalTee(local)); + } + + Result<> makeLocalSet(Index pos, Index local) { + return withLoc(pos, irBuilder.makeLocalSet(local)); + } + + Result<> makeGlobalGet(Index pos, Name global) { + return withLoc(pos, irBuilder.makeGlobalGet(global)); + } + + Result<> makeGlobalSet(Index pos, Name global) { + assert(wasm.getGlobalOrNull(global)); + return withLoc(pos, irBuilder.makeGlobalSet(global)); + } + + Result<> makeI32Const(Index pos, uint32_t c) { + return withLoc(pos, irBuilder.makeConst(Literal(c))); + } + + Result<> makeI64Const(Index pos, uint64_t c) { + return withLoc(pos, irBuilder.makeConst(Literal(c))); + } + + Result<> makeF32Const(Index pos, float c) { + return withLoc(pos, irBuilder.makeConst(Literal(c))); + } + + Result<> makeF64Const(Index pos, double c) { + return withLoc(pos, irBuilder.makeConst(Literal(c))); + } + + Result<> makeLoad(Index pos, + Type type, + bool signed_, + int bytes, + bool isAtomic, + Name* mem, + Memarg memarg) { + auto m = getMemory(pos, mem); + CHECK_ERR(m); + if (isAtomic) { + return withLoc(pos, + irBuilder.makeAtomicLoad(bytes, memarg.offset, type, *m)); + } + return withLoc(pos, + irBuilder.makeLoad( + bytes, signed_, memarg.offset, memarg.align, type, *m)); + } + + Result<> makeStore( + Index pos, Type type, int bytes, bool isAtomic, Name* mem, Memarg memarg) { + auto m = getMemory(pos, mem); + CHECK_ERR(m); + if (isAtomic) { + return withLoc(pos, + irBuilder.makeAtomicStore(bytes, memarg.offset, type, *m)); + } + return withLoc( + pos, irBuilder.makeStore(bytes, memarg.offset, memarg.align, type, *m)); + } + + Result<> makeAtomicRMW( + Index pos, AtomicRMWOp op, Type type, int bytes, Name* mem, Memarg memarg) { + auto m = getMemory(pos, mem); + CHECK_ERR(m); + return withLoc(pos, + irBuilder.makeAtomicRMW(op, bytes, memarg.offset, type, *m)); + } + + Result<> + makeAtomicCmpxchg(Index pos, Type type, int bytes, Name* mem, Memarg memarg) { + auto m = getMemory(pos, mem); + CHECK_ERR(m); + return withLoc(pos, + irBuilder.makeAtomicCmpxchg(bytes, memarg.offset, type, *m)); + } + + Result<> makeAtomicWait(Index pos, Type type, Name* mem, Memarg memarg) { + auto m = getMemory(pos, mem); + CHECK_ERR(m); + return withLoc(pos, irBuilder.makeAtomicWait(type, memarg.offset, *m)); + } + + Result<> makeAtomicNotify(Index pos, Name* mem, Memarg memarg) { + auto m = getMemory(pos, mem); + CHECK_ERR(m); + return withLoc(pos, irBuilder.makeAtomicNotify(memarg.offset, *m)); + } + + Result<> makeAtomicFence(Index pos) { + return withLoc(pos, irBuilder.makeAtomicFence()); + } + + Result<> makeSIMDExtract(Index pos, SIMDExtractOp op, uint8_t lane) { + return withLoc(pos, irBuilder.makeSIMDExtract(op, lane)); + } + + Result<> makeSIMDReplace(Index pos, SIMDReplaceOp op, uint8_t lane) { + return withLoc(pos, irBuilder.makeSIMDReplace(op, lane)); + } + + Result<> makeSIMDShuffle(Index pos, const std::array<uint8_t, 16>& lanes) { + return withLoc(pos, irBuilder.makeSIMDShuffle(lanes)); + } + + Result<> makeSIMDTernary(Index pos, SIMDTernaryOp op) { + return withLoc(pos, irBuilder.makeSIMDTernary(op)); + } + + Result<> makeSIMDShift(Index pos, SIMDShiftOp op) { + return withLoc(pos, irBuilder.makeSIMDShift(op)); + } + + Result<> makeSIMDLoad(Index pos, SIMDLoadOp op, Name* mem, Memarg memarg) { + auto m = getMemory(pos, mem); + CHECK_ERR(m); + return withLoc(pos, + irBuilder.makeSIMDLoad(op, memarg.offset, memarg.align, *m)); + } + + Result<> makeSIMDLoadStoreLane( + Index pos, SIMDLoadStoreLaneOp op, Name* mem, Memarg memarg, uint8_t lane) { + auto m = getMemory(pos, mem); + CHECK_ERR(m); + return withLoc(pos, + irBuilder.makeSIMDLoadStoreLane( + op, memarg.offset, memarg.align, lane, *m)); + } + + Result<> makeMemoryInit(Index pos, Name* mem, Name data) { + auto m = getMemory(pos, mem); + CHECK_ERR(m); + return withLoc(pos, irBuilder.makeMemoryInit(data, *m)); + } + + Result<> makeDataDrop(Index pos, Name data) { + return withLoc(pos, irBuilder.makeDataDrop(data)); + } + + Result<> makeMemoryCopy(Index pos, Name* destMem, Name* srcMem) { + auto destMemory = getMemory(pos, destMem); + CHECK_ERR(destMemory); + auto srcMemory = getMemory(pos, srcMem); + CHECK_ERR(srcMemory); + return withLoc(pos, irBuilder.makeMemoryCopy(*destMemory, *srcMemory)); + } + + Result<> makeMemoryFill(Index pos, Name* mem) { + auto m = getMemory(pos, mem); + CHECK_ERR(m); + return withLoc(pos, irBuilder.makeMemoryFill(*m)); + } + + Result<> makeReturn(Index pos) { + return withLoc(pos, irBuilder.makeReturn()); + } + + Result<> makeRefNull(Index pos, HeapType type) { + return withLoc(pos, irBuilder.makeRefNull(type)); + } + + Result<> makeRefIsNull(Index pos) { + return withLoc(pos, irBuilder.makeRefIsNull()); + } + + Result<> makeRefEq(Index pos) { return withLoc(pos, irBuilder.makeRefEq()); } + + Result<> makeRefI31(Index pos) { + return withLoc(pos, irBuilder.makeRefI31()); + } + + Result<> makeI31Get(Index pos, bool signed_) { + return withLoc(pos, irBuilder.makeI31Get(signed_)); + } + + Result<> makeStructNew(Index pos, HeapType type) { + return withLoc(pos, irBuilder.makeStructNew(type)); + } + + Result<> makeStructNewDefault(Index pos, HeapType type) { + return withLoc(pos, irBuilder.makeStructNewDefault(type)); + } + + Result<> makeStructGet(Index pos, HeapType type, Index field, bool signed_) { + return withLoc(pos, irBuilder.makeStructGet(type, field, signed_)); + } + + Result<> makeStructSet(Index pos, HeapType type, Index field) { + return withLoc(pos, irBuilder.makeStructSet(type, field)); + } + + Result<> makeArrayNew(Index pos, HeapType type) { + return withLoc(pos, irBuilder.makeArrayNew(type)); + } + + Result<> makeArrayNewDefault(Index pos, HeapType type) { + return withLoc(pos, irBuilder.makeArrayNewDefault(type)); + } + + Result<> makeArrayNewData(Index pos, HeapType type, Name data) { + return withLoc(pos, irBuilder.makeArrayNewData(type, data)); + } + + Result<> makeArrayNewElem(Index pos, HeapType type, Name elem) { + return withLoc(pos, irBuilder.makeArrayNewElem(type, elem)); + } + + Result<> makeArrayGet(Index pos, HeapType type, bool signed_) { + return withLoc(pos, irBuilder.makeArrayGet(type, signed_)); + } + + Result<> makeArraySet(Index pos, HeapType type) { + return withLoc(pos, irBuilder.makeArraySet(type)); + } + + Result<> makeArrayLen(Index pos) { + return withLoc(pos, irBuilder.makeArrayLen()); + } + + Result<> makeArrayCopy(Index pos, HeapType destType, HeapType srcType) { + return withLoc(pos, irBuilder.makeArrayCopy(destType, srcType)); + } + + Result<> makeArrayFill(Index pos, HeapType type) { + return withLoc(pos, irBuilder.makeArrayFill(type)); + } +}; + +} // namespace wasm::WATParser + +#endif // parser_context_h diff --git a/src/parser/input-impl.h b/src/parser/input-impl.h new file mode 100644 index 000000000..35a39b2f3 --- /dev/null +++ b/src/parser/input-impl.h @@ -0,0 +1,273 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "input.h" + +#ifndef parser_input_impl_h +#define parser_input_impl_h + +inline std::optional<Token> ParseInput::peek() { + if (!empty()) { + return *lexer; + } + return {}; +} + +inline bool ParseInput::takeLParen() { + auto t = peek(); + if (!t || !t->isLParen()) { + return false; + } + ++lexer; + return true; +} + +inline bool ParseInput::takeRParen() { + auto t = peek(); + if (!t || !t->isRParen()) { + return false; + } + ++lexer; + return true; +} + +inline bool ParseInput::takeUntilParen() { + while (true) { + auto t = peek(); + if (!t) { + return false; + } + if (t->isLParen() || t->isRParen()) { + return true; + } + ++lexer; + } +} + +inline std::optional<Name> ParseInput::takeID() { + if (auto t = peek()) { + if (auto id = t->getID()) { + ++lexer; + // See comment on takeName. + return Name(std::string(*id)); + } + } + return {}; +} + +inline std::optional<std::string_view> ParseInput::takeKeyword() { + if (auto t = peek()) { + if (auto keyword = t->getKeyword()) { + ++lexer; + return *keyword; + } + } + return {}; +} + +inline bool ParseInput::takeKeyword(std::string_view expected) { + if (auto t = peek()) { + if (auto keyword = t->getKeyword()) { + if (*keyword == expected) { + ++lexer; + return true; + } + } + } + return false; +} + +inline std::optional<uint64_t> ParseInput::takeOffset() { + if (auto t = peek()) { + if (auto keyword = t->getKeyword()) { + if (keyword->substr(0, 7) != "offset="sv) { + return {}; + } + Lexer subLexer(keyword->substr(7)); + if (subLexer == subLexer.end()) { + return {}; + } + if (auto o = subLexer->getU64()) { + ++subLexer; + if (subLexer == subLexer.end()) { + ++lexer; + return o; + } + } + } + } + return std::nullopt; +} + +inline std::optional<uint32_t> ParseInput::takeAlign() { + if (auto t = peek()) { + if (auto keyword = t->getKeyword()) { + if (keyword->substr(0, 6) != "align="sv) { + return {}; + } + Lexer subLexer(keyword->substr(6)); + if (subLexer == subLexer.end()) { + return {}; + } + if (auto a = subLexer->getU32()) { + ++subLexer; + if (subLexer == subLexer.end()) { + ++lexer; + return a; + } + } + } + } + return {}; +} + +inline std::optional<uint64_t> ParseInput::takeU64() { + if (auto t = peek()) { + if (auto n = t->getU64()) { + ++lexer; + return n; + } + } + return std::nullopt; +} + +inline std::optional<int64_t> ParseInput::takeS64() { + if (auto t = peek()) { + if (auto n = t->getS64()) { + ++lexer; + return n; + } + } + return {}; +} + +inline std::optional<int64_t> ParseInput::takeI64() { + if (auto t = peek()) { + if (auto n = t->getI64()) { + ++lexer; + return n; + } + } + return {}; +} + +inline std::optional<uint32_t> ParseInput::takeU32() { + if (auto t = peek()) { + if (auto n = t->getU32()) { + ++lexer; + return n; + } + } + return std::nullopt; +} + +inline std::optional<int32_t> ParseInput::takeS32() { + if (auto t = peek()) { + if (auto n = t->getS32()) { + ++lexer; + return n; + } + } + return {}; +} + +inline std::optional<int32_t> ParseInput::takeI32() { + if (auto t = peek()) { + if (auto n = t->getI32()) { + ++lexer; + return n; + } + } + return {}; +} + +inline std::optional<uint8_t> ParseInput::takeU8() { + if (auto t = peek()) { + if (auto n = t->getU32()) { + if (n <= std::numeric_limits<uint8_t>::max()) { + ++lexer; + return uint8_t(*n); + } + } + } + return {}; +} + +inline std::optional<double> ParseInput::takeF64() { + if (auto t = peek()) { + if (auto d = t->getF64()) { + ++lexer; + return d; + } + } + return std::nullopt; +} + +inline std::optional<float> ParseInput::takeF32() { + if (auto t = peek()) { + if (auto f = t->getF32()) { + ++lexer; + return f; + } + } + return std::nullopt; +} + +inline std::optional<std::string_view> ParseInput::takeString() { + if (auto t = peek()) { + if (auto s = t->getString()) { + ++lexer; + return s; + } + } + return {}; +} + +inline std::optional<Name> ParseInput::takeName() { + // TODO: Move this to lexer and validate UTF. + if (auto str = takeString()) { + // Copy to a std::string to make sure we have a null terminator, otherwise + // the `Name` constructor won't work correctly. + // TODO: Update `Name` to use string_view instead of char* and/or to take + // rvalue strings to avoid this extra copy. + return Name(std::string(*str)); + } + return {}; +} + +inline bool ParseInput::takeSExprStart(std::string_view expected) { + auto original = lexer; + if (takeLParen() && takeKeyword(expected)) { + return true; + } + lexer = original; + return false; +} + +inline Index ParseInput::getPos() { + if (auto t = peek()) { + return lexer.getIndex() - t->span.size(); + } + return lexer.getIndex(); +} + +inline Err ParseInput::err(Index pos, std::string reason) { + std::stringstream msg; + msg << lexer.position(pos) << ": error: " << reason; + return Err{msg.str()}; +} + +#endif // parser_input_impl_h diff --git a/src/parser/input.h b/src/parser/input.h new file mode 100644 index 000000000..5c7c57d20 --- /dev/null +++ b/src/parser/input.h @@ -0,0 +1,75 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef parser_input_h +#define parser_input_h + +#include "lexer.h" +#include "support/result.h" +#include "wasm.h" + +namespace wasm::WATParser { + +using namespace std::string_view_literals; + +// Wraps a lexer and provides utilities for consuming tokens. +struct ParseInput { + Lexer lexer; + + explicit ParseInput(std::string_view in) : lexer(in) {} + + ParseInput(std::string_view in, size_t index) : lexer(in) { + lexer.setIndex(index); + } + + ParseInput(const ParseInput& other, size_t index) : lexer(other.lexer) { + lexer.setIndex(index); + } + + bool empty() { return lexer.empty(); } + + std::optional<Token> peek(); + bool takeLParen(); + bool takeRParen(); + bool takeUntilParen(); + std::optional<Name> takeID(); + std::optional<std::string_view> takeKeyword(); + bool takeKeyword(std::string_view expected); + std::optional<uint64_t> takeOffset(); + std::optional<uint32_t> takeAlign(); + std::optional<uint64_t> takeU64(); + std::optional<int64_t> takeS64(); + std::optional<int64_t> takeI64(); + std::optional<uint32_t> takeU32(); + std::optional<int32_t> takeS32(); + std::optional<int32_t> takeI32(); + std::optional<uint8_t> takeU8(); + std::optional<double> takeF64(); + std::optional<float> takeF32(); + std::optional<std::string_view> takeString(); + std::optional<Name> takeName(); + bool takeSExprStart(std::string_view expected); + + Index getPos(); + [[nodiscard]] Err err(Index pos, std::string reason); + [[nodiscard]] Err err(std::string reason) { return err(getPos(), reason); } +}; + +#include "input-impl.h" + +} // namespace wasm::WATParser + +#endif // parser_input_h diff --git a/src/parser/lexer.cpp b/src/parser/lexer.cpp new file mode 100644 index 000000000..0796013fe --- /dev/null +++ b/src/parser/lexer.cpp @@ -0,0 +1,1038 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <cassert> +#include <cctype> +#include <cmath> +#include <iostream> +#include <optional> +#include <sstream> +#include <variant> + +#include "lexer.h" + +using namespace std::string_view_literals; + +namespace wasm::WATParser { + +namespace { + +// ================ +// Lexical Analysis +// ================ + +// The result of lexing a token fragment. +struct LexResult { + std::string_view span; +}; + +// Lexing context that accumulates lexed input to produce a token fragment. +struct LexCtx { +private: + // The input we are lexing. + std::string_view input; + + // How much of the input we have already lexed. + size_t lexedSize = 0; + +public: + explicit LexCtx(std::string_view in) : input(in) {} + + // Return the fragment that has been lexed so far. + std::optional<LexResult> lexed() const { + if (lexedSize > 0) { + return {LexResult{input.substr(0, lexedSize)}}; + } + return {}; + } + + // The next input that has not already been lexed. + std::string_view next() const { return input.substr(lexedSize); } + + // Get the next character without consuming it. + uint8_t peek() const { return next()[0]; } + + // The size of the unlexed input. + size_t size() const { return input.size() - lexedSize; } + + // Whether there is no more input. + bool empty() const { return size() == 0; } + + // Tokens must be separated by spaces or parentheses. + bool canFinish() const; + + // Whether the unlexed input starts with prefix `sv`. + size_t startsWith(std::string_view sv) const { + return next().substr(0, sv.size()) == sv; + } + + // Consume the next `n` characters. + void take(size_t n) { lexedSize += n; } + + // Consume an additional lexed fragment. + void take(const LexResult& res) { lexedSize += res.span.size(); } + + // Consume the prefix and return true if possible. + bool takePrefix(std::string_view sv) { + if (startsWith(sv)) { + take(sv.size()); + return true; + } + return false; + } + + // Consume the rest of the input. + void takeAll() { lexedSize = input.size(); } +}; + +enum OverflowBehavior { DisallowOverflow, IgnoreOverflow }; + +std::optional<int> getDigit(char c) { + if ('0' <= c && c <= '9') { + return c - '0'; + } + return {}; +} + +std::optional<int> getHexDigit(char c) { + if ('0' <= c && c <= '9') { + return c - '0'; + } + if ('A' <= c && c <= 'F') { + return 10 + c - 'A'; + } + if ('a' <= c && c <= 'f') { + return 10 + c - 'a'; + } + return {}; +} + +// The result of lexing an integer token fragment. +struct LexIntResult : LexResult { + uint64_t n; + Sign sign; +}; + +// Lexing context that accumulates lexed input to produce an integer token +// fragment. +struct LexIntCtx : LexCtx { + using LexCtx::take; + +private: + uint64_t n = 0; + Sign sign = NoSign; + bool overflow = false; + +public: + explicit LexIntCtx(std::string_view in) : LexCtx(in) {} + + // Lex only the underlying span, ignoring the overflow and value. + std::optional<LexIntResult> lexedRaw() { + if (auto basic = LexCtx::lexed()) { + return LexIntResult{*basic, 0, NoSign}; + } + return {}; + } + + std::optional<LexIntResult> lexed() { + if (overflow) { + return {}; + } + if (auto basic = LexCtx::lexed()) { + return LexIntResult{*basic, sign == Neg ? -n : n, sign}; + } + return {}; + } + + void takeSign() { + if (takePrefix("+"sv)) { + sign = Pos; + } else if (takePrefix("-"sv)) { + sign = Neg; + } else { + sign = NoSign; + } + } + + bool takeDigit() { + if (!empty()) { + if (auto d = getDigit(peek())) { + take(1); + uint64_t newN = n * 10 + *d; + if (newN < n) { + overflow = true; + } + n = newN; + return true; + } + } + return false; + } + + bool takeHexdigit() { + if (!empty()) { + if (auto h = getHexDigit(peek())) { + take(1); + uint64_t newN = n * 16 + *h; + if (newN < n) { + overflow = true; + } + n = newN; + return true; + } + } + return false; + } + + void take(const LexIntResult& res) { + LexCtx::take(res); + n = res.n; + } +}; + +struct LexFloatResult : LexResult { + // The payload if we lexed a nan with payload. We cannot store the payload + // directly in `d` because we do not know at this point whether we are parsing + // an f32 or f64 and therefore we do not know what the allowable payloads are. + // No payload with NaN means to use the default payload for the expected float + // width. + std::optional<uint64_t> nanPayload; + double d; +}; + +struct LexFloatCtx : LexCtx { + std::optional<uint64_t> nanPayload; + + LexFloatCtx(std::string_view in) : LexCtx(in) {} + + std::optional<LexFloatResult> lexed() { + const double posNan = std::copysign(NAN, 1.0); + const double negNan = std::copysign(NAN, -1.0); + assert(!std::signbit(posNan) && "expected positive NaN to be positive"); + assert(std::signbit(negNan) && "expected negative NaN to be negative"); + auto basic = LexCtx::lexed(); + if (!basic) { + return {}; + } + // strtod does not return NaNs with the expected signs on all platforms. + // TODO: use starts_with once we have C++20. + if (basic->span.substr(0, 3) == "nan"sv || + basic->span.substr(0, 4) == "+nan"sv) { + return LexFloatResult{*basic, nanPayload, posNan}; + } + if (basic->span.substr(0, 4) == "-nan"sv) { + return LexFloatResult{*basic, nanPayload, negNan}; + } + // Do not try to implement fully general and precise float parsing + // ourselves. Instead, call out to std::strtod to do our parsing. This means + // we need to strip any underscores since `std::strtod` does not understand + // them. + std::stringstream ss; + for (const char *curr = basic->span.data(), + *end = curr + basic->span.size(); + curr != end; + ++curr) { + if (*curr != '_') { + ss << *curr; + } + } + std::string str = ss.str(); + char* last; + double d = std::strtod(str.data(), &last); + assert(last == str.data() + str.size() && "could not parse float"); + return LexFloatResult{*basic, {}, d}; + } +}; + +struct LexStrResult : LexResult { + // Allocate a string only if there are escape sequences, otherwise just use + // the original string_view. + std::optional<std::string> str; +}; + +struct LexStrCtx : LexCtx { +private: + // Used to build a string with resolved escape sequences. Only used when the + // parsed string contains escape sequences, otherwise we can just use the + // parsed string directly. + std::optional<std::stringstream> escapeBuilder; + +public: + LexStrCtx(std::string_view in) : LexCtx(in) {} + + std::optional<LexStrResult> lexed() { + if (auto basic = LexCtx::lexed()) { + if (escapeBuilder) { + return LexStrResult{*basic, {escapeBuilder->str()}}; + } else { + return LexStrResult{*basic, {}}; + } + } + return {}; + } + + void takeChar() { + if (escapeBuilder) { + *escapeBuilder << peek(); + } + LexCtx::take(1); + } + + void ensureBuildingEscaped() { + if (escapeBuilder) { + return; + } + // Drop the opening '"'. + escapeBuilder = std::stringstream{}; + *escapeBuilder << LexCtx::lexed()->span.substr(1); + } + + void appendEscaped(char c) { *escapeBuilder << c; } + + bool appendUnicode(uint64_t u) { + if ((0xd800 <= u && u < 0xe000) || 0x110000 <= u) { + return false; + } + if (u < 0x80) { + // 0xxxxxxx + *escapeBuilder << uint8_t(u); + } else if (u < 0x800) { + // 110xxxxx 10xxxxxx + *escapeBuilder << uint8_t(0b11000000 | ((u >> 6) & 0b00011111)); + *escapeBuilder << uint8_t(0b10000000 | ((u >> 0) & 0b00111111)); + } else if (u < 0x10000) { + // 1110xxxx 10xxxxxx 10xxxxxx + *escapeBuilder << uint8_t(0b11100000 | ((u >> 12) & 0b00001111)); + *escapeBuilder << uint8_t(0b10000000 | ((u >> 6) & 0b00111111)); + *escapeBuilder << uint8_t(0b10000000 | ((u >> 0) & 0b00111111)); + } else { + // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + *escapeBuilder << uint8_t(0b11110000 | ((u >> 18) & 0b00000111)); + *escapeBuilder << uint8_t(0b10000000 | ((u >> 12) & 0b00111111)); + *escapeBuilder << uint8_t(0b10000000 | ((u >> 6) & 0b00111111)); + *escapeBuilder << uint8_t(0b10000000 | ((u >> 0) & 0b00111111)); + } + return true; + } +}; + +std::optional<LexResult> lparen(std::string_view in) { + LexCtx ctx(in); + ctx.takePrefix("("sv); + return ctx.lexed(); +} + +std::optional<LexResult> rparen(std::string_view in) { + LexCtx ctx(in); + ctx.takePrefix(")"sv); + return ctx.lexed(); +} + +// comment ::= linecomment | blockcomment +// linecomment ::= ';;' linechar* ('\n' | eof) +// linechar ::= c:char (if c != '\n') +// blockcomment ::= '(;' blockchar* ';)' +// blockchar ::= c:char (if c != ';' and c != '(') +// | ';' (if the next char is not ')') +// | '(' (if the next char is not ';') +// | blockcomment +std::optional<LexResult> comment(std::string_view in) { + LexCtx ctx(in); + if (ctx.size() < 2) { + return {}; + } + + // Line comment + if (ctx.takePrefix(";;"sv)) { + if (auto size = ctx.next().find('\n'); size != ""sv.npos) { + ctx.take(size); + } else { + ctx.takeAll(); + } + return ctx.lexed(); + } + + // Block comment (possibly nested!) + if (ctx.takePrefix("(;"sv)) { + size_t depth = 1; + while (depth > 0 && ctx.size() >= 2) { + if (ctx.takePrefix("(;"sv)) { + ++depth; + } else if (ctx.takePrefix(";)"sv)) { + --depth; + } else { + ctx.take(1); + } + } + if (depth > 0) { + // TODO: Add error production for non-terminated block comment. + return {}; + } + return ctx.lexed(); + } + + return {}; +} + +std::optional<LexResult> spacechar(std::string_view in) { + LexCtx ctx(in); + ctx.takePrefix(" "sv) || ctx.takePrefix("\n"sv) || ctx.takePrefix("\r"sv) || + ctx.takePrefix("\t"sv); + return ctx.lexed(); +} + +// space ::= (' ' | format | comment)* +// format ::= '\t' | '\n' | '\r' +std::optional<LexResult> space(std::string_view in) { + LexCtx ctx(in); + while (ctx.size()) { + if (auto lexed = spacechar(ctx.next())) { + ctx.take(*lexed); + } else if (auto lexed = comment(ctx.next())) { + ctx.take(*lexed); + } else { + break; + } + } + return ctx.lexed(); +} + +bool LexCtx::canFinish() const { + // Logically we want to check for eof, parens, and space. But we don't + // actually want to parse more than a couple characters of space, so check for + // individual space chars or comment starts instead. + return empty() || lparen(next()) || rparen(next()) || spacechar(next()) || + startsWith(";;"sv); +} + +// num ::= d:digit => d +// | n:num '_'? d:digit => 10*n + d +// digit ::= '0' => 0 | ... | '9' => 9 +std::optional<LexIntResult> num(std::string_view in, + OverflowBehavior overflow = DisallowOverflow) { + LexIntCtx ctx(in); + if (ctx.empty()) { + return {}; + } + if (!ctx.takeDigit()) { + return {}; + } + while (true) { + bool under = ctx.takePrefix("_"sv); + if (!ctx.takeDigit()) { + if (!under) { + return overflow == DisallowOverflow ? ctx.lexed() : ctx.lexedRaw(); + } + // TODO: Add error production for trailing underscore. + return {}; + } + } +} + +// hexnum ::= h:hexdigit => h +// | n:hexnum '_'? h:hexdigit => 16*n + h +// hexdigit ::= d:digit => d +// | 'A' => 10 | ... | 'F' => 15 +// | 'a' => 10 | ... | 'f' => 15 +std::optional<LexIntResult> +hexnum(std::string_view in, OverflowBehavior overflow = DisallowOverflow) { + LexIntCtx ctx(in); + if (!ctx.takeHexdigit()) { + return {}; + } + while (true) { + bool under = ctx.takePrefix("_"sv); + if (!ctx.takeHexdigit()) { + if (!under) { + return overflow == DisallowOverflow ? ctx.lexed() : ctx.lexedRaw(); + } + // TODO: Add error production for trailing underscore. + return {}; + } + } +} + +// uN ::= n:num => n (if n < 2^N) +// | '0x' n:hexnum => n (if n < 2^N) +// sN ::= s:sign n:num => [s]n (if -2^(N-1) <= [s]n < 2^(N-1)) +// | s:sign '0x' n:hexnum => [s]n (if -2^(N-1) <= [s]n < 2^(N-1)) +// sign ::= {} => + | '+' => + | '-' => - +// +// Note: Defer bounds and sign checking until we know what kind of integer we +// expect. +std::optional<LexIntResult> integer(std::string_view in) { + LexIntCtx ctx(in); + ctx.takeSign(); + if (ctx.takePrefix("0x"sv)) { + if (auto lexed = hexnum(ctx.next())) { + ctx.take(*lexed); + if (ctx.canFinish()) { + return ctx.lexed(); + } + } + // TODO: Add error production for unrecognized hexnum. + return {}; + } + if (auto lexed = num(ctx.next())) { + ctx.take(*lexed); + if (ctx.canFinish()) { + return ctx.lexed(); + } + } + return {}; +} + +// float ::= p:num '.'? => p +// | p:num '.' q:frac => p + q +// | p:num '.'? ('E'|'e') s:sign e:num => p * 10^([s]e) +// | p:num '.' q:frac ('E'|'e') s:sign e:num => (p + q) * 10^([s]e) +// frac ::= d:digit => d/10 +// | d:digit '_'? p:frac => (d + p/10) / 10 +std::optional<LexResult> decfloat(std::string_view in) { + LexCtx ctx(in); + if (auto lexed = num(ctx.next(), IgnoreOverflow)) { + ctx.take(*lexed); + } else { + return {}; + } + // Optional '.' followed by optional frac + if (ctx.takePrefix("."sv)) { + if (auto lexed = num(ctx.next(), IgnoreOverflow)) { + ctx.take(*lexed); + } + } + if (ctx.takePrefix("E"sv) || ctx.takePrefix("e"sv)) { + // Optional sign + ctx.takePrefix("+"sv) || ctx.takePrefix("-"sv); + if (auto lexed = num(ctx.next(), IgnoreOverflow)) { + ctx.take(*lexed); + } else { + // TODO: Add error production for missing exponent. + return {}; + } + } + return ctx.lexed(); +} + +// hexfloat ::= '0x' p:hexnum '.'? => p +// | '0x' p:hexnum '.' q:hexfrac => p + q +// | '0x' p:hexnum '.'? ('P'|'p') s:sign e:num => p * 2^([s]e) +// | '0x' p:hexnum '.' q:hexfrac ('P'|'p') s:sign e:num +// => (p + q) * 2^([s]e) +// hexfrac ::= h:hexdigit => h/16 +// | h:hexdigit '_'? p:hexfrac => (h + p/16) / 16 +std::optional<LexResult> hexfloat(std::string_view in) { + LexCtx ctx(in); + if (!ctx.takePrefix("0x"sv)) { + return {}; + } + if (auto lexed = hexnum(ctx.next(), IgnoreOverflow)) { + ctx.take(*lexed); + } else { + return {}; + } + // Optional '.' followed by optional hexfrac + if (ctx.takePrefix("."sv)) { + if (auto lexed = hexnum(ctx.next(), IgnoreOverflow)) { + ctx.take(*lexed); + } + } + if (ctx.takePrefix("P"sv) || ctx.takePrefix("p"sv)) { + // Optional sign + ctx.takePrefix("+"sv) || ctx.takePrefix("-"sv); + if (auto lexed = num(ctx.next(), IgnoreOverflow)) { + ctx.take(*lexed); + } else { + // TODO: Add error production for missing exponent. + return {}; + } + } + return ctx.lexed(); +} + +// fN ::= s:sign z:fNmag => [s]z +// fNmag ::= z:float => float_N(z) (if float_N(z) != +/-infinity) +// | z:hexfloat => float_N(z) (if float_N(z) != +/-infinity) +// | 'inf' => infinity +// | 'nan' => nan(2^(signif(N)-1)) +// | 'nan:0x' n:hexnum => nan(n) (if 1 <= n < 2^signif(N)) +std::optional<LexFloatResult> float_(std::string_view in) { + LexFloatCtx ctx(in); + // Optional sign + ctx.takePrefix("+"sv) || ctx.takePrefix("-"sv); + if (auto lexed = hexfloat(ctx.next())) { + ctx.take(*lexed); + } else if (auto lexed = decfloat(ctx.next())) { + ctx.take(*lexed); + } else if (ctx.takePrefix("inf"sv)) { + // nop + } else if (ctx.takePrefix("nan"sv)) { + if (ctx.takePrefix(":0x"sv)) { + if (auto lexed = hexnum(ctx.next())) { + ctx.take(*lexed); + ctx.nanPayload = lexed->n; + } else { + // TODO: Add error production for malformed NaN payload. + return {}; + } + } else { + // No explicit payload necessary; we will inject the default payload + // later. + } + } else { + return {}; + } + if (ctx.canFinish()) { + return ctx.lexed(); + } + return {}; +} + +// idchar ::= '0' | ... | '9' +// | 'A' | ... | 'Z' +// | 'a' | ... | 'z' +// | '!' | '#' | '$' | '%' | '&' | ''' | '*' | '+' +// | '-' | '.' | '/' | ':' | '<' | '=' | '>' | '?' +// | '@' | '\' | '^' | '_' | '`' | '|' | '~' +std::optional<LexResult> idchar(std::string_view in) { + LexCtx ctx(in); + if (ctx.empty()) { + return {}; + } + uint8_t c = ctx.peek(); + if (('0' <= c && c <= '9') || ('A' <= c && c <= 'Z') || + ('a' <= c && c <= 'z')) { + ctx.take(1); + } else { + switch (c) { + case '!': + case '#': + case '$': + case '%': + case '&': + case '\'': + case '*': + case '+': + case '-': + case '.': + case '/': + case ':': + case '<': + case '=': + case '>': + case '?': + case '@': + case '\\': + case '^': + case '_': + case '`': + case '|': + case '~': + ctx.take(1); + } + } + return ctx.lexed(); +} + +// id ::= '$' idchar+ +std::optional<LexResult> ident(std::string_view in) { + LexCtx ctx(in); + if (!ctx.takePrefix("$"sv)) { + return {}; + } + if (auto lexed = idchar(ctx.next())) { + ctx.take(*lexed); + } else { + return {}; + } + while (auto lexed = idchar(ctx.next())) { + ctx.take(*lexed); + } + if (ctx.canFinish()) { + return ctx.lexed(); + } + return {}; +} + +// string ::= '"' (b*:stringelem)* '"' => concat((b*)*) +// (if |concat((b*)*)| < 2^32) +// stringelem ::= c:stringchar => utf8(c) +// | '\' n:hexdigit m:hexdigit => 16*n + m +// stringchar ::= c:char => c +// (if c >= U+20 && c != U+7f && c != '"' && c != '\') +// | '\t' => \t | '\n' => \n | '\r' => \r +// | '\\' => \ | '\"' => " | '\'' => ' +// | '\u{' n:hexnum '}' => U+(n) +// (if n < 0xD800 and 0xE000 <= n <= 0x110000) +std::optional<LexStrResult> str(std::string_view in) { + LexStrCtx ctx(in); + if (!ctx.takePrefix("\""sv)) { + return {}; + } + while (!ctx.takePrefix("\""sv)) { + if (ctx.empty()) { + // TODO: Add error production for unterminated string. + return {}; + } + if (ctx.startsWith("\\"sv)) { + // Escape sequences + ctx.ensureBuildingEscaped(); + ctx.take(1); + if (ctx.takePrefix("t"sv)) { + ctx.appendEscaped('\t'); + } else if (ctx.takePrefix("n"sv)) { + ctx.appendEscaped('\n'); + } else if (ctx.takePrefix("r"sv)) { + ctx.appendEscaped('\r'); + } else if (ctx.takePrefix("\\"sv)) { + ctx.appendEscaped('\\'); + } else if (ctx.takePrefix("\""sv)) { + ctx.appendEscaped('"'); + } else if (ctx.takePrefix("'"sv)) { + ctx.appendEscaped('\''); + } else if (ctx.takePrefix("u{"sv)) { + auto lexed = hexnum(ctx.next()); + if (!lexed) { + // TODO: Add error production for malformed unicode escapes. + return {}; + } + ctx.take(*lexed); + if (!ctx.takePrefix("}"sv)) { + // TODO: Add error production for malformed unicode escapes. + return {}; + } + if (!ctx.appendUnicode(lexed->n)) { + // TODO: Add error production for invalid unicode values. + return {}; + } + } else { + LexIntCtx ictx(ctx.next()); + if (!ictx.takeHexdigit() || !ictx.takeHexdigit()) { + // TODO: Add error production for unrecognized escape sequence. + return {}; + } + auto lexed = *ictx.lexed(); + ctx.take(lexed); + ctx.appendEscaped(char(lexed.n)); + } + } else { + // Normal characters + if (uint8_t c = ctx.peek(); c >= 0x20 && c != 0x7F) { + ctx.takeChar(); + } else { + // TODO: Add error production for unescaped control characters. + return {}; + } + } + } + return ctx.lexed(); +} + +// keyword ::= ( 'a' | ... | 'z' ) idchar* (if literal terminal in grammar) +// reserved ::= idchar+ +// +// The "keyword" token we lex here covers both keywords as well as any reserved +// tokens that match the keyword format. This saves us from having to enumerate +// all the valid keywords here. These invalid keywords will still produce +// errors, just at a higher level of the parser. +std::optional<LexResult> keyword(std::string_view in) { + LexCtx ctx(in); + if (ctx.empty()) { + return {}; + } + uint8_t start = ctx.peek(); + if ('a' <= start && start <= 'z') { + ctx.take(1); + } else { + return {}; + } + while (auto lexed = idchar(ctx.next())) { + ctx.take(*lexed); + } + return ctx.lexed(); +} + +} // anonymous namespace + +std::optional<uint64_t> Token::getU64() const { + if (auto* tok = std::get_if<IntTok>(&data)) { + if (tok->sign == NoSign) { + return tok->n; + } + } + return {}; +} + +std::optional<int64_t> Token::getS64() const { + if (auto* tok = std::get_if<IntTok>(&data)) { + if (tok->sign == Neg) { + if (uint64_t(INT64_MIN) <= tok->n || tok->n == 0) { + return int64_t(tok->n); + } + // TODO: Add error production for signed underflow. + } else { + if (tok->n <= uint64_t(INT64_MAX)) { + return int64_t(tok->n); + } + // TODO: Add error production for signed overflow. + } + } + return {}; +} + +std::optional<uint64_t> Token::getI64() const { + if (auto n = getU64()) { + return *n; + } + if (auto n = getS64()) { + return *n; + } + return {}; +} + +std::optional<uint32_t> Token::getU32() const { + if (auto* tok = std::get_if<IntTok>(&data)) { + if (tok->sign == NoSign && tok->n <= UINT32_MAX) { + return int32_t(tok->n); + } + // TODO: Add error production for unsigned overflow. + } + return {}; +} + +std::optional<int32_t> Token::getS32() const { + if (auto* tok = std::get_if<IntTok>(&data)) { + if (tok->sign == Neg) { + if (uint64_t(INT32_MIN) <= tok->n || tok->n == 0) { + return int32_t(tok->n); + } + } else { + if (tok->n <= uint64_t(INT32_MAX)) { + return int32_t(tok->n); + } + } + } + return {}; +} + +std::optional<uint32_t> Token::getI32() const { + if (auto n = getU32()) { + return *n; + } + if (auto n = getS32()) { + return uint32_t(*n); + } + return {}; +} + +std::optional<double> Token::getF64() const { + constexpr int signif = 52; + constexpr uint64_t payloadMask = (1ull << signif) - 1; + constexpr uint64_t nanDefault = 1ull << (signif - 1); + if (auto* tok = std::get_if<FloatTok>(&data)) { + double d = tok->d; + if (std::isnan(d)) { + // Inject payload. + uint64_t payload = tok->nanPayload ? *tok->nanPayload : nanDefault; + if (payload == 0 || payload > payloadMask) { + // TODO: Add error production for out-of-bounds payload. + return {}; + } + uint64_t bits; + static_assert(sizeof(bits) == sizeof(d)); + memcpy(&bits, &d, sizeof(bits)); + bits = (bits & ~payloadMask) | payload; + memcpy(&d, &bits, sizeof(bits)); + } + return d; + } + if (auto* tok = std::get_if<IntTok>(&data)) { + if (tok->sign == Neg) { + if (tok->n == 0) { + return -0.0; + } + return double(int64_t(tok->n)); + } + return double(tok->n); + } + return {}; +} + +std::optional<float> Token::getF32() const { + constexpr int signif = 23; + constexpr uint32_t payloadMask = (1u << signif) - 1; + constexpr uint64_t nanDefault = 1ull << (signif - 1); + if (auto* tok = std::get_if<FloatTok>(&data)) { + float f = tok->d; + if (std::isnan(f)) { + // Validate and inject payload. + uint64_t payload = tok->nanPayload ? *tok->nanPayload : nanDefault; + if (payload == 0 || payload > payloadMask) { + // TODO: Add error production for out-of-bounds payload. + return {}; + } + uint32_t bits; + static_assert(sizeof(bits) == sizeof(f)); + memcpy(&bits, &f, sizeof(bits)); + bits = (bits & ~payloadMask) | payload; + memcpy(&f, &bits, sizeof(bits)); + } + return f; + } + if (auto* tok = std::get_if<IntTok>(&data)) { + if (tok->sign == Neg) { + if (tok->n == 0) { + return -0.0f; + } + return float(int64_t(tok->n)); + } + return float(tok->n); + } + return {}; +} + +std::optional<std::string_view> Token::getString() const { + if (auto* tok = std::get_if<StringTok>(&data)) { + if (tok->str) { + return std::string_view(*tok->str); + } + return span.substr(1, span.size() - 2); + } + return {}; +} + +void Lexer::skipSpace() { + if (auto ctx = space(next())) { + index += ctx->span.size(); + } +} + +void Lexer::lexToken() { + // TODO: Ensure we're getting the longest possible match. + Token tok; + if (auto t = lparen(next())) { + tok = Token{t->span, LParenTok{}}; + } else if (auto t = rparen(next())) { + tok = Token{t->span, RParenTok{}}; + } else if (auto t = ident(next())) { + tok = Token{t->span, IdTok{}}; + } else if (auto t = integer(next())) { + tok = Token{t->span, IntTok{t->n, t->sign}}; + } else if (auto t = float_(next())) { + tok = Token{t->span, FloatTok{t->nanPayload, t->d}}; + } else if (auto t = str(next())) { + tok = Token{t->span, StringTok{t->str}}; + } else if (auto t = keyword(next())) { + tok = Token{t->span, KeywordTok{}}; + } else { + // TODO: Do something about lexing errors. + curr = std::nullopt; + return; + } + index += tok.span.size(); + curr = {tok}; +} + +TextPos Lexer::position(const char* c) const { + assert(size_t(c - buffer.data()) <= buffer.size()); + TextPos pos{1, 0}; + for (const char* p = buffer.data(); p != c; ++p) { + if (*p == '\n') { + pos.line++; + pos.col = 0; + } else { + pos.col++; + } + } + return pos; +} + +bool TextPos::operator==(const TextPos& other) const { + return line == other.line && col == other.col; +} + +bool IntTok::operator==(const IntTok& other) const { + return n == other.n && sign == other.sign; +} + +bool FloatTok::operator==(const FloatTok& other) const { + return std::signbit(d) == std::signbit(other.d) && + (d == other.d || (std::isnan(d) && std::isnan(other.d) && + nanPayload == other.nanPayload)); +} + +bool Token::operator==(const Token& other) const { + return span == other.span && + std::visit( + [](auto& t1, auto& t2) { + if constexpr (std::is_same_v<decltype(t1), decltype(t2)>) { + return t1 == t2; + } else { + return false; + } + }, + data, + other.data); +} + +std::ostream& operator<<(std::ostream& os, const TextPos& pos) { + return os << pos.line << ":" << pos.col; +} + +std::ostream& operator<<(std::ostream& os, const LParenTok&) { + return os << "'('"; +} + +std::ostream& operator<<(std::ostream& os, const RParenTok&) { + return os << "')'"; +} + +std::ostream& operator<<(std::ostream& os, const IdTok&) { return os << "id"; } + +std::ostream& operator<<(std::ostream& os, const IntTok& tok) { + return os << (tok.sign == Pos ? "+" : tok.sign == Neg ? "-" : "") << tok.n; +} + +std::ostream& operator<<(std::ostream& os, const FloatTok& tok) { + if (std::isnan(tok.d)) { + os << (std::signbit(tok.d) ? "+" : "-"); + if (tok.nanPayload) { + return os << "nan:0x" << std::hex << *tok.nanPayload << std::dec; + } + return os << "nan"; + } + return os << tok.d; +} + +std::ostream& operator<<(std::ostream& os, const StringTok& tok) { + if (tok.str) { + os << '"' << *tok.str << '"'; + } else { + os << "(raw string)"; + } + return os; +} + +std::ostream& operator<<(std::ostream& os, const KeywordTok&) { + return os << "keyword"; +} + +std::ostream& operator<<(std::ostream& os, const Token& tok) { + std::visit([&](const auto& t) { os << t; }, tok.data); + return os << " \"" << tok.span << "\""; +} + +} // namespace wasm::WATParser diff --git a/src/parser/lexer.h b/src/parser/lexer.h new file mode 100644 index 000000000..67d29b002 --- /dev/null +++ b/src/parser/lexer.h @@ -0,0 +1,227 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <iterator> +#include <optional> +#include <ostream> +#include <string_view> +#include <variant> + +#ifndef parser_lexer_h +#define parser_lexer_h + +namespace wasm::WATParser { + +struct TextPos { + size_t line; + size_t col; + + bool operator==(const TextPos& other) const; + bool operator!=(const TextPos& other) const { return !(*this == other); } + + friend std::ostream& operator<<(std::ostream& os, const TextPos& pos); +}; + +// ====== +// Tokens +// ====== + +struct LParenTok { + bool operator==(const LParenTok&) const { return true; } + friend std::ostream& operator<<(std::ostream&, const LParenTok&); +}; + +struct RParenTok { + bool operator==(const RParenTok&) const { return true; } + friend std::ostream& operator<<(std::ostream&, const RParenTok&); +}; + +struct IdTok { + bool operator==(const IdTok&) const { return true; } + friend std::ostream& operator<<(std::ostream&, const IdTok&); +}; + +enum Sign { NoSign, Pos, Neg }; + +struct IntTok { + uint64_t n; + Sign sign; + + bool operator==(const IntTok&) const; + friend std::ostream& operator<<(std::ostream&, const IntTok&); +}; + +struct FloatTok { + // The payload if we lexed a nan with payload. We cannot store the payload + // directly in `d` because we do not know at this point whether we are parsing + // an f32 or f64 and therefore we do not know what the allowable payloads are. + // No payload with NaN means to use the default payload for the expected float + // width. + std::optional<uint64_t> nanPayload; + double d; + + bool operator==(const FloatTok&) const; + friend std::ostream& operator<<(std::ostream&, const FloatTok&); +}; + +struct StringTok { + std::optional<std::string> str; + + bool operator==(const StringTok& other) const { return str == other.str; } + friend std::ostream& operator<<(std::ostream&, const StringTok&); +}; + +struct KeywordTok { + bool operator==(const KeywordTok&) const { return true; } + friend std::ostream& operator<<(std::ostream&, const KeywordTok&); +}; + +struct Token { + using Data = std::variant<LParenTok, + RParenTok, + IdTok, + IntTok, + FloatTok, + StringTok, + KeywordTok>; + std::string_view span; + Data data; + + // ==================== + // Token classification + // ==================== + + bool isLParen() const { return std::get_if<LParenTok>(&data); } + + bool isRParen() const { return std::get_if<RParenTok>(&data); } + + std::optional<std::string_view> getID() const { + if (std::get_if<IdTok>(&data)) { + // Drop leading '$'. + return span.substr(1); + } + return {}; + } + + std::optional<std::string_view> getKeyword() const { + if (std::get_if<KeywordTok>(&data)) { + return span; + } + return {}; + } + std::optional<uint64_t> getU64() const; + std::optional<int64_t> getS64() const; + std::optional<uint64_t> getI64() const; + std::optional<uint32_t> getU32() const; + std::optional<int32_t> getS32() const; + std::optional<uint32_t> getI32() const; + std::optional<double> getF64() const; + std::optional<float> getF32() const; + std::optional<std::string_view> getString() const; + + bool operator==(const Token&) const; + friend std::ostream& operator<<(std::ostream& os, const Token&); +}; + +// ===== +// Lexer +// ===== + +// Lexer's purpose is twofold. First, it wraps a buffer to provide a tokenizing +// iterator over it. Second, it implements that iterator itself. Also provides +// utilities for locating the text position of tokens within the buffer. Text +// positions are computed on demand rather than eagerly because they are +// typically only needed when there is an error to report. +struct Lexer { + using iterator = Lexer; + using difference_type = std::ptrdiff_t; + using value_type = Token; + using pointer = const Token*; + using reference = const Token&; + using iterator_category = std::forward_iterator_tag; + +private: + std::string_view buffer; + size_t index = 0; + std::optional<Token> curr; + +public: + // The end sentinel. + Lexer() = default; + + Lexer(std::string_view buffer) : buffer(buffer) { setIndex(0); } + + size_t getIndex() const { return index; } + + void setIndex(size_t i) { + index = i; + skipSpace(); + lexToken(); + } + + std::string_view next() const { return buffer.substr(index); } + Lexer& operator++() { + // Preincrement + skipSpace(); + lexToken(); + return *this; + } + + Lexer operator++(int) { + // Postincrement + Lexer ret = *this; + ++(*this); + return ret; + } + + const Token& operator*() { return *curr; } + const Token* operator->() { return &*curr; } + + bool operator==(const Lexer& other) const { + // The iterator is equal to the end sentinel when there is no current token. + if (!curr && !other.curr) { + return true; + } + // Otherwise they are equivalent when they are at the same position. + return index == other.index; + } + + bool operator!=(const Lexer& other) const { return !(*this == other); } + + Lexer begin() { return *this; } + + Lexer end() const { return Lexer(); } + + bool empty() const { return *this == end(); } + + TextPos position(const char* c) const; + TextPos position(size_t i) const { return position(buffer.data() + i); } + TextPos position(std::string_view span) const { + return position(span.data()); + } + TextPos position(Token tok) const { return position(tok.span); } + +private: + void skipSpace(); + void lexToken(); +}; + +} // namespace wasm::WATParser + +#endif // parser_lexer_h diff --git a/src/parser/parsers.h b/src/parser/parsers.h new file mode 100644 index 000000000..5f9f23a2a --- /dev/null +++ b/src/parser/parsers.h @@ -0,0 +1,2036 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef parser_parsers_h +#define parser_parsers_h + +#include "common.h" +#include "input.h" + +namespace wasm::WATParser { + +using namespace std::string_view_literals; + +// Types +template<typename Ctx> Result<typename Ctx::HeapTypeT> heaptype(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::RefTypeT> reftype(Ctx&); +template<typename Ctx> Result<typename Ctx::TypeT> valtype(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::ParamsT> params(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::ResultsT> results(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::SignatureT> functype(Ctx&); +template<typename Ctx> Result<typename Ctx::FieldT> storagetype(Ctx&); +template<typename Ctx> Result<typename Ctx::FieldT> fieldtype(Ctx&); +template<typename Ctx> Result<typename Ctx::FieldsT> fields(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::StructT> structtype(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::ArrayT> arraytype(Ctx&); +template<typename Ctx> Result<typename Ctx::LimitsT> limits32(Ctx&); +template<typename Ctx> Result<typename Ctx::LimitsT> limits64(Ctx&); +template<typename Ctx> Result<typename Ctx::MemTypeT> memtype(Ctx&); +template<typename Ctx> Result<typename Ctx::GlobalTypeT> globaltype(Ctx&); + +// Instructions +template<typename Ctx> MaybeResult<typename Ctx::InstrT> foldedBlockinstr(Ctx&); +template<typename Ctx> +MaybeResult<typename Ctx::InstrT> unfoldedBlockinstr(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::InstrT> blockinstr(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::InstrT> plaininstr(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::InstrT> instr(Ctx&); +template<typename Ctx> Result<typename Ctx::InstrsT> instrs(Ctx&); +template<typename Ctx> Result<typename Ctx::ExprT> expr(Ctx&); +template<typename Ctx> Result<typename Ctx::MemargT> memarg(Ctx&, uint32_t); +template<typename Ctx> Result<typename Ctx::BlockTypeT> blocktype(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::InstrT> block(Ctx&, bool); +template<typename Ctx> +Result<typename Ctx::InstrT> makeUnreachable(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeNop(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeBinary(Ctx&, Index, BinaryOp op); +template<typename Ctx> +Result<typename Ctx::InstrT> makeUnary(Ctx&, Index, UnaryOp op); +template<typename Ctx> Result<typename Ctx::InstrT> makeSelect(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeDrop(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeMemorySize(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeMemoryGrow(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeLocalGet(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeLocalTee(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeLocalSet(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeGlobalGet(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeGlobalSet(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeBlock(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeThenOrElse(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeConst(Ctx&, Index, Type type); +template<typename Ctx> +Result<typename Ctx::InstrT> +makeLoad(Ctx&, Index, Type type, bool signed_, int bytes, bool isAtomic); +template<typename Ctx> +Result<typename Ctx::InstrT> +makeStore(Ctx&, Index, Type type, int bytes, bool isAtomic); +template<typename Ctx> +Result<typename Ctx::InstrT> +makeAtomicRMW(Ctx&, Index, AtomicRMWOp op, Type type, uint8_t bytes); +template<typename Ctx> +Result<typename Ctx::InstrT> +makeAtomicCmpxchg(Ctx&, Index, Type type, uint8_t bytes); +template<typename Ctx> +Result<typename Ctx::InstrT> makeAtomicWait(Ctx&, Index, Type type); +template<typename Ctx> +Result<typename Ctx::InstrT> makeAtomicNotify(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeAtomicFence(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> +makeSIMDExtract(Ctx&, Index, SIMDExtractOp op, size_t lanes); +template<typename Ctx> +Result<typename Ctx::InstrT> +makeSIMDReplace(Ctx&, Index, SIMDReplaceOp op, size_t lanes); +template<typename Ctx> +Result<typename Ctx::InstrT> makeSIMDShuffle(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeSIMDTernary(Ctx&, Index, SIMDTernaryOp op); +template<typename Ctx> +Result<typename Ctx::InstrT> makeSIMDShift(Ctx&, Index, SIMDShiftOp op); +template<typename Ctx> +Result<typename Ctx::InstrT> +makeSIMDLoad(Ctx&, Index, SIMDLoadOp op, int bytes); +template<typename Ctx> +Result<typename Ctx::InstrT> +makeSIMDLoadStoreLane(Ctx&, Index, SIMDLoadStoreLaneOp op, int bytes); +template<typename Ctx> Result<typename Ctx::InstrT> makeMemoryInit(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeDataDrop(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeMemoryCopy(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeMemoryFill(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makePop(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeIf(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeMaybeBlock(Ctx&, Index, size_t i, Type type); +template<typename Ctx> Result<typename Ctx::InstrT> makeLoop(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeCall(Ctx&, Index, bool isReturn); +template<typename Ctx> +Result<typename Ctx::InstrT> makeCallIndirect(Ctx&, Index, bool isReturn); +template<typename Ctx> Result<typename Ctx::InstrT> makeBreak(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeBreakTable(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeReturn(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeRefNull(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeRefIsNull(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeRefFunc(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeRefEq(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeTableGet(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeTableSet(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeTableSize(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeTableGrow(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeTableFill(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeTry(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> +makeTryOrCatchBody(Ctx&, Index, Type type, bool isTry); +template<typename Ctx> Result<typename Ctx::InstrT> makeThrow(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeRethrow(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeTupleMake(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeTupleExtract(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeCallRef(Ctx&, Index, bool isReturn); +template<typename Ctx> Result<typename Ctx::InstrT> makeRefI31(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeI31Get(Ctx&, Index, bool signed_); +template<typename Ctx> Result<typename Ctx::InstrT> makeRefTest(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeRefCast(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeBrOnNull(Ctx&, Index, bool onFail = false); +template<typename Ctx> +Result<typename Ctx::InstrT> makeBrOnCast(Ctx&, Index, bool onFail = false); +template<typename Ctx> +Result<typename Ctx::InstrT> makeStructNew(Ctx&, Index, bool default_); +template<typename Ctx> +Result<typename Ctx::InstrT> makeStructGet(Ctx&, Index, bool signed_ = false); +template<typename Ctx> Result<typename Ctx::InstrT> makeStructSet(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayNew(Ctx&, Index, bool default_); +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayNewData(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayNewElem(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayNewFixed(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayGet(Ctx&, Index, bool signed_ = false); +template<typename Ctx> Result<typename Ctx::InstrT> makeArraySet(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeArrayLen(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeArrayCopy(Ctx&, Index); +template<typename Ctx> Result<typename Ctx::InstrT> makeArrayFill(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayInitData(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayInitElem(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeRefAs(Ctx&, Index, RefAsOp op); +template<typename Ctx> +Result<typename Ctx::InstrT> +makeStringNew(Ctx&, Index, StringNewOp op, bool try_); +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringConst(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringMeasure(Ctx&, Index, StringMeasureOp op); +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringEncode(Ctx&, Index, StringEncodeOp op); +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringConcat(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringEq(Ctx&, Index, StringEqOp); +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringAs(Ctx&, Index, StringAsOp op); +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringWTF8Advance(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringWTF16Get(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringIterNext(Ctx&, Index); +template<typename Ctx> +Result<typename Ctx::InstrT> +makeStringIterMove(Ctx&, Index, StringIterMoveOp op); +template<typename Ctx> +Result<typename Ctx::InstrT> +makeStringSliceWTF(Ctx&, Index, StringSliceWTFOp op); +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringSliceIter(Ctx&, Index); + +// Modules +template<typename Ctx> MaybeResult<Index> maybeTypeidx(Ctx& ctx); +template<typename Ctx> Result<typename Ctx::HeapTypeT> typeidx(Ctx&); +template<typename Ctx> +Result<typename Ctx::FieldIdxT> fieldidx(Ctx&, typename Ctx::HeapTypeT); +template<typename Ctx> MaybeResult<typename Ctx::MemoryIdxT> maybeMemidx(Ctx&); +template<typename Ctx> Result<typename Ctx::MemoryIdxT> memidx(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::MemoryIdxT> maybeMemuse(Ctx&); +template<typename Ctx> Result<typename Ctx::GlobalIdxT> globalidx(Ctx&); +template<typename Ctx> Result<typename Ctx::LocalIdxT> localidx(Ctx&); +template<typename Ctx> Result<typename Ctx::TypeUseT> typeuse(Ctx&); +MaybeResult<ImportNames> inlineImport(ParseInput&); +Result<std::vector<Name>> inlineExports(ParseInput&); +template<typename Ctx> Result<> strtype(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::ModuleNameT> subtype(Ctx&); +template<typename Ctx> MaybeResult<> deftype(Ctx&); +template<typename Ctx> MaybeResult<typename Ctx::LocalsT> locals(Ctx&); +template<typename Ctx> MaybeResult<> func(Ctx&); +template<typename Ctx> MaybeResult<> memory(Ctx&); +template<typename Ctx> MaybeResult<> global(Ctx&); +template<typename Ctx> Result<typename Ctx::DataStringT> datastring(Ctx&); +template<typename Ctx> MaybeResult<> data(Ctx&); +template<typename Ctx> MaybeResult<> modulefield(Ctx&); +template<typename Ctx> Result<> module(Ctx&); + +// ========= +// Utilities +// ========= + +// RAII utility for temporarily changing the parsing position of a parsing +// context. +template<typename Ctx> struct WithPosition { + Ctx& ctx; + Index original; + + WithPosition(Ctx& ctx, Index pos) : ctx(ctx), original(ctx.in.getPos()) { + ctx.in.lexer.setIndex(pos); + } + + ~WithPosition() { ctx.in.lexer.setIndex(original); } +}; + +// Deduction guide to satisfy -Wctad-maybe-unsupported. +template<typename Ctx> WithPosition(Ctx& ctx, Index) -> WithPosition<Ctx>; + +// ===== +// Types +// ===== + +// heaptype ::= x:typeidx => types[x] +// | 'func' => func +// | 'extern' => extern +template<typename Ctx> Result<typename Ctx::HeapTypeT> heaptype(Ctx& ctx) { + if (ctx.in.takeKeyword("func"sv)) { + return ctx.makeFunc(); + } + if (ctx.in.takeKeyword("any"sv)) { + return ctx.makeAny(); + } + if (ctx.in.takeKeyword("extern"sv)) { + return ctx.makeExtern(); + } + if (ctx.in.takeKeyword("eq"sv)) { + return ctx.makeEq(); + } + if (ctx.in.takeKeyword("i31"sv)) { + return ctx.makeI31(); + } + if (ctx.in.takeKeyword("struct"sv)) { + return ctx.makeStructType(); + } + if (ctx.in.takeKeyword("array"sv)) { + return ctx.makeArrayType(); + } + auto type = typeidx(ctx); + CHECK_ERR(type); + return *type; +} + +// reftype ::= 'funcref' => funcref +// | 'externref' => externref +// | 'anyref' => anyref +// | 'eqref' => eqref +// | 'i31ref' => i31ref +// | 'structref' => structref +// | 'arrayref' => arrayref +// | '(' ref null? t:heaptype ')' => ref null? t +template<typename Ctx> MaybeResult<typename Ctx::TypeT> reftype(Ctx& ctx) { + if (ctx.in.takeKeyword("funcref"sv)) { + return ctx.makeRefType(ctx.makeFunc(), Nullable); + } + if (ctx.in.takeKeyword("externref"sv)) { + return ctx.makeRefType(ctx.makeExtern(), Nullable); + } + if (ctx.in.takeKeyword("anyref"sv)) { + return ctx.makeRefType(ctx.makeAny(), Nullable); + } + if (ctx.in.takeKeyword("eqref"sv)) { + return ctx.makeRefType(ctx.makeEq(), Nullable); + } + if (ctx.in.takeKeyword("i31ref"sv)) { + return ctx.makeRefType(ctx.makeI31(), Nullable); + } + if (ctx.in.takeKeyword("structref"sv)) { + return ctx.makeRefType(ctx.makeStructType(), Nullable); + } + if (ctx.in.takeKeyword("arrayref"sv)) { + return ctx.in.err("arrayref not yet supported"); + } + + if (!ctx.in.takeSExprStart("ref"sv)) { + return {}; + } + + auto nullability = ctx.in.takeKeyword("null"sv) ? Nullable : NonNullable; + + auto type = heaptype(ctx); + CHECK_ERR(type); + + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of reftype"); + } + + return ctx.makeRefType(*type, nullability); +} + +// numtype ::= 'i32' => i32 +// | 'i64' => i64 +// | 'f32' => f32 +// | 'f64' => f64 +// vectype ::= 'v128' => v128 +// valtype ::= t:numtype => t +// | t:vectype => t +// | t:reftype => t +template<typename Ctx> Result<typename Ctx::TypeT> valtype(Ctx& ctx) { + if (ctx.in.takeKeyword("i32"sv)) { + return ctx.makeI32(); + } else if (ctx.in.takeKeyword("i64"sv)) { + return ctx.makeI64(); + } else if (ctx.in.takeKeyword("f32"sv)) { + return ctx.makeF32(); + } else if (ctx.in.takeKeyword("f64"sv)) { + return ctx.makeF64(); + } else if (ctx.in.takeKeyword("v128"sv)) { + return ctx.makeV128(); + } else if (auto type = reftype(ctx)) { + CHECK_ERR(type); + return *type; + } else { + return ctx.in.err("expected valtype"); + } +} + +// param ::= '(' 'param id? t:valtype ')' => [t] +// | '(' 'param t*:valtype* ')' => [t*] +// params ::= param* +template<typename Ctx> MaybeResult<typename Ctx::ParamsT> params(Ctx& ctx) { + bool hasAny = false; + auto res = ctx.makeParams(); + while (ctx.in.takeSExprStart("param"sv)) { + hasAny = true; + if (auto id = ctx.in.takeID()) { + // Single named param + auto type = valtype(ctx); + CHECK_ERR(type); + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of param"); + } + ctx.appendParam(res, *id, *type); + } else { + // Repeated unnamed params + while (!ctx.in.takeRParen()) { + auto type = valtype(ctx); + CHECK_ERR(type); + ctx.appendParam(res, {}, *type); + } + } + } + if (hasAny) { + return res; + } + return {}; +} + +// result ::= '(' 'result' t*:valtype ')' => [t*] +// results ::= result* +template<typename Ctx> MaybeResult<typename Ctx::ResultsT> results(Ctx& ctx) { + bool hasAny = false; + auto res = ctx.makeResults(); + while (ctx.in.takeSExprStart("result"sv)) { + hasAny = true; + while (!ctx.in.takeRParen()) { + auto type = valtype(ctx); + CHECK_ERR(type); + ctx.appendResult(res, *type); + } + } + if (hasAny) { + return res; + } + return {}; +} + +// functype ::= '(' 'func' t1*:vec(param) t2*:vec(result) ')' => [t1*] -> [t2*] +template<typename Ctx> +MaybeResult<typename Ctx::SignatureT> functype(Ctx& ctx) { + if (!ctx.in.takeSExprStart("func"sv)) { + return {}; + } + + auto parsedParams = params(ctx); + CHECK_ERR(parsedParams); + + auto parsedResults = results(ctx); + CHECK_ERR(parsedResults); + + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of functype"); + } + + return ctx.makeFuncType(parsedParams.getPtr(), parsedResults.getPtr()); +} + +// storagetype ::= valtype | packedtype +// packedtype ::= i8 | i16 +template<typename Ctx> Result<typename Ctx::FieldT> storagetype(Ctx& ctx) { + if (ctx.in.takeKeyword("i8"sv)) { + return ctx.makeI8(); + } + if (ctx.in.takeKeyword("i16"sv)) { + return ctx.makeI16(); + } + auto type = valtype(ctx); + CHECK_ERR(type); + return ctx.makeStorageType(*type); +} + +// fieldtype ::= t:storagetype => const t +// | '(' 'mut' t:storagetype ')' => var t +template<typename Ctx> Result<typename Ctx::FieldT> fieldtype(Ctx& ctx) { + auto mutability = Immutable; + if (ctx.in.takeSExprStart("mut"sv)) { + mutability = Mutable; + } + + auto field = storagetype(ctx); + CHECK_ERR(field); + + if (mutability == Mutable) { + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of field type"); + } + } + + return ctx.makeFieldType(*field, mutability); +} + +// field ::= '(' 'field' id t:fieldtype ')' => [(id, t)] +// | '(' 'field' t*:fieldtype* ')' => [(_, t*)*] +// | fieldtype +template<typename Ctx> Result<typename Ctx::FieldsT> fields(Ctx& ctx) { + auto res = ctx.makeFields(); + while (true) { + if (auto t = ctx.in.peek(); !t || t->isRParen()) { + return res; + } + if (ctx.in.takeSExprStart("field")) { + if (auto id = ctx.in.takeID()) { + auto field = fieldtype(ctx); + CHECK_ERR(field); + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of field"); + } + ctx.appendField(res, *id, *field); + } else { + while (!ctx.in.takeRParen()) { + auto field = fieldtype(ctx); + CHECK_ERR(field); + ctx.appendField(res, {}, *field); + } + } + } else { + auto field = fieldtype(ctx); + CHECK_ERR(field); + ctx.appendField(res, {}, *field); + } + } +} + +// structtype ::= '(' 'struct' field* ')' +template<typename Ctx> MaybeResult<typename Ctx::StructT> structtype(Ctx& ctx) { + if (!ctx.in.takeSExprStart("struct"sv)) { + return {}; + } + auto namedFields = fields(ctx); + CHECK_ERR(namedFields); + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of struct definition"); + } + + return ctx.makeStruct(*namedFields); +} + +// arraytype ::= '(' 'array' field ')' +template<typename Ctx> MaybeResult<typename Ctx::ArrayT> arraytype(Ctx& ctx) { + if (!ctx.in.takeSExprStart("array"sv)) { + return {}; + } + auto namedFields = fields(ctx); + CHECK_ERR(namedFields); + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of array definition"); + } + + if (auto array = ctx.makeArray(*namedFields)) { + return *array; + } + return ctx.in.err("expected exactly one field in array definition"); +} + +// limits32 ::= n:u32 m:u32? +template<typename Ctx> Result<typename Ctx::LimitsT> limits32(Ctx& ctx) { + auto n = ctx.in.takeU32(); + if (!n) { + return ctx.in.err("expected initial size"); + } + std::optional<uint64_t> m = ctx.in.takeU32(); + return ctx.makeLimits(uint64_t(*n), m); +} + +// limits64 ::= n:u64 m:u64? +template<typename Ctx> Result<typename Ctx::LimitsT> limits64(Ctx& ctx) { + auto n = ctx.in.takeU64(); + if (!n) { + return ctx.in.err("expected initial size"); + } + std::optional<uint64_t> m = ctx.in.takeU64(); + return ctx.makeLimits(uint64_t(*n), m); +} + +// memtype ::= (limits32 | 'i32' limits32 | 'i64' limit64) shared? +template<typename Ctx> Result<typename Ctx::MemTypeT> memtype(Ctx& ctx) { + auto type = Type::i32; + if (ctx.in.takeKeyword("i64"sv)) { + type = Type::i64; + } else { + ctx.in.takeKeyword("i32"sv); + } + auto limits = type == Type::i32 ? limits32(ctx) : limits64(ctx); + CHECK_ERR(limits); + bool shared = false; + if (ctx.in.takeKeyword("shared"sv)) { + shared = true; + } + return ctx.makeMemType(type, *limits, shared); +} + +// globaltype ::= t:valtype => const t +// | '(' 'mut' t:valtype ')' => var t +template<typename Ctx> Result<typename Ctx::GlobalTypeT> globaltype(Ctx& ctx) { + auto mutability = Immutable; + if (ctx.in.takeSExprStart("mut"sv)) { + mutability = Mutable; + } + + auto type = valtype(ctx); + CHECK_ERR(type); + + if (mutability == Mutable && !ctx.in.takeRParen()) { + return ctx.in.err("expected end of globaltype"); + } + + return ctx.makeGlobalType(mutability, *type); +} + +// ============ +// Instructions +// ============ + +// blockinstr ::= block | loop | if-else | try-catch +template<typename Ctx> +MaybeResult<typename Ctx::InstrT> foldedBlockinstr(Ctx& ctx) { + if (auto i = block(ctx, true)) { + return i; + } + // TODO: Other block instructions + return {}; +} + +template<typename Ctx> +MaybeResult<typename Ctx::InstrT> unfoldedBlockinstr(Ctx& ctx) { + if (auto i = block(ctx, false)) { + return i; + } + // TODO: Other block instructions + return {}; +} + +template<typename Ctx> MaybeResult<typename Ctx::InstrT> blockinstr(Ctx& ctx) { + if (auto i = foldedBlockinstr(ctx)) { + return i; + } + if (auto i = unfoldedBlockinstr(ctx)) { + return i; + } + return {}; +} + +// plaininstr ::= ... all plain instructions ... +template<typename Ctx> MaybeResult<typename Ctx::InstrT> plaininstr(Ctx& ctx) { + auto pos = ctx.in.getPos(); + auto keyword = ctx.in.takeKeyword(); + if (!keyword) { + return {}; + } + +#define NEW_INSTRUCTION_PARSER +#define NEW_WAT_PARSER +#include <gen-s-parser.inc> +} + +// instr ::= plaininstr | blockinstr +template<typename Ctx> MaybeResult<typename Ctx::InstrT> instr(Ctx& ctx) { + // Check for valid strings that are not instructions. + if (auto tok = ctx.in.peek()) { + if (auto keyword = tok->getKeyword()) { + if (keyword == "end"sv) { + return {}; + } + } + } + if (auto i = blockinstr(ctx)) { + return i; + } + if (auto i = plaininstr(ctx)) { + return i; + } + // TODO: Handle folded plain instructions as well. + return {}; +} + +template<typename Ctx> Result<typename Ctx::InstrsT> instrs(Ctx& ctx) { + auto insts = ctx.makeInstrs(); + + while (true) { + if (auto blockinst = foldedBlockinstr(ctx)) { + CHECK_ERR(blockinst); + ctx.appendInstr(insts, *blockinst); + continue; + } + // Parse an arbitrary number of folded instructions. + if (ctx.in.takeLParen()) { + // A stack of (start, end) position pairs defining the positions of + // instructions that need to be parsed after their folded children. + std::vector<std::pair<Index, std::optional<Index>>> foldedInstrs; + + // Begin a folded instruction. Push its start position and a placeholder + // end position. + foldedInstrs.push_back({ctx.in.getPos(), {}}); + while (!foldedInstrs.empty()) { + // Consume everything up to the next paren. This span will be parsed as + // an instruction later after its folded children have been parsed. + if (!ctx.in.takeUntilParen()) { + return ctx.in.err(foldedInstrs.back().first, + "unterminated folded instruction"); + } + + if (!foldedInstrs.back().second) { + // The folded instruction we just started should end here. + foldedInstrs.back().second = ctx.in.getPos(); + } + + // We have either the start of a new folded child or the end of the last + // one. + if (auto blockinst = foldedBlockinstr(ctx)) { + CHECK_ERR(blockinst); + ctx.appendInstr(insts, *blockinst); + } else if (ctx.in.takeLParen()) { + foldedInstrs.push_back({ctx.in.getPos(), {}}); + } else if (ctx.in.takeRParen()) { + auto [start, end] = foldedInstrs.back(); + assert(end && "Should have found end of instruction"); + foldedInstrs.pop_back(); + + WithPosition with(ctx, start); + if (auto inst = plaininstr(ctx)) { + CHECK_ERR(inst); + ctx.appendInstr(insts, *inst); + } else { + return ctx.in.err(start, "expected folded instruction"); + } + + if (ctx.in.getPos() != *end) { + return ctx.in.err("expected end of instruction"); + } + } else { + WASM_UNREACHABLE("expected paren"); + } + } + continue; + } + + // A non-folded instruction. + if (auto inst = instr(ctx)) { + CHECK_ERR(inst); + ctx.appendInstr(insts, *inst); + } else { + break; + } + } + + return ctx.finishInstrs(insts); +} + +template<typename Ctx> Result<typename Ctx::ExprT> expr(Ctx& ctx) { + auto insts = instrs(ctx); + CHECK_ERR(insts); + return ctx.makeExpr(*insts); +} + +// memarg_n ::= o:offset a:align_n +// offset ::= 'offset='o:u64 => o | _ => 0 +// align_n ::= 'align='a:u32 => a | _ => n +template<typename Ctx> +Result<typename Ctx::MemargT> memarg(Ctx& ctx, uint32_t n) { + uint64_t offset = 0; + uint32_t align = n; + if (auto o = ctx.in.takeOffset()) { + offset = *o; + } + if (auto a = ctx.in.takeAlign()) { + align = *a; + } + return ctx.getMemarg(offset, align); +} + +// blocktype ::= (t:result)? => t? | x,I:typeuse => x if I = {} +template<typename Ctx> Result<typename Ctx::BlockTypeT> blocktype(Ctx& ctx) { + auto pos = ctx.in.getPos(); + + if (auto res = results(ctx)) { + CHECK_ERR(res); + if (ctx.getResultsSize(*res) == 1) { + return ctx.getBlockTypeFromResult(*res); + } + } + + // We either had no results or multiple results. Reset and parse again as a + // type use. + ctx.in.lexer.setIndex(pos); + auto use = typeuse(ctx); + CHECK_ERR(use); + + auto type = ctx.getBlockTypeFromTypeUse(pos, *use); + CHECK_ERR(type); + return *type; +} + +// block ::= 'block' label blocktype instr* 'end' id? if id = {} or id = label +// | '(' 'block' label blocktype instr* ')' +template<typename Ctx> +MaybeResult<typename Ctx::InstrT> block(Ctx& ctx, bool folded) { + auto pos = ctx.in.getPos(); + + if (folded) { + if (!ctx.in.takeSExprStart("block"sv)) { + return {}; + } + } else { + if (!ctx.in.takeKeyword("block"sv)) { + return {}; + } + } + + auto label = ctx.in.takeID(); + + auto type = blocktype(ctx); + CHECK_ERR(type); + + ctx.makeBlock(pos, label, *type); + + auto insts = instrs(ctx); + CHECK_ERR(insts); + + if (folded) { + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected ')' at end of block"); + } + } else { + if (!ctx.in.takeKeyword("end"sv)) { + return ctx.in.err("expected 'end' at end of block"); + } + auto id = ctx.in.takeID(); + if (id && id != label) { + return ctx.in.err("end label does not match block label"); + } + } + + return ctx.finishBlock(pos, std::move(*insts)); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeUnreachable(Ctx& ctx, Index pos) { + return ctx.makeUnreachable(pos); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeNop(Ctx& ctx, Index pos) { + return ctx.makeNop(pos); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeBinary(Ctx& ctx, Index pos, BinaryOp op) { + return ctx.makeBinary(pos, op); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeUnary(Ctx& ctx, Index pos, UnaryOp op) { + return ctx.makeUnary(pos, op); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeSelect(Ctx& ctx, Index pos) { + auto res = results(ctx); + CHECK_ERR(res); + return ctx.makeSelect(pos, res.getPtr()); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeDrop(Ctx& ctx, Index pos) { + return ctx.makeDrop(pos); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeMemorySize(Ctx& ctx, Index pos) { + auto mem = maybeMemidx(ctx); + CHECK_ERR(mem); + return ctx.makeMemorySize(pos, mem.getPtr()); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeMemoryGrow(Ctx& ctx, Index pos) { + auto mem = maybeMemidx(ctx); + CHECK_ERR(mem); + return ctx.makeMemoryGrow(pos, mem.getPtr()); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeLocalGet(Ctx& ctx, Index pos) { + auto local = localidx(ctx); + CHECK_ERR(local); + return ctx.makeLocalGet(pos, *local); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeLocalTee(Ctx& ctx, Index pos) { + auto local = localidx(ctx); + CHECK_ERR(local); + return ctx.makeLocalTee(pos, *local); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeLocalSet(Ctx& ctx, Index pos) { + auto local = localidx(ctx); + CHECK_ERR(local); + return ctx.makeLocalSet(pos, *local); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeGlobalGet(Ctx& ctx, Index pos) { + auto global = globalidx(ctx); + CHECK_ERR(global); + return ctx.makeGlobalGet(pos, *global); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeGlobalSet(Ctx& ctx, Index pos) { + auto global = globalidx(ctx); + CHECK_ERR(global); + return ctx.makeGlobalSet(pos, *global); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeBlock(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeThenOrElse(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeConst(Ctx& ctx, Index pos, Type type) { + assert(type.isBasic()); + switch (type.getBasic()) { + case Type::i32: + if (auto c = ctx.in.takeI32()) { + return ctx.makeI32Const(pos, *c); + } + return ctx.in.err("expected i32"); + case Type::i64: + if (auto c = ctx.in.takeI64()) { + return ctx.makeI64Const(pos, *c); + } + return ctx.in.err("expected i64"); + case Type::f32: + if (auto c = ctx.in.takeF32()) { + return ctx.makeF32Const(pos, *c); + } + return ctx.in.err("expected f32"); + case Type::f64: + if (auto c = ctx.in.takeF64()) { + return ctx.makeF64Const(pos, *c); + } + return ctx.in.err("expected f64"); + case Type::v128: + return ctx.in.err("unimplemented instruction"); + case Type::none: + case Type::unreachable: + break; + } + WASM_UNREACHABLE("unexpected type"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeLoad( + Ctx& ctx, Index pos, Type type, bool signed_, int bytes, bool isAtomic) { + auto mem = maybeMemidx(ctx); + CHECK_ERR(mem); + auto arg = memarg(ctx, bytes); + CHECK_ERR(arg); + return ctx.makeLoad(pos, type, signed_, bytes, isAtomic, mem.getPtr(), *arg); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeStore(Ctx& ctx, Index pos, Type type, int bytes, bool isAtomic) { + auto mem = maybeMemidx(ctx); + CHECK_ERR(mem); + auto arg = memarg(ctx, bytes); + CHECK_ERR(arg); + return ctx.makeStore(pos, type, bytes, isAtomic, mem.getPtr(), *arg); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeAtomicRMW(Ctx& ctx, Index pos, AtomicRMWOp op, Type type, uint8_t bytes) { + auto mem = maybeMemidx(ctx); + CHECK_ERR(mem); + auto arg = memarg(ctx, bytes); + CHECK_ERR(arg); + return ctx.makeAtomicRMW(pos, op, type, bytes, mem.getPtr(), *arg); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeAtomicCmpxchg(Ctx& ctx, Index pos, Type type, uint8_t bytes) { + auto mem = maybeMemidx(ctx); + CHECK_ERR(mem); + auto arg = memarg(ctx, bytes); + CHECK_ERR(arg); + return ctx.makeAtomicCmpxchg(pos, type, bytes, mem.getPtr(), *arg); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeAtomicWait(Ctx& ctx, Index pos, Type type) { + auto mem = maybeMemidx(ctx); + CHECK_ERR(mem); + auto arg = memarg(ctx, type == Type::i32 ? 4 : 8); + CHECK_ERR(arg); + return ctx.makeAtomicWait(pos, type, mem.getPtr(), *arg); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeAtomicNotify(Ctx& ctx, Index pos) { + auto mem = maybeMemidx(ctx); + CHECK_ERR(mem); + auto arg = memarg(ctx, 4); + CHECK_ERR(arg); + return ctx.makeAtomicNotify(pos, mem.getPtr(), *arg); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeAtomicFence(Ctx& ctx, Index pos) { + return ctx.makeAtomicFence(pos); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeSIMDExtract(Ctx& ctx, Index pos, SIMDExtractOp op, size_t) { + auto lane = ctx.in.takeU8(); + if (!lane) { + return ctx.in.err("expected lane index"); + } + return ctx.makeSIMDExtract(pos, op, *lane); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeSIMDReplace(Ctx& ctx, Index pos, SIMDReplaceOp op, size_t lanes) { + auto lane = ctx.in.takeU8(); + if (!lane) { + return ctx.in.err("expected lane index"); + } + return ctx.makeSIMDReplace(pos, op, *lane); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeSIMDShuffle(Ctx& ctx, Index pos) { + std::array<uint8_t, 16> lanes; + for (int i = 0; i < 16; ++i) { + auto lane = ctx.in.takeU8(); + if (!lane) { + return ctx.in.err("expected lane index"); + } + lanes[i] = *lane; + } + return ctx.makeSIMDShuffle(pos, lanes); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeSIMDTernary(Ctx& ctx, Index pos, SIMDTernaryOp op) { + return ctx.makeSIMDTernary(pos, op); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeSIMDShift(Ctx& ctx, Index pos, SIMDShiftOp op) { + return ctx.makeSIMDShift(pos, op); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeSIMDLoad(Ctx& ctx, Index pos, SIMDLoadOp op, int bytes) { + auto mem = maybeMemidx(ctx); + CHECK_ERR(mem); + auto arg = memarg(ctx, bytes); + CHECK_ERR(arg); + return ctx.makeSIMDLoad(pos, op, mem.getPtr(), *arg); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeSIMDLoadStoreLane(Ctx& ctx, Index pos, SIMDLoadStoreLaneOp op, int bytes) { + auto reset = ctx.in.getPos(); + + auto retry = [&]() -> Result<typename Ctx::InstrT> { + // We failed to parse. Maybe the lane index was accidentally parsed as the + // optional memory index. Try again without parsing a memory index. + WithPosition with(ctx, reset); + auto arg = memarg(ctx, bytes); + CHECK_ERR(arg); + auto lane = ctx.in.takeU8(); + if (!lane) { + return ctx.in.err("expected lane index"); + } + return ctx.makeSIMDLoadStoreLane(pos, op, nullptr, *arg, *lane); + }; + + auto mem = maybeMemidx(ctx); + if (mem.getErr()) { + return retry(); + } + auto arg = memarg(ctx, bytes); + CHECK_ERR(arg); + auto lane = ctx.in.takeU8(); + if (!lane) { + return retry(); + } + return ctx.makeSIMDLoadStoreLane(pos, op, mem.getPtr(), *arg, *lane); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeMemoryInit(Ctx& ctx, Index pos) { + auto reset = ctx.in.getPos(); + + auto retry = [&]() -> Result<typename Ctx::InstrT> { + // We failed to parse. Maybe the data index was accidentally parsed as the + // optional memory index. Try again without parsing a memory index. + WithPosition with(ctx, reset); + auto data = dataidx(ctx); + CHECK_ERR(data); + return ctx.makeMemoryInit(pos, nullptr, *data); + }; + + auto mem = maybeMemidx(ctx); + if (mem.getErr()) { + return retry(); + } + auto data = dataidx(ctx); + if (data.getErr()) { + return retry(); + } + return ctx.makeMemoryInit(pos, mem.getPtr(), *data); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeDataDrop(Ctx& ctx, Index pos) { + auto data = dataidx(ctx); + CHECK_ERR(data); + return ctx.makeDataDrop(pos, *data); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeMemoryCopy(Ctx& ctx, Index pos) { + auto destMem = maybeMemidx(ctx); + CHECK_ERR(destMem); + std::optional<typename Ctx::MemoryIdxT> srcMem = std::nullopt; + if (destMem) { + auto mem = memidx(ctx); + CHECK_ERR(mem); + srcMem = *mem; + } + return ctx.makeMemoryCopy(pos, destMem.getPtr(), srcMem ? &*srcMem : nullptr); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeMemoryFill(Ctx& ctx, Index pos) { + auto mem = maybeMemidx(ctx); + CHECK_ERR(mem); + return ctx.makeMemoryFill(pos, mem.getPtr()); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makePop(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeIf(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeMaybeBlock(Ctx& ctx, Index pos, size_t i, Type type) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeLoop(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeCall(Ctx& ctx, Index pos, bool isReturn) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeCallIndirect(Ctx& ctx, Index pos, bool isReturn) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeBreak(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeBreakTable(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeReturn(Ctx& ctx, Index pos) { + return ctx.makeReturn(pos); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeRefNull(Ctx& ctx, Index pos) { + auto t = heaptype(ctx); + CHECK_ERR(t); + return ctx.makeRefNull(pos, *t); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeRefIsNull(Ctx& ctx, Index pos) { + return ctx.makeRefIsNull(pos); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeRefFunc(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeRefEq(Ctx& ctx, Index pos) { + return ctx.makeRefEq(pos); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeTableGet(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeTableSet(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeTableSize(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeTableGrow(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeTableFill(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeTry(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeTryOrCatchBody(Ctx& ctx, Index pos, Type type, bool isTry) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeThrow(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeRethrow(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeTupleMake(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeTupleExtract(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeCallRef(Ctx& ctx, Index pos, bool isReturn) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeRefI31(Ctx& ctx, Index pos) { + return ctx.makeRefI31(pos); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeI31Get(Ctx& ctx, Index pos, bool signed_) { + return ctx.makeI31Get(pos, signed_); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeRefTest(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeRefCast(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeBrOnNull(Ctx& ctx, Index pos, bool onFail) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeBrOnCast(Ctx& ctx, Index pos, bool onFail) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeStructNew(Ctx& ctx, Index pos, bool default_) { + auto type = typeidx(ctx); + CHECK_ERR(type); + if (default_) { + return ctx.makeStructNewDefault(pos, *type); + } + return ctx.makeStructNew(pos, *type); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeStructGet(Ctx& ctx, Index pos, bool signed_) { + auto type = typeidx(ctx); + CHECK_ERR(type); + auto field = fieldidx(ctx, *type); + CHECK_ERR(field); + return ctx.makeStructGet(pos, *type, *field, signed_); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeStructSet(Ctx& ctx, Index pos) { + auto type = typeidx(ctx); + CHECK_ERR(type); + auto field = fieldidx(ctx, *type); + CHECK_ERR(field); + return ctx.makeStructSet(pos, *type, *field); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayNew(Ctx& ctx, Index pos, bool default_) { + auto type = typeidx(ctx); + CHECK_ERR(type); + if (default_) { + return ctx.makeArrayNewDefault(pos, *type); + } + return ctx.makeArrayNew(pos, *type); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayNewData(Ctx& ctx, Index pos) { + auto type = typeidx(ctx); + CHECK_ERR(type); + auto data = dataidx(ctx); + CHECK_ERR(data); + return ctx.makeArrayNewData(pos, *type, *data); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayNewElem(Ctx& ctx, Index pos) { + auto type = typeidx(ctx); + CHECK_ERR(type); + auto data = dataidx(ctx); + CHECK_ERR(data); + return ctx.makeArrayNewElem(pos, *type, *data); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayNewFixed(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayGet(Ctx& ctx, Index pos, bool signed_) { + auto type = typeidx(ctx); + CHECK_ERR(type); + return ctx.makeArrayGet(pos, *type, signed_); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeArraySet(Ctx& ctx, Index pos) { + auto type = typeidx(ctx); + CHECK_ERR(type); + return ctx.makeArraySet(pos, *type); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayLen(Ctx& ctx, Index pos) { + return ctx.makeArrayLen(pos); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayCopy(Ctx& ctx, Index pos) { + auto destType = typeidx(ctx); + CHECK_ERR(destType); + auto srcType = typeidx(ctx); + CHECK_ERR(srcType); + return ctx.makeArrayCopy(pos, *destType, *srcType); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayFill(Ctx& ctx, Index pos) { + auto type = typeidx(ctx); + CHECK_ERR(type); + return ctx.makeArrayFill(pos, *type); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayInitData(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeArrayInitElem(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeRefAs(Ctx& ctx, Index pos, RefAsOp op) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeStringNew(Ctx& ctx, Index pos, StringNewOp op, bool try_) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringConst(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeStringMeasure(Ctx& ctx, Index pos, StringMeasureOp op) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeStringEncode(Ctx& ctx, Index pos, StringEncodeOp op) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringConcat(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringEq(Ctx& ctx, Index pos, StringEqOp op) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringAs(Ctx& ctx, Index pos, StringAsOp op) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringWTF8Advance(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringWTF16Get(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringIterNext(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeStringIterMove(Ctx& ctx, Index pos, StringIterMoveOp op) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> +makeStringSliceWTF(Ctx& ctx, Index pos, StringSliceWTFOp op) { + return ctx.in.err("unimplemented instruction"); +} + +template<typename Ctx> +Result<typename Ctx::InstrT> makeStringSliceIter(Ctx& ctx, Index pos) { + return ctx.in.err("unimplemented instruction"); +} + +// ======= +// Modules +// ======= + +// typeidx ::= x:u32 => x +// | v:id => x (if types[x] = v) +template<typename Ctx> MaybeResult<Index> maybeTypeidx(Ctx& ctx) { + if (auto x = ctx.in.takeU32()) { + return *x; + } + if (auto id = ctx.in.takeID()) { + // TODO: Fix position to point to start of id, not next element. + auto idx = ctx.getTypeIndex(*id); + CHECK_ERR(idx); + return *idx; + } + return {}; +} + +template<typename Ctx> Result<typename Ctx::HeapTypeT> typeidx(Ctx& ctx) { + if (auto idx = maybeTypeidx(ctx)) { + CHECK_ERR(idx); + return ctx.getHeapTypeFromIdx(*idx); + } + return ctx.in.err("expected type index or identifier"); +} + +// fieldidx_t ::= x:u32 => x +// | v:id => x (if t.fields[x] = v) +template<typename Ctx> +Result<typename Ctx::FieldIdxT> fieldidx(Ctx& ctx, + typename Ctx::HeapTypeT type) { + if (auto x = ctx.in.takeU32()) { + return ctx.getFieldFromIdx(type, *x); + } + if (auto id = ctx.in.takeID()) { + return ctx.getFieldFromName(type, *id); + } + return ctx.in.err("expected field index or identifier"); +} + +// memidx ::= x:u32 => x +// | v:id => x (if memories[x] = v) +template<typename Ctx> +MaybeResult<typename Ctx::MemoryIdxT> maybeMemidx(Ctx& ctx) { + if (auto x = ctx.in.takeU32()) { + return ctx.getMemoryFromIdx(*x); + } + if (auto id = ctx.in.takeID()) { + return ctx.getMemoryFromName(*id); + } + return {}; +} + +template<typename Ctx> Result<typename Ctx::MemoryIdxT> memidx(Ctx& ctx) { + if (auto idx = maybeMemidx(ctx)) { + CHECK_ERR(idx); + return *idx; + } + return ctx.in.err("expected memory index or identifier"); +} + +// memuse ::= '(' 'memory' x:memidx ')' => x +template<typename Ctx> +MaybeResult<typename Ctx::MemoryIdxT> maybeMemuse(Ctx& ctx) { + if (!ctx.in.takeSExprStart("memory"sv)) { + return {}; + } + auto idx = memidx(ctx); + CHECK_ERR(idx); + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of memory use"); + } + return *idx; +} + +// globalidx ::= x:u32 => x +// | v:id => x (if globals[x] = v) +template<typename Ctx> Result<typename Ctx::GlobalIdxT> globalidx(Ctx& ctx) { + if (auto x = ctx.in.takeU32()) { + return ctx.getGlobalFromIdx(*x); + } + if (auto id = ctx.in.takeID()) { + return ctx.getGlobalFromName(*id); + } + return ctx.in.err("expected global index or identifier"); +} + +// dataidx ::= x:u32 => x +// | v:id => x (if data[x] = v) +template<typename Ctx> Result<typename Ctx::DataIdxT> dataidx(Ctx& ctx) { + if (auto x = ctx.in.takeU32()) { + return ctx.getDataFromIdx(*x); + } + if (auto id = ctx.in.takeID()) { + return ctx.getDataFromName(*id); + } + return ctx.in.err("expected data index or identifier"); +} + +// localidx ::= x:u32 => x +// | v:id => x (if locals[x] = v) +template<typename Ctx> Result<typename Ctx::LocalIdxT> localidx(Ctx& ctx) { + if (auto x = ctx.in.takeU32()) { + return ctx.getLocalFromIdx(*x); + } + if (auto id = ctx.in.takeID()) { + return ctx.getLocalFromName(*id); + } + return ctx.in.err("expected local index or identifier"); +} + +// typeuse ::= '(' 'type' x:typeidx ')' => x, [] +// (if typedefs[x] = [t1*] -> [t2*] +// | '(' 'type' x:typeidx ')' ((t1,IDs):param)* (t2:result)* => x, IDs +// (if typedefs[x] = [t1*] -> [t2*]) +// | ((t1,IDs):param)* (t2:result)* => x, IDs +// (if x is minimum s.t. typedefs[x] = [t1*] -> [t2*]) +template<typename Ctx> Result<typename Ctx::TypeUseT> typeuse(Ctx& ctx) { + auto pos = ctx.in.getPos(); + std::optional<typename Ctx::HeapTypeT> type; + if (ctx.in.takeSExprStart("type"sv)) { + auto x = typeidx(ctx); + CHECK_ERR(x); + + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of type use"); + } + + type = *x; + } + + auto namedParams = params(ctx); + CHECK_ERR(namedParams); + + auto resultTypes = results(ctx); + CHECK_ERR(resultTypes); + + return ctx.makeTypeUse(pos, type, namedParams.getPtr(), resultTypes.getPtr()); +} + +// ('(' 'import' mod:name nm:name ')')? +MaybeResult<ImportNames> inlineImport(ParseInput& in) { + if (!in.takeSExprStart("import"sv)) { + return {}; + } + auto mod = in.takeName(); + if (!mod) { + return in.err("expected import module"); + } + auto nm = in.takeName(); + if (!nm) { + return in.err("expected import name"); + } + if (!in.takeRParen()) { + return in.err("expected end of import"); + } + // TODO: Return Ok when parsing Decls. + return {{*mod, *nm}}; +} + +// ('(' 'export' name ')')* +Result<std::vector<Name>> inlineExports(ParseInput& in) { + std::vector<Name> exports; + while (in.takeSExprStart("export"sv)) { + auto name = in.takeName(); + if (!name) { + return in.err("expected export name"); + } + if (!in.takeRParen()) { + return in.err("expected end of import"); + } + exports.push_back(*name); + } + return exports; +} + +// strtype ::= ft:functype => ft +// | st:structtype => st +// | at:arraytype => at +template<typename Ctx> Result<> strtype(Ctx& ctx) { + if (auto type = functype(ctx)) { + CHECK_ERR(type); + ctx.addFuncType(*type); + return Ok{}; + } + if (auto type = structtype(ctx)) { + CHECK_ERR(type); + ctx.addStructType(*type); + return Ok{}; + } + if (auto type = arraytype(ctx)) { + CHECK_ERR(type); + ctx.addArrayType(*type); + return Ok{}; + } + return ctx.in.err("expected type description"); +} + +// subtype ::= '(' 'type' id? '(' 'sub' typeidx? strtype ')' ')' +// | '(' 'type' id? strtype ')' +template<typename Ctx> MaybeResult<> subtype(Ctx& ctx) { + auto pos = ctx.in.getPos(); + + if (!ctx.in.takeSExprStart("type"sv)) { + return {}; + } + + Name name; + if (auto id = ctx.in.takeID()) { + name = *id; + } + + if (ctx.in.takeSExprStart("sub"sv)) { + if (ctx.in.takeKeyword("open"sv)) { + ctx.setOpen(); + } + if (auto super = maybeTypeidx(ctx)) { + CHECK_ERR(super); + CHECK_ERR(ctx.addSubtype(*super)); + } + + CHECK_ERR(strtype(ctx)); + + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of subtype definition"); + } + } else { + CHECK_ERR(strtype(ctx)); + } + + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of type definition"); + } + + ctx.finishSubtype(name, pos); + return Ok{}; +} + +// deftype ::= '(' 'rec' subtype* ')' +// | subtype +template<typename Ctx> MaybeResult<> deftype(Ctx& ctx) { + auto pos = ctx.in.getPos(); + + if (ctx.in.takeSExprStart("rec"sv)) { + size_t startIndex = ctx.getRecGroupStartIndex(); + size_t groupLen = 0; + while (auto type = subtype(ctx)) { + CHECK_ERR(type); + ++groupLen; + } + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected type definition or end of recursion group"); + } + ctx.addRecGroup(startIndex, groupLen); + } else if (auto type = subtype(ctx)) { + CHECK_ERR(type); + } else { + return {}; + } + + ctx.finishDeftype(pos); + return Ok{}; +} + +// local ::= '(' 'local id? t:valtype ')' => [t] +// | '(' 'local t*:valtype* ')' => [t*] +// locals ::= local* +template<typename Ctx> MaybeResult<typename Ctx::LocalsT> locals(Ctx& ctx) { + bool hasAny = false; + auto res = ctx.makeLocals(); + while (ctx.in.takeSExprStart("local"sv)) { + hasAny = true; + if (auto id = ctx.in.takeID()) { + // Single named local + auto type = valtype(ctx); + CHECK_ERR(type); + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of local"); + } + ctx.appendLocal(res, *id, *type); + } else { + // Repeated unnamed locals + while (!ctx.in.takeRParen()) { + auto type = valtype(ctx); + CHECK_ERR(type); + ctx.appendLocal(res, {}, *type); + } + } + } + if (hasAny) { + return res; + } + return {}; +} + +// func ::= '(' 'func' id? ('(' 'export' name ')')* +// x,I:typeuse t*:vec(local) (in:instr)* ')' +// | '(' 'func' id? ('(' 'export' name ')')* +// '(' 'import' mod:name nm:name ')' typeuse ')' +template<typename Ctx> MaybeResult<> func(Ctx& ctx) { + auto pos = ctx.in.getPos(); + if (!ctx.in.takeSExprStart("func"sv)) { + return {}; + } + + Name name; + if (auto id = ctx.in.takeID()) { + name = *id; + } + + auto exports = inlineExports(ctx.in); + CHECK_ERR(exports); + + auto import = inlineImport(ctx.in); + CHECK_ERR(import); + + auto type = typeuse(ctx); + CHECK_ERR(type); + + std::optional<typename Ctx::LocalsT> localVars; + if (!import) { + if (auto l = locals(ctx)) { + CHECK_ERR(l); + localVars = *l; + } + } + + std::optional<typename Ctx::InstrsT> insts; + if (!import) { + auto i = instrs(ctx); + CHECK_ERR(i); + insts = *i; + } + + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of function"); + } + + CHECK_ERR( + ctx.addFunc(name, *exports, import.getPtr(), *type, localVars, insts, pos)); + return Ok{}; +} + +// mem ::= '(' 'memory' id? ('(' 'export' name ')')* +// ('(' 'data' b:datastring ')' | memtype) ')' +// | '(' 'memory' id? ('(' 'export' name ')')* +// '(' 'import' mod:name nm:name ')' memtype ')' +template<typename Ctx> MaybeResult<> memory(Ctx& ctx) { + auto pos = ctx.in.getPos(); + if (!ctx.in.takeSExprStart("memory"sv)) { + return {}; + } + + Name name; + if (auto id = ctx.in.takeID()) { + name = *id; + } + + auto exports = inlineExports(ctx.in); + CHECK_ERR(exports); + + auto import = inlineImport(ctx.in); + CHECK_ERR(import); + + std::optional<typename Ctx::MemTypeT> mtype; + std::optional<typename Ctx::DataStringT> data; + if (ctx.in.takeSExprStart("data"sv)) { + if (import) { + return ctx.in.err("imported memories cannot have inline data"); + } + auto datastr = datastring(ctx); + CHECK_ERR(datastr); + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of inline data"); + } + mtype = ctx.makeMemType(Type::i32, ctx.getLimitsFromData(*datastr), false); + data = *datastr; + } else { + auto type = memtype(ctx); + CHECK_ERR(type); + mtype = *type; + } + + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of memory declaration"); + } + + CHECK_ERR(ctx.addMemory(name, *exports, import.getPtr(), *mtype, pos)); + + if (data) { + CHECK_ERR(ctx.addImplicitData(std::move(*data))); + } + + return Ok{}; +} + +// global ::= '(' 'global' id? ('(' 'export' name ')')* gt:globaltype e:expr ')' +// | '(' 'global' id? ('(' 'export' name ')')* +// '(' 'import' mod:name nm:name ')' gt:globaltype ')' +template<typename Ctx> MaybeResult<> global(Ctx& ctx) { + auto pos = ctx.in.getPos(); + if (!ctx.in.takeSExprStart("global"sv)) { + return {}; + } + + Name name; + if (auto id = ctx.in.takeID()) { + name = *id; + } + + auto exports = inlineExports(ctx.in); + CHECK_ERR(exports); + + auto import = inlineImport(ctx.in); + CHECK_ERR(import); + + auto type = globaltype(ctx); + CHECK_ERR(type); + + std::optional<typename Ctx::ExprT> exp; + if (!import) { + auto e = expr(ctx); + CHECK_ERR(e); + exp = *e; + } + + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of global"); + } + + CHECK_ERR(ctx.addGlobal(name, *exports, import.getPtr(), *type, exp, pos)); + return Ok{}; +} + +// datastring ::= (b:string)* => concat(b*) +template<typename Ctx> Result<typename Ctx::DataStringT> datastring(Ctx& ctx) { + auto data = ctx.makeDataString(); + while (auto str = ctx.in.takeString()) { + ctx.appendDataString(data, *str); + } + return data; +} + +// data ::= '(' 'data' id? b*:datastring ')' => {init b*, mode passive} +// | '(' 'data' id? x:memuse? ('(' 'offset' e:expr ')' | e:instr) +// b*:datastring ') +// => {init b*, mode active {memory x, offset e}} +template<typename Ctx> MaybeResult<> data(Ctx& ctx) { + auto pos = ctx.in.getPos(); + if (!ctx.in.takeSExprStart("data"sv)) { + return {}; + } + + Name name; + if (auto id = ctx.in.takeID()) { + name = *id; + } + + auto mem = maybeMemuse(ctx); + CHECK_ERR(mem); + + std::optional<typename Ctx::ExprT> offset; + if (ctx.in.takeSExprStart("offset"sv)) { + auto e = expr(ctx); + CHECK_ERR(e); + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of offset expression"); + } + offset = *e; + } else if (ctx.in.takeLParen()) { + auto inst = instr(ctx); + CHECK_ERR(inst); + auto offsetExpr = ctx.instrToExpr(*inst); + CHECK_ERR(offsetExpr); + offset = *offsetExpr; + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of offset instruction"); + } + } + + if (mem && !offset) { + return ctx.in.err("expected offset for active segment"); + } + + auto str = datastring(ctx); + CHECK_ERR(str); + + if (!ctx.in.takeRParen()) { + return ctx.in.err("expected end of data segment"); + } + + CHECK_ERR(ctx.addData(name, mem.getPtr(), offset, std::move(*str), pos)); + + return Ok{}; +} + +// modulefield ::= deftype +// | import +// | func +// | table +// | memory +// | global +// | export +// | start +// | elem +// | data +template<typename Ctx> MaybeResult<> modulefield(Ctx& ctx) { + if (auto t = ctx.in.peek(); !t || t->isRParen()) { + return {}; + } + if (auto res = deftype(ctx)) { + CHECK_ERR(res); + return Ok{}; + } + if (auto res = func(ctx)) { + CHECK_ERR(res); + return Ok{}; + } + if (auto res = memory(ctx)) { + CHECK_ERR(res); + return Ok{}; + } + if (auto res = global(ctx)) { + CHECK_ERR(res); + return Ok{}; + } + if (auto res = data(ctx)) { + CHECK_ERR(res); + return Ok{}; + } + return ctx.in.err("unrecognized module field"); +} + +// module ::= '(' 'module' id? (m:modulefield)* ')' +// | (m:modulefield)* eof +template<typename Ctx> Result<> module(Ctx& ctx) { + bool outer = ctx.in.takeSExprStart("module"sv); + + if (outer) { + if (auto id = ctx.in.takeID()) { + ctx.wasm.name = *id; + } + } + + while (auto field = modulefield(ctx)) { + CHECK_ERR(field); + } + + if (outer && !ctx.in.takeRParen()) { + return ctx.in.err("expected end of module"); + } + + return Ok{}; +} + +} // namespace wasm::WATParser + +#endif // parser_parsers_h diff --git a/src/parser/wat-parser.cpp b/src/parser/wat-parser.cpp new file mode 100644 index 000000000..7b58be4d5 --- /dev/null +++ b/src/parser/wat-parser.cpp @@ -0,0 +1,172 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "wat-parser.h" +#include "contexts.h" +#include "ir/names.h" +#include "lexer.h" +#include "parsers.h" +#include "wasm-type.h" +#include "wasm.h" + +// The WebAssembly text format is recursive in the sense that elements may be +// referred to before they are declared. Furthermore, elements may be referred +// to by index or by name. As a result, we need to parse text modules in +// multiple phases. +// +// In the first phase, we find all of the module element declarations and +// record, but do not interpret, the input spans of their corresponding +// definitions. This phase establishes the indices and names of each module +// element so that subsequent phases can look them up. +// +// The second phase parses type definitions to construct the types used in the +// module. This has to be its own phase because we have no way to refer to a +// type before it has been built along with all the other types, unlike for +// other module elements that can be referred to by name before their +// definitions have been parsed. +// +// The third phase further parses and constructs types implicitly defined by +// type uses in functions, blocks, and call_indirect instructions. These +// implicitly defined types may be referred to by index elsewhere. +// +// The fourth phase parses and sets the types of globals, functions, and other +// top-level module elements. These types need to be set before we parse +// instructions because they determine the types of instructions such as +// global.get and ref.func. +// +// The fifth and final phase parses the remaining contents of all module +// elements, including instructions. +// +// Each phase of parsing gets its own context type that is passed to the +// individual parsing functions. There is a parsing function for each element of +// the grammar given in the spec. Parsing functions are templatized so that they +// may be passed the appropriate context type and return the correct result type +// for each phase. + +namespace wasm::WATParser { + +namespace { + +Result<IndexMap> createIndexMap(ParseInput& in, + const std::vector<DefPos>& defs) { + IndexMap indices; + for (auto& def : defs) { + if (def.name.is()) { + if (!indices.insert({def.name, def.index}).second) { + return in.err(def.pos, "duplicate element name"); + } + } + } + return indices; +} + +template<typename Ctx> +Result<> parseDefs(Ctx& ctx, + const std::vector<DefPos>& defs, + MaybeResult<> (*parser)(Ctx&)) { + for (auto& def : defs) { + ctx.index = def.index; + WithPosition with(ctx, def.pos); + auto parsed = parser(ctx); + CHECK_ERR(parsed); + assert(parsed); + } + return Ok{}; +} + +// ================ +// Parser Functions +// ================ + +} // anonymous namespace + +Result<> parseModule(Module& wasm, std::string_view input) { + // Parse module-level declarations. + ParseDeclsCtx decls(input, wasm); + CHECK_ERR(module(decls)); + if (!decls.in.empty()) { + return decls.in.err("Unexpected tokens after module"); + } + + auto typeIndices = createIndexMap(decls.in, decls.subtypeDefs); + CHECK_ERR(typeIndices); + + // Parse type definitions. + std::vector<HeapType> types; + { + TypeBuilder builder(decls.subtypeDefs.size()); + ParseTypeDefsCtx ctx(input, builder, *typeIndices); + for (auto& typeDef : decls.typeDefs) { + WithPosition with(ctx, typeDef.pos); + CHECK_ERR(deftype(ctx)); + } + auto built = builder.build(); + if (auto* err = built.getError()) { + std::stringstream msg; + msg << "invalid type: " << err->reason; + return ctx.in.err(decls.typeDefs[err->index].pos, msg.str()); + } + types = *built; + // Record type names on the module. + for (size_t i = 0; i < types.size(); ++i) { + auto& names = ctx.names[i]; + if (names.name.is() || names.fieldNames.size()) { + wasm.typeNames.insert({types[i], names}); + } + } + } + + // Parse implicit type definitions and map typeuses without explicit types to + // the correct types. + std::unordered_map<Index, HeapType> implicitTypes; + { + ParseImplicitTypeDefsCtx ctx(input, types, implicitTypes, *typeIndices); + for (Index pos : decls.implicitTypeDefs) { + WithPosition with(ctx, pos); + CHECK_ERR(typeuse(ctx)); + } + } + + { + // Parse module-level types. + ParseModuleTypesCtx ctx(input, wasm, types, implicitTypes, *typeIndices); + CHECK_ERR(parseDefs(ctx, decls.funcDefs, func)); + CHECK_ERR(parseDefs(ctx, decls.memoryDefs, memory)); + CHECK_ERR(parseDefs(ctx, decls.globalDefs, global)); + // TODO: Parse types of other module elements. + } + { + // Parse definitions. + // TODO: Parallelize this. + ParseDefsCtx ctx(input, wasm, types, implicitTypes, *typeIndices); + CHECK_ERR(parseDefs(ctx, decls.globalDefs, global)); + CHECK_ERR(parseDefs(ctx, decls.dataDefs, data)); + + for (Index i = 0; i < decls.funcDefs.size(); ++i) { + ctx.index = i; + ctx.setFunction(wasm.functions[i].get()); + CHECK_ERR(ctx.irBuilder.makeBlock(Name{}, ctx.func->getResults())); + WithPosition with(ctx, decls.funcDefs[i].pos); + auto parsed = func(ctx); + CHECK_ERR(parsed); + assert(parsed); + } + } + + return Ok{}; +} + +} // namespace wasm::WATParser diff --git a/src/parser/wat-parser.h b/src/parser/wat-parser.h new file mode 100644 index 000000000..d3ad8d7f3 --- /dev/null +++ b/src/parser/wat-parser.h @@ -0,0 +1,32 @@ +/* + * Copyright 2023 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef parser_wat_parser_h +#define parser_wat_parser_h + +#include <string_view> + +#include "support/result.h" +#include "wasm.h" + +namespace wasm::WATParser { + +// Parse a single WAT module. +Result<> parseModule(Module& wasm, std::string_view in); + +} // namespace wasm::WATParser + +#endif // paser_wat_parser_h |