summaryrefslogtreecommitdiff
path: root/src/wasm-s-parser.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm-s-parser.h')
-rw-r--r--src/wasm-s-parser.h115
1 files changed, 68 insertions, 47 deletions
diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h
index 718c459ca..3d55a1bc1 100644
--- a/src/wasm-s-parser.h
+++ b/src/wasm-s-parser.h
@@ -237,24 +237,56 @@ class SExpressionWasmBuilder {
MixedArena& allocator;
std::function<void ()> onError;
int functionCounter;
- std::vector<Call*> calls; // we only know call types afterwards, so we set their type in a post-pass
+ std::map<Name, WasmType> functionTypes; // we need to know function return types before we parse their contents
public:
// Assumes control of and modifies the input.
- SExpressionWasmBuilder(AllocatingModule& wasm, Element& module, std::function<void ()> onError) : wasm(wasm), allocator(wasm.allocator), onError(onError), functionCounter(0) {
+ SExpressionWasmBuilder(AllocatingModule& wasm, Element& module, std::function<void ()> onError) : wasm(wasm), allocator(wasm.allocator), onError(onError) {
assert(module[0]->str() == MODULE);
+ functionCounter = 0;
for (unsigned i = 1; i < module.size(); i++) {
- parseModuleElement(*module[i]);
+ preParseFunctionType(*module[i]);
}
- // post-pass, fix up call types
- for (auto call : calls) {
- call->type = wasm.functionsMap[call->target]->result;
+ functionCounter = 0;
+ for (unsigned i = 1; i < module.size(); i++) {
+ parseModuleElement(*module[i]);
}
- calls.clear();
}
private:
+ // pre-parse types and function definitions, so we know function return types before parsing their contents
+ void preParseFunctionType(Element& s) {
+ IString id = s[0]->str();
+ if (id == TYPE) return parseType(s);
+ if (id != FUNC) return;
+ size_t i = 1;
+ Name name;
+ if (s[i]->isStr()) {
+ name = s[i]->str();
+ i++;
+ } else {
+ // unnamed, use an index
+ name = Name::fromInt(functionCounter);
+ }
+ functionCounter++;
+ for (;i < s.size(); i++) {
+ Element& curr = *s[i];
+ IString id = curr[0]->str();
+ if (id == RESULT) {
+ functionTypes[name] = stringToWasmType(curr[1]->str());
+ return;
+ } else if (id == TYPE) {
+ Name name = curr[1]->str();
+ if (wasm.functionTypesMap.find(name) == wasm.functionTypesMap.end()) onError();
+ FunctionType* type = wasm.functionTypesMap[name];
+ functionTypes[name] = type->result;
+ return;
+ }
+ }
+ functionTypes[name] = none;
+ }
+
void parseModuleElement(Element& curr) {
IString id = curr[0]->str();
if (id == FUNC) return parseFunction(curr);
@@ -262,7 +294,7 @@ private:
if (id == EXPORT) return parseExport(curr);
if (id == IMPORT) return parseImport(curr);
if (id == TABLE) return parseTable(curr);
- if (id == TYPE) return parseType(curr);
+ if (id == TYPE) return; // already done
std::cerr << "bad module element " << id.str << '\n';
onError();
}
@@ -409,8 +441,8 @@ public:
if (op[2] == 'p') return makeBinary(s, BinaryOp::CopySign, type);
if (op[2] == 'n') {
if (op[3] == 'v') {
- if (op[8] == 's') return makeConvert(s, op[11] == '3' ? ConvertOp::ConvertSInt32 : ConvertOp::ConvertSInt64, type);
- if (op[8] == 'u') return makeConvert(s, op[11] == '3' ? ConvertOp::ConvertUInt32 : ConvertOp::ConvertUInt64, type);
+ if (op[8] == 's') return makeUnary(s, op[11] == '3' ? UnaryOp::ConvertSInt32 : UnaryOp::ConvertSInt64, type);
+ if (op[8] == 'u') return makeUnary(s, op[11] == '3' ? UnaryOp::ConvertUInt32 : UnaryOp::ConvertUInt64, type);
}
if (op[3] == 's') return makeConst(s, type);
}
@@ -423,12 +455,12 @@ public:
if (op[3] == '_') return makeBinary(s, op[4] == 'u' ? BinaryOp::DivU : BinaryOp::DivS, type);
if (op[3] == 0) return makeBinary(s, BinaryOp::Div, type);
}
- if (op[1] == 'e') return makeConvert(s, ConvertOp::DemoteFloat64, type);
+ if (op[1] == 'e') return makeUnary(s, UnaryOp::DemoteFloat64, type);
abort_on(op);
}
case 'e': {
- if (op[1] == 'q') return makeCompare(s, RelationalOp::Eq, type);
- if (op[1] == 'x') return makeConvert(s, op[7] == 'u' ? ConvertOp::ExtendUInt32 : ConvertOp::ExtendSInt32, type);
+ if (op[1] == 'q') return makeBinary(s, BinaryOp::Eq, type);
+ if (op[1] == 'x') return makeUnary(s, op[7] == 'u' ? UnaryOp::ExtendUInt32 : UnaryOp::ExtendSInt32, type);
abort_on(op);
}
case 'f': {
@@ -437,23 +469,23 @@ public:
}
case 'g': {
if (op[1] == 't') {
- if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::GtU : RelationalOp::GtS, type);
- if (op[2] == 0) return makeCompare(s, RelationalOp::Gt, type);
+ if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::GtU : BinaryOp::GtS, type);
+ if (op[2] == 0) return makeBinary(s, BinaryOp::Gt, type);
}
if (op[1] == 'e') {
- if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::GeU : RelationalOp::GeS, type);
- if (op[2] == 0) return makeCompare(s, RelationalOp::Ge, type);
+ if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::GeU : BinaryOp::GeS, type);
+ if (op[2] == 0) return makeBinary(s, BinaryOp::Ge, type);
}
abort_on(op);
}
case 'l': {
if (op[1] == 't') {
- if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::LtU : RelationalOp::LtS, type);
- if (op[2] == 0) return makeCompare(s, RelationalOp::Lt, type);
+ if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::LtU : BinaryOp::LtS, type);
+ if (op[2] == 0) return makeBinary(s, BinaryOp::Lt, type);
}
if (op[1] == 'e') {
- if (op[2] == '_') return makeCompare(s, op[3] == 'u' ? RelationalOp::LeU : RelationalOp::LeS, type);
- if (op[2] == 0) return makeCompare(s, RelationalOp::Le, type);
+ if (op[2] == '_') return makeBinary(s, op[3] == 'u' ? BinaryOp::LeU : BinaryOp::LeS, type);
+ if (op[2] == 0) return makeBinary(s, BinaryOp::Le, type);
}
if (op[1] == 'o') return makeLoad(s, type);
abort_on(op);
@@ -466,7 +498,7 @@ public:
}
case 'n': {
if (op[1] == 'e') {
- if (op[2] == 0) return makeCompare(s, RelationalOp::Ne, type);
+ if (op[2] == 0) return makeBinary(s, BinaryOp::Ne, type);
if (op[2] == 'a') return makeUnary(s, UnaryOp::Nearest, type);
if (op[2] == 'g') return makeUnary(s, UnaryOp::Neg, type);
}
@@ -477,14 +509,14 @@ public:
abort_on(op);
}
case 'p': {
- if (op[1] == 'r') return makeConvert(s, ConvertOp::PromoteFloat32, type);
+ if (op[1] == 'r') return makeUnary(s, UnaryOp::PromoteFloat32, type);
if (op[1] == 'o') return makeUnary(s, UnaryOp::Popcnt, type);
abort_on(op);
}
case 'r': {
if (op[1] == 'e') {
if (op[2] == 'm') return makeBinary(s, op[4] == 'u' ? BinaryOp::RemU : BinaryOp::RemS, type);
- if (op[2] == 'i') return makeConvert(s, isWasmTypeFloat(type) ? ConvertOp::ReinterpretInt : ConvertOp::ReinterpretFloat, type);
+ if (op[2] == 'i') return makeUnary(s, isWasmTypeFloat(type) ? UnaryOp::ReinterpretInt : UnaryOp::ReinterpretFloat, type);
}
abort_on(op);
}
@@ -501,14 +533,14 @@ public:
}
case 't': {
if (op[1] == 'r') {
- if (op[6] == 's') return makeConvert(s, op[9] == '3' ? ConvertOp::TruncSFloat32 : ConvertOp::TruncSFloat64, type);
- if (op[6] == 'u') return makeConvert(s, op[9] == '3' ? ConvertOp::TruncUFloat32 : ConvertOp::TruncUFloat64, type);
+ if (op[6] == 's') return makeUnary(s, op[9] == '3' ? UnaryOp::TruncSFloat32 : UnaryOp::TruncSFloat64, type);
+ if (op[6] == 'u') return makeUnary(s, op[9] == '3' ? UnaryOp::TruncUFloat32 : UnaryOp::TruncUFloat64, type);
if (op[2] == 'u') return makeUnary(s, UnaryOp::Trunc, type);
}
abort_on(op);
}
case 'w': {
- if (op[1] == 'r') return makeConvert(s, ConvertOp::WrapInt64, type);
+ if (op[1] == 'r') return makeUnary(s, UnaryOp::WrapInt64, type);
abort_on(op);
}
case 'x': {
@@ -591,7 +623,7 @@ private:
ret->op = op;
ret->left = parseExpression(s[1]);
ret->right = parseExpression(s[2]);
- ret->type = type;
+ ret->finalize();
return ret;
}
@@ -603,23 +635,6 @@ private:
return ret;
}
- Expression* makeCompare(Element& s, RelationalOp op, WasmType type) {
- auto ret = allocator.alloc<Compare>();
- ret->op = op;
- ret->left = parseExpression(s[1]);
- ret->right = parseExpression(s[2]);
- ret->inputType = type;
- return ret;
- }
-
- Expression* makeConvert(Element& s, ConvertOp op, WasmType type) {
- auto ret = allocator.alloc<Convert>();
- ret->op = op;
- ret->value = parseExpression(s[1]);
- ret->type = type;
- return ret;
- }
-
Expression* makeSelect(Element& s, WasmType type) {
auto ret = allocator.alloc<Select>();
ret->condition = parseExpression(s[1]);
@@ -678,6 +693,7 @@ private:
for (; i < s.size(); i++) {
ret->list.push_back(parseExpression(s[i]));
}
+ ret->type = ret->list.back()->type;
return ret;
}
@@ -904,6 +920,7 @@ private:
ret->ifTrue = parseExpression(s[2]);
if (s.size() == 4) {
ret->ifFalse = parseExpression(s[3]);
+ ret->type = ret->ifTrue->type == ret->ifFalse->type ? ret->ifTrue->type : none; // if not the same type, this does not return a value
}
return ret;
}
@@ -929,6 +946,7 @@ private:
for (; i < s.size() && i < stopAt; i++) {
ret->list.push_back(parseExpression(s[i]));
}
+ ret->type = ret->list.back()->type;
return ret;
}
@@ -957,8 +975,8 @@ private:
Expression* makeCall(Element& s) {
auto ret = allocator.alloc<Call>();
- calls.push_back(ret);
ret->target = s[1]->str();
+ ret->type = functionTypes[ret->target];
parseCallOperands(s, 2, ret);
return ret;
}
@@ -966,6 +984,8 @@ private:
Expression* makeCallImport(Element& s) {
auto ret = allocator.alloc<CallImport>();
ret->target = s[1]->str();
+ Import* import = wasm.importsMap[ret->target];
+ ret->type = import->type.result;
parseCallOperands(s, 2, ret);
return ret;
}
@@ -974,7 +994,8 @@ private:
auto ret = allocator.alloc<CallIndirect>();
IString type = s[1]->str();
assert(wasm.functionTypesMap.find(type) != wasm.functionTypesMap.end());
- ret->type = wasm.functionTypesMap[type];
+ ret->fullType = wasm.functionTypesMap[type];
+ ret->type = ret->fullType->result;
ret->target = parseExpression(s[2]);
parseCallOperands(s, 3, ret);
return ret;