diff options
Diffstat (limited to 'src/passes/LegalizeJSInterface.cpp')
-rw-r--r-- | src/passes/LegalizeJSInterface.cpp | 92 |
1 files changed, 48 insertions, 44 deletions
diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp index 1f02cd4d7..5f64ad2e6 100644 --- a/src/passes/LegalizeJSInterface.cpp +++ b/src/passes/LegalizeJSInterface.cpp @@ -26,11 +26,12 @@ // disallow f32s. TODO: an option to not do that, if it matters? // -#include <wasm.h> -#include <pass.h> -#include <wasm-builder.h> -#include <ir/utils.h> -#include <ir/literal-utils.h> +#include "wasm.h" +#include "pass.h" +#include "wasm-builder.h" +#include "ir/function-type-utils.h" +#include "ir/literal-utils.h" +#include "ir/utils.h" namespace wasm { @@ -44,22 +45,23 @@ struct LegalizeJSInterface : public Pass { for (auto& ex : module->exports) { if (ex->kind == ExternalKind::Function) { // if it's an import, ignore it - if (auto* func = module->getFunctionOrNull(ex->value)) { - if (isIllegal(func)) { - auto legalName = makeLegalStub(func, module); - ex->value = legalName; - } + auto* func = module->getFunction(ex->value); + if (isIllegal(func)) { + auto legalName = makeLegalStub(func, module); + ex->value = legalName; } } } + // Avoid iterator invalidation later. + std::vector<Function*> originalFunctions; + for (auto& func : module->functions) { + originalFunctions.push_back(func.get()); + } // for each illegal import, we must call a legalized stub instead - std::vector<Import*> newImports; // add them at the end, to not invalidate the iter - for (auto& im : module->imports) { - if (im->kind == ExternalKind::Function && isIllegal(module->getFunctionType(im->functionType))) { - Name funcName; - auto* legal = makeLegalStub(im.get(), module, funcName); - illegalToLegal[im->name] = funcName; - newImports.push_back(legal); + for (auto* im : originalFunctions) { + if (im->imported() && isIllegal(module->getFunctionType(im->type))) { + auto funcName = makeLegalStubForCalledImport(im, module); + illegalImportsToLegal[im->name] = funcName; // we need to use the legalized version in the table, as the import from JS // is legal for JS. Our stub makes it look like a native wasm function. for (auto& segment : module->table.segments) { @@ -71,13 +73,9 @@ struct LegalizeJSInterface : public Pass { } } } - if (illegalToLegal.size() > 0) { - for (auto& pair : illegalToLegal) { - module->removeImport(pair.first); - } - - for (auto* im : newImports) { - module->addImport(im); + if (illegalImportsToLegal.size() > 0) { + for (auto& pair : illegalImportsToLegal) { + module->removeFunction(pair.first); } // fix up imports: call_import of an illegal must be turned to a call of a legal @@ -85,15 +83,15 @@ struct LegalizeJSInterface : public Pass { struct FixImports : public WalkerPass<PostWalker<FixImports>> { bool isFunctionParallel() override { return true; } - Pass* create() override { return new FixImports(illegalToLegal); } + Pass* create() override { return new FixImports(illegalImportsToLegal); } - std::map<Name, Name>* illegalToLegal; + std::map<Name, Name>* illegalImportsToLegal; - FixImports(std::map<Name, Name>* illegalToLegal) : illegalToLegal(illegalToLegal) {} + FixImports(std::map<Name, Name>* illegalImportsToLegal) : illegalImportsToLegal(illegalImportsToLegal) {} - void visitCallImport(CallImport* curr) { - auto iter = illegalToLegal->find(curr->target); - if (iter == illegalToLegal->end()) return; + void visitCall(Call* curr) { + auto iter = illegalImportsToLegal->find(curr->target); + if (iter == illegalImportsToLegal->end()) return; if (iter->second == getFunction()->name) return; // inside the stub function itself, is the one safe place to do the call replaceCurrent(Builder(*getModule()).makeCall(iter->second, curr->operands, curr->type)); @@ -102,7 +100,7 @@ struct LegalizeJSInterface : public Pass { PassRunner passRunner(module); passRunner.setIsNested(true); - passRunner.add<FixImports>(&illegalToLegal); + passRunner.add<FixImports>(&illegalImportsToLegal); passRunner.run(); } @@ -113,7 +111,7 @@ struct LegalizeJSInterface : public Pass { private: // map of illegal to legal names for imports - std::map<Name, Name> illegalToLegal; + std::map<Name, Name> illegalImportsToLegal; bool needTempRet0Helpers = false; @@ -179,24 +177,22 @@ private: } // wasm calls the import, so it must call a stub that calls the actual legal JS import - Import* makeLegalStub(Import* im, Module* module, Name& funcName) { + Name makeLegalStubForCalledImport(Function* im, Module* module) { Builder builder(*module); - auto* type = new FunctionType(); + auto* type = new FunctionType; type->name = Name(std::string("legaltype$") + im->name.str); - auto* legal = new Import(); + auto* legal = new Function; legal->name = Name(std::string("legalimport$") + im->name.str); legal->module = im->module; legal->base = im->base; - legal->kind = ExternalKind::Function; - legal->functionType = type->name; - auto* func = new Function(); + legal->type = type->name; + auto* func = new Function; func->name = Name(std::string("legalfunc$") + im->name.str); - funcName = func->name; - auto* call = module->allocator.alloc<CallImport>(); + auto* call = module->allocator.alloc<Call>(); call->target = legal->name; - auto* imFunctionType = module->getFunctionType(im->functionType); + auto* imFunctionType = module->getFunctionType(im->type); for (auto param : imFunctionType->params) { if (param == i64) { @@ -231,10 +227,18 @@ private: type->result = imFunctionType->result; } func->result = imFunctionType->result; + FunctionTypeUtils::fillFunction(legal, type); - module->addFunction(func); - module->addFunctionType(type); - return legal; + if (!module->getFunctionOrNull(func->name)) { + module->addFunction(func); + } + if (!module->getFunctionTypeOrNull(type->name)) { + module->addFunctionType(type); + } + if (!module->getFunctionOrNull(legal->name)) { + module->addFunction(legal); + } + return func->name; } void ensureTempRet0(Module* module) { |