diff options
-rw-r--r-- | src/wasm-binary.h | 121 |
1 files changed, 71 insertions, 50 deletions
diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 37f2d459b..870f5832d 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -1061,10 +1061,11 @@ class WasmBinaryBuilder { std::vector<char>& input; bool debug; - size_t pos; + size_t pos = 0; + int32_t startIndex = -1; public: - WasmBinaryBuilder(AllocatingModule& wasm, std::vector<char>& input, bool debug) : wasm(wasm), allocator(wasm.allocator), input(input), debug(debug), pos(0) {} + WasmBinaryBuilder(AllocatingModule& wasm, std::vector<char>& input, bool debug) : wasm(wasm), allocator(wasm.allocator), input(input), debug(debug) {} void read() { @@ -1224,7 +1225,7 @@ public: void readStart() { if (debug) std::cerr << "== readStart" << std::endl; - wasm.start = wasm.functions[getLEB128()]->name; + startIndex = getLEB128(); } void readMemory() { @@ -1289,6 +1290,11 @@ public: return cashew::IString(("label$" + std::to_string(nextLabel++)).c_str(), false); } + // We read functions before we know their names, so we need to backpatch the names later + + std::vector<Function*> functions; // we store functions here before wasm.addFunction after we know their names + std::map<size_t, std::vector<Call*>> functionCalls; // at index i we have all calls to i + void readFunctions() { if (debug) std::cerr << "== readFunctions" << std::endl; size_t total = getLEB128(); @@ -1327,15 +1333,39 @@ public: addLocals(f64); } size_t size = getInt32(); // XXX int32, diverge from v8 format, to get more code to compile - // we can't read the function yet - it might call other functions that are defined later, - // and we do depend on the function type. - functions.emplace_back(func, pos, size); - pos += size; - func->body = nullptr; // will be filled later. but we do have the name and the type already. - wasm.addFunction(func); + assert(size > 0); // we could also check it matches the see in the next {} + { + // process the function body + if (debug) std::cerr << "processing function: " << i << std::endl; + nextLabel = 0; + // prepare locals + mappedLocals.clear(); + localTypes.clear(); + for (size_t i = 0; i < func->params.size(); i++) { + mappedLocals.push_back(func->params[i].name); + localTypes[func->params[i].name] = func->params[i].type; + } + for (size_t i = 0; i < func->locals.size(); i++) { + mappedLocals.push_back(func->locals[i].name); + localTypes[func->locals[i].name] = func->locals[i].type; + } + // process body + assert(breakStack.empty()); + assert(expressionStack.empty()); + depth = 0; + processExpressions(); + assert(expressionStack.size() == 1); + func->body = popExpression(); + assert(depth == 0); + assert(breakStack.empty()); + assert(expressionStack.empty()); + } + functions.push_back(func); } } + std::map<Export*, size_t> exportIndexes; + void readExports() { if (debug) std::cerr << "== readExports" << std::endl; size_t num = getLEB128(); @@ -1344,22 +1374,12 @@ public: if (debug) std::cerr << "read one" << std::endl; auto curr = allocator.alloc<Export>(); auto index = getLEB128(); - assert(index < wasm.functions.size()); - curr->value = wasm.functions[index]->name; - assert(curr->value.is()); + assert(index < functionTypes.size()); curr->name = getInlineString(); - wasm.addExport(curr); + exportIndexes[curr] = index; } } - struct FunctionData { - Function* func; - size_t pos, size; - FunctionData(Function* func, size_t pos, size_t size) : func(func), pos(pos), size(size) {} - }; - - std::vector<FunctionData> functions; - std::vector<Name> mappedLocals; // index => local name std::map<Name, WasmType> localTypes; // TODO: optimize @@ -1385,32 +1405,31 @@ public: void processFunctions() { for (auto& func : functions) { - Function* curr = func.func; - if (debug) std::cerr << "processing function: " << curr->name << std::endl; - pos = func.pos; - nextLabel = 0; - // prepare locals - mappedLocals.clear(); - localTypes.clear(); - for (size_t i = 0; i < curr->params.size(); i++) { - mappedLocals.push_back(curr->params[i].name); - localTypes[curr->params[i].name] = curr->params[i].type; - } - for (size_t i = 0; i < curr->locals.size(); i++) { - mappedLocals.push_back(curr->locals[i].name); - localTypes[curr->locals[i].name] = curr->locals[i].type; + wasm.addFunction(func); + } + // now that we have names for each function, apply things + + if (startIndex >= 0) { + wasm.start = wasm.functions[startIndex]->name; + } + + for (auto& iter : exportIndexes) { + Export* curr = iter.first; + curr->value = wasm.functions[iter.second]->name; + wasm.addExport(curr); + } + + for (auto& iter : functionCalls) { + size_t index = iter.first; + auto& calls = iter.second; + for (auto* call : calls) { + call->target = wasm.functions[index]->name; } - // process body - assert(breakStack.empty()); - assert(expressionStack.empty()); - depth = 0; - processExpressions(); - assert(expressionStack.size() == 1); - curr->body = popExpression(); - assert(depth == 0); - assert(breakStack.empty()); - assert(expressionStack.empty()); - assert(pos == func.pos + func.size); + } + + for (size_t index : functionTable) { + assert(index < wasm.functions.size()); + wasm.table.names.push_back(wasm.functions[index]->name); } } @@ -1431,13 +1450,14 @@ public: } } + std::vector<size_t> functionTable; + void readFunctionTable() { if (debug) std::cerr << "== readFunctionTable" << std::endl; auto num = getLEB128(); for (size_t i = 0; i < num; i++) { auto index = getLEB128(); - assert(index < wasm.functions.size()); - wasm.table.names.push_back(wasm.functions[index]->name); + functionTable.push_back(index); } } @@ -1588,14 +1608,15 @@ public: } void visitCall(Call *curr) { if (debug) std::cerr << "zz node: Call" << std::endl; - curr->target = wasm.functions[getLEB128()]->name; - auto type = wasm.functionTypesMap[wasm.functionsMap[curr->target]->type]; + auto index = getLEB128(); + auto type = functionTypes[index]; auto num = type->params.size(); curr->operands.resize(num); for (size_t i = 0; i < num; i++) { curr->operands[num - i - 1] = popExpression(); } curr->type = type->result; + functionCalls[index].push_back(curr); } void visitCallImport(CallImport *curr) { if (debug) std::cerr << "zz node: CallImport" << std::endl; |