summaryrefslogtreecommitdiff
path: root/src/wasm-binary.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/wasm-binary.h')
-rw-r--r--src/wasm-binary.h147
1 files changed, 86 insertions, 61 deletions
diff --git a/src/wasm-binary.h b/src/wasm-binary.h
index 561a310b0..bbfde5b21 100644
--- a/src/wasm-binary.h
+++ b/src/wasm-binary.h
@@ -411,7 +411,6 @@ enum ASTNodes {
SetLocal = 0x15,
CallFunction = 0x16,
CallIndirect = 0x17,
- CallImport = 0x18,
TeeLocal = 0x19,
GetGlobal = 0x1a,
SetGlobal = 0x1b,
@@ -653,7 +652,7 @@ public:
if (debug) std::cerr << "write one at" << o.size() << std::endl;
size_t sizePos = writeU32LEBPlaceholder();
size_t start = o.size();
- Function* function = wasm->getFunction(i);
+ Function* function = wasm->functions[i].get();
mappedLocals.clear();
numLocalsByType.clear();
if (debug) std::cerr << "writing" << function->name << std::endl;
@@ -727,27 +726,21 @@ public:
}
finishSection(start);
}
-
- std::map<Name, uint32_t> mappedImports; // name of the Import => index
- uint32_t getImportIndex(Name name) {
- if (!mappedImports.size()) {
- // Create name => index mapping.
- for (size_t i = 0; i < wasm->imports.size(); i++) {
- assert(mappedImports.count(wasm->imports[i]->name) == 0);
- mappedImports[wasm->imports[i]->name] = i;
- }
- }
- assert(mappedImports.count(name));
- return mappedImports[name];
- }
- std::map<Name, uint32_t> mappedFunctions; // name of the Function => index
+ std::map<Name, Index> mappedFunctions; // name of the Function => index. first imports, then internals
uint32_t getFunctionIndex(Name name) {
if (!mappedFunctions.size()) {
// Create name => index mapping.
+ for (auto& import : wasm->imports) {
+ if (import->kind != Import::Function) continue;
+ assert(mappedFunctions.count(import->name) == 0);
+ auto index = mappedFunctions.size();
+ mappedFunctions[import->name] = index;
+ }
for (size_t i = 0; i < wasm->functions.size(); i++) {
assert(mappedFunctions.count(wasm->functions[i]->name) == 0);
- mappedFunctions[wasm->functions[i]->name] = i;
+ auto index = mappedFunctions.size();
+ mappedFunctions[wasm->functions[i]->name] = index;
}
}
assert(mappedFunctions.count(name));
@@ -957,7 +950,7 @@ public:
for (auto* operand : curr->operands) {
recurse(operand);
}
- o << int8_t(BinaryConsts::CallImport) << U32LEB(curr->operands.size()) << U32LEB(getImportIndex(curr->target));
+ o << int8_t(BinaryConsts::CallFunction) << U32LEB(curr->operands.size()) << U32LEB(getFunctionIndex(curr->target));
}
void visitCallIndirect(CallIndirect *curr) {
if (debug) std::cerr << "zz node: CallIndirect" << std::endl;
@@ -1300,8 +1293,13 @@ public:
else if (match(BinaryConsts::Section::FunctionSignatures)) readFunctionSignatures();
else if (match(BinaryConsts::Section::Functions)) readFunctions();
else if (match(BinaryConsts::Section::ExportTable)) readExports();
- else if (match(BinaryConsts::Section::Globals)) readGlobals();
- else if (match(BinaryConsts::Section::DataSegments)) readDataSegments();
+ else if (match(BinaryConsts::Section::Globals)) {
+ readGlobals();
+ // imports can read global imports, so we run getGlobalName and create the mapping
+ // but after we read globals, we need to add the internal globals too, so do that here
+ mappedGlobals.clear(); // wipe the mapping
+ getGlobalName(0); // force rebuild
+ } else if (match(BinaryConsts::Section::DataSegments)) readDataSegments();
else if (match(BinaryConsts::Section::FunctionTable)) readFunctionTable();
else if (match(BinaryConsts::Section::Names)) readNames();
else {
@@ -1497,10 +1495,25 @@ public:
assert(numResults == 1);
curr->result = getWasmType();
}
+ curr->name = Name::fromInt(wasm.functionTypes.size());
wasm.addFunctionType(curr);
}
}
+ std::vector<Name> functionImportIndexes; // index in function index space => name of function import
+
+ // gets a name in the combined function import+defined function space
+ Name getFunctionIndexName(Index i) {
+ if (i < functionImportIndexes.size()) {
+ auto* import = wasm.getImport(functionImportIndexes[i]);
+ assert(import->kind == Import::Function);
+ return import->name;
+ } else {
+ i -= functionImportIndexes.size();
+ return wasm.functions.at(i)->name;
+ }
+ }
+
void readImports() {
if (debug) std::cerr << "== readImports" << std::endl;
size_t num = getU32LEB();
@@ -1511,16 +1524,17 @@ public:
curr->name = Name(std::string("import$") + std::to_string(i));
curr->kind = (Import::Kind)getU32LEB();
switch (curr->kind) {
- case Export::Function: {
+ case Import::Function: {
auto index = getU32LEB();
assert(index < wasm.functionTypes.size());
- curr->functionType = wasm.getFunctionType(index);
+ curr->functionType = wasm.functionTypes[index].get();
assert(curr->functionType->name.is());
+ functionImportIndexes.push_back(curr->name);
break;
}
- case Export::Table: break;
- case Export::Memory: break;
- case Export::Global: curr->globalType = getWasmType(); break;
+ case Import::Table: break;
+ case Import::Memory: break;
+ case Import::Global: curr->globalType = getWasmType(); break;
default: WASM_UNREACHABLE();
}
curr->module = getInlineString();
@@ -1529,7 +1543,7 @@ public:
}
}
- std::vector<FunctionType*> functionTypes;
+ std::vector<FunctionType*> functionTypes; // types of defined functions
void readFunctionSignatures() {
if (debug) std::cerr << "== readFunctionSignatures" << std::endl;
@@ -1538,7 +1552,7 @@ public:
for (size_t i = 0; i < num; i++) {
if (debug) std::cerr << "read one" << std::endl;
auto index = getU32LEB();
- functionTypes.push_back(wasm.getFunctionType(index));
+ functionTypes.push_back(wasm.functionTypes[index].get());
}
}
@@ -1551,9 +1565,9 @@ public:
// We read functions before we know their names, so we need to backpatch the names later
std::vector<Function*> functions; // we store functions here before wasm.addFunction after we know their names
- std::map<size_t, std::vector<Call*>> functionCalls; // at index i we have all calls to i
+ std::map<Index, std::vector<Call*>> functionCalls; // at index i we have all calls to the defined function i
Function* currFunction = nullptr;
- size_t endOfFunction;
+ Index endOfFunction = -1; // before we see a function (like global init expressions), there is no end of function to check
void readFunctions() {
if (debug) std::cerr << "== readFunctions" << std::endl;
@@ -1611,7 +1625,7 @@ public:
}
}
- std::map<Export*, size_t> exportIndexes;
+ std::map<Export*, Index> exportIndexes;
void readExports() {
if (debug) std::cerr << "== readExports" << std::endl;
@@ -1645,6 +1659,8 @@ public:
auto curr = new Global;
curr->type = getWasmType();
curr->init = readExpression();
+ curr->mutable_ = true; // TODO
+ curr->name = Name("global$" + std::to_string(wasm.globals.size()));
wasm.addGlobal(curr);
}
}
@@ -1698,14 +1714,17 @@ public:
}
// now that we have names for each function, apply things
- if (startIndex != static_cast<Index>(-1) && startIndex < wasm.functions.size()) {
- wasm.start = wasm.functions[startIndex]->name;
+ if (startIndex != static_cast<Index>(-1)) {
+ wasm.start = getFunctionIndexName(startIndex);
}
for (auto& iter : exportIndexes) {
Export* curr = iter.first;
switch (curr->kind) {
- case Export::Function: curr->value = wasm.functions[iter.second]->name; break;
+ case Export::Function: {
+ curr->value = getFunctionIndexName(iter.second);
+ break;
+ }
case Export::Table: curr->value = Name::fromInt(0); break;
case Export::Memory: curr->value = Name::fromInt(0); break;
case Export::Global: curr->value = getGlobalName(iter.second); break;
@@ -1726,7 +1745,7 @@ public:
auto i = pair.first;
auto& indexes = pair.second;
for (auto j : indexes) {
- wasm.table.segments[i].data.push_back(wasm.functions[j]->name);
+ wasm.table.segments[i].data.push_back(getFunctionIndexName(j));
}
}
}
@@ -1795,8 +1814,7 @@ public:
case BinaryConsts::Br:
case BinaryConsts::BrIf: visitBreak((curr = allocator.alloc<Break>())->cast<Break>(), code); break; // code distinguishes br from br_if
case BinaryConsts::TableSwitch: visitSwitch((curr = allocator.alloc<Switch>())->cast<Switch>()); break;
- case BinaryConsts::CallFunction: visitCall((curr = allocator.alloc<Call>())->cast<Call>()); break;
- case BinaryConsts::CallImport: visitCallImport((curr = allocator.alloc<CallImport>())->cast<CallImport>()); break;
+ case BinaryConsts::CallFunction: curr = visitCall(); break; // we don't know if it's a call or call_import yet
case BinaryConsts::CallIndirect: visitCallIndirect((curr = allocator.alloc<CallIndirect>())->cast<CallIndirect>()); break;
case BinaryConsts::GetLocal: visitGetLocal((curr = allocator.alloc<GetLocal>())->cast<GetLocal>()); break;
case BinaryConsts::TeeLocal:
@@ -1938,44 +1956,51 @@ public:
}
curr->default_ = getBreakName(getInt32());
}
- void visitCall(Call *curr) {
- if (debug) std::cerr << "zz node: Call" << std::endl;
- auto arity = getU32LEB();
- WASM_UNUSED(arity);
- auto index = getU32LEB();
- assert(index < functionTypes.size());
- auto type = functionTypes[index];
+
+ template<typename T>
+ void fillCall(T* call, FunctionType* type, Index arity) {
+ assert(type);
auto num = type->params.size();
assert(num == arity);
- curr->operands.resize(num);
+ call->operands.resize(num);
for (size_t i = 0; i < num; i++) {
- curr->operands[num - i - 1] = popExpression();
+ call->operands[num - i - 1] = popExpression();
}
- curr->type = type->result;
- functionCalls[index].push_back(curr);
+ call->type = type->result;
}
- void visitCallImport(CallImport *curr) {
- if (debug) std::cerr << "zz node: CallImport" << std::endl;
+
+ Expression* visitCall() {
+ if (debug) std::cerr << "zz node: Call" << std::endl;
auto arity = getU32LEB();
WASM_UNUSED(arity);
- auto import = wasm.getImport(getU32LEB());
- curr->target = import->name;
- auto type = import->functionType;
- assert(type);
- auto num = type->params.size();
- assert(num == arity);
- if (debug) std::cerr << "zz node: CallImport " << curr->target << " with type " << type->name << " and " << num << " params\n";
- curr->operands.resize(num);
- for (size_t i = 0; i < num; i++) {
- curr->operands[num - i - 1] = popExpression();
+ auto index = getU32LEB();
+ FunctionType* type;
+ Expression* ret;
+ if (index < functionImportIndexes.size()) {
+ // this is a call of an imported function
+ auto* call = allocator.alloc<CallImport>();
+ auto* import = wasm.getImport(functionImportIndexes[index]);
+ call->target = import->name;
+ type = import->functionType;
+ fillCall(call, type, arity);
+ ret = call;
+ } else {
+ // this is a call of a defined function
+ auto* call = allocator.alloc<Call>();
+ auto adjustedIndex = index - functionImportIndexes.size();
+ assert(adjustedIndex < functionTypes.size());
+ type = functionTypes[adjustedIndex];
+ fillCall(call, type, arity);
+ functionCalls[adjustedIndex].push_back(call); // we don't know function names yet
+ ret = call;
}
- curr->type = type->result;
+ return ret;
}
void visitCallIndirect(CallIndirect *curr) {
if (debug) std::cerr << "zz node: CallIndirect" << std::endl;
auto arity = getU32LEB();
WASM_UNUSED(arity);
- auto* fullType = wasm.getFunctionType(getU32LEB());
+ auto* fullType = wasm.functionTypes.at(getU32LEB()).get();
curr->fullType = fullType->name;
auto num = fullType->params.size();
assert(num == arity);