diff options
Diffstat (limited to 'src/tools/wasm-merge.cpp')
-rw-r--r-- | src/tools/wasm-merge.cpp | 168 |
1 files changed, 90 insertions, 78 deletions
diff --git a/src/tools/wasm-merge.cpp b/src/tools/wasm-merge.cpp index e8910860e..70611e269 100644 --- a/src/tools/wasm-merge.cpp +++ b/src/tools/wasm-merge.cpp @@ -34,20 +34,10 @@ #include "wasm-binary.h" #include "wasm-builder.h" #include "wasm-validator.h" +#include "ir/module-utils.h" using namespace wasm; -// Calls note() on every import that has form "env".(base) -static void findImportsByBase(Module& wasm, Name base, std::function<void (Name)> note) { - for (auto& curr : wasm.imports) { - if (curr->module == ENV) { - if (curr->base == base) { - note(curr->name); - } - } - } -} - // Ensure a memory or table is of at least a size template<typename T> static void ensureSize(T& what, Index size) { @@ -119,31 +109,33 @@ struct Mergeable { } void findImports() { - findImportsByBase(wasm, MEMORY_BASE, [&](Name name) { - memoryBaseGlobals.insert(name); + ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) { + if (import->module == ENV && import->base == MEMORY_BASE) { + memoryBaseGlobals.insert(import->name); + } }); if (memoryBaseGlobals.size() == 0) { // add one - auto* import = new Import; + auto* import = new Global; import->name = MEMORY_BASE; import->module = ENV; import->base = MEMORY_BASE; - import->kind = ExternalKind::Global; - import->globalType = i32; - wasm.addImport(import); + import->type = i32; + wasm.addGlobal(import); memoryBaseGlobals.insert(import->name); } - findImportsByBase(wasm, TABLE_BASE, [&](Name name) { - tableBaseGlobals.insert(name); + ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) { + if (import->module == ENV && import->base == TABLE_BASE) { + tableBaseGlobals.insert(import->name); + } }); if (tableBaseGlobals.size() == 0) { - auto* import = new Import; + auto* import = new Global; import->name = TABLE_BASE; import->module = ENV; import->base = TABLE_BASE; - import->kind = ExternalKind::Global; - import->globalType = i32; - wasm.addImport(import); + import->type = i32; + wasm.addGlobal(import); tableBaseGlobals.insert(import->name); } } @@ -244,7 +236,7 @@ struct Mergeable { struct OutputMergeable : public PostWalker<OutputMergeable, Visitor<OutputMergeable>>, public Mergeable { OutputMergeable(Module& wasm) : Mergeable(wasm) {} - void visitCallImport(CallImport* curr) { + void visitCall(Call* curr) { auto iter = implementedFunctionImports.find(curr->target); if (iter != implementedFunctionImports.end()) { // this import is now in the module - call it @@ -264,10 +256,10 @@ struct OutputMergeable : public PostWalker<OutputMergeable, Visitor<OutputMergea void visitModule(Module* curr) { // remove imports that are being implemented for (auto& pair : implementedFunctionImports) { - curr->removeImport(pair.first); + curr->removeFunction(pair.first); } for (auto& pair : implementedGlobalImports) { - curr->removeImport(pair.first); + curr->removeGlobal(pair.first); } } }; @@ -288,11 +280,6 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp std::map<Name, Name> gNames; // globals void visitCall(Call* curr) { - curr->target = fNames[curr->target]; - assert(curr->target.is()); - } - - void visitCallImport(CallImport* curr) { auto iter = implementedFunctionImports.find(curr->target); if (iter != implementedFunctionImports.end()) { // this import is now in the module - call it @@ -333,29 +320,38 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp void merge() { // find function imports in us that are implemented in the output // TODO make maps, avoid N^2 - for (auto& imp : wasm.imports) { + ModuleUtils::iterImportedFunctions(wasm, [&](Function* import) { // per wasm dynamic library rules, we expect to see exports on 'env' - if ((imp->kind == ExternalKind::Function || imp->kind == ExternalKind::Global) && imp->module == ENV) { + if (import->module == ENV) { // seek an export on the other side that matches for (auto& exp : outputMergeable.wasm.exports) { - if (exp->kind == imp->kind && exp->name == imp->base) { + if (exp->name == import->base) { // fits! - if (imp->kind == ExternalKind::Function) { - implementedFunctionImports[imp->name] = exp->value; - } else { - implementedGlobalImports[imp->name] = exp->value; - } + implementedFunctionImports[import->name] = exp->value; break; } } } - } + }); + ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) { + // per wasm dynamic library rules, we expect to see exports on 'env' + if (import->module == ENV) { + // seek an export on the other side that matches + for (auto& exp : outputMergeable.wasm.exports) { + if (exp->name == import->base) { + // fits! + implementedGlobalImports[import->name] = exp->value; + break; + } + } + } + }); // remove the unneeded ones for (auto& pair : implementedFunctionImports) { - wasm.removeImport(pair.first); + wasm.removeFunction(pair.first); } for (auto& pair : implementedGlobalImports) { - wasm.removeImport(pair.first); + wasm.removeGlobal(pair.first); } // find new names @@ -364,27 +360,26 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp return outputMergeable.wasm.getFunctionTypeOrNull(name); }); } - for (auto& curr : wasm.imports) { - if (curr->kind == ExternalKind::Function) { - curr->name = fNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool { - return !!outputMergeable.wasm.getImportOrNull(name) || !!outputMergeable.wasm.getFunctionOrNull(name); - }); - } else if (curr->kind == ExternalKind::Global) { - curr->name = gNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool { - return !!outputMergeable.wasm.getImportOrNull(name) || !!outputMergeable.wasm.getGlobalOrNull(name); - }); - } - } - for (auto& curr : wasm.functions) { + ModuleUtils::iterImportedFunctions(wasm, [&](Function* curr) { + curr->name = fNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool { + return !!outputMergeable.wasm.getFunctionOrNull(name); + }); + }); + ModuleUtils::iterImportedGlobals(wasm, [&](Global* curr) { + curr->name = gNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool { + return !!outputMergeable.wasm.getGlobalOrNull(name); + }); + }); + ModuleUtils::iterDefinedFunctions(wasm, [&](Function* curr) { curr->name = fNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool { return outputMergeable.wasm.getFunctionOrNull(name); }); - } - for (auto& curr : wasm.globals) { + }); + ModuleUtils::iterDefinedGlobals(wasm, [&](Global* curr) { curr->name = gNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool { return outputMergeable.wasm.getGlobalOrNull(name); }); - } + }); // update global names in input { @@ -403,20 +398,26 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp } // find function imports in output that are implemented in the input - for (auto& imp : outputMergeable.wasm.imports) { - if ((imp->kind == ExternalKind::Function || imp->kind == ExternalKind::Global) && imp->module == ENV) { + ModuleUtils::iterImportedFunctions(outputMergeable.wasm, [&](Function* import) { + if (import->module == ENV) { for (auto& exp : wasm.exports) { - if (exp->kind == imp->kind && exp->name == imp->base) { - if (imp->kind == ExternalKind::Function) { - outputMergeable.implementedFunctionImports[imp->name] = fNames[exp->value]; - } else { - outputMergeable.implementedGlobalImports[imp->name] = gNames[exp->value]; - } + if (exp->name == import->base) { + outputMergeable.implementedFunctionImports[import->name] = fNames[exp->value]; break; } } } - } + }); + ModuleUtils::iterImportedGlobals(outputMergeable.wasm, [&](Global* import) { + if (import->module == ENV) { + for (auto& exp : wasm.exports) { + if (exp->name == import->base) { + outputMergeable.implementedGlobalImports[import->name] = gNames[exp->value]; + break; + } + } + } + }); // update the output before bringing anything in. avoid doing so when possible, as in the // common case the output module is very large. @@ -448,16 +449,19 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp for (auto& curr : wasm.functionTypes) { outputMergeable.wasm.addFunctionType(curr.release()); } - for (auto& curr : wasm.imports) { - if (curr->kind == ExternalKind::Memory || curr->kind == ExternalKind::Table) { - continue; // wasm has just 1 of each, they must match + for (auto& curr : wasm.globals) { + if (curr->imported()) { + outputMergeable.wasm.addGlobal(curr.release()); } - // update and add - if (curr->functionType.is()) { - curr->functionType = ftNames[curr->functionType]; - assert(curr->functionType.is()); + } + for (auto& curr : wasm.functions) { + if (curr->imported()) { + if (curr->type.is()) { + curr->type = ftNames[curr->type]; + assert(curr->type.is()); + } + outputMergeable.wasm.addFunction(curr.release()); } - outputMergeable.wasm.addImport(curr.release()); } for (auto& curr : wasm.exports) { if (curr->kind == ExternalKind::Memory || curr->kind == ExternalKind::Table) { @@ -477,13 +481,21 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp } } } + // Copy over the remaining non-imports (we have already transferred + // the imports, and they are nullptrs). for (auto& curr : wasm.functions) { - curr->type = ftNames[curr->type]; - assert(curr->type.is()); - outputMergeable.wasm.addFunction(curr.release()); + if (curr) { + assert(!curr->imported()); + curr->type = ftNames[curr->type]; + assert(curr->type.is()); + outputMergeable.wasm.addFunction(curr.release()); + } } for (auto& curr : wasm.globals) { - outputMergeable.wasm.addGlobal(curr.release()); + if (curr) { + assert(!curr->imported()); + outputMergeable.wasm.addGlobal(curr.release()); + } } } |