diff options
Diffstat (limited to 'src/wasm-s-parser.h')
-rw-r--r-- | src/wasm-s-parser.h | 115 |
1 files changed, 68 insertions, 47 deletions
diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index 718c459ca..3d55a1bc1 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -237,24 +237,56 @@ class SExpressionWasmBuilder { MixedArena& allocator; std::function<void ()> onError; int functionCounter; - std::vector<Call*> calls; // we only know call types afterwards, so we set their type in a post-pass + std::map<Name, WasmType> functionTypes; // we need to know function return types before we parse their contents public: // Assumes control of and modifies the input. - SExpressionWasmBuilder(AllocatingModule& wasm, Element& module, std::function<void ()> onError) : wasm(wasm), allocator(wasm.allocator), onError(onError), functionCounter(0) { + SExpressionWasmBuilder(AllocatingModule& wasm, Element& module, std::function<void ()> onError) : wasm(wasm), allocator(wasm.allocator), onError(onError) { assert(module[0]->str() == MODULE); + functionCounter = 0; for (unsigned i = 1; i < module.size(); i++) { - parseModuleElement(*module[i]); + preParseFunctionType(*module[i]); } - // post-pass, fix up call types - for (auto call : calls) { - call->type = wasm.functionsMap[call->target]->result; + functionCounter = 0; + for (unsigned i = 1; i < module.size(); i++) { + parseModuleElement(*module[i]); } - calls.clear(); } private: + // pre-parse types and function definitions, so we know function return types before parsing their contents + void preParseFunctionType(Element& s) { + IString id = s[0]->str(); + if (id == TYPE) return parseType(s); + if (id != FUNC) return; + size_t i = 1; + Name name; + if (s[i]->isStr()) { + name = s[i]->str(); + i++; + } else { + // unnamed, use an index + name = Name::fromInt(functionCounter); + } + functionCounter++; + for (;i < s.size(); i++) { + Element& curr = *s[i]; + IString id = curr[0]->str(); + if (id == RESULT) { + functionTypes[name] = stringToWasmType(curr[1]->str()); + return; + } else if (id == TYPE) { + Name name = curr[1]->str(); + if (wasm.functionTypesMap.find(name) == wasm.functionTypesMap.end()) onError(); + FunctionType* type = wasm.functionTypesMap[name]; + functionTypes[name] = type->result; + return; + } + } + functionTypes[name] = none; + } + void parseModuleElement(Element& curr) { IString id = curr[0]->str(); if (id == FUNC) return parseFunction(curr); @@ -262,7 +294,7 @@ private: if (id == EXPORT) return parseExport(curr); if (id == IMPORT) return parseImport(curr); if (id == TABLE) return parseTable(curr); - if (id == TYPE) return parseType(curr); + if (id == TYPE) return; // already done std::cerr << "bad module element " << id.str << '\n'; onError(); } @@ -409,8 +441,8 @@ public: if (op[2] == 'p') return makeBinary(s, BinaryOp::CopySign, type); if (op[2] == 'n') { if (op[3] == 'v') { - if (op[8] == 's') return makeConvert(s, op[11] == '3' ? ConvertOp::ConvertSInt32 : ConvertOp::ConvertSInt64, type); - if (op[8] == 'u') return makeConvert(s, op[11] == '3' ? ConvertOp::ConvertUInt32 : ConvertOp::ConvertUInt64, type); + if (op[8] == 's') return makeUnary(s, op[11] == '3' ? UnaryOp::ConvertSInt32 : UnaryOp::ConvertSInt64, type); + if (op[8] == 'u') return makeUnary(s, op[11] == '3' ? UnaryOp::ConvertUInt32 : UnaryOp::ConvertUInt64, type); } if (op[3] == 's') return makeConst(s, type); } @@ -423,12 +455,12 @@ public: if (op[3] == '_') return makeBinary(s, op[4] == 'u' ? BinaryOp::DivU : BinaryOp::DivS, type); if (op[3] == 0) return makeBinary(s, BinaryOp::Div, type); } - if (op[1] == 'e') return makeConvert(s, ConvertOp::DemoteFloat64, type); + if (op[1] == 'e') return makeUnary(s, UnaryOp::DemoteFloat64, type); abort_on(op); } case 'e': { - if (op[1] == 'q') return makeCompare(s, RelationalOp::Eq, type); - if (op[1] == 'x') return makeConvert(s, op[7] == 'u' ? ConvertOp::ExtendUInt32 : ConvertOp::ExtendSInt32, type); + if (op[1] == 'q') return makeBinary(s, BinaryOp::Eq, type); + if (op[1] == 'x') return makeUnary(s, op[7] == 'u' ? UnaryOp::ExtendUInt32 : UnaryOp::ExtendSInt32, type); abort_on(op); } case 'f': { @@ -437,23 +469,23 @@ public: } case 'g': { if (op[1] == 't') { - if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::GtU : RelationalOp::GtS, type); - if (op[2] == 0) return makeCompare(s, RelationalOp::Gt, type); + if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::GtU : BinaryOp::GtS, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Gt, type); } if (op[1] == 'e') { - if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::GeU : RelationalOp::GeS, type); - if (op[2] == 0) return makeCompare(s, RelationalOp::Ge, type); + if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::GeU : BinaryOp::GeS, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Ge, type); } abort_on(op); } case 'l': { if (op[1] == 't') { - if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::LtU : RelationalOp::LtS, type); - if (op[2] == 0) return makeCompare(s, RelationalOp::Lt, type); + if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::LtU : BinaryOp::LtS, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Lt, type); } if (op[1] == 'e') { - if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::LeU : RelationalOp::LeS, type); - if (op[2] == 0) return makeCompare(s, RelationalOp::Le, type); + if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::LeU : BinaryOp::LeS, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Le, type); } if (op[1] == 'o') return makeLoad(s, type); abort_on(op); @@ -466,7 +498,7 @@ public: } case 'n': { if (op[1] == 'e') { - if (op[2] == 0) return makeCompare(s, RelationalOp::Ne, type); + if (op[2] == 0) return makeBinary(s, BinaryOp::Ne, type); if (op[2] == 'a') return makeUnary(s, UnaryOp::Nearest, type); if (op[2] == 'g') return makeUnary(s, UnaryOp::Neg, type); } @@ -477,14 +509,14 @@ public: abort_on(op); } case 'p': { - if (op[1] == 'r') return makeConvert(s, ConvertOp::PromoteFloat32, type); + if (op[1] == 'r') return makeUnary(s, UnaryOp::PromoteFloat32, type); if (op[1] == 'o') return makeUnary(s, UnaryOp::Popcnt, type); abort_on(op); } case 'r': { if (op[1] == 'e') { if (op[2] == 'm') return makeBinary(s, op[4] == 'u' ? BinaryOp::RemU : BinaryOp::RemS, type); - if (op[2] == 'i') return makeConvert(s, isWasmTypeFloat(type) ? ConvertOp::ReinterpretInt : ConvertOp::ReinterpretFloat, type); + if (op[2] == 'i') return makeUnary(s, isWasmTypeFloat(type) ? UnaryOp::ReinterpretInt : UnaryOp::ReinterpretFloat, type); } abort_on(op); } @@ -501,14 +533,14 @@ public: } case 't': { if (op[1] == 'r') { - if (op[6] == 's') return makeConvert(s, op[9] == '3' ? ConvertOp::TruncSFloat32 : ConvertOp::TruncSFloat64, type); - if (op[6] == 'u') return makeConvert(s, op[9] == '3' ? ConvertOp::TruncUFloat32 : ConvertOp::TruncUFloat64, type); + if (op[6] == 's') return makeUnary(s, op[9] == '3' ? UnaryOp::TruncSFloat32 : UnaryOp::TruncSFloat64, type); + if (op[6] == 'u') return makeUnary(s, op[9] == '3' ? UnaryOp::TruncUFloat32 : UnaryOp::TruncUFloat64, type); if (op[2] == 'u') return makeUnary(s, UnaryOp::Trunc, type); } abort_on(op); } case 'w': { - if (op[1] == 'r') return makeConvert(s, ConvertOp::WrapInt64, type); + if (op[1] == 'r') return makeUnary(s, UnaryOp::WrapInt64, type); abort_on(op); } case 'x': { @@ -591,7 +623,7 @@ private: ret->op = op; ret->left = parseExpression(s[1]); ret->right = parseExpression(s[2]); - ret->type = type; + ret->finalize(); return ret; } @@ -603,23 +635,6 @@ private: return ret; } - Expression* makeCompare(Element& s, RelationalOp op, WasmType type) { - auto ret = allocator.alloc<Compare>(); - ret->op = op; - ret->left = parseExpression(s[1]); - ret->right = parseExpression(s[2]); - ret->inputType = type; - return ret; - } - - Expression* makeConvert(Element& s, ConvertOp op, WasmType type) { - auto ret = allocator.alloc<Convert>(); - ret->op = op; - ret->value = parseExpression(s[1]); - ret->type = type; - return ret; - } - Expression* makeSelect(Element& s, WasmType type) { auto ret = allocator.alloc<Select>(); ret->condition = parseExpression(s[1]); @@ -678,6 +693,7 @@ private: for (; i < s.size(); i++) { ret->list.push_back(parseExpression(s[i])); } + ret->type = ret->list.back()->type; return ret; } @@ -904,6 +920,7 @@ private: ret->ifTrue = parseExpression(s[2]); if (s.size() == 4) { ret->ifFalse = parseExpression(s[3]); + ret->type = ret->ifTrue->type == ret->ifFalse->type ? ret->ifTrue->type : none; // if not the same type, this does not return a value } return ret; } @@ -929,6 +946,7 @@ private: for (; i < s.size() && i < stopAt; i++) { ret->list.push_back(parseExpression(s[i])); } + ret->type = ret->list.back()->type; return ret; } @@ -957,8 +975,8 @@ private: Expression* makeCall(Element& s) { auto ret = allocator.alloc<Call>(); - calls.push_back(ret); ret->target = s[1]->str(); + ret->type = functionTypes[ret->target]; parseCallOperands(s, 2, ret); return ret; } @@ -966,6 +984,8 @@ private: Expression* makeCallImport(Element& s) { auto ret = allocator.alloc<CallImport>(); ret->target = s[1]->str(); + Import* import = wasm.importsMap[ret->target]; + ret->type = import->type.result; parseCallOperands(s, 2, ret); return ret; } @@ -974,7 +994,8 @@ private: auto ret = allocator.alloc<CallIndirect>(); IString type = s[1]->str(); assert(wasm.functionTypesMap.find(type) != wasm.functionTypesMap.end()); - ret->type = wasm.functionTypesMap[type]; + ret->fullType = wasm.functionTypesMap[type]; + ret->type = ret->fullType->result; ret->target = parseExpression(s[2]); parseCallOperands(s, 3, ret); return ret; |