summaryrefslogtreecommitdiff
path: root/src/tools/wasm-merge.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/tools/wasm-merge.cpp')
-rw-r--r--src/tools/wasm-merge.cpp168
1 files changed, 90 insertions, 78 deletions
diff --git a/src/tools/wasm-merge.cpp b/src/tools/wasm-merge.cpp
index e8910860e..70611e269 100644
--- a/src/tools/wasm-merge.cpp
+++ b/src/tools/wasm-merge.cpp
@@ -34,20 +34,10 @@
#include "wasm-binary.h"
#include "wasm-builder.h"
#include "wasm-validator.h"
+#include "ir/module-utils.h"
using namespace wasm;
-// Calls note() on every import that has form "env".(base)
-static void findImportsByBase(Module& wasm, Name base, std::function<void (Name)> note) {
- for (auto& curr : wasm.imports) {
- if (curr->module == ENV) {
- if (curr->base == base) {
- note(curr->name);
- }
- }
- }
-}
-
// Ensure a memory or table is of at least a size
template<typename T>
static void ensureSize(T& what, Index size) {
@@ -119,31 +109,33 @@ struct Mergeable {
}
void findImports() {
- findImportsByBase(wasm, MEMORY_BASE, [&](Name name) {
- memoryBaseGlobals.insert(name);
+ ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) {
+ if (import->module == ENV && import->base == MEMORY_BASE) {
+ memoryBaseGlobals.insert(import->name);
+ }
});
if (memoryBaseGlobals.size() == 0) {
// add one
- auto* import = new Import;
+ auto* import = new Global;
import->name = MEMORY_BASE;
import->module = ENV;
import->base = MEMORY_BASE;
- import->kind = ExternalKind::Global;
- import->globalType = i32;
- wasm.addImport(import);
+ import->type = i32;
+ wasm.addGlobal(import);
memoryBaseGlobals.insert(import->name);
}
- findImportsByBase(wasm, TABLE_BASE, [&](Name name) {
- tableBaseGlobals.insert(name);
+ ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) {
+ if (import->module == ENV && import->base == TABLE_BASE) {
+ tableBaseGlobals.insert(import->name);
+ }
});
if (tableBaseGlobals.size() == 0) {
- auto* import = new Import;
+ auto* import = new Global;
import->name = TABLE_BASE;
import->module = ENV;
import->base = TABLE_BASE;
- import->kind = ExternalKind::Global;
- import->globalType = i32;
- wasm.addImport(import);
+ import->type = i32;
+ wasm.addGlobal(import);
tableBaseGlobals.insert(import->name);
}
}
@@ -244,7 +236,7 @@ struct Mergeable {
struct OutputMergeable : public PostWalker<OutputMergeable, Visitor<OutputMergeable>>, public Mergeable {
OutputMergeable(Module& wasm) : Mergeable(wasm) {}
- void visitCallImport(CallImport* curr) {
+ void visitCall(Call* curr) {
auto iter = implementedFunctionImports.find(curr->target);
if (iter != implementedFunctionImports.end()) {
// this import is now in the module - call it
@@ -264,10 +256,10 @@ struct OutputMergeable : public PostWalker<OutputMergeable, Visitor<OutputMergea
void visitModule(Module* curr) {
// remove imports that are being implemented
for (auto& pair : implementedFunctionImports) {
- curr->removeImport(pair.first);
+ curr->removeFunction(pair.first);
}
for (auto& pair : implementedGlobalImports) {
- curr->removeImport(pair.first);
+ curr->removeGlobal(pair.first);
}
}
};
@@ -288,11 +280,6 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp
std::map<Name, Name> gNames; // globals
void visitCall(Call* curr) {
- curr->target = fNames[curr->target];
- assert(curr->target.is());
- }
-
- void visitCallImport(CallImport* curr) {
auto iter = implementedFunctionImports.find(curr->target);
if (iter != implementedFunctionImports.end()) {
// this import is now in the module - call it
@@ -333,29 +320,38 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp
void merge() {
// find function imports in us that are implemented in the output
// TODO make maps, avoid N^2
- for (auto& imp : wasm.imports) {
+ ModuleUtils::iterImportedFunctions(wasm, [&](Function* import) {
// per wasm dynamic library rules, we expect to see exports on 'env'
- if ((imp->kind == ExternalKind::Function || imp->kind == ExternalKind::Global) && imp->module == ENV) {
+ if (import->module == ENV) {
// seek an export on the other side that matches
for (auto& exp : outputMergeable.wasm.exports) {
- if (exp->kind == imp->kind && exp->name == imp->base) {
+ if (exp->name == import->base) {
// fits!
- if (imp->kind == ExternalKind::Function) {
- implementedFunctionImports[imp->name] = exp->value;
- } else {
- implementedGlobalImports[imp->name] = exp->value;
- }
+ implementedFunctionImports[import->name] = exp->value;
break;
}
}
}
- }
+ });
+ ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) {
+ // per wasm dynamic library rules, we expect to see exports on 'env'
+ if (import->module == ENV) {
+ // seek an export on the other side that matches
+ for (auto& exp : outputMergeable.wasm.exports) {
+ if (exp->name == import->base) {
+ // fits!
+ implementedGlobalImports[import->name] = exp->value;
+ break;
+ }
+ }
+ }
+ });
// remove the unneeded ones
for (auto& pair : implementedFunctionImports) {
- wasm.removeImport(pair.first);
+ wasm.removeFunction(pair.first);
}
for (auto& pair : implementedGlobalImports) {
- wasm.removeImport(pair.first);
+ wasm.removeGlobal(pair.first);
}
// find new names
@@ -364,27 +360,26 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp
return outputMergeable.wasm.getFunctionTypeOrNull(name);
});
}
- for (auto& curr : wasm.imports) {
- if (curr->kind == ExternalKind::Function) {
- curr->name = fNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool {
- return !!outputMergeable.wasm.getImportOrNull(name) || !!outputMergeable.wasm.getFunctionOrNull(name);
- });
- } else if (curr->kind == ExternalKind::Global) {
- curr->name = gNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool {
- return !!outputMergeable.wasm.getImportOrNull(name) || !!outputMergeable.wasm.getGlobalOrNull(name);
- });
- }
- }
- for (auto& curr : wasm.functions) {
+ ModuleUtils::iterImportedFunctions(wasm, [&](Function* curr) {
+ curr->name = fNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool {
+ return !!outputMergeable.wasm.getFunctionOrNull(name);
+ });
+ });
+ ModuleUtils::iterImportedGlobals(wasm, [&](Global* curr) {
+ curr->name = gNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool {
+ return !!outputMergeable.wasm.getGlobalOrNull(name);
+ });
+ });
+ ModuleUtils::iterDefinedFunctions(wasm, [&](Function* curr) {
curr->name = fNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool {
return outputMergeable.wasm.getFunctionOrNull(name);
});
- }
- for (auto& curr : wasm.globals) {
+ });
+ ModuleUtils::iterDefinedGlobals(wasm, [&](Global* curr) {
curr->name = gNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool {
return outputMergeable.wasm.getGlobalOrNull(name);
});
- }
+ });
// update global names in input
{
@@ -403,20 +398,26 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp
}
// find function imports in output that are implemented in the input
- for (auto& imp : outputMergeable.wasm.imports) {
- if ((imp->kind == ExternalKind::Function || imp->kind == ExternalKind::Global) && imp->module == ENV) {
+ ModuleUtils::iterImportedFunctions(outputMergeable.wasm, [&](Function* import) {
+ if (import->module == ENV) {
for (auto& exp : wasm.exports) {
- if (exp->kind == imp->kind && exp->name == imp->base) {
- if (imp->kind == ExternalKind::Function) {
- outputMergeable.implementedFunctionImports[imp->name] = fNames[exp->value];
- } else {
- outputMergeable.implementedGlobalImports[imp->name] = gNames[exp->value];
- }
+ if (exp->name == import->base) {
+ outputMergeable.implementedFunctionImports[import->name] = fNames[exp->value];
break;
}
}
}
- }
+ });
+ ModuleUtils::iterImportedGlobals(outputMergeable.wasm, [&](Global* import) {
+ if (import->module == ENV) {
+ for (auto& exp : wasm.exports) {
+ if (exp->name == import->base) {
+ outputMergeable.implementedGlobalImports[import->name] = gNames[exp->value];
+ break;
+ }
+ }
+ }
+ });
// update the output before bringing anything in. avoid doing so when possible, as in the
// common case the output module is very large.
@@ -448,16 +449,19 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp
for (auto& curr : wasm.functionTypes) {
outputMergeable.wasm.addFunctionType(curr.release());
}
- for (auto& curr : wasm.imports) {
- if (curr->kind == ExternalKind::Memory || curr->kind == ExternalKind::Table) {
- continue; // wasm has just 1 of each, they must match
+ for (auto& curr : wasm.globals) {
+ if (curr->imported()) {
+ outputMergeable.wasm.addGlobal(curr.release());
}
- // update and add
- if (curr->functionType.is()) {
- curr->functionType = ftNames[curr->functionType];
- assert(curr->functionType.is());
+ }
+ for (auto& curr : wasm.functions) {
+ if (curr->imported()) {
+ if (curr->type.is()) {
+ curr->type = ftNames[curr->type];
+ assert(curr->type.is());
+ }
+ outputMergeable.wasm.addFunction(curr.release());
}
- outputMergeable.wasm.addImport(curr.release());
}
for (auto& curr : wasm.exports) {
if (curr->kind == ExternalKind::Memory || curr->kind == ExternalKind::Table) {
@@ -477,13 +481,21 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp
}
}
}
+ // Copy over the remaining non-imports (we have already transferred
+ // the imports, and they are nullptrs).
for (auto& curr : wasm.functions) {
- curr->type = ftNames[curr->type];
- assert(curr->type.is());
- outputMergeable.wasm.addFunction(curr.release());
+ if (curr) {
+ assert(!curr->imported());
+ curr->type = ftNames[curr->type];
+ assert(curr->type.is());
+ outputMergeable.wasm.addFunction(curr.release());
+ }
}
for (auto& curr : wasm.globals) {
- outputMergeable.wasm.addGlobal(curr.release());
+ if (curr) {
+ assert(!curr->imported());
+ outputMergeable.wasm.addGlobal(curr.release());
+ }
}
}