diff options
Diffstat (limited to 'src/wasm')
-rw-r--r-- | src/wasm/wasm-binary.cpp | 148 | ||||
-rw-r--r-- | src/wasm/wasm-emscripten.cpp | 124 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 235 | ||||
-rw-r--r-- | src/wasm/wasm-stack.cpp | 4 | ||||
-rw-r--r-- | src/wasm/wasm-type.cpp | 25 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 136 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 91 |
7 files changed, 285 insertions, 478 deletions
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index b8b001927..d86eea8f5 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -30,56 +30,7 @@ namespace wasm { void WasmBinaryWriter::prepare() { // Collect function types and their frequencies. Collect information in each // function in parallel, then merge. - typedef std::unordered_map<Signature, size_t> Counts; - ModuleUtils::ParallelFunctionAnalysis<Counts> analysis( - *wasm, [&](Function* func, Counts& counts) { - if (func->imported()) { - return; - } - struct TypeCounter : PostWalker<TypeCounter> { - Module& wasm; - Counts& counts; - - TypeCounter(Module& wasm, Counts& counts) - : wasm(wasm), counts(counts) {} - - void visitCallIndirect(CallIndirect* curr) { - auto* type = wasm.getFunctionType(curr->fullType); - Signature sig(Type(type->params), type->result); - counts[sig]++; - } - }; - TypeCounter(*wasm, counts).walk(func->body); - }); - // Collect all the counts. - Counts counts; - for (auto& curr : wasm->functions) { - counts[Signature(Type(curr->params), curr->result)]++; - } - for (auto& curr : wasm->events) { - counts[curr->sig]++; - } - for (auto& pair : analysis.map) { - Counts& functionCounts = pair.second; - for (auto& innerPair : functionCounts) { - counts[innerPair.first] += innerPair.second; - } - } - std::vector<std::pair<Signature, size_t>> sorted(counts.begin(), - counts.end()); - std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) { - // order by frequency then simplicity - if (a.second != b.second) { - return a.second > b.second; - } else { - return a.first < b.first; - } - }); - for (Index i = 0; i < sorted.size(); ++i) { - typeIndexes[sorted[i].first] = i; - types.push_back(sorted[i].first); - } - + ModuleUtils::collectSignatures(*wasm, types, typeIndices); importInfo = wasm::make_unique<ImportInfo>(*wasm); } @@ -250,7 +201,7 @@ void WasmBinaryWriter::writeImports() { BYN_TRACE("write one function\n"); writeImportHeader(func); o << U32LEB(int32_t(ExternalKind::Function)); - o << U32LEB(getTypeIndex(Signature(Type(func->params), func->result))); + o << U32LEB(getTypeIndex(func->sig)); }); ModuleUtils::iterImportedGlobals(*wasm, [&](Global* global) { BYN_TRACE("write one global\n"); @@ -297,7 +248,7 @@ void WasmBinaryWriter::writeFunctionSignatures() { o << U32LEB(importInfo->getNumDefinedFunctions()); ModuleUtils::iterDefinedFunctions(*wasm, [&](Function* func) { BYN_TRACE("write one\n"); - o << U32LEB(getTypeIndex(Signature(Type(func->params), func->result))); + o << U32LEB(getTypeIndex(func->sig)); }); finishSection(start); } @@ -458,8 +409,8 @@ uint32_t WasmBinaryWriter::getEventIndex(Name name) const { } uint32_t WasmBinaryWriter::getTypeIndex(Signature sig) const { - auto it = typeIndexes.find(sig); - assert(it != typeIndexes.end()); + auto it = typeIndices.find(sig); + assert(it != typeIndices.end()); return it->second; } @@ -1124,7 +1075,8 @@ void WasmBinaryBuilder::readSignatures() { BYN_TRACE("num: " << numTypes << std::endl); for (size_t i = 0; i < numTypes; i++) { BYN_TRACE("read one\n"); - auto curr = make_unique<FunctionType>(); + std::vector<Type> params; + std::vector<Type> results; auto form = getS32LEB(); if (form != BinaryConsts::EncodedType::Func) { throwError("bad signature form " + std::to_string(form)); @@ -1132,19 +1084,14 @@ void WasmBinaryBuilder::readSignatures() { size_t numParams = getU32LEB(); BYN_TRACE("num params: " << numParams << std::endl); for (size_t j = 0; j < numParams; j++) { - curr->params.push_back(getConcreteType()); + params.push_back(getConcreteType()); } auto numResults = getU32LEB(); - if (numResults == 0) { - curr->result = none; - } else { - if (numResults != 1) { - throwError("signature must have 1 result"); - } - curr->result = getType(); + BYN_TRACE("num results: " << numResults << std::endl); + for (size_t j = 0; j < numResults; j++) { + results.push_back(getConcreteType()); } - curr->name = Name::fromInt(wasm.functionTypes.size()); - wasm.addFunctionType(std::move(curr)); + signatures.emplace_back(Type(params), Type(results)); } } @@ -1205,17 +1152,13 @@ void WasmBinaryBuilder::readImports() { case ExternalKind::Function: { auto name = Name(std::string("fimport$") + std::to_string(i)); auto index = getU32LEB(); - if (index >= wasm.functionTypes.size()) { + if (index > signatures.size()) { throwError("invalid function index " + std::to_string(index) + " / " + - std::to_string(wasm.functionTypes.size())); + std::to_string(signatures.size())); } - auto* functionType = wasm.functionTypes[index].get(); - auto params = functionType->params; - auto result = functionType->result; - auto* curr = builder.makeFunction(name, std::move(params), result, {}); + auto* curr = builder.makeFunction(name, signatures[index], {}); curr->module = module; curr->base = base; - curr->type = functionType->name; wasm.addFunction(curr); functionImports.push_back(curr); break; @@ -1267,13 +1210,11 @@ void WasmBinaryBuilder::readImports() { auto name = Name(std::string("eimport$") + std::to_string(i)); auto attribute = getU32LEB(); auto index = getU32LEB(); - if (index >= wasm.functionTypes.size()) { + if (index >= signatures.size()) { throwError("invalid event index " + std::to_string(index) + " / " + - std::to_string(wasm.functionTypes.size())); + std::to_string(signatures.size())); } - Type params = Type(wasm.functionTypes[index]->params); - auto* curr = - builder.makeEvent(name, attribute, Signature(params, Type::none)); + auto* curr = builder.makeEvent(name, attribute, signatures[index]); curr->module = module; curr->base = base; wasm.addEvent(curr); @@ -1302,17 +1243,17 @@ void WasmBinaryBuilder::readFunctionSignatures() { for (size_t i = 0; i < num; i++) { BYN_TRACE("read one\n"); auto index = getU32LEB(); - if (index >= wasm.functionTypes.size()) { + if (index >= signatures.size()) { throwError("invalid function type index for function"); } - functionTypes.push_back(wasm.functionTypes[index].get()); + functionSignatures.push_back(signatures[index]); } } void WasmBinaryBuilder::readFunctions() { BYN_TRACE("== readFunctions\n"); size_t total = getU32LEB(); - if (total != functionTypes.size()) { + if (total != functionSignatures.size()) { throwError("invalid function section size, must equal types"); } for (size_t i = 0; i < total; i++) { @@ -1325,17 +1266,12 @@ void WasmBinaryBuilder::readFunctions() { Function* func = new Function; func->name = Name::fromInt(i); + func->sig = functionSignatures[i]; currFunction = func; readNextDebugLocation(); - auto type = functionTypes[i]; BYN_TRACE("reading " << i << std::endl); - func->type = type->name; - func->result = type->result; - for (size_t j = 0; j < type->params.size(); j++) { - func->params.emplace_back(type->params[j]); - } size_t numLocalTypes = getU32LEB(); for (size_t t = 0; t < numLocalTypes; t++) { auto num = getU32LEB(); @@ -1357,7 +1293,7 @@ void WasmBinaryBuilder::readFunctions() { assert(breakStack.empty()); assert(expressionStack.empty()); assert(depth == 0); - func->body = getBlockOrSingleton(func->result); + func->body = getBlockOrSingleton(func->sig.results); assert(depth == 0); assert(breakStack.size() == 0); assert(breakTargetNames.size() == 0); @@ -1391,7 +1327,7 @@ void WasmBinaryBuilder::readExports() { names.insert(curr->name); curr->kind = (ExternalKind)getU32LEB(); auto index = getU32LEB(); - exportIndexes[curr] = index; + exportIndices[curr] = index; exportOrder.push_back(curr); } } @@ -1753,7 +1689,7 @@ void WasmBinaryBuilder::processFunctions() { } for (auto* curr : exportOrder) { - auto index = exportIndexes[curr]; + auto index = exportIndices[curr]; switch (curr->kind) { case ExternalKind::Function: { curr->value = getFunctionName(index); @@ -1787,8 +1723,8 @@ void WasmBinaryBuilder::processFunctions() { for (auto& pair : functionTable) { auto i = pair.first; - auto& indexes = pair.second; - for (auto j : indexes) { + auto& indices = pair.second; + for (auto j : indices) { wasm.table.segments[i].data.push_back(getFunctionName(j)); } } @@ -1884,13 +1820,12 @@ void WasmBinaryBuilder::readEvents() { BYN_TRACE("read one\n"); auto attribute = getU32LEB(); auto typeIndex = getU32LEB(); - if (typeIndex >= wasm.functionTypes.size()) { + if (typeIndex >= signatures.size()) { throwError("invalid event index " + std::to_string(typeIndex) + " / " + - std::to_string(wasm.functionTypes.size())); + std::to_string(signatures.size())); } - Type params = Type(wasm.functionTypes[typeIndex]->params); wasm.addEvent(Builder::makeEvent( - "event$" + std::to_string(i), attribute, Signature(params, Type::none))); + "event$" + std::to_string(i), attribute, signatures[typeIndex])); } } @@ -2463,24 +2398,23 @@ void WasmBinaryBuilder::visitSwitch(Switch* curr) { void WasmBinaryBuilder::visitCall(Call* curr) { BYN_TRACE("zz node: Call\n"); auto index = getU32LEB(); - FunctionType* type; + Signature sig; if (index < functionImports.size()) { auto* import = functionImports[index]; - type = wasm.getFunctionType(import->type); + sig = import->sig; } else { Index adjustedIndex = index - functionImports.size(); - if (adjustedIndex >= functionTypes.size()) { + if (adjustedIndex >= functionSignatures.size()) { throwError("invalid call index"); } - type = functionTypes[adjustedIndex]; + sig = functionSignatures[adjustedIndex]; } - assert(type); - auto num = type->params.size(); + auto num = sig.params.size(); curr->operands.resize(num); for (size_t i = 0; i < num; i++) { curr->operands[num - i - 1] = popNonVoidExpression(); } - curr->type = type->result; + curr->type = sig.results; functionCalls[index].push_back(curr); // we don't know function names yet curr->finalize(); } @@ -2488,22 +2422,20 @@ void WasmBinaryBuilder::visitCall(Call* curr) { void WasmBinaryBuilder::visitCallIndirect(CallIndirect* curr) { BYN_TRACE("zz node: CallIndirect\n"); auto index = getU32LEB(); - if (index >= wasm.functionTypes.size()) { + if (index >= signatures.size()) { throwError("bad call_indirect function index"); } - auto* fullType = wasm.functionTypes[index].get(); + curr->sig = signatures[index]; auto reserved = getU32LEB(); if (reserved != 0) { throwError("Invalid flags field in call_indirect"); } - curr->fullType = fullType->name; - auto num = fullType->params.size(); + auto num = curr->sig.params.size(); curr->operands.resize(num); curr->target = popNonVoidExpression(); for (size_t i = 0; i < num; i++) { curr->operands[num - i - 1] = popNonVoidExpression(); } - curr->type = fullType->result; curr->finalize(); } @@ -4299,7 +4231,7 @@ void WasmBinaryBuilder::visitSelect(Select* curr) { void WasmBinaryBuilder::visitReturn(Return* curr) { BYN_TRACE("zz node: Return\n"); requireFunctionContext("return"); - if (currFunction->result != none) { + if (currFunction->sig.results != Type::none) { curr->value = popNonVoidExpression(); } curr->finalize(); diff --git a/src/wasm/wasm-emscripten.cpp b/src/wasm/wasm-emscripten.cpp index 695561f68..9a0b145da 100644 --- a/src/wasm/wasm-emscripten.cpp +++ b/src/wasm/wasm-emscripten.cpp @@ -20,7 +20,6 @@ #include "asm_v_wasm.h" #include "asmjs/shared-constants.h" -#include "ir/function-type-utils.h" #include "ir/import-utils.h" #include "ir/literal-utils.h" #include "ir/module-utils.h" @@ -207,7 +206,7 @@ void EmscriptenGlueGenerator::generateRuntimeFunctions() { } static Function* -ensureFunctionImport(Module* module, Name name, std::string sig) { +ensureFunctionImport(Module* module, Name name, Signature sig) { // Then see if its already imported ImportInfo info(*module); if (Function* f = info.getImportedFunction(ENV, name)) { @@ -218,9 +217,7 @@ ensureFunctionImport(Module* module, Name name, std::string sig) { import->name = name; import->module = ENV; import->base = name; - auto* functionType = ensureFunctionType(sig, module); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = sig; module->addFunction(import); return import; } @@ -267,7 +264,7 @@ Function* EmscriptenGlueGenerator::generateAssignGOTEntriesFunction() { for (Global* g : gotMemEntries) { Name getter(std::string("g$") + g->base.c_str()); - ensureFunctionImport(&wasm, getter, "i"); + ensureFunctionImport(&wasm, getter, Signature(Type::none, Type::i32)); Expression* call = builder.makeCall(getter, {}, i32); GlobalSet* globalSet = builder.makeGlobalSet(g->name, call); block->list.push_back(globalSet); @@ -293,7 +290,7 @@ Function* EmscriptenGlueGenerator::generateAssignGOTEntriesFunction() { Name getter( (std::string("fp$") + g->base.c_str() + std::string("$") + getSig(f)) .c_str()); - ensureFunctionImport(&wasm, getter, "i"); + ensureFunctionImport(&wasm, getter, Signature(Type::none, Type::i32)); Expression* call = builder.makeCall(getter, {}, i32); GlobalSet* globalSet = builder.makeGlobalSet(g->name, call); block->list.push_back(globalSet); @@ -371,29 +368,28 @@ inline void exportFunction(Module& wasm, Name name, bool must_export) { wasm.addExport(exp); } -void EmscriptenGlueGenerator::generateDynCallThunk(std::string sig) { - auto* funcType = ensureFunctionType(sig, &wasm); +void EmscriptenGlueGenerator::generateDynCallThunk(Signature sig) { if (!sigs.insert(sig).second) { return; // sig is already in the set } - Name name = std::string("dynCall_") + sig; + Name name = std::string("dynCall_") + getSig(sig.results, sig.params); if (wasm.getFunctionOrNull(name) || wasm.getExportOrNull(name)) { return; // module already contains this dyncall } std::vector<NameType> params; params.emplace_back("fptr", i32); // function pointer param int p = 0; - for (const auto& ty : funcType->params) { + const std::vector<Type>& paramTypes = sig.params.expand(); + for (const auto& ty : paramTypes) { params.emplace_back(std::to_string(p++), ty); } - Function* f = - builder.makeFunction(name, std::move(params), funcType->result, {}); + Function* f = builder.makeFunction(name, std::move(params), sig.results, {}); Expression* fptr = builder.makeLocalGet(0, i32); std::vector<Expression*> args; - for (unsigned i = 0; i < funcType->params.size(); ++i) { - args.push_back(builder.makeLocalGet(i + 1, funcType->params[i])); + for (unsigned i = 0; i < paramTypes.size(); ++i) { + args.push_back(builder.makeLocalGet(i + 1, paramTypes[i])); } - Expression* call = builder.makeCallIndirect(funcType, fptr, args); + Expression* call = builder.makeCallIndirect(fptr, args, sig); f->body = call; wasm.addFunction(f); @@ -407,8 +403,7 @@ void EmscriptenGlueGenerator::generateDynCallThunks() { tableSegmentData = wasm.table.segments[0].data; } for (const auto& indirectFunc : tableSegmentData) { - std::string sig = getSig(wasm.getFunction(indirectFunc)); - generateDynCallThunk(sig); + generateDynCallThunk(wasm.getFunction(indirectFunc)->sig); } } @@ -479,10 +474,11 @@ void EmscriptenGlueGenerator::replaceStackPointerGlobal() { RemoveStackPointer walker(stackPointer); walker.walkModule(&wasm); if (walker.needStackSave) { - ensureFunctionImport(&wasm, STACK_SAVE, "i"); + ensureFunctionImport(&wasm, STACK_SAVE, Signature(Type::none, Type::i32)); } if (walker.needStackRestore) { - ensureFunctionImport(&wasm, STACK_RESTORE, "vi"); + ensureFunctionImport( + &wasm, STACK_RESTORE, Signature(Type::i32, Type::none)); } // Finally remove the stack pointer global itself. This avoids importing @@ -545,7 +541,7 @@ void EmscriptenGlueGenerator::enforceStackLimit() { void EmscriptenGlueGenerator::generateSetStackLimitFunction() { Function* function = - builder.makeFunction(SET_STACK_LIMIT, std::vector<Type>({i32}), none, {}); + builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {}); LocalGet* getArg = builder.makeLocalGet(0, i32); Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg); function->body = store; @@ -562,9 +558,7 @@ Name EmscriptenGlueGenerator::importStackOverflowHandler() { import->name = STACK_OVERFLOW_IMPORT; import->module = ENV; import->base = STACK_OVERFLOW_IMPORT; - auto* functionType = ensureFunctionType("v", &wasm); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = Signature(Type::none, Type::none); wasm.addFunction(import); return STACK_OVERFLOW_IMPORT; } @@ -696,14 +690,14 @@ struct AsmConstWalker : public LinearExecutionWalker<AsmConstWalker> { std::vector<Address> segmentOffsets; // segment index => address offset struct AsmConst { - std::set<std::string> sigs; + std::set<Signature> sigs; Address id; std::string code; Proxying proxy; }; std::vector<AsmConst> asmConsts; - std::set<std::pair<std::string, Proxying>> allSigs; + std::set<std::pair<Signature, Proxying>> allSigs; // last sets in the current basic block, per index std::map<Index, LocalSet*> sets; @@ -719,12 +713,12 @@ struct AsmConstWalker : public LinearExecutionWalker<AsmConstWalker> { void process(); private: - std::string fixupName(Name& name, std::string baseSig, Proxying proxy); + Signature fixupName(Name& name, Signature baseSig, Proxying proxy); AsmConst& - createAsmConst(uint32_t id, std::string code, std::string sig, Name name); - std::string asmConstSig(std::string baseSig); - Name nameForImportWithSig(std::string sig, Proxying proxy); - void queueImport(Name importName, std::string baseSig); + createAsmConst(uint32_t id, std::string code, Signature sig, Name name); + Signature asmConstSig(Signature baseSig); + Name nameForImportWithSig(Signature sig, Proxying proxy); + void queueImport(Name importName, Signature baseSig); void addImports(); Proxying proxyType(Name name); @@ -750,7 +744,7 @@ void AsmConstWalker::visitCall(Call* curr) { return; } - auto baseSig = getSig(curr); + auto baseSig = wasm.getFunction(curr->target)->sig; auto sig = asmConstSig(baseSig); auto* arg = curr->operands[0]; while (!arg->dynCast<Const>()) { @@ -816,9 +810,8 @@ void AsmConstWalker::visitTable(Table* curr) { for (auto& name : segment.data) { auto* func = wasm.getFunction(name); if (func->imported() && func->base.hasSubstring(EM_ASM_PREFIX)) { - std::string baseSig = getSig(func); auto proxy = proxyType(func->base); - fixupName(name, baseSig, proxy); + fixupName(name, func->sig, proxy); } } } @@ -832,8 +825,8 @@ void AsmConstWalker::process() { addImports(); } -std::string -AsmConstWalker::fixupName(Name& name, std::string baseSig, Proxying proxy) { +Signature +AsmConstWalker::fixupName(Name& name, Signature baseSig, Proxying proxy) { auto sig = asmConstSig(baseSig); auto importName = nameForImportWithSig(sig, proxy); name = importName; @@ -848,7 +841,7 @@ AsmConstWalker::fixupName(Name& name, std::string baseSig, Proxying proxy) { AsmConstWalker::AsmConst& AsmConstWalker::createAsmConst(uint32_t id, std::string code, - std::string sig, + Signature sig, Name name) { AsmConst asmConst; asmConst.id = id; @@ -859,31 +852,27 @@ AsmConstWalker::AsmConst& AsmConstWalker::createAsmConst(uint32_t id, return asmConsts.back(); } -std::string AsmConstWalker::asmConstSig(std::string baseSig) { - std::string sig = ""; - for (size_t i = 0; i < baseSig.size(); ++i) { - // Omit the signature of the "code" parameter, taken as a string, as the - // first argument - if (i != 1) { - sig += baseSig[i]; - } - } - return sig; +Signature AsmConstWalker::asmConstSig(Signature baseSig) { + std::vector<Type> params = baseSig.params.expand(); + assert(params.size() >= 1); + // Omit the signature of the "code" parameter, taken as a string, as the + // first argument + params.erase(params.begin()); + return Signature(Type(params), baseSig.results); } -Name AsmConstWalker::nameForImportWithSig(std::string sig, Proxying proxy) { - std::string fixedTarget = - EM_ASM_PREFIX.str + std::string("_") + proxyingSuffix(proxy) + sig; +Name AsmConstWalker::nameForImportWithSig(Signature sig, Proxying proxy) { + std::string fixedTarget = EM_ASM_PREFIX.str + std::string("_") + + proxyingSuffix(proxy) + + getSig(sig.results, sig.params); return Name(fixedTarget.c_str()); } -void AsmConstWalker::queueImport(Name importName, std::string baseSig) { +void AsmConstWalker::queueImport(Name importName, Signature baseSig) { auto import = new Function; import->name = import->base = importName; import->module = ENV; - auto* funcType = ensureFunctionType(baseSig, &wasm); - import->type = funcType->name; - FunctionTypeUtils::fillFunction(import, funcType); + import->sig = baseSig; queuedImports.push_back(std::unique_ptr<Function>(import)); } @@ -994,7 +983,7 @@ struct FixInvokeFunctionNamesWalker std::map<Name, Name> importRenames; std::vector<Name> toRemove; std::set<Name> newImports; - std::set<std::string> invokeSigs; + std::set<Signature> invokeSigs; FixInvokeFunctionNamesWalker(Module& _wasm) : wasm(_wasm) {} @@ -1017,7 +1006,7 @@ struct FixInvokeFunctionNamesWalker // This function converts the names of invoke wrappers based on their lowered // argument types and a return type. In the example above, the resulting new // wrapper name becomes "invoke_vii". - Name fixEmExceptionInvoke(const Name& name, const std::string& sig) { + Name fixEmExceptionInvoke(const Name& name, Signature sig) { std::string nameStr = name.c_str(); if (nameStr.front() == '"' && nameStr.back() == '"') { nameStr = nameStr.substr(1, nameStr.size() - 2); @@ -1025,12 +1014,16 @@ struct FixInvokeFunctionNamesWalker if (nameStr.find("__invoke_") != 0) { return name; } - std::string sigWoOrigFunc = sig.front() + sig.substr(2, sig.size() - 2); + + const std::vector<Type>& params = sig.params.expand(); + std::vector<Type> newParams(params.begin() + 1, params.end()); + Signature sigWoOrigFunc = Signature(Type(newParams), sig.results); invokeSigs.insert(sigWoOrigFunc); - return Name("invoke_" + sigWoOrigFunc); + return Name("invoke_" + + getSig(sigWoOrigFunc.results, sigWoOrigFunc.params)); } - Name fixEmEHSjLjNames(const Name& name, const std::string& sig) { + Name fixEmEHSjLjNames(const Name& name, Signature sig) { if (name == "emscripten_longjmp_jmpbuf") { return "emscripten_longjmp"; } @@ -1042,8 +1035,7 @@ struct FixInvokeFunctionNamesWalker return; } - FunctionType* func = wasm.getFunctionType(curr->type); - Name newname = fixEmEHSjLjNames(curr->base, getSig(func)); + Name newname = fixEmEHSjLjNames(curr->base, curr->sig); if (newname == curr->base) { return; } @@ -1085,16 +1077,16 @@ void EmscriptenGlueGenerator::fixInvokeFunctionNames() { } } -template<class C> void printSet(std::ostream& o, C& c) { +void printSignatures(std::ostream& o, const std::set<Signature>& c) { o << "["; bool first = true; - for (auto& item : c) { + for (auto& sig : c) { if (first) { first = false; } else { o << ","; } - o << '"' << item << '"'; + o << '"' << getSig(sig.results, sig.params) << '"'; } o << "]"; } @@ -1123,7 +1115,7 @@ std::string EmscriptenGlueGenerator::generateEmscriptenMetadata( for (auto& asmConst : emAsmWalker.asmConsts) { meta << nextElement(); meta << '"' << asmConst.id << "\": [\"" << asmConst.code << "\", "; - printSet(meta, asmConst.sigs); + printSignatures(meta, asmConst.sigs); meta << ", [\"" << proxyingSuffix(asmConst.proxy) << "\"]"; meta << "]"; @@ -1303,7 +1295,7 @@ void EmscriptenGlueGenerator::exportWasiStart() { {LiteralUtils::makeZero(i32, wasm), LiteralUtils::makeZero(i32, wasm)}, i32)); auto* func = - builder.makeFunction(_start, std::vector<wasm::Type>{}, none, {}, body); + builder.makeFunction(_start, Signature(Type::none, Type::none), {}, body); wasm.addFunction(func); wasm.addExport(builder.makeExport(_start, _start, ExternalKind::Function)); } diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 24a4fcb35..10b6aead7 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -23,7 +23,6 @@ #include "asm_v_wasm.h" #include "asmjs/shared-constants.h" #include "ir/branch-utils.h" -#include "ir/function-type-utils.h" #include "shared-constants.h" #include "wasm-binary.h" @@ -452,16 +451,20 @@ Name SExpressionWasmBuilder::getFunctionName(Element& s) { } } -Name SExpressionWasmBuilder::getFunctionTypeName(Element& s) { +Signature SExpressionWasmBuilder::getFunctionSignature(Element& s) { if (s.dollared()) { - return s.str(); + auto it = signatureIndices.find(s.str().str); + if (it == signatureIndices.end()) { + throw ParseException("unknown function type in getFunctionSignature"); + } + return signatures[it->second]; } else { // index size_t offset = atoi(s.str().c_str()); - if (offset >= wasm.functionTypes.size()) { - throw ParseException("unknown function type in getFunctionTypeName"); + if (offset >= signatures.size()) { + throw ParseException("unknown function type in getFunctionSignature"); } - return wasm.functionTypes[offset]->name; + return signatures[offset]; } } @@ -539,7 +542,7 @@ SExpressionWasmBuilder::parseParamOrLocal(Element& s, size_t& localIndex) { } // Parses (result type) element. (e.g. (result i32)) -Type SExpressionWasmBuilder::parseResult(Element& s) { +Type SExpressionWasmBuilder::parseResults(Element& s) { assert(elementStartsWith(s, RESULT)); if (s.size() != 2) { throw ParseException("invalid result arity", s.line, s.col); @@ -550,131 +553,95 @@ Type SExpressionWasmBuilder::parseResult(Element& s) { // Parses an element that references an entry in the type section. The element // should be in the form of (type name) or (type index). // (e.g. (type $a), (type 0)) -FunctionType* SExpressionWasmBuilder::parseTypeRef(Element& s) { +Signature SExpressionWasmBuilder::parseTypeRef(Element& s) { assert(elementStartsWith(s, TYPE)); if (s.size() != 2) { throw ParseException("invalid type reference", s.line, s.col); } - IString name = getFunctionTypeName(*s[1]); - FunctionType* functionType = wasm.getFunctionTypeOrNull(name); - if (!functionType) { - throw ParseException("bad function type for import", s[1]->line, s[1]->col); - } - return functionType; + return getFunctionSignature(*s[1]); } // Prases typeuse, a reference to a type definition. It is in the form of either // (type index) or (type name), possibly augmented by inlined (param) and -// (result) nodes. (type) node can be omitted as well, in which case we get an -// existing type if there's one with the same structure or create one. -// Outputs are returned by parameter references. +// (result) nodes. (type) node can be omitted as well. Outputs are returned by +// parameter references. // typeuse ::= (type index|name)+ | // (type index|name)+ (param ..)* (result ..)* | // (param ..)* (result ..)* -// TODO Remove FunctionType* parameter and the related logic to create -// FunctionType once we remove FunctionType class. -size_t SExpressionWasmBuilder::parseTypeUse(Element& s, - size_t startPos, - FunctionType*& functionType, - std::vector<NameType>& namedParams, - Type& result) { +size_t +SExpressionWasmBuilder::parseTypeUse(Element& s, + size_t startPos, + Signature& functionSignature, + std::vector<NameType>& namedParams) { + std::vector<Type> params, results; size_t i = startPos; - bool typeExists = false, paramOrResultExists = false; + + bool typeExists = false, paramsOrResultsExist = false; if (i < s.size() && elementStartsWith(*s[i], TYPE)) { typeExists = true; - functionType = parseTypeRef(*s[i++]); + functionSignature = parseTypeRef(*s[i++]); } - size_t paramPos = i; + size_t paramPos = i; size_t localIndex = 0; + while (i < s.size() && elementStartsWith(*s[i], PARAM)) { - paramOrResultExists = true; + paramsOrResultsExist = true; auto newParams = parseParamOrLocal(*s[i++], localIndex); namedParams.insert(namedParams.end(), newParams.begin(), newParams.end()); + for (auto p : newParams) { + params.push_back(p.type); + } } - result = none; - if (i < s.size() && elementStartsWith(*s[i], RESULT)) { - paramOrResultExists = true; - result = parseResult(*s[i++]); + + while (i < s.size() && elementStartsWith(*s[i], RESULT)) { + paramsOrResultsExist = true; + // TODO: make parseResults return a vector + results.push_back(parseResults(*s[i++])); } + + auto inlineSig = Signature(Type(params), Type(results)); + // If none of type/param/result exists, this is equivalent to a type that does // not have parameters and returns nothing. - if (!typeExists && !paramOrResultExists) { - paramOrResultExists = true; + if (!typeExists && !paramsOrResultsExist) { + paramsOrResultsExist = true; } - // verify if (type) and (params)/(result) match, if both are specified - if (typeExists && paramOrResultExists) { - size_t line = s[paramPos]->line, col = s[paramPos]->col; - const char* msg = "type and param/result don't match"; - if (functionType->result != result) { - throw ParseException(msg, line, col); - } - if (functionType->params.size() != namedParams.size()) { - throw ParseException(msg, line, col); - } - for (size_t i = 0, n = namedParams.size(); i < n; i++) { - if (functionType->params[i] != namedParams[i].type) { - throw ParseException(msg, line, col); - } + if (!typeExists) { + functionSignature = inlineSig; + } else if (paramsOrResultsExist) { + // verify that (type) and (params)/(result) match + if (inlineSig != functionSignature) { + throw ParseException("type and param/result don't match", + s[paramPos]->line, + s[paramPos]->col); } } - // If only (param)/(result) is specified, check if there's a matching type, - // and if there isn't, create one. - if (!typeExists) { - bool need = true; - std::vector<Type> params; - for (auto& p : namedParams) { - params.push_back(p.type); - } - for (auto& existing : wasm.functionTypes) { - if (existing->structuralComparison(params, result)) { - functionType = existing.get(); - need = false; - break; - } - } - if (need) { - functionType = ensureFunctionType(params, result, &wasm); - } + // Add implicitly defined type to global list so it has an index + if (std::find(signatures.begin(), signatures.end(), functionSignature) == + signatures.end()) { + signatures.push_back(functionSignature); } - // If only (type) is specified, populate params and result. - if (!paramOrResultExists) { - assert(functionType); - result = functionType->result; - for (size_t index = 0, e = functionType->params.size(); index < e; - index++) { - Type type = functionType->params[index]; - namedParams.emplace_back(Name::fromInt(index), type); + // If only (type) is specified, populate `namedParams` + if (!paramsOrResultsExist) { + const std::vector<Type>& funcParams = functionSignature.params.expand(); + for (size_t index = 0, e = funcParams.size(); index < e; index++) { + namedParams.emplace_back(Name::fromInt(index), funcParams[index]); } } return i; } -// Parses a typeuse. Ignores all parameter names. -size_t SExpressionWasmBuilder::parseTypeUse(Element& s, - size_t startPos, - FunctionType*& functionType, - std::vector<Type>& params, - Type& result) { - std::vector<NameType> namedParams; - size_t nextPos = parseTypeUse(s, startPos, functionType, namedParams, result); - for (auto& p : namedParams) { - params.push_back(p.type); - } - return nextPos; -} - // Parses a typeuse. Use this when only FunctionType* is needed. size_t SExpressionWasmBuilder::parseTypeUse(Element& s, size_t startPos, - FunctionType*& functionType) { - std::vector<Type> params; - Type result; - return parseTypeUse(s, startPos, functionType, params, result); + Signature& functionSignature) { + std::vector<NameType> params; + return parseTypeUse(s, startPos, functionSignature, params); } void SExpressionWasmBuilder::preParseFunctionType(Element& s) { @@ -694,11 +661,9 @@ void SExpressionWasmBuilder::preParseFunctionType(Element& s) { } functionNames.push_back(name); functionCounter++; - FunctionType* type = nullptr; - std::vector<Type> params; - parseTypeUse(s, i, type); - assert(type && "type should've been set by parseTypeUse"); - functionTypes[name] = type->result; + Signature sig; + parseTypeUse(s, i, sig); + functionTypes[name] = sig.results; } size_t SExpressionWasmBuilder::parseFunctionNames(Element& s, @@ -771,11 +736,9 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { } // parse typeuse: type/param/result - FunctionType* functionType = nullptr; + Signature sig; std::vector<NameType> params; - Type result = none; - i = parseTypeUse(s, i, functionType, params, result); - assert(functionType && "functionType should've been set by parseTypeUse"); + i = parseTypeUse(s, i, sig, params); // when (import) is inside a (func) element, this is not a function definition // but an import. @@ -790,9 +753,8 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { im->name = name; im->module = importModule; im->base = importBase; - im->type = functionType->name; - FunctionTypeUtils::fillFunction(im.get(), functionType); - functionTypes[name] = im->result; + im->sig = sig; + functionTypes[name] = sig.results; if (wasm.getFunctionOrNull(im->name)) { throw ParseException("duplicate import", s.line, s.col); } @@ -808,7 +770,6 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { throw ParseException("preParseImport in func"); } - result = functionType->result; size_t localIndex = params.size(); // local index for params and locals // parse locals @@ -820,8 +781,7 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { // make a new function currFunction = std::unique_ptr<Function>(Builder(wasm).makeFunction( - name, std::move(params), result, std::move(vars))); - currFunction->type = functionType->name; + name, std::move(params), sig.results, std::move(vars))); // parse body Block* autoBlock = nullptr; // may need to add a block for the very top level @@ -847,14 +807,11 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { autoBlock->name = FAKE_RETURN; } if (autoBlock) { - autoBlock->finalize(result); + autoBlock->finalize(sig.results); } if (!currFunction->body) { currFunction->body = allocator.alloc<Nop>(); } - if (currFunction->result != result) { - throw ParseException("bad func declaration", s.line, s.col); - } if (s.startLoc) { currFunction->prologLocation.insert(getDebugLocation(*s.startLoc)); } @@ -1677,13 +1634,9 @@ Expression* SExpressionWasmBuilder::makeCallIndirect(Element& s, if (!wasm.table.exists) { throw ParseException("no table"); } - auto ret = allocator.alloc<CallIndirect>(); Index i = 1; - FunctionType* functionType = nullptr; - i = parseTypeUse(s, i, functionType); - assert(functionType && "functionType should've been set by parseTypeUse"); - ret->fullType = functionType->name; - ret->type = functionType->result; + auto ret = allocator.alloc<CallIndirect>(); + i = parseTypeUse(s, i, ret->sig); parseCallOperands(s, i, s.size() - 1, ret); ret->target = parseExpression(s[s.size() - 1]); ret->isReturn = isReturn; @@ -2135,14 +2088,13 @@ void SExpressionWasmBuilder::parseImport(Element& s) { Element& inner = newStyle ? *s[3] : s; Index j = newStyle ? newStyleInner : i; if (kind == ExternalKind::Function) { - FunctionType* functionType = nullptr; auto func = make_unique<Function>(); - j = parseTypeUse(inner, j, functionType, func->params, func->result); + + j = parseTypeUse(inner, j, func->sig); func->name = name; func->module = module; func->base = base; - func->type = functionType->name; - functionTypes[name] = func->result; + functionTypes[name] = func->sig.results; wasm.addFunction(func.release()); } else if (kind == ExternalKind::Global) { Type type; @@ -2202,14 +2154,10 @@ void SExpressionWasmBuilder::parseImport(Element& s) { throw ParseException("invalid attribute", attrElem.line, attrElem.col); } event->attribute = atoi(attrElem[1]->c_str()); - std::vector<Type> paramTypes; - FunctionType* fakeFunctionType; // just to call parseTypeUse - Type results; - j = parseTypeUse(inner, j, fakeFunctionType, paramTypes, results); + j = parseTypeUse(inner, j, event->sig); event->name = name; event->module = module; event->base = base; - event->sig = Signature(Type(paramTypes), results); wasm.addEvent(event.release()); } // If there are more elements, they are invalid @@ -2406,10 +2354,15 @@ void SExpressionWasmBuilder::parseInnerElem(Element& s, } void SExpressionWasmBuilder::parseType(Element& s) { - std::unique_ptr<FunctionType> type = make_unique<FunctionType>(); + std::vector<Type> params; + std::vector<Type> results; size_t i = 1; if (s[i]->isStr()) { - type->name = s[i]->str(); + std::string name = s[i]->str().str; + if (signatureIndices.find(name) != signatureIndices.end()) { + throw ParseException("duplicate function type", s.line, s.col); + } + signatureIndices[name] = signatures.size(); i++; } Element& func = *s[i]; @@ -2417,25 +2370,13 @@ void SExpressionWasmBuilder::parseType(Element& s) { Element& curr = *func[k]; if (elementStartsWith(curr, PARAM)) { auto newParams = parseParamOrLocal(curr); - type->params.insert( - type->params.end(), newParams.begin(), newParams.end()); + params.insert(params.end(), newParams.begin(), newParams.end()); } else if (elementStartsWith(curr, RESULT)) { - type->result = parseResult(curr); + // TODO: Parse multiple results at once + results.push_back(parseResults(curr)); } } - while (type->name.is() && wasm.getFunctionTypeOrNull(type->name)) { - throw ParseException("duplicate function type", s.line, s.col); - } - // We allow duplicate types in the type section, i.e., we can have - // (func (param i32) (result i32)) many times. For unnamed types, find a name - // that does not clash with existing ones. - if (!type->name.is()) { - type->name = "FUNCSIG$" + getSig(type.get()); - } - while (wasm.getFunctionTypeOrNull(type->name)) { - type->name = Name(std::string(type->name.c_str()) + "_"); - } - wasm.addFunctionType(std::move(type)); + signatures.emplace_back(Type(params), Type(results)); } void SExpressionWasmBuilder::parseEvent(Element& s, bool preParseImport) { @@ -2515,11 +2456,7 @@ void SExpressionWasmBuilder::parseEvent(Element& s, bool preParseImport) { event->attribute = atoi(attrElem[1]->c_str()); // Parse typeuse - std::vector<Type> paramTypes; - Type results; - FunctionType* fakeFunctionType; // just co call parseTypeUse - i = parseTypeUse(s, i, fakeFunctionType, paramTypes, results); - event->sig = Signature(Type(paramTypes), results); + i = parseTypeUse(s, i, event->sig); // If there are more elements, they are invalid if (i < s.size()) { diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index 0f12a2f2b..abaf17e05 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -68,9 +68,7 @@ void BinaryInstWriter::visitCall(Call* curr) { void BinaryInstWriter::visitCallIndirect(CallIndirect* curr) { int8_t op = curr->isReturn ? BinaryConsts::RetCallIndirect : BinaryConsts::CallIndirect; - auto* type = parent.getModule()->getFunctionType(curr->fullType); - Signature sig(Type(type->params), type->result); - o << op << U32LEB(parent.getTypeIndex(sig)) + o << op << U32LEB(parent.getTypeIndex(curr->sig)) << U32LEB(0); // Reserved flags field } diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index bc17e2193..7d29b0dcd 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -254,16 +254,23 @@ unsigned getTypeSize(Type type) { } FeatureSet getFeatures(Type type) { - switch (type) { - case v128: - return FeatureSet::SIMD; - case anyref: - return FeatureSet::ReferenceTypes; - case exnref: - return FeatureSet::ExceptionHandling; - default: - return FeatureSet(); + FeatureSet feats = FeatureSet::MVP; + for (Type t : type.expand()) { + switch (t) { + case v128: + feats |= FeatureSet::SIMD; + break; + case anyref: + feats |= FeatureSet::ReferenceTypes; + break; + case exnref: + feats |= FeatureSet::ExceptionHandling; + break; + default: + break; + } } + return feats; } Type getType(unsigned size, bool float_) { diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 75121d4f4..c6de444c1 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -593,14 +593,15 @@ void FunctionValidator::visitCall(Call* curr) { if (!shouldBeTrue(!!target, curr, "call target must exist")) { return; } - if (!shouldBeTrue(curr->operands.size() == target->params.size(), + const std::vector<Type> params = target->sig.params.expand(); + if (!shouldBeTrue(curr->operands.size() == params.size(), curr, "call param number must match")) { return; } for (size_t i = 0; i < curr->operands.size(); i++) { if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, - target->params[i], + params[i], curr, "call param types must match") && !info.quiet) { @@ -613,8 +614,8 @@ void FunctionValidator::visitCall(Call* curr) { curr, "return_call should have unreachable type"); shouldBeEqual( - getFunction()->result, - target->result, + getFunction()->sig.results, + target->sig.results, curr, "return_call callee return type must match caller return type"); } else { @@ -629,7 +630,7 @@ void FunctionValidator::visitCall(Call* curr) { "calls may only be unreachable if they have unreachable operands"); } else { shouldBeEqual(curr->type, - target->result, + target->sig.results, curr, "call type must match callee return type"); } @@ -643,20 +644,17 @@ void FunctionValidator::visitCallIndirect(CallIndirect* curr) { if (!info.validateGlobally) { return; } - auto* type = getModule()->getFunctionTypeOrNull(curr->fullType); - if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) { - return; - } + const std::vector<Type>& params = curr->sig.params.expand(); shouldBeEqualOrFirstIsUnreachable( curr->target->type, i32, curr, "indirect call target must be an i32"); - if (!shouldBeTrue(curr->operands.size() == type->params.size(), + if (!shouldBeTrue(curr->operands.size() == params.size(), curr, "call param number must match")) { return; } for (size_t i = 0; i < curr->operands.size(); i++) { if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, - type->params[i], + params[i], curr, "call param types must match") && !info.quiet) { @@ -669,8 +667,8 @@ void FunctionValidator::visitCallIndirect(CallIndirect* curr) { curr, "return_call_indirect should have unreachable type"); shouldBeEqual( - getFunction()->result, - type->result, + getFunction()->sig.results, + curr->sig.results, curr, "return_call_indirect callee return type must match caller return type"); } else { @@ -687,7 +685,7 @@ void FunctionValidator::visitCallIndirect(CallIndirect* curr) { } } else { shouldBeEqual(curr->type, - type->result, + curr->sig.results, curr, "call_indirect type must match callee return type"); } @@ -701,31 +699,35 @@ void FunctionValidator::visitConst(Const* curr) { } void FunctionValidator::visitLocalGet(LocalGet* curr) { - shouldBeTrue(curr->index < getFunction()->getNumLocals(), - curr, - "local.get index must be small enough"); shouldBeTrue(curr->type.isConcrete(), curr, "local.get must have a valid type - check what you provided " "when you constructed the node"); - shouldBeTrue(curr->type == getFunction()->getLocalType(curr->index), - curr, - "local.get must have proper type"); + if (shouldBeTrue(curr->index < getFunction()->getNumLocals(), + curr, + "local.get index must be small enough")) { + shouldBeTrue(curr->type == getFunction()->getLocalType(curr->index), + curr, + "local.get must have proper type"); + } } void FunctionValidator::visitLocalSet(LocalSet* curr) { - shouldBeTrue(curr->index < getFunction()->getNumLocals(), - curr, - "local.set index must be small enough"); - if (curr->value->type != unreachable) { - if (curr->type != none) { // tee is ok anyhow - shouldBeEqualOrFirstIsUnreachable( - curr->value->type, curr->type, curr, "local.set type must be correct"); + if (shouldBeTrue(curr->index < getFunction()->getNumLocals(), + curr, + "local.set index must be small enough")) { + if (curr->value->type != unreachable) { + if (curr->type != none) { // tee is ok anyhow + shouldBeEqualOrFirstIsUnreachable(curr->value->type, + curr->type, + curr, + "local.set type must be correct"); + } + shouldBeEqual(getFunction()->getLocalType(curr->index), + curr->value->type, + curr, + "local.set type must match function"); } - shouldBeEqual(getFunction()->getLocalType(curr->index), - curr->value->type, - curr, - "local.set type must match function"); } } @@ -1754,28 +1756,35 @@ void FunctionValidator::visitBrOnExn(BrOnExn* curr) { } void FunctionValidator::visitFunction(Function* curr) { - FeatureSet typeFeatures = getFeatures(curr->result); - for (auto type : curr->params) { - typeFeatures |= getFeatures(type); + shouldBeTrue(!curr->sig.results.isMulti(), + curr->body, + "Multivalue functions not allowed yet"); + FeatureSet features; + for (auto type : curr->sig.params.expand()) { + features |= getFeatures(type); shouldBeTrue(type.isConcrete(), curr, "params must be concretely typed"); } + for (auto type : curr->sig.results.expand()) { + features |= getFeatures(type); + shouldBeTrue(type.isConcrete(), curr, "results must be concretely typed"); + } for (auto type : curr->vars) { - typeFeatures |= getFeatures(type); + features |= getFeatures(type); shouldBeTrue(type.isConcrete(), curr, "vars must be concretely typed"); } - shouldBeTrue(typeFeatures <= getModule()->features, + shouldBeTrue(features <= getModule()->features, curr, "all used types should be allowed"); // if function has no result, it is ignored // if body is unreachable, it might be e.g. a return if (curr->body->type != unreachable) { - shouldBeEqual(curr->result, + shouldBeEqual(curr->sig.results, curr->body->type, curr->body, "function body type must match, if function returns"); } if (returnType != unreachable) { - shouldBeEqual(curr->result, + shouldBeEqual(curr->sig.results, returnType, curr->body, "function result must match, if function has returns"); @@ -1784,22 +1793,6 @@ void FunctionValidator::visitFunction(Function* curr) { breakInfos.empty(), curr->body, "all named break targets must exist"); returnType = unreachable; labelNames.clear(); - // if function has a named type, it must match up with the function's params - // and result - if (info.validateGlobally && curr->type.is()) { - auto* ft = getModule()->getFunctionType(curr->type); - shouldBeTrue(ft->params == curr->params, - curr->name, - "function params must match its declared type"); - shouldBeTrue(ft->result == curr->result, - curr->name, - "function result must match its declared type"); - } - if (curr->imported()) { - shouldBeTrue(curr->type.is(), - curr->name, - "imported functions must have a function type"); - } // validate optional local names std::set<Name> seen; for (auto& pair : curr->localNames) { @@ -1925,17 +1918,18 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) { static void validateImports(Module& module, ValidationInfo& info) { ModuleUtils::iterImportedFunctions(module, [&](Function* curr) { if (info.validateWeb) { - auto* functionType = module.getFunctionType(curr->type); - info.shouldBeUnequal(functionType->result, - i64, - curr->name, - "Imported function must not have i64 return type"); - for (Type param : functionType->params) { + for (Type param : curr->sig.params.expand()) { info.shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters"); } + for (Type result : curr->sig.results.expand()) { + info.shouldBeUnequal(result, + i64, + curr->name, + "Imported function must not have i64 results"); + } } }); if (!module.features.hasMutableGlobals()) { @@ -1951,17 +1945,19 @@ static void validateExports(Module& module, ValidationInfo& info) { if (curr->kind == ExternalKind::Function) { if (info.validateWeb) { Function* f = module.getFunction(curr->value); - info.shouldBeUnequal(f->result, - i64, - f->name, - "Exported function must not have i64 return type"); - for (auto param : f->params) { + for (auto param : f->sig.params.expand()) { info.shouldBeUnequal( param, i64, f->name, "Exported function must not have i64 parameters"); } + for (auto result : f->sig.results.expand()) { + info.shouldBeUnequal(result, + i64, + f->name, + "Exported function must not have i64 results"); + } } } else if (curr->kind == ExternalKind::Global && !module.features.hasMutableGlobals()) { @@ -2133,10 +2129,12 @@ static void validateModule(Module& module, ValidationInfo& info) { auto func = module.getFunctionOrNull(module.start); if (info.shouldBeTrue( func != nullptr, module.start, "start must be found")) { - info.shouldBeTrue( - func->params.size() == 0, module.start, "start must have 0 params"); - info.shouldBeTrue( - func->result == none, module.start, "start must not return a value"); + info.shouldBeTrue(func->sig.params == Type::none, + module.start, + "start must have 0 params"); + info.shouldBeTrue(func->sig.results == Type::none, + module.start, + "start must not return a value"); } } } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 5a722e600..83829418e 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -432,6 +432,7 @@ void Call::finalize() { } void CallIndirect::finalize() { + type = sig.results; handleUnreachableOperands(this); if (isReturn) { type = unreachable; @@ -441,34 +442,6 @@ void CallIndirect::finalize() { } } -bool FunctionType::structuralComparison(FunctionType& b) { - return structuralComparison(b.params, b.result); -} - -bool FunctionType::structuralComparison(const std::vector<Type>& otherParams, - Type otherResult) { - if (result != otherResult) { - return false; - } - if (params.size() != otherParams.size()) { - return false; - } - for (size_t i = 0; i < params.size(); i++) { - if (params[i] != otherParams[i]) { - return false; - } - } - return true; -} - -bool FunctionType::operator==(FunctionType& b) { - if (name != b.name) { - return false; - } - return structuralComparison(b); -} -bool FunctionType::operator!=(FunctionType& b) { return !(*this == b); } - bool LocalSet::isTee() { return type != none; } void LocalSet::setTee(bool is) { @@ -932,15 +905,23 @@ void Push::finalize() { } } -size_t Function::getNumParams() { return params.size(); } +size_t Function::getNumParams() { return sig.params.size(); } size_t Function::getNumVars() { return vars.size(); } -size_t Function::getNumLocals() { return params.size() + vars.size(); } +size_t Function::getNumLocals() { return sig.params.size() + vars.size(); } -bool Function::isParam(Index index) { return index < params.size(); } +bool Function::isParam(Index index) { + size_t size = sig.params.size(); + assert(index < size + vars.size()); + return index < size; +} -bool Function::isVar(Index index) { return index >= params.size(); } +bool Function::isVar(Index index) { + auto base = getVarIndexBase(); + assert(index < base + vars.size()); + return index >= base; +} bool Function::hasLocalName(Index index) const { return localNames.find(index) != localNames.end(); @@ -973,13 +954,14 @@ Index Function::getLocalIndex(Name name) { return iter->second; } -Index Function::getVarIndexBase() { return params.size(); } +Index Function::getVarIndexBase() { return sig.params.size(); } Type Function::getLocalType(Index index) { - if (isParam(index)) { + const std::vector<Type>& params = sig.params.expand(); + if (index < params.size()) { return params[index]; } else if (isVar(index)) { - return vars[index - getVarIndexBase()]; + return vars[index - params.size()]; } else { WASM_UNREACHABLE("invalid local index"); } @@ -994,14 +976,6 @@ void Function::clearDebugInfo() { epilogLocation.clear(); } -FunctionType* Module::getFunctionType(Name name) { - auto iter = functionTypesMap.find(name); - if (iter == functionTypesMap.end()) { - Fatal() << "Module::getFunctionType: " << name << " does not exist"; - } - return iter->second; -} - Export* Module::getExport(Name name) { auto iter = exportsMap.find(name); if (iter == exportsMap.end()) { @@ -1035,14 +1009,6 @@ Event* Module::getEvent(Name name) { return iter->second; } -FunctionType* Module::getFunctionTypeOrNull(Name name) { - auto iter = functionTypesMap.find(name); - if (iter == functionTypesMap.end()) { - return nullptr; - } - return iter->second; -} - Export* Module::getExportOrNull(Name name) { auto iter = exportsMap.find(name); if (iter == exportsMap.end()) { @@ -1075,19 +1041,6 @@ Event* Module::getEventOrNull(Name name) { return iter->second; } -FunctionType* Module::addFunctionType(std::unique_ptr<FunctionType> curr) { - if (!curr->name.is()) { - Fatal() << "Module::addFunctionType: empty name"; - } - if (getFunctionTypeOrNull(curr->name)) { - Fatal() << "Module::addFunctionType: " << curr->name << " already exists"; - } - auto* p = curr.get(); - functionTypes.emplace_back(std::move(curr)); - functionTypesMap[p->name] = p; - return p; -} - Export* Module::addExport(Export* curr) { if (!curr->name.is()) { Fatal() << "Module::addExport: empty name"; @@ -1166,9 +1119,6 @@ void removeModuleElement(Vector& v, Map& m, Name name) { } } -void Module::removeFunctionType(Name name) { - removeModuleElement(functionTypes, functionTypesMap, name); -} void Module::removeExport(Name name) { removeModuleElement(exports, exportsMap, name); } @@ -1198,9 +1148,6 @@ void removeModuleElements(Vector& v, v.end()); } -void Module::removeFunctionTypes(std::function<bool(FunctionType*)> pred) { - removeModuleElements(functionTypes, functionTypesMap, pred); -} void Module::removeExports(std::function<bool(Export*)> pred) { removeModuleElements(exports, exportsMap, pred); } @@ -1219,10 +1166,6 @@ void Module::updateMaps() { for (auto& curr : functions) { functionsMap[curr->name] = curr.get(); } - functionTypesMap.clear(); - for (auto& curr : functionTypes) { - functionTypesMap[curr->name] = curr.get(); - } exportsMap.clear(); for (auto& curr : exports) { exportsMap[curr->name] = curr.get(); |