diff options
Diffstat (limited to 'src')
62 files changed, 848 insertions, 1477 deletions
diff --git a/src/abi/js.h b/src/abi/js.h index 52ca5657f..89e3f0087 100644 --- a/src/abi/js.h +++ b/src/abi/js.h @@ -52,22 +52,20 @@ extern cashew::IString SCRATCH_STORE_F64; inline void ensureScratchMemoryHelpers(Module* wasm, cashew::IString specific = cashew::IString()) { - auto ensureImport = - [&](Name name, const std::vector<Type> params, Type result) { - if (wasm->getFunctionOrNull(name)) { - return; - } - if (specific.is() && name != specific) { - return; - } - auto func = make_unique<Function>(); - func->name = name; - func->params = params; - func->result = result; - func->module = ENV; - func->base = name; - wasm->addFunction(std::move(func)); - }; + auto ensureImport = [&](Name name, Type params, Type results) { + if (wasm->getFunctionOrNull(name)) { + return; + } + if (specific.is() && name != specific) { + return; + } + auto func = make_unique<Function>(); + func->name = name; + func->sig = Signature(params, results); + func->module = ENV; + func->base = name; + wasm->addFunction(std::move(func)); + }; ensureImport(SCRATCH_LOAD_I32, {i32}, i32); ensureImport(SCRATCH_STORE_I32, {i32, i32}, none); diff --git a/src/abi/stack.h b/src/abi/stack.h index 68ac7f08a..265a7af6e 100644 --- a/src/abi/stack.h +++ b/src/abi/stack.h @@ -128,10 +128,10 @@ getStackSpace(Index local, Function* func, Index size, Module& wasm) { // no need to restore the old stack value, we're gone anyhow } else { // save the return value - auto temp = builder.addVar(func, func->result); + auto temp = builder.addVar(func, func->sig.results); block->list.push_back(builder.makeLocalSet(temp, func->body)); block->list.push_back(makeStackRestore()); - block->list.push_back(builder.makeLocalGet(temp, func->result)); + block->list.push_back(builder.makeLocalGet(temp, func->sig.results)); } block->finalize(); func->body = block; diff --git a/src/asm2wasm.h b/src/asm2wasm.h index 048d363ef..fc841e634 100644 --- a/src/asm2wasm.h +++ b/src/asm2wasm.h @@ -28,7 +28,6 @@ #include "emscripten-optimizer/optimizer.h" #include "ir/bits.h" #include "ir/branch-utils.h" -#include "ir/function-type-utils.h" #include "ir/literal-utils.h" #include "ir/module-utils.h" #include "ir/trapping.h" @@ -509,47 +508,51 @@ private: // function types. we fill in this information as we see // uses, in the first pass - std::map<IString, std::unique_ptr<FunctionType>> importedFunctionTypes; + std::map<IString, Signature> importedSignatures; void noteImportedFunctionCall(Ref ast, Type resultType, Call* call) { assert(ast[0] == CALL && ast[1]->isString()); IString importName = ast[1]->getIString(); - auto type = make_unique<FunctionType>(); - type->name = IString((std::string("type$") + importName.str).c_str(), - false); // TODO: make a list of such types - type->result = resultType; + std::vector<Type> params; for (auto* operand : call->operands) { - type->params.push_back(operand->type); + params.push_back(operand->type); } + Signature sig = Signature(Type(params), resultType); // if we already saw this signature, verify it's the same (or else handle // that) - if (importedFunctionTypes.find(importName) != importedFunctionTypes.end()) { - FunctionType* previous = importedFunctionTypes[importName].get(); - if (*type != *previous) { + if (importedSignatures.find(importName) != importedSignatures.end()) { + Signature& previous = importedSignatures[importName]; + if (sig != previous) { + std::vector<Type> mergedParams = previous.params.expand(); // merge it in. we'll add on extra 0 parameters for ones not actually // used, and upgrade types to double where there is a conflict (which is // ok since in JS, double can contain everything i32 and f32 can). - for (size_t i = 0; i < type->params.size(); i++) { - if (previous->params.size() > i) { - if (previous->params[i] == none) { - previous->params[i] = type->params[i]; // use a more concrete type - } else if (previous->params[i] != type->params[i]) { - previous->params[i] = f64; // overloaded type, make it a double + for (size_t i = 0; i < params.size(); i++) { + if (mergedParams.size() > i) { + // TODO: Is this dead? + // if (mergedParams[i] == Type::none) { + // mergedParams[i] = params[i]; // use a more concrete type + // } else + if (mergedParams[i] != params[i]) { + mergedParams[i] = f64; // overloaded type, make it a double } } else { - previous->params.push_back(type->params[i]); // add a new param + mergedParams.push_back(params[i]); // add a new param } } + previous.params = Type(mergedParams); // we accept none and a concrete type, but two concrete types mean we // need to use an f64 to contain anything - if (previous->result == none) { - previous->result = type->result; // use a more concrete type - } else if (previous->result != type->result && type->result != none) { - previous->result = f64; // overloaded return type, make it a double + if (previous.results == Type::none) { + previous.results = sig.results; // use a more concrete type + } else if (previous.results != sig.results && + sig.results != Type::none) { + // overloaded return type, make it a double + previous.results = Type::f64; } } } else { - importedFunctionTypes[importName].swap(type); + importedSignatures[importName] = sig; } } @@ -566,10 +569,14 @@ private: return result; } - FunctionType* - getFunctionType(Ref parent, ExpressionList& operands, AsmData* data) { - Type result = getResultTypeOfCallUsingParent(parent, data); - return ensureFunctionType(getSig(result, operands), &wasm); + Signature getSignature(Ref parent, ExpressionList& operands, AsmData* data) { + Type results = getResultTypeOfCallUsingParent(parent, data); + std::vector<Type> paramTypes; + for (auto& op : operands) { + assert(op->type != Type::unreachable); + paramTypes.push_back(op->type); + } + return Signature(Type(paramTypes), results); } public: @@ -790,25 +797,29 @@ private: } } - FunctionType* getBuiltinFunctionType(Name module, - Name base, - ExpressionList* operands = nullptr) { + bool getBuiltinSignature(Signature& sig, + Name module, + Name base, + ExpressionList* operands = nullptr) { if (module == GLOBAL_MATH) { if (base == ABS) { assert(operands && operands->size() == 1); Type type = (*operands)[0]->type; if (type == i32) { - return ensureFunctionType("ii", &wasm); + sig = Signature(Type::i32, Type::i32); + return true; } if (type == f32) { - return ensureFunctionType("ff", &wasm); + sig = Signature(Type::f32, Type::f32); + return true; } if (type == f64) { - return ensureFunctionType("dd", &wasm); + sig = Signature(Type::f64, Type::f64); + return true; } } } - return nullptr; + return false; } // ensure a nameless block @@ -1043,6 +1054,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) { import->name = name; import->module = moduleName; import->base = base; + import->sig = Signature(Type::none, Type::none); wasm.addFunction(import); } }; @@ -1377,16 +1389,13 @@ void Asm2WasmBuilder::processAsm(Ref ast) { ModuleUtils::iterImportedFunctions(wasm, [&](Function* import) { IString name = import->name; - if (importedFunctionTypes.find(name) != importedFunctionTypes.end()) { + if (importedSignatures.find(name) != importedSignatures.end()) { // special math builtins - FunctionType* builtin = - getBuiltinFunctionType(import->module, import->base); - if (builtin) { - import->type = builtin->name; + Signature builtin; + if (getBuiltinSignature(builtin, import->module, import->base)) { + import->sig = builtin; } else { - import->type = - ensureFunctionType(getSig(importedFunctionTypes[name].get()), &wasm) - ->name; + import->sig = importedSignatures[name]; } } else if (import->module != ASM2WASM) { // special-case the special module // never actually used, which means we don't know the function type since @@ -1399,12 +1408,6 @@ void Asm2WasmBuilder::processAsm(Ref ast) { wasm.removeFunction(curr); } - // Finalize function imports now that we've seen all the calls - - ModuleUtils::iterImportedFunctions(wasm, [&](Function* func) { - FunctionTypeUtils::fillFunction(func, wasm.getFunctionType(func->type)); - }); - // Finalize calls now that everything is known and generated struct FinalizeCalls : public WalkerPass<PostWalker<FinalizeCalls>> { @@ -1450,9 +1453,9 @@ void Asm2WasmBuilder::processAsm(Ref ast) { if (calledFunc && !calledFunc->imported()) { // The result type of the function being called is now known, and can be // applied. - auto result = calledFunc->result; - if (curr->type != result) { - curr->type = result; + auto results = calledFunc->sig.results; + if (curr->type != results) { + curr->type = results; } // Handle mismatched numbers of arguments. In clang, if a function is // declared one way but called in another, it inserts bitcasts to make @@ -1460,26 +1463,26 @@ void Asm2WasmBuilder::processAsm(Ref ast) { // parameters in native platforms, even though it's undefined behavior. // We warn about it here, but tolerate it, if there is a simple // solution. - if (curr->operands.size() < calledFunc->params.size()) { + const std::vector<Type>& params = calledFunc->sig.params.expand(); + if (curr->operands.size() < params.size()) { notifyAboutWrongOperands("warning: asm2wasm adding operands", calledFunc); - while (curr->operands.size() < calledFunc->params.size()) { + while (curr->operands.size() < params.size()) { // Add params as necessary, with zeros. curr->operands.push_back(LiteralUtils::makeZero( - calledFunc->params[curr->operands.size()], *getModule())); + params[curr->operands.size()], *getModule())); } } - if (curr->operands.size() > calledFunc->params.size()) { + if (curr->operands.size() > params.size()) { notifyAboutWrongOperands("warning: asm2wasm dropping operands", calledFunc); - curr->operands.resize(calledFunc->params.size()); + curr->operands.resize(params.size()); } // If the types are wrong, validation will fail later anyhow, but add a // warning here, it may help people. for (Index i = 0; i < curr->operands.size(); i++) { auto sent = curr->operands[i]->type; - auto expected = calledFunc->params[i]; - if (sent != unreachable && sent != expected) { + if (sent != Type::unreachable && sent != params[i]) { notifyAboutWrongOperands( "error: asm2wasm seeing an invalid argument type at index " + std::to_string(i) + " (this will not validate)", @@ -1490,23 +1493,23 @@ void Asm2WasmBuilder::processAsm(Ref ast) { // A call to an import // fill things out: add extra params as needed, etc. asm tolerates ffi // overloading, wasm does not - auto iter = parent->importedFunctionTypes.find(curr->target); - if (iter == parent->importedFunctionTypes.end()) { + auto iter = parent->importedSignatures.find(curr->target); + if (iter == parent->importedSignatures.end()) { return; // one of our fake imports for callIndirect fixups } - auto type = iter->second.get(); - for (size_t i = 0; i < type->params.size(); i++) { + const std::vector<Type>& params = iter->second.params.expand(); + for (size_t i = 0; i < params.size(); i++) { if (i >= curr->operands.size()) { // add a new param auto val = parent->allocator.alloc<Const>(); - val->type = val->value.type = type->params[i]; + val->type = val->value.type = params[i]; curr->operands.push_back(val); - } else if (curr->operands[i]->type != type->params[i]) { + } else if (curr->operands[i]->type != params[i]) { // if the param is used, then we have overloading here and the // combined type must be f64; if this is an unreachable param, then // it doesn't matter. - assert(type->params[i] == f64 || - curr->operands[i]->type == unreachable); + assert(params[i] == Type::f64 || + curr->operands[i]->type == Type::unreachable); // overloaded, upgrade to f64 switch (curr->operands[i]->type) { case i32: @@ -1522,12 +1525,11 @@ void Asm2WasmBuilder::processAsm(Ref ast) { } } Module* wasm = getModule(); - auto importResult = - wasm->getFunctionType(wasm->getFunction(curr->target)->type)->result; - if (curr->type != importResult) { + Type importResults = wasm->getFunction(curr->target)->sig.results; + if (curr->type != importResults) { auto old = curr->type; - curr->type = importResult; - if (importResult == f64) { + curr->type = importResults; + if (importResults == Type::f64) { // we use a JS f64 value which is the most general, and convert to // it switch (old) { @@ -1743,7 +1745,6 @@ void Asm2WasmBuilder::processAsm(Ref ast) { // if r then *r = x % y // returns x / y auto* func = wasm.getFunction(udivmoddi4); - assert(!func->type.is()); Builder::clearLocals(func); Index xl = Builder::addParam(func, "xl", i32), xh = Builder::addParam(func, "xh", i32), @@ -1786,6 +1787,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { BYN_TRACE("asm2wasming func: " << ast[1]->getIString().str << '\n'); auto function = new Function; + function->sig = Signature(Type::none, Type::none); function->name = name; Ref params = ast[2]; Ref body = ast[3]; @@ -1851,8 +1853,8 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { std::function<Expression*(Ref, unsigned)> processIgnoringShift; std::function<Expression*(Ref)> process = [&](Ref ast) -> Expression* { - AstStackHelper astStackHelper( - ast); // TODO: only create one when we need it? + // TODO: only create one when we need it? + AstStackHelper astStackHelper(ast); if (ast->isString()) { IString name = ast->getIString(); if (functionVariables.has(name)) { @@ -1873,9 +1875,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { import->name = DEBUGGER; import->module = ASM2WASM; import->base = DEBUGGER; - auto* functionType = ensureFunctionType("v", &wasm); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = Signature(Type::none, Type::none); wasm.addFunction(import); } return call; @@ -1989,9 +1989,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { import->name = F64_REM; import->module = ASM2WASM; import->base = F64_REM; - auto* functionType = ensureFunctionType("ddd", &wasm); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = Signature({Type::f64, Type::f64}, Type::f64); wasm.addFunction(import); } return call; @@ -2647,7 +2645,6 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { } break; } - default: {} } } // ftCall_* and mftCall_* represent function table calls, either from @@ -2685,10 +2682,10 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { auto specific = ret->dynCast<CallIndirect>(); // note that we could also get the type from the suffix of the name, // e.g., mftCall_vi - auto* fullType = getFunctionType( + auto sig = getSignature( astStackHelper.getParent(), specific->operands, &asmData); - specific->fullType = fullType->name; - specific->type = fullType->result; + specific->sig = sig; + specific->type = sig.results; } if (callImport) { // apply the detected type from the parent @@ -2719,10 +2716,10 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { for (unsigned i = 0; i < args->size(); i++) { ret->operands.push_back(process(args[i])); } - auto* fullType = - getFunctionType(astStackHelper.getParent(), ret->operands, &asmData); - ret->fullType = fullType->name; - ret->type = fullType->result; + auto sig = + getSignature(astStackHelper.getParent(), ret->operands, &asmData); + ret->sig = sig; + ret->type = sig.results; // we don't know the table offset yet. emit target = target + // callImport(tableName), which we fix up later when we know how asm // function tables are layed out inside the wasm table. @@ -2734,9 +2731,9 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { } else if (what == RETURN) { Type type = !!ast[1] ? detectWasmType(ast[1], &asmData) : none; if (seenReturn) { - assert(function->result == type); + assert(function->sig.results == type); } else { - function->result = type; + function->sig.results = type; } // wasm has no return, so we just break on the topmost block auto ret = allocator.alloc<Return>(); diff --git a/src/asm_v_wasm.h b/src/asm_v_wasm.h index 95b4cbbca..d7e2e5b24 100644 --- a/src/asm_v_wasm.h +++ b/src/asm_v_wasm.h @@ -28,19 +28,8 @@ Type asmToWasmType(AsmType asmType); AsmType wasmToAsmType(Type type); char getSig(Type type); - -template<typename ListType> -std::string getSig(const ListType& params, Type result) { - std::string ret; - ret += getSig(result); - for (auto param : params) { - ret += getSig(param); - } - return ret; -} - -std::string getSig(const FunctionType* type); std::string getSig(Function* func); +std::string getSig(Type results, Type params); template<typename T, typename std::enable_if<std::is_base_of<Expression, T>::value>::type* = @@ -74,21 +63,6 @@ std::string getSigFromStructs(Type result, const ListType& operands) { return ret; } -Type sigToType(char sig); - -FunctionType sigToFunctionType(const std::string& sig); - -FunctionType* -ensureFunctionType(const std::string& sig, Module* wasm, Name name = Name()); - -template<typename ListType> -FunctionType* ensureFunctionType(const ListType& params, - Type result, - Module* wasm, - Name name = Name()) { - return ensureFunctionType(getSig(params, result), wasm, name); -} - // converts an f32 to an f64 if necessary Expression* ensureDouble(Expression* expr, MixedArena& allocator); diff --git a/src/asmjs/asm_v_wasm.cpp b/src/asmjs/asm_v_wasm.cpp index 2fba0520a..3720ca079 100644 --- a/src/asmjs/asm_v_wasm.cpp +++ b/src/asmjs/asm_v_wasm.cpp @@ -89,62 +89,18 @@ char getSig(Type type) { WASM_UNREACHABLE("invalid type"); } -std::string getSig(const FunctionType* type) { - return getSig(type->params, type->result); -} - std::string getSig(Function* func) { - return getSig(func->params, func->result); -} - -Type sigToType(char sig) { - switch (sig) { - case 'i': - return i32; - case 'j': - return i64; - case 'f': - return f32; - case 'd': - return f64; - case 'V': - return v128; - case 'a': - return anyref; - case 'e': - return exnref; - case 'v': - return none; - default: - abort(); - } + return getSig(func->sig.results, func->sig.params); } -FunctionType sigToFunctionType(const std::string& sig) { - FunctionType ret; - ret.result = sigToType(sig[0]); - for (size_t i = 1; i < sig.size(); i++) { - ret.params.push_back(sigToType(sig[i])); - } - return ret; -} - -FunctionType* -ensureFunctionType(const std::string& sig, Module* wasm, Name name) { - if (!name.is()) { - name = "FUNCSIG$" + sig; - } - if (wasm->getFunctionTypeOrNull(name)) { - return wasm->getFunctionType(name); - } - // add new type - auto type = make_unique<FunctionType>(); - type->name = name; - type->result = sigToType(sig[0]); - for (size_t i = 1; i < sig.size(); i++) { - type->params.push_back(sigToType(sig[i])); +std::string getSig(Type results, Type params) { + assert(!results.isMulti()); + std::string sig; + sig += getSig(results); + for (Type t : params.expand()) { + sig += getSig(t); } - return wasm->addFunctionType(std::move(type)); + return sig; } Expression* ensureDouble(Expression* expr, MixedArena& allocator) { diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index 87e1fdb0b..35f394fd1 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -22,7 +22,6 @@ #include "binaryen-c.h" #include "cfg/Relooper.h" -#include "ir/function-type-utils.h" #include "ir/utils.h" #include "pass.h" #include "shell-interface.h" @@ -105,7 +104,6 @@ Literal fromBinaryenLiteral(BinaryenLiteral x) { // module, but likely it doesn't matter) static std::mutex BinaryenFunctionMutex; -static std::mutex BinaryenFunctionTypeMutex; // Optimization options static PassOptions globalPassOptions = @@ -123,7 +121,6 @@ void traceNameOrNULL(const char* name, std::ostream& out = std::cout) { } } -std::map<BinaryenFunctionTypeRef, size_t> functionTypes; std::map<BinaryenExpressionRef, size_t> expressions; std::map<BinaryenFunctionRef, size_t> functions; std::map<BinaryenGlobalRef, size_t> globals; @@ -479,14 +476,12 @@ BinaryenModuleRef BinaryenModuleCreate(void) { void BinaryenModuleDispose(BinaryenModuleRef module) { if (tracing) { std::cout << " BinaryenModuleDispose(the_module);\n"; - std::cout << " functionTypes.clear();\n"; std::cout << " expressions.clear();\n"; std::cout << " functions.clear();\n"; std::cout << " globals.clear();\n"; std::cout << " events.clear();\n"; std::cout << " exports.clear();\n"; std::cout << " relooperBlocks.clear();\n"; - functionTypes.clear(); expressions.clear(); functions.clear(); globals.clear(); @@ -498,70 +493,7 @@ void BinaryenModuleDispose(BinaryenModuleRef module) { delete (Module*)module; } -// Function types - -BinaryenFunctionTypeRef BinaryenAddFunctionType(BinaryenModuleRef module, - const char* name, - BinaryenType result, - BinaryenType* paramTypes, - BinaryenIndex numParams) { - auto* wasm = (Module*)module; - auto ret = make_unique<FunctionType>(); - if (name) { - ret->name = name; - } else { - ret->name = Name::fromInt(wasm->functionTypes.size()); - } - ret->result = Type(result); - for (BinaryenIndex i = 0; i < numParams; i++) { - ret->params.push_back(Type(paramTypes[i])); - } - - if (tracing) { - std::cout << " {\n"; - std::cout << " BinaryenType paramTypes[] = { "; - for (BinaryenIndex i = 0; i < numParams; i++) { - if (i > 0) { - std::cout << ", "; - } - std::cout << paramTypes[i]; - } - if (numParams == 0) { - // ensure the array is not empty, otherwise a compiler error on VS - std::cout << "0"; - } - std::cout << " };\n"; - size_t id = functionTypes.size(); - std::cout << " functionTypes[" << id - << "] = BinaryenAddFunctionType(the_module, "; - functionTypes[ret.get()] = id; - traceNameOrNULL(name); - std::cout << ", " << result << ", paramTypes, " << numParams << ");\n"; - std::cout << " }\n"; - } - - // Lock. This can be called from multiple threads at once, and is a - // point where they all access and modify the module. - std::lock_guard<std::mutex> lock(BinaryenFunctionTypeMutex); - return wasm->addFunctionType(std::move(ret)); -} -void BinaryenRemoveFunctionType(BinaryenModuleRef module, const char* name) { - if (tracing) { - std::cout << " BinaryenRemoveFunctionType(the_module, "; - traceNameOrNULL(name); - std::cout << ");\n"; - } - - auto* wasm = (Module*)module; - assert(name != NULL); - - // Lock. This can be called from multiple threads at once, and is a - // point where they all access and modify the module. - { - std::lock_guard<std::mutex> lock(BinaryenFunctionTypeMutex); - wasm->removeFunctionType(name); - } -} +// Literals BinaryenLiteral BinaryenLiteralInt32(int32_t x) { return toBinaryenLiteral(Literal(x)); @@ -1180,7 +1112,8 @@ makeBinaryenCallIndirect(BinaryenModuleRef module, BinaryenExpressionRef target, BinaryenExpressionRef* operands, BinaryenIndex numOperands, - const char* type, + BinaryenType params, + BinaryenType results, bool isReturn) { auto* wasm = (Module*)module; auto* ret = wasm->allocator.alloc<CallIndirect>(); @@ -1205,7 +1138,8 @@ makeBinaryenCallIndirect(BinaryenModuleRef module, target, "operands", numOperands, - StringLit(type)); + params, + results); std::cout << " }\n"; } @@ -1213,8 +1147,8 @@ makeBinaryenCallIndirect(BinaryenModuleRef module, for (BinaryenIndex i = 0; i < numOperands; i++) { ret->operands.push_back((Expression*)operands[i]); } - ret->fullType = type; - ret->type = wasm->getFunctionType(ret->fullType)->result; + ret->sig = Signature(Type(params), Type(results)); + ret->type = Type(results); ret->isReturn = isReturn; ret->finalize(); return static_cast<Expression*>(ret); @@ -1223,18 +1157,20 @@ BinaryenExpressionRef BinaryenCallIndirect(BinaryenModuleRef module, BinaryenExpressionRef target, BinaryenExpressionRef* operands, BinaryenIndex numOperands, - const char* type) { + BinaryenType params, + BinaryenType results) { return makeBinaryenCallIndirect( - module, target, operands, numOperands, type, false); + module, target, operands, numOperands, params, results, false); } BinaryenExpressionRef BinaryenReturnCallIndirect(BinaryenModuleRef module, BinaryenExpressionRef target, BinaryenExpressionRef* operands, BinaryenIndex numOperands, - const char* type) { + BinaryenType params, + BinaryenType results) { return makeBinaryenCallIndirect( - module, target, operands, numOperands, type, true); + module, target, operands, numOperands, params, results, true); } BinaryenExpressionRef BinaryenLocalGet(BinaryenModuleRef module, BinaryenIndex index, @@ -3126,7 +3062,8 @@ BinaryenExpressionRef BinaryenBrOnExnGetExnref(BinaryenExpressionRef expr) { BinaryenFunctionRef BinaryenAddFunction(BinaryenModuleRef module, const char* name, - BinaryenFunctionTypeRef type, + BinaryenType params, + BinaryenType results, BinaryenType* varTypes, BinaryenIndex numVarTypes, BinaryenExpressionRef body) { @@ -3150,18 +3087,14 @@ BinaryenFunctionRef BinaryenAddFunction(BinaryenModuleRef module, auto id = functions.size(); functions[ret] = id; std::cout << " functions[" << id - << "] = BinaryenAddFunction(the_module, \"" << name - << "\", functionTypes[" << functionTypes[type] << "], varTypes, " - << numVarTypes << ", expressions[" << expressions[body] - << "]);\n"; + << "] = BinaryenAddFunction(the_module, \"" << name << "\", " + << params << ", " << results << ", varTypes, " << numVarTypes + << ", expressions[" << expressions[body] << "]);\n"; std::cout << " }\n"; } ret->name = name; - ret->type = ((FunctionType*)type)->name; - auto* functionType = wasm->getFunctionType(ret->type); - ret->result = functionType->result; - ret->params = functionType->params; + ret->sig = Signature(Type(params), Type(results)); for (BinaryenIndex i = 0; i < numVarTypes; i++) { ret->vars.push_back(Type(varTypes[i])); } @@ -3301,21 +3234,21 @@ void BinaryenAddFunctionImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName, - BinaryenFunctionTypeRef functionType) { + BinaryenType params, + BinaryenType results) { auto* wasm = (Module*)module; auto* ret = new Function(); if (tracing) { std::cout << " BinaryenAddFunctionImport(the_module, \"" << internalName << "\", \"" << externalModuleName << "\", \"" << externalBaseName - << "\", functionTypes[" << functionTypes[functionType] << "]);\n"; + << "\", " << params << ", " << results << ");\n"; } ret->name = internalName; ret->module = externalModuleName; ret->base = externalBaseName; - ret->type = ((FunctionType*)functionType)->name; - FunctionTypeUtils::fillFunction(ret, (FunctionType*)functionType); + ret->sig = Signature(Type(params), Type(results)); wasm->addFunction(ret); } void BinaryenAddTableImport(BinaryenModuleRef module, @@ -4082,46 +4015,6 @@ const char* BinaryenModuleGetDebugInfoFileName(BinaryenModuleRef module, } // -// ======== FunctionType Operations ======== -// - -const char* BinaryenFunctionTypeGetName(BinaryenFunctionTypeRef ftype) { - if (tracing) { - std::cout << " BinaryenFunctionTypeGetName(functionsTypes[" - << functionTypes[ftype] << "]);\n"; - } - - return ((FunctionType*)ftype)->name.c_str(); -} -BinaryenIndex BinaryenFunctionTypeGetNumParams(BinaryenFunctionTypeRef ftype) { - if (tracing) { - std::cout << " BinaryenFunctionTypeGetNumParams(functionsTypes[" - << functionTypes[ftype] << "]);\n"; - } - - return ((FunctionType*)ftype)->params.size(); -} -BinaryenType BinaryenFunctionTypeGetParam(BinaryenFunctionTypeRef ftype, - BinaryenIndex index) { - if (tracing) { - std::cout << " BinaryenFunctionTypeGetParam(functionsTypes[" - << functionTypes[ftype] << "], " << index << ");\n"; - } - - auto* ft = (FunctionType*)ftype; - assert(index < ft->params.size()); - return ft->params[index]; -} -BinaryenType BinaryenFunctionTypeGetResult(BinaryenFunctionTypeRef ftype) { - if (tracing) { - std::cout << " BinaryenFunctionTypeGetResult(functionsTypes[" - << functionTypes[ftype] << "]);\n"; - } - - return ((FunctionType*)ftype)->result; -} - -// // ========== Function Operations ========== // @@ -4133,40 +4026,21 @@ const char* BinaryenFunctionGetName(BinaryenFunctionRef func) { return ((Function*)func)->name.c_str(); } -const char* BinaryenFunctionGetType(BinaryenFunctionRef func) { - if (tracing) { - std::cout << " BinaryenFunctionGetType(functions[" << functions[func] - << "]);\n"; - } - - return ((Function*)func)->type.c_str(); -} -BinaryenIndex BinaryenFunctionGetNumParams(BinaryenFunctionRef func) { +BinaryenType BinaryenFunctionGetParams(BinaryenFunctionRef func) { if (tracing) { - std::cout << " BinaryenFunctionGetNumParams(functions[" << functions[func] + std::cout << " BinaryenFunctionGetParams(functions[" << functions[func] << "]);\n"; } - return ((Function*)func)->params.size(); + return ((Function*)func)->sig.params; } -BinaryenType BinaryenFunctionGetParam(BinaryenFunctionRef func, - BinaryenIndex index) { +BinaryenType BinaryenFunctionGetResults(BinaryenFunctionRef func) { if (tracing) { - std::cout << " BinaryenFunctionGetParam(functions[" << functions[func] - << "], " << index << ");\n"; - } - - auto* fn = (Function*)func; - assert(index < fn->params.size()); - return fn->params[index]; -} -BinaryenType BinaryenFunctionGetResult(BinaryenFunctionRef func) { - if (tracing) { - std::cout << " BinaryenFunctionGetResult(functions[" << functions[func] + std::cout << " BinaryenFunctionGetResults(functions[" << functions[func] << "]);\n"; } - return ((Function*)func)->result; + return ((Function*)func)->sig.results; } BinaryenIndex BinaryenFunctionGetNumVars(BinaryenFunctionRef func) { if (tracing) { @@ -4625,7 +4499,6 @@ void BinaryenSetAPITracing(int on) { "#include <map>\n" "#include \"binaryen-c.h\"\n" "int main() {\n" - " std::map<size_t, BinaryenFunctionTypeRef> functionTypes;\n" " std::map<size_t, BinaryenExpressionRef> expressions;\n" " std::map<size_t, BinaryenFunctionRef> functions;\n" " std::map<size_t, BinaryenGlobalRef> globals;\n" @@ -4637,6 +4510,7 @@ void BinaryenSetAPITracing(int on) { } else { std::cout << " return 0;\n"; std::cout << "}\n"; + std::cout << "// ending a Binaryen API trace\n"; } } @@ -4644,36 +4518,6 @@ void BinaryenSetAPITracing(int on) { // ========= Utilities ========= // -BinaryenFunctionTypeRef -BinaryenGetFunctionTypeBySignature(BinaryenModuleRef module, - BinaryenType result, - BinaryenType* paramTypes, - BinaryenIndex numParams) { - if (tracing) { - std::cout << " // BinaryenGetFunctionTypeBySignature\n"; - } - - auto* wasm = (Module*)module; - FunctionType test; - test.result = Type(result); - for (BinaryenIndex i = 0; i < numParams; i++) { - test.params.push_back(Type(paramTypes[i])); - } - - // Lock. Guard against reading the list while types are being added. - { - std::lock_guard<std::mutex> lock(BinaryenFunctionTypeMutex); - for (BinaryenIndex i = 0; i < wasm->functionTypes.size(); i++) { - FunctionType* curr = wasm->functionTypes[i].get(); - if (curr->structuralComparison(test)) { - return curr; - } - } - } - - return NULL; -} - void BinaryenSetColorsEnabled(int enabled) { Colors::setEnabled(enabled); } int BinaryenAreColorsEnabled() { return Colors::isEnabled(); } diff --git a/src/binaryen-c.h b/src/binaryen-c.h index 046a56363..ab3210644 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -30,11 +30,10 @@ // --------------- // // Thread safety: You can create Expressions in parallel, as they do not -// refer to global state. BinaryenAddFunction and -// BinaryenAddFunctionType are also thread-safe, which means -// that you can create functions and their contents in multiple -// threads. This is important since functions are where the -// majority of the work is done. +// refer to global state. BinaryenAddFunction is also +// thread-safe, which means that you can create functions and +// their contents in multiple threads. This is important since +// functions are where the majority of the work is done. // Other methods - creating imports, exports, etc. - are // not currently thread-safe (as there is typically no need // to parallelize them). @@ -213,22 +212,6 @@ BINARYEN_REF(Module); BINARYEN_API BinaryenModuleRef BinaryenModuleCreate(void); BINARYEN_API void BinaryenModuleDispose(BinaryenModuleRef module); -// Function types - -BINARYEN_REF(FunctionType); - -// Add a new function type. This is thread-safe. -// Note: name can be NULL, in which case we auto-generate a name -BINARYEN_API BinaryenFunctionTypeRef -BinaryenAddFunctionType(BinaryenModuleRef module, - const char* name, - BinaryenType result, - BinaryenType* paramTypes, - BinaryenIndex numParams); -// Removes a function type. -BINARYEN_API void BinaryenRemoveFunctionType(BinaryenModuleRef module, - const char* name); - // Literals. These are passed by value. struct BinaryenLiteral { @@ -632,7 +615,8 @@ BinaryenCallIndirect(BinaryenModuleRef module, BinaryenExpressionRef target, BinaryenExpressionRef* operands, BinaryenIndex numOperands, - const char* type); + BinaryenType params, + BinaryenType results); BINARYEN_API BinaryenExpressionRef BinaryenReturnCall(BinaryenModuleRef module, const char* target, @@ -644,7 +628,8 @@ BinaryenReturnCallIndirect(BinaryenModuleRef module, BinaryenExpressionRef target, BinaryenExpressionRef* operands, BinaryenIndex numOperands, - const char* type); + BinaryenType params, + BinaryenType results); // LocalGet: Note the 'type' parameter. It might seem redundant, since the // local at that index must have a type. However, this API lets you @@ -1083,7 +1068,8 @@ BINARYEN_REF(Function); BINARYEN_API BinaryenFunctionRef BinaryenAddFunction(BinaryenModuleRef module, const char* name, - BinaryenFunctionTypeRef type, + BinaryenType params, + BinaryenType results, BinaryenType* varTypes, BinaryenIndex numVarTypes, BinaryenExpressionRef body); @@ -1102,12 +1088,12 @@ BinaryenGetFunctionByIndex(BinaryenModuleRef module, BinaryenIndex id); // Imports -BINARYEN_API void -BinaryenAddFunctionImport(BinaryenModuleRef module, - const char* internalName, - const char* externalModuleName, - const char* externalBaseName, - BinaryenFunctionTypeRef functionType); +BINARYEN_API void BinaryenAddFunctionImport(BinaryenModuleRef module, + const char* internalName, + const char* externalModuleName, + const char* externalBaseName, + BinaryenType params, + BinaryenType results); BINARYEN_API void BinaryenAddTableImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, @@ -1361,41 +1347,16 @@ BinaryenModuleGetDebugInfoFileName(BinaryenModuleRef module, BinaryenIndex index); // -// ======== FunctionType Operations ======== -// - -// Gets the name of the specified `FunctionType`. -BINARYEN_API const char* -BinaryenFunctionTypeGetName(BinaryenFunctionTypeRef ftype); -// Gets the number of parameters of the specified `FunctionType`. -BINARYEN_API BinaryenIndex -BinaryenFunctionTypeGetNumParams(BinaryenFunctionTypeRef ftype); -// Gets the type of the parameter at the specified index of the specified -// `FunctionType`. -BINARYEN_API BinaryenType BinaryenFunctionTypeGetParam( - BinaryenFunctionTypeRef ftype, BinaryenIndex index); -// Gets the result type of the specified `FunctionType`. -BINARYEN_API BinaryenType -BinaryenFunctionTypeGetResult(BinaryenFunctionTypeRef ftype); - -// // ========== Function Operations ========== // // Gets the name of the specified `Function`. BINARYEN_API const char* BinaryenFunctionGetName(BinaryenFunctionRef func); -// Gets the name of the `FunctionType` associated with the specified `Function`. -// May be `NULL` if the signature is implicit. -BINARYEN_API const char* BinaryenFunctionGetType(BinaryenFunctionRef func); -// Gets the number of parameters of the specified `Function`. -BINARYEN_API BinaryenIndex -BinaryenFunctionGetNumParams(BinaryenFunctionRef func); // Gets the type of the parameter at the specified index of the specified // `Function`. -BINARYEN_API BinaryenType BinaryenFunctionGetParam(BinaryenFunctionRef func, - BinaryenIndex index); +BINARYEN_API BinaryenType BinaryenFunctionGetParams(BinaryenFunctionRef func); // Gets the result type of the specified `Function`. -BINARYEN_API BinaryenType BinaryenFunctionGetResult(BinaryenFunctionRef func); +BINARYEN_API BinaryenType BinaryenFunctionGetResults(BinaryenFunctionRef func); // Gets the number of additional locals within the specified `Function`. BINARYEN_API BinaryenIndex BinaryenFunctionGetNumVars(BinaryenFunctionRef func); // Gets the type of the additional local at the specified index within the @@ -1574,17 +1535,6 @@ BINARYEN_API void BinaryenSetAPITracing(int on); // ========= Utilities ========= // -// Note that this function has been added because there is no better alternative -// currently and is scheduled for removal once there is one. It takes the same -// set of parameters as BinaryenAddFunctionType but instead of adding a new -// function signature, it returns a pointer to the existing signature or NULL if -// there is no such signature yet. -BINARYEN_API BinaryenFunctionTypeRef -BinaryenGetFunctionTypeBySignature(BinaryenModuleRef module, - BinaryenType result, - BinaryenType* paramTypes, - BinaryenIndex numParams); - // Enable or disable coloring for the WASM printer BINARYEN_API void BinaryenSetColorsEnabled(int enabled); diff --git a/src/ir/ExpressionAnalyzer.cpp b/src/ir/ExpressionAnalyzer.cpp index f7b7b636e..5128f525b 100644 --- a/src/ir/ExpressionAnalyzer.cpp +++ b/src/ir/ExpressionAnalyzer.cpp @@ -59,7 +59,7 @@ bool ExpressionAnalyzer::isResultUsed(ExpressionStack& stack, Function* func) { } } // The value might be used, so it depends on if the function returns - return func->result != none; + return func->sig.results != Type::none; } // Checks if a value is dropped. @@ -137,7 +137,8 @@ template<typename T> void visitImmediates(Expression* curr, T& visitor) { visitor.visitInt(curr->isReturn); } void visitCallIndirect(CallIndirect* curr) { - visitor.visitNonScopeName(curr->fullType); + visitor.visitInt(curr->sig.params); + visitor.visitInt(curr->sig.results); visitor.visitInt(curr->isReturn); } void visitLocalGet(LocalGet* curr) { visitor.visitIndex(curr->index); } diff --git a/src/ir/ExpressionManipulator.cpp b/src/ir/ExpressionManipulator.cpp index 2542e5743..fd0e6fd75 100644 --- a/src/ir/ExpressionManipulator.cpp +++ b/src/ir/ExpressionManipulator.cpp @@ -79,12 +79,12 @@ flexibleCopy(Expression* original, Module& wasm, CustomCopier custom) { return ret; } Expression* visitCallIndirect(CallIndirect* curr) { - auto* ret = builder.makeCallIndirect( - curr->fullType, copy(curr->target), {}, curr->type, curr->isReturn); - for (Index i = 0; i < curr->operands.size(); i++) { - ret->operands.push_back(copy(curr->operands[i])); + std::vector<Expression*> copiedOps; + for (auto op : curr->operands) { + copiedOps.push_back(copy(op)); } - return ret; + return builder.makeCallIndirect( + copy(curr->target), copiedOps, curr->sig, curr->isReturn); } Expression* visitLocalGet(LocalGet* curr) { return builder.makeLocalGet(curr->index, curr->type); diff --git a/src/ir/ReFinalize.cpp b/src/ir/ReFinalize.cpp index a8054d261..be0a8604b 100644 --- a/src/ir/ReFinalize.cpp +++ b/src/ir/ReFinalize.cpp @@ -145,15 +145,12 @@ void ReFinalize::visitPop(Pop* curr) { curr->finalize(); } void ReFinalize::visitFunction(Function* curr) { // we may have changed the body from unreachable to none, which might be bad // if the function has a return value - if (curr->result != none && curr->body->type == none) { + if (curr->sig.results != Type::none && curr->body->type == Type::none) { Builder builder(*getModule()); curr->body = builder.blockify(curr->body, builder.makeUnreachable()); } } -void ReFinalize::visitFunctionType(FunctionType* curr) { - WASM_UNREACHABLE("unimp"); -} void ReFinalize::visitExport(Export* curr) { WASM_UNREACHABLE("unimp"); } void ReFinalize::visitGlobal(Global* curr) { WASM_UNREACHABLE("unimp"); } void ReFinalize::visitTable(Table* curr) { WASM_UNREACHABLE("unimp"); } diff --git a/src/ir/function-type-utils.h b/src/ir/function-type-utils.h deleted file mode 100644 index ee1e95f70..000000000 --- a/src/ir/function-type-utils.h +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright 2018 WebAssembly Community Group participants - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef wasm_ir_function_type_utils_h -#define wasm_ir_function_type_utils_h - -namespace wasm { - -namespace FunctionTypeUtils { - -// Fills in function info from a function type -inline void fillFunction(Function* func, FunctionType* type) { - func->params = type->params; - func->result = type->result; -} - -} // namespace FunctionTypeUtils - -} // namespace wasm - -#endif // wasm_ir_function_type_utils_h diff --git a/src/ir/function-utils.h b/src/ir/function-utils.h index 61af153a0..f172240e2 100644 --- a/src/ir/function-utils.h +++ b/src/ir/function-utils.h @@ -28,23 +28,17 @@ namespace FunctionUtils { // everything but their name (which can't be the same, in the same // module!) - same params, vars, body, result, etc. inline bool equal(Function* left, Function* right) { - if (left->getNumParams() != right->getNumParams()) { + if (left->sig != right->sig) { return false; } if (left->getNumVars() != right->getNumVars()) { return false; } - for (Index i = 0; i < left->getNumLocals(); i++) { + for (Index i = left->sig.params.size(); i < left->getNumLocals(); i++) { if (left->getLocalType(i) != right->getLocalType(i)) { return false; } } - if (left->result != right->result) { - return false; - } - if (left->type != right->type) { - return false; - } if (!left->imported() && !right->imported()) { return ExpressionAnalyzer::equal(left->body, right->body); } diff --git a/src/ir/hashed.h b/src/ir/hashed.h index a3d285cf5..9e9717cda 100644 --- a/src/ir/hashed.h +++ b/src/ir/hashed.h @@ -82,18 +82,11 @@ struct FunctionHasher : public WalkerPass<PostWalker<FunctionHasher>> { static HashType hashFunction(Function* func) { HashType ret = 0; - ret = rehash(ret, (HashType)func->getNumParams()); - for (auto type : func->params) { - ret = rehash(ret, (HashType)type); - } - ret = rehash(ret, (HashType)func->getNumVars()); + ret = rehash(ret, (HashType)func->sig.params); + ret = rehash(ret, (HashType)func->sig.results); for (auto type : func->vars) { ret = rehash(ret, (HashType)type); } - ret = rehash(ret, (HashType)func->result); - ret = rehash(ret, - HashType(func->type.is() ? std::hash<wasm::Name>{}(func->type) - : HashType(0))); ret = rehash(ret, (HashType)ExpressionAnalyzer::hash(func->body)); return ret; } diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h index e212ae8dc..d84648dfd 100644 --- a/src/ir/module-utils.h +++ b/src/ir/module-utils.h @@ -63,11 +63,8 @@ struct BinaryIndexes { inline Function* copyFunction(Function* func, Module& out) { auto* ret = new Function(); ret->name = func->name; - ret->result = func->result; - ret->params = func->params; + ret->sig = func->sig; ret->vars = func->vars; - // start with no named type; the names in the other module may differ - ret->type = Name(); ret->localNames = func->localNames; ret->localIndices = func->localIndices; ret->debugLocations = func->debugLocations; @@ -108,9 +105,6 @@ inline Event* copyEvent(Event* event, Module& out) { inline void copyModule(const Module& in, Module& out) { // we use names throughout, not raw pointers, so simple copying is fine // for everything *but* expressions - for (auto& curr : in.functionTypes) { - out.addFunctionType(make_unique<FunctionType>(*curr)); - } for (auto& curr : in.exports) { out.addExport(new Export(*curr)); } @@ -137,7 +131,6 @@ inline void copyModule(const Module& in, Module& out) { } inline void clearModule(Module& wasm) { - wasm.functionTypes.clear(); wasm.exports.clear(); wasm.functions.clear(); wasm.globals.clear(); @@ -413,6 +406,64 @@ template<typename T> struct CallGraphPropertyAnalysis { } }; +// Helper function for collecting the type signature used in a module +// +// Used when emitting or printing a module to give signatures canonical +// indices. Signatures are sorted in order of decreasing frequency to minize the +// size of their collective encoding. Both a vector mapping indices to +// signatures and a map mapping signatures to indices are produced. +inline void +collectSignatures(Module& wasm, + std::vector<Signature>& signatures, + std::unordered_map<Signature, Index>& sigIndices) { + using Counts = std::unordered_map<Signature, size_t>; + + // Collect the signature use counts for a single function + auto updateCounts = [&](Function* func, Counts& counts) { + if (func->imported()) { + return; + } + struct TypeCounter : PostWalker<TypeCounter> { + Counts& counts; + + TypeCounter(Counts& counts) : counts(counts) {} + + void visitCallIndirect(CallIndirect* curr) { counts[curr->sig]++; } + }; + TypeCounter(counts).walk(func->body); + }; + + ModuleUtils::ParallelFunctionAnalysis<Counts> analysis(wasm, updateCounts); + + // Collect all the counts. + Counts counts; + for (auto& curr : wasm.functions) { + counts[curr->sig]++; + } + for (auto& curr : wasm.events) { + counts[curr->sig]++; + } + for (auto& pair : analysis.map) { + Counts& functionCounts = pair.second; + for (auto& innerPair : functionCounts) { + counts[innerPair.first] += innerPair.second; + } + } + std::vector<std::pair<Signature, size_t>> sorted(counts.begin(), + counts.end()); + std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) { + // order by frequency then simplicity + if (a.second != b.second) { + return a.second > b.second; + } + return a.first < b.first; + }); + for (Index i = 0; i < sorted.size(); ++i) { + sigIndices[sorted[i].first] = i; + signatures.push_back(sorted[i].first); + } +} + } // namespace ModuleUtils } // namespace wasm diff --git a/src/ir/utils.h b/src/ir/utils.h index 722277bc3..cad7bc885 100644 --- a/src/ir/utils.h +++ b/src/ir/utils.h @@ -157,7 +157,6 @@ struct ReFinalize void visitFunction(Function* curr); - void visitFunctionType(FunctionType* curr); void visitExport(Export* curr); void visitGlobal(Global* curr); void visitTable(Table* curr); @@ -220,7 +219,6 @@ struct ReFinalizeNode : public OverriddenVisitor<ReFinalizeNode> { void visitPush(Push* curr) { curr->finalize(); } void visitPop(Pop* curr) { curr->finalize(); } - void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE("unimp"); } void visitExport(Export* curr) { WASM_UNREACHABLE("unimp"); } void visitGlobal(Global* curr) { WASM_UNREACHABLE("unimp"); } void visitTable(Table* curr) { WASM_UNREACHABLE("unimp"); } @@ -300,7 +298,7 @@ struct AutoDrop : public WalkerPass<ExpressionStackWalker<AutoDrop>> { void doWalkFunction(Function* curr) { ReFinalize().walkFunctionInModule(curr, getModule()); walk(curr->body); - if (curr->result == none && curr->body->type.isConcrete()) { + if (curr->sig.results == Type::none && curr->body->type.isConcrete()) { curr->body = Builder(*getModule()).makeDrop(curr->body); } ReFinalize().walkFunctionInModule(curr, getModule()); diff --git a/src/js/binaryen.js-post.js b/src/js/binaryen.js-post.js index 3d732c1fc..65a4e15dc 100644 --- a/src/js/binaryen.js-post.js +++ b/src/js/binaryen.js-post.js @@ -1,5 +1,4 @@ // export friendly API methods - function preserveStack(func) { try { var stack = stackSave(); @@ -518,9 +517,9 @@ function wrapModule(module, self) { return Module['_BinaryenCall'](module, strToStack(name), i32sToStack(operands), operands.length, type); }); }; - self['callIndirect'] = self['call_indirect'] = function(target, operands, type) { + self['callIndirect'] = self['call_indirect'] = function(target, operands, params, results) { return preserveStack(function() { - return Module['_BinaryenCallIndirect'](module, target, i32sToStack(operands), operands.length, strToStack(type)); + return Module['_BinaryenCallIndirect'](module, target, i32sToStack(operands), operands.length, params, results); }); }; self['returnCall'] = function(name, operands, type) { @@ -528,9 +527,9 @@ function wrapModule(module, self) { return Module['_BinaryenReturnCall'](module, strToStack(name), i32sToStack(operands), operands.length, type); }); }; - self['returnCallIndirect'] = function(target, operands, type) { + self['returnCallIndirect'] = function(target, operands, params, results) { return preserveStack(function() { - return Module['_BinaryenReturnCallIndirect'](module, target, i32sToStack(operands), operands.length, strToStack(type)); + return Module['_BinaryenReturnCallIndirect'](module, target, i32sToStack(operands), operands.length, params, results); }); }; @@ -2008,28 +2007,9 @@ function wrapModule(module, self) { }; // 'Module' operations - self['addFunctionType'] = function(name, result, paramTypes) { - if (!paramTypes) paramTypes = []; - return preserveStack(function() { - return Module['_BinaryenAddFunctionType'](module, strToStack(name), result, - i32sToStack(paramTypes), paramTypes.length); - }); - }; - self['getFunctionTypeBySignature'] = function(result, paramTypes) { - if (!paramTypes) paramTypes = []; - return preserveStack(function() { - return Module['_BinaryenGetFunctionTypeBySignature'](module, result, - i32sToStack(paramTypes), paramTypes.length); - }); - }; - self['removeFunctionType'] = function(name) { + self['addFunction'] = function(name, params, results, varTypes, body) { return preserveStack(function() { - return Module['_BinaryenRemoveFunctionType'](module, strToStack(name)); - }); - }; - self['addFunction'] = function(name, functionType, varTypes, body) { - return preserveStack(function() { - return Module['_BinaryenAddFunction'](module, strToStack(name), functionType, i32sToStack(varTypes), varTypes.length, body); + return Module['_BinaryenAddFunction'](module, strToStack(name), params, results, i32sToStack(varTypes), varTypes.length, body); }); }; self['getFunction'] = function(name) { @@ -2072,9 +2052,9 @@ function wrapModule(module, self) { return Module['_BinaryenRemoveEvent'](module, strToStack(name)); }); }; - self['addFunctionImport'] = function(internalName, externalModuleName, externalBaseName, functionType) { + self['addFunctionImport'] = function(internalName, externalModuleName, externalBaseName, params, results) { return preserveStack(function() { - return Module['_BinaryenAddFunctionImport'](module, strToStack(internalName), strToStack(externalModuleName), strToStack(externalBaseName), functionType); + return Module['_BinaryenAddFunctionImport'](module, strToStack(internalName), strToStack(externalModuleName), strToStack(externalBaseName), params, results); }); }; self['addTableImport'] = function(internalName, externalModuleName, externalBaseName) { @@ -2698,24 +2678,14 @@ Module['getExpressionInfo'] = function(expr) { } }; -// Obtains information about a 'FunctionType' -Module['getFunctionTypeInfo'] = function(func) { - return { - 'name': UTF8ToString(Module['_BinaryenFunctionTypeGetName'](func)), - 'params': getAllNested(func, Module['_BinaryenFunctionTypeGetNumParams'], Module['_BinaryenFunctionTypeGetParam']), - 'result': Module['_BinaryenFunctionTypeGetResult'](func) - }; -}; - // Obtains information about a 'Function' Module['getFunctionInfo'] = function(func) { return { 'name': UTF8ToString(Module['_BinaryenFunctionGetName'](func)), 'module': UTF8ToString(Module['_BinaryenFunctionImportGetModule'](func)), 'base': UTF8ToString(Module['_BinaryenFunctionImportGetBase'](func)), - 'type': UTF8ToString(Module['_BinaryenFunctionGetType'](func)), - 'params': getAllNested(func, Module['_BinaryenFunctionGetNumParams'], Module['_BinaryenFunctionGetParam']), - 'result': Module['_BinaryenFunctionGetResult'](func), + 'params': Module['_BinaryenFunctionGetParams'](func), + 'results': Module['_BinaryenFunctionGetResults'](func), 'vars': getAllNested(func, Module['_BinaryenFunctionGetNumVars'], Module['_BinaryenFunctionGetVar']), 'body': Module['_BinaryenFunctionGetBody'](func) }; diff --git a/src/passes/Asyncify.cpp b/src/passes/Asyncify.cpp index d92181639..8e583863c 100644 --- a/src/passes/Asyncify.cpp +++ b/src/passes/Asyncify.cpp @@ -706,7 +706,7 @@ struct AsyncifyFlow : public Pass { State::Rewinding), // TODO: such checks can be !normal makeCallIndexPop()), process(func->body)}); - if (func->result != none) { + if (func->sig.results != Type::none) { // Rewriting control flow may alter things; make sure the function ends in // something valid (which the optimizer can remove later). block->list.push_back(builder->makeUnreachable()); @@ -1045,7 +1045,7 @@ struct AsyncifyLocals : public WalkerPass<PostWalker<AsyncifyLocals>> { walk(func->body); // After the normal function body, emit a barrier before the postamble. Expression* barrier; - if (func->result == none) { + if (func->sig.results == Type::none) { // The function may have ended without a return; ensure one. barrier = builder->makeReturn(); } else { @@ -1063,12 +1063,12 @@ struct AsyncifyLocals : public WalkerPass<PostWalker<AsyncifyLocals>> { builder->makeSequence(func->body, barrier))), makeCallIndexPush(unwindIndex), makeLocalSaving()}); - if (func->result != none) { + if (func->sig.results != Type::none) { // If we unwind, we must still "return" a value, even if it will be // ignored on the outside. newBody->list.push_back( - LiteralUtils::makeZero(func->result, *getModule())); - newBody->finalize(func->result); + LiteralUtils::makeZero(func->sig.results, *getModule())); + newBody->finalize(func->sig.results); } func->body = newBody; // Making things like returns conditional may alter types. @@ -1324,7 +1324,7 @@ private: builder.makeUnreachable())); body->finalize(); auto* func = builder.makeFunction( - name, std::move(params), none, std::vector<Type>{}, body); + name, Signature(Type(params), Type::none), {}, body); module->addFunction(func); module->addExport(builder.makeExport(name, name, ExternalKind::Function)); }; diff --git a/src/passes/CMakeLists.txt b/src/passes/CMakeLists.txt index 7ec9054d7..8eb3c7e88 100644 --- a/src/passes/CMakeLists.txt +++ b/src/passes/CMakeLists.txt @@ -53,7 +53,6 @@ set(passes_SOURCES StripTargetFeatures.cpp RedundantSetElimination.cpp RelooperJumpThreading.cpp - ReReloop.cpp RemoveImports.cpp RemoveMemory.cpp RemoveNonJSOps.cpp @@ -62,6 +61,7 @@ set(passes_SOURCES RemoveUnusedModuleElements.cpp ReorderLocals.cpp ReorderFunctions.cpp + ReReloop.cpp TrapMode.cpp SafeHeap.cpp SimplifyGlobals.cpp diff --git a/src/passes/CodeFolding.cpp b/src/passes/CodeFolding.cpp index 818bb33b8..947e64715 100644 --- a/src/passes/CodeFolding.cpp +++ b/src/passes/CodeFolding.cpp @@ -718,7 +718,7 @@ private: mergeable.pop_back(); } // ensure the replacement has the same type, so the outside is not surprised - outer->finalize(getFunction()->result); + outer->finalize(getFunction()->sig.results); getFunction()->body = outer; return true; } diff --git a/src/passes/DataFlowOpts.cpp b/src/passes/DataFlowOpts.cpp index d4a3cc087..cf0092d37 100644 --- a/src/passes/DataFlowOpts.cpp +++ b/src/passes/DataFlowOpts.cpp @@ -138,7 +138,7 @@ struct DataFlowOpts : public WalkerPass<PostWalker<DataFlowOpts>> { // XXX we should copy expr here, in principle, and definitely will need to // when we do arbitrarily regenerated expressions auto* func = Builder(temp).makeFunction( - "temp", std::vector<Type>{}, none, std::vector<Type>{}, expr); + "temp", Signature(Type::none, Type::none), {}, expr); PassRunner runner(&temp); runner.setIsNested(true); runner.add("precompute"); diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index 8e1da3ac5..4395fc780 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -362,7 +362,7 @@ struct DAE : public Pass { // once to remove a param, once to drop the return value). if (changed.empty()) { for (auto& func : module->functions) { - if (func->result == none) { + if (func->sig.results == Type::none) { continue; } auto name = func->name; @@ -403,15 +403,15 @@ private: std::unordered_map<Call*, Expression**> allDroppedCalls; void removeParameter(Function* func, Index i, std::vector<Call*>& calls) { - // Clear the type, which is no longer accurate. - func->type = Name(); // It's cumbersome to adjust local names - TODO don't clear them? Builder::clearLocalNames(func); // Remove the parameter from the function. We must add a new local // for uses of the parameter, but cannot make it use the same index // (in general). - auto type = func->getLocalType(i); - func->params.erase(func->params.begin() + i); + std::vector<Type> params = func->sig.params.expand(); + auto type = params[i]; + params.erase(params.begin() + i); + func->sig.params = Type(params); Index newIndex = Builder::addVar(func, type); // Update local operations. struct LocalUpdater : public PostWalker<LocalUpdater> { @@ -439,9 +439,7 @@ private: void removeReturnValue(Function* func, std::vector<Call*>& calls, Module* module) { - // Clear the type, which is no longer accurate. - func->type = Name(); - func->result = none; + func->sig.results = Type::none; Builder builder(*module); // Remove any return values. struct ReturnUpdater : public PostWalker<ReturnUpdater> { diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp index 663b8257b..52aa7e087 100644 --- a/src/passes/Directize.cpp +++ b/src/passes/Directize.cpp @@ -58,8 +58,7 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { return; } auto* func = getModule()->getFunction(name); - if (getSig(getModule()->getFunctionType(curr->fullType)) != - getSig(func)) { + if (curr->sig != func->sig) { replaceWithUnreachable(curr); return; } diff --git a/src/passes/DuplicateImportElimination.cpp b/src/passes/DuplicateImportElimination.cpp index 39b126b6c..87126b139 100644 --- a/src/passes/DuplicateImportElimination.cpp +++ b/src/passes/DuplicateImportElimination.cpp @@ -42,7 +42,7 @@ struct DuplicateImportElimination : public Pass { auto previousFunc = module->getFunction(previousName); // It is ok to import the same thing with multiple types; we can only // merge if the types match, of course. - if (getSig(previousFunc) == getSig(func)) { + if (previousFunc->sig == func->sig) { replacements[func->name] = previousName; toRemove.push_back(func->name); continue; diff --git a/src/passes/FuncCastEmulation.cpp b/src/passes/FuncCastEmulation.cpp index f36e1e909..729a4a6c3 100644 --- a/src/passes/FuncCastEmulation.cpp +++ b/src/passes/FuncCastEmulation.cpp @@ -131,7 +131,7 @@ struct ParallelFuncCastEmulation Pass* create() override { return new ParallelFuncCastEmulation(ABIType); } - ParallelFuncCastEmulation(Name ABIType) : ABIType(ABIType) {} + ParallelFuncCastEmulation(Signature ABIType) : ABIType(ABIType) {} void visitCallIndirect(CallIndirect* curr) { if (curr->operands.size() > NUM_PARAMS) { @@ -146,7 +146,7 @@ struct ParallelFuncCastEmulation curr->operands.push_back(LiteralUtils::makeZero(i64, *getModule())); } // Set the new types - curr->fullType = ABIType; + curr->sig = ABIType; auto oldType = curr->type; curr->type = i64; curr->finalize(); // may be unreachable @@ -155,18 +155,15 @@ struct ParallelFuncCastEmulation } private: - // the name of a type for a call with the right params and return - Name ABIType; + // The signature of a call with the right params and return + Signature ABIType; }; struct FuncCastEmulation : public Pass { void run(PassRunner* runner, Module* module) override { // we just need the one ABI function type for all indirect calls - std::string sig = "j"; - for (Index i = 0; i < NUM_PARAMS; i++) { - sig += 'j'; - } - ABIType = ensureFunctionType(sig, module)->name; + Signature ABIType(Type(std::vector<Type>(NUM_PARAMS, Type::i64)), + Type::i64); // Add a way for JS to call into the table (as our i64 ABI means an i64 // is returned when there is a return value, which JS engines will fail on), // using dynCalls @@ -191,9 +188,6 @@ struct FuncCastEmulation : public Pass { } private: - // the name of a type for a call with the right params and return - Name ABIType; - // Creates a thunk for a function, casting args and return value as needed. Name makeThunk(Name name, Module* module) { Name thunk = std::string("byn$fpcast-emu$") + name.str; @@ -203,8 +197,8 @@ private: } // The item in the table may be a function or a function import. auto* func = module->getFunction(name); - std::vector<Type>& params = func->params; - Type type = func->result; + const std::vector<Type>& params = func->sig.params.expand(); + Type type = func->sig.results; Builder builder(*module); std::vector<Expression*> callOperands; for (Index i = 0; i < params.size(); i++) { @@ -216,12 +210,11 @@ private: for (Index i = 0; i < NUM_PARAMS; i++) { thunkParams.push_back(i64); } - auto* thunkFunc = builder.makeFunction(thunk, - std::move(thunkParams), - i64, - {}, // no vars - toABI(call, module)); - thunkFunc->type = ABIType; + auto* thunkFunc = + builder.makeFunction(thunk, + Signature(Type(thunkParams), Type::i64), + {}, // no vars + toABI(call, module)); module->addFunction(thunkFunc); return thunk; } diff --git a/src/passes/I64ToI32Lowering.cpp b/src/passes/I64ToI32Lowering.cpp index e2f744957..c9a4f46ea 100644 --- a/src/passes/I64ToI32Lowering.cpp +++ b/src/passes/I64ToI32Lowering.cpp @@ -147,22 +147,6 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { PostWalker<I64ToI32Lowering>::doWalkModule(module); } - void visitFunctionType(FunctionType* curr) { - std::vector<Type> params; - for (auto t : curr->params) { - if (t == i64) { - params.push_back(i32); - params.push_back(i32); - } else { - params.push_back(t); - } - } - std::swap(params, curr->params); - if (curr->result == i64) { - curr->result = i32; - } - } - void doWalkFunction(Function* func) { Flat::verifyFlatness(func); // create builder here if this is first entry to module for this object @@ -174,7 +158,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { freeTemps.clear(); Module temp; auto* oldFunc = ModuleUtils::copyFunction(func, temp); - func->params.clear(); + func->sig.params = Type::none; func->vars.clear(); func->localNames.clear(); func->localIndices.clear(); @@ -190,8 +174,8 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { ? Builder::addParam : static_cast<Index (*)(Function*, Name, Type)>(Builder::addVar); if (paramType == i64) { - builderFunc(func, lowName, i32); - builderFunc(func, highName, i32); + builderFunc(func, lowName, Type::i32); + builderFunc(func, highName, Type::i32); indexMap[i] = newIdx; newIdx += 2; } else { @@ -207,8 +191,8 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { if (func->imported()) { return; } - if (func->result == i64) { - func->result = i32; + if (func->sig.results == Type::i64) { + func->sig.results = Type::i32; // body may not have out param if it ends with control flow if (hasOutParam(func->body)) { TempVar highBits = fetchOutParam(func->body); @@ -244,14 +228,14 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { fixed = true; } } - if (curr->type != i64) { + if (curr->type != Type::i64) { auto* ret = callBuilder(args, curr->type); replaceCurrent(ret); return fixed ? ret : nullptr; } TempVar lowBits = getTemp(); TempVar highBits = getTemp(); - auto* call = callBuilder(args, i32); + auto* call = callBuilder(args, Type::i32); LocalSet* doCall = builder->makeLocalSet(lowBits, call); LocalSet* setHigh = builder->makeLocalSet( highBits, builder->makeGlobalGet(INT64_TO_32_HIGH_BITS, i32)); @@ -263,13 +247,13 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { } void visitCall(Call* curr) { if (curr->isReturn && - getModule()->getFunction(curr->target)->result == i64) { + getModule()->getFunction(curr->target)->sig.results == Type::i64) { Fatal() << "i64 to i32 lowering of return_call values not yet implemented"; } auto* fixedCall = visitGenericCall<Call>( - curr, [&](std::vector<Expression*>& args, Type ty) { - return builder->makeCall(curr->target, args, ty, curr->isReturn); + curr, [&](std::vector<Expression*>& args, Type results) { + return builder->makeCall(curr->target, args, results, curr->isReturn); }); // If this was to an import, we need to call the legal version. This assumes // that legalize-js-interface has been run before. @@ -280,15 +264,23 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { } void visitCallIndirect(CallIndirect* curr) { - if (curr->isReturn && - getModule()->getFunctionType(curr->fullType)->result == i64) { + if (curr->isReturn && curr->sig.results == Type::i64) { Fatal() << "i64 to i32 lowering of return_call values not yet implemented"; } visitGenericCall<CallIndirect>( - curr, [&](std::vector<Expression*>& args, Type ty) { + curr, [&](std::vector<Expression*>& args, Type results) { + std::vector<Type> params; + for (auto param : curr->sig.params.expand()) { + if (param == Type::i64) { + params.push_back(Type::i32); + params.push_back(Type::i32); + } else { + params.push_back(param); + } + } return builder->makeCallIndirect( - curr->fullType, curr->target, args, ty, curr->isReturn); + curr->target, args, Signature(Type(params), results), curr->isReturn); }); } diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index 96c4531c8..db1db5971 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -198,12 +198,12 @@ struct Updater : public PostWalker<Updater> { } void visitCall(Call* curr) { if (curr->isReturn) { - handleReturnCall(curr, module->getFunction(curr->target)->result); + handleReturnCall(curr, module->getFunction(curr->target)->sig.results); } } void visitCallIndirect(CallIndirect* curr) { if (curr->isReturn) { - handleReturnCall(curr, module->getFunctionType(curr->fullType)->result); + handleReturnCall(curr, curr->sig.results); } } void visitLocalGet(LocalGet* curr) { @@ -221,7 +221,7 @@ doInlining(Module* module, Function* into, const InliningAction& action) { Function* from = action.contents; auto* call = (*action.callSite)->cast<Call>(); // Works for return_call, too - Type retType = module->getFunction(call->target)->result; + Type retType = module->getFunction(call->target)->sig.results; Builder builder(*module); auto* block = builder.makeBlock(); block->name = Name(std::string("__inlined_func$") + from->name.str); @@ -244,7 +244,7 @@ doInlining(Module* module, Function* into, const InliningAction& action) { updater.localMapping[i] = builder.addVar(into, from->getLocalType(i)); } // Assign the operands into the params - for (Index i = 0; i < from->params.size(); i++) { + for (Index i = 0; i < from->sig.params.size(); i++) { block->list.push_back( builder.makeLocalSet(updater.localMapping[i], call->operands[i])); } diff --git a/src/passes/InstrumentLocals.cpp b/src/passes/InstrumentLocals.cpp index 0b21c5000..407903219 100644 --- a/src/passes/InstrumentLocals.cpp +++ b/src/passes/InstrumentLocals.cpp @@ -45,7 +45,6 @@ #include "asm_v_wasm.h" #include "asmjs/shared-constants.h" -#include "ir/function-type-utils.h" #include "shared-constants.h" #include <pass.h> #include <wasm-builder.h> @@ -147,36 +146,38 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> { } void visitModule(Module* curr) { - addImport(curr, get_i32, "iiii"); - addImport(curr, get_i64, "jiij"); - addImport(curr, get_f32, "fiif"); - addImport(curr, get_f64, "diid"); - addImport(curr, set_i32, "iiii"); - addImport(curr, set_i64, "jiij"); - addImport(curr, set_f32, "fiif"); - addImport(curr, set_f64, "diid"); + addImport(curr, get_i32, {Type::i32, Type::i32, Type::i32}, Type::i32); + addImport(curr, get_i64, {Type::i32, Type::i32, Type::i64}, Type::i64); + addImport(curr, get_f32, {Type::i32, Type::i32, Type::f32}, Type::f32); + addImport(curr, get_f64, {Type::i32, Type::i32, Type::f64}, Type::f64); + addImport(curr, set_i32, {Type::i32, Type::i32, Type::i32}, Type::i32); + addImport(curr, set_i64, {Type::i32, Type::i32, Type::i64}, Type::i64); + addImport(curr, set_f32, {Type::i32, Type::i32, Type::f32}, Type::f32); + addImport(curr, set_f64, {Type::i32, Type::i32, Type::f64}, Type::f64); if (curr->features.hasReferenceTypes()) { - addImport(curr, get_anyref, "aiia"); - addImport(curr, set_anyref, "aiia"); + addImport( + curr, get_anyref, {Type::i32, Type::i32, Type::anyref}, Type::anyref); + addImport( + curr, set_anyref, {Type::i32, Type::i32, Type::anyref}, Type::anyref); } if (curr->features.hasExceptionHandling()) { - addImport(curr, get_exnref, "eiie"); - addImport(curr, set_exnref, "eiie"); + addImport( + curr, get_exnref, {Type::i32, Type::i32, Type::exnref}, Type::exnref); + addImport( + curr, set_exnref, {Type::i32, Type::i32, Type::exnref}, Type::exnref); } } private: Index id = 0; - void addImport(Module* wasm, Name name, std::string sig) { + void addImport(Module* wasm, Name name, Type params, Type results) { auto import = new Function; import->name = name; import->module = ENV; import->base = name; - auto* functionType = ensureFunctionType(sig, wasm); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = Signature(params, results); wasm->addFunction(import); } }; diff --git a/src/passes/InstrumentMemory.cpp b/src/passes/InstrumentMemory.cpp index 4a479db34..9a805b19b 100644 --- a/src/passes/InstrumentMemory.cpp +++ b/src/passes/InstrumentMemory.cpp @@ -54,7 +54,6 @@ #include "asm_v_wasm.h" #include "asmjs/shared-constants.h" -#include "ir/function-type-utils.h" #include "shared-constants.h" #include <pass.h> #include <wasm-builder.h> @@ -141,29 +140,29 @@ struct InstrumentMemory : public WalkerPass<PostWalker<InstrumentMemory>> { } void visitModule(Module* curr) { - addImport(curr, load_ptr, "iiiii"); - addImport(curr, load_val_i32, "iii"); - addImport(curr, load_val_i64, "jij"); - addImport(curr, load_val_f32, "fif"); - addImport(curr, load_val_f64, "did"); - addImport(curr, store_ptr, "iiiii"); - addImport(curr, store_val_i32, "iii"); - addImport(curr, store_val_i64, "jij"); - addImport(curr, store_val_f32, "fif"); - addImport(curr, store_val_f64, "did"); + addImport( + curr, load_ptr, {Type::i32, Type::i32, Type::i32, Type::i32}, Type::i32); + addImport(curr, load_val_i32, {Type::i32, Type::i32}, Type::i32); + addImport(curr, load_val_i64, {Type::i32, Type::i64}, Type::i64); + addImport(curr, load_val_f32, {Type::i32, Type::f32}, Type::f32); + addImport(curr, load_val_f64, {Type::i32, Type::f64}, Type::f64); + addImport( + curr, store_ptr, {Type::i32, Type::i32, Type::i32, Type::i32}, Type::i32); + addImport(curr, store_val_i32, {Type::i32, Type::i32}, Type::i32); + addImport(curr, store_val_i64, {Type::i32, Type::i64}, Type::i64); + addImport(curr, store_val_f32, {Type::i32, Type::f32}, Type::f32); + addImport(curr, store_val_f64, {Type::i32, Type::f64}, Type::f64); } private: Index id; - void addImport(Module* curr, Name name, std::string sig) { + void addImport(Module* curr, Name name, Type params, Type results) { auto import = new Function; import->name = name; import->module = ENV; import->base = name; - auto* functionType = ensureFunctionType(sig, curr); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = Signature(params, results); curr->addFunction(import); } }; diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp index 4dc680972..8c7bc4414 100644 --- a/src/passes/LegalizeJSInterface.cpp +++ b/src/passes/LegalizeJSInterface.cpp @@ -32,7 +32,6 @@ #include "asm_v_wasm.h" #include "asmjs/shared-constants.h" -#include "ir/function-type-utils.h" #include "ir/import-utils.h" #include "ir/literal-utils.h" #include "ir/utils.h" @@ -153,12 +152,12 @@ private: std::map<Name, Name> illegalImportsToLegal; template<typename T> bool isIllegal(T* t) { - for (auto param : t->params) { - if (param == i64) { + for (auto param : t->sig.params.expand()) { + if (param == Type::i64) { return true; } } - return t->result == i64; + return t->sig.results == Type::i64; } bool isDynCall(Name name) { return name.startsWith("dynCall_"); } @@ -190,25 +189,29 @@ private: auto* call = module->allocator.alloc<Call>(); call->target = func->name; - call->type = func->result; + call->type = func->sig.results; - for (auto param : func->params) { - if (param == i64) { + const std::vector<Type>& params = func->sig.params.expand(); + std::vector<Type> legalParams; + for (auto param : params) { + if (param == Type::i64) { call->operands.push_back(I64Utilities::recreateI64( - builder, legal->params.size(), legal->params.size() + 1)); - legal->params.push_back(i32); - legal->params.push_back(i32); + builder, legalParams.size(), legalParams.size() + 1)); + legalParams.push_back(Type::i32); + legalParams.push_back(Type::i32); } else { call->operands.push_back( - builder.makeLocalGet(legal->params.size(), param)); - legal->params.push_back(param); + builder.makeLocalGet(legalParams.size(), param)); + legalParams.push_back(param); } } + legal->sig.params = Type(legalParams); - if (func->result == i64) { - Function* f = getFunctionOrImport(module, SET_TEMP_RET0, "vi"); - legal->result = i32; - auto index = Builder::addVar(legal, Name(), i64); + if (func->sig.results == Type::i64) { + Function* f = + getFunctionOrImport(module, SET_TEMP_RET0, Type::i32, Type::none); + legal->sig.results = Type::i32; + auto index = Builder::addVar(legal, Name(), Type::i64); auto* block = builder.makeBlock(); block->list.push_back(builder.makeLocalSet(index, call)); block->list.push_back(builder.makeCall( @@ -217,7 +220,7 @@ private: block->finalize(); legal->body = block; } else { - legal->result = func->result; + legal->sig.results = func->sig.results; legal->body = call; } @@ -232,66 +235,55 @@ private: // JS import Name makeLegalStubForCalledImport(Function* im, Module* module) { Builder builder(*module); - auto type = make_unique<FunctionType>(); - type->name = Name(std::string("legaltype$") + im->name.str); - auto legal = make_unique<Function>(); - legal->name = Name(std::string("legalimport$") + im->name.str); - legal->module = im->module; - legal->base = im->base; - legal->type = type->name; - auto func = make_unique<Function>(); - func->name = Name(std::string("legalfunc$") + im->name.str); + auto legalIm = make_unique<Function>(); + legalIm->name = Name(std::string("legalimport$") + im->name.str); + legalIm->module = im->module; + legalIm->base = im->base; + auto stub = make_unique<Function>(); + stub->name = Name(std::string("legalfunc$") + im->name.str); + stub->sig = im->sig; auto* call = module->allocator.alloc<Call>(); - call->target = legal->name; - - auto* imFunctionType = ensureFunctionType(getSig(im), module); - - for (auto param : imFunctionType->params) { - if (param == i64) { - call->operands.push_back( - I64Utilities::getI64Low(builder, func->params.size())); - call->operands.push_back( - I64Utilities::getI64High(builder, func->params.size())); - type->params.push_back(i32); - type->params.push_back(i32); + call->target = legalIm->name; + + const std::vector<Type>& imParams = im->sig.params.expand(); + std::vector<Type> params; + for (size_t i = 0; i < imParams.size(); ++i) { + if (imParams[i] == Type::i64) { + call->operands.push_back(I64Utilities::getI64Low(builder, i)); + call->operands.push_back(I64Utilities::getI64High(builder, i)); + params.push_back(i32); + params.push_back(i32); } else { - call->operands.push_back( - builder.makeLocalGet(func->params.size(), param)); - type->params.push_back(param); + call->operands.push_back(builder.makeLocalGet(i, imParams[i])); + params.push_back(imParams[i]); } - func->params.push_back(param); } - if (imFunctionType->result == i64) { - Function* f = getFunctionOrImport(module, GET_TEMP_RET0, "i"); - call->type = i32; + if (im->sig.results == Type::i64) { + Function* f = + getFunctionOrImport(module, GET_TEMP_RET0, Type::none, Type::i32); + call->type = Type::i32; Expression* get = builder.makeCall(f->name, {}, call->type); - func->body = I64Utilities::recreateI64(builder, call, get); - type->result = i32; + stub->body = I64Utilities::recreateI64(builder, call, get); } else { - call->type = imFunctionType->result; - func->body = call; - type->result = imFunctionType->result; + call->type = im->sig.results; + stub->body = call; } - func->result = imFunctionType->result; - FunctionTypeUtils::fillFunction(legal.get(), type.get()); + legalIm->sig = Signature(Type(params), call->type); - const auto& funcName = func->name; - if (!module->getFunctionOrNull(funcName)) { - module->addFunction(std::move(func)); - } - if (!module->getFunctionTypeOrNull(type->name)) { - module->addFunctionType(std::move(type)); + const auto& stubName = stub->name; + if (!module->getFunctionOrNull(stubName)) { + module->addFunction(std::move(stub)); } - if (!module->getFunctionOrNull(legal->name)) { - module->addFunction(std::move(legal)); + if (!module->getFunctionOrNull(legalIm->name)) { + module->addFunction(std::move(legalIm)); } - return funcName; + return stubName; } static Function* - getFunctionOrImport(Module* module, Name name, std::string sig) { + getFunctionOrImport(Module* module, Name name, Type params, Type results) { // First look for the function by name if (Function* f = module->getFunctionOrNull(name)) { return f; @@ -306,9 +298,7 @@ private: import->name = name; import->module = ENV; import->base = name; - auto* functionType = ensureFunctionType(std::move(sig), module); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = Signature(params, results); module->addFunction(import); return import; } diff --git a/src/passes/LogExecution.cpp b/src/passes/LogExecution.cpp index 7bfee7c24..611f79dfd 100644 --- a/src/passes/LogExecution.cpp +++ b/src/passes/LogExecution.cpp @@ -30,7 +30,6 @@ #include "asm_v_wasm.h" #include "asmjs/shared-constants.h" -#include "ir/function-type-utils.h" #include "shared-constants.h" #include <pass.h> #include <wasm-builder.h> @@ -63,9 +62,7 @@ struct LogExecution : public WalkerPass<PostWalker<LogExecution>> { import->name = LOGGER; import->module = ENV; import->base = LOGGER; - auto* functionType = ensureFunctionType("vi", curr); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = Signature(Type::i32, Type::none); curr->addFunction(import); } diff --git a/src/passes/Metrics.cpp b/src/passes/Metrics.cpp index a408ccf95..56d04802a 100644 --- a/src/passes/Metrics.cpp +++ b/src/passes/Metrics.cpp @@ -50,10 +50,6 @@ struct Metrics ImportInfo imports(*module); // global things - - for (auto& curr : module->functionTypes) { - visitFunctionType(curr.get()); - } for (auto& curr : module->exports) { visitExport(curr.get()); } diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 2ec92a086..f8e913079 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -27,14 +27,16 @@ namespace wasm { -static bool isFullForced() { +namespace { + +bool isFullForced() { if (getenv("BINARYEN_PRINT_FULL")) { return std::stoi(getenv("BINARYEN_PRINT_FULL")) != 0; } return false; } -static std::ostream& printName(Name name, std::ostream& o) { +std::ostream& printName(Name name, std::ostream& o) { // we need to quote names if they have tricky chars if (!name.str || !strpbrk(name.str, "()")) { o << '$' << name.str; @@ -55,6 +57,36 @@ static std::ostream& printLocal(Index index, Function* func, std::ostream& o) { return printName(name, o); } +// Wrapper for printing signature names +struct SigName { + Signature sig; + SigName(Signature sig) : sig(sig) {} +}; + +std::ostream& operator<<(std::ostream& os, SigName sigName) { + auto printType = [&](Type type) { + if (type == Type::none) { + os << "none"; + } else { + const std::vector<Type>& types = type.expand(); + for (size_t i = 0; i < types.size(); ++i) { + if (i != 0) { + os << '_'; + } + os << types[i]; + } + } + }; + + os << '$'; + printType(sigName.sig.params); + os << "_=>_"; + printType(sigName.sig.results); + return os; +} + +} // anonymous namespace + // Printing "unreachable" as a instruction prefix type is not valid in wasm text // format. Print something else to make it pass. static Type forceConcrete(Type type) { return type.isConcrete() ? type : i32; } @@ -126,7 +158,7 @@ struct PrintExpressionContents } else { printMedium(o, "call_indirect (type "); } - printName(curr->fullType, o) << ')'; + o << SigName(curr->sig) << ')'; } void visitLocalGet(LocalGet* curr) { printMedium(o, "local.get "); @@ -1869,19 +1901,18 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { o << ')'; } // Module-level visitors - void visitFunctionType(FunctionType* curr, Name* internalName = nullptr) { + void handleSignature(Signature curr, Name* funcName = nullptr) { o << "(func"; - if (internalName) { - o << ' '; - printName(*internalName, o); + if (funcName) { + o << " $" << *funcName; } - if (curr->params.size() > 0) { + if (curr.params.size() > 0) { o << maybeSpace; - o << ParamType(Type(curr->params)); + o << ParamType(curr.params); } - if (curr->result != none) { + if (curr.results.size() > 0) { o << maybeSpace; - o << ResultType(curr->result); + o << ResultType(curr.results); } o << ")"; } @@ -1963,12 +1994,7 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { lastPrintedLocation = {0, 0, 0}; o << '('; emitImportHeader(curr); - if (curr->type.is()) { - visitFunctionType(currModule->getFunctionType(curr->type), &curr->name); - } else { - auto functionType = sigToFunctionType(getSig(curr)); - visitFunctionType(&functionType, &curr->name); - } + handleSignature(curr->sig, &curr->name); o << ')'; o << maybeNewLine; } @@ -1993,21 +2019,19 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { if (!printStackIR && curr->stackIR && !minify) { o << " (; has Stack IR ;)"; } - if (curr->type.is()) { - o << maybeSpace << "(type "; - printName(curr->type, o) << ')'; - } - if (curr->params.size() > 0) { - for (size_t i = 0; i < curr->params.size(); i++) { + const std::vector<Type>& params = curr->sig.params.expand(); + if (params.size() > 0) { + for (size_t i = 0; i < params.size(); i++) { o << maybeSpace; o << '('; printMinor(o, "param "); - printLocal(i, currFunction, o) << ' ' << curr->getLocalType(i) << ')'; + printLocal(i, currFunction, o); + o << ' ' << params[i] << ')'; } } - if (curr->result != none) { + if (curr->sig.results != Type::none) { o << maybeSpace; - o << ResultType(curr->result); + o << ResultType(curr->sig.results); } incIndent(); for (size_t i = curr->getVarIndexBase(); i < curr->getNumLocals(); i++) { @@ -2204,12 +2228,15 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { o << '('; printMajor(o, "module"); incIndent(); - for (auto& child : curr->functionTypes) { + std::vector<Signature> signatures; + std::unordered_map<Signature, Index> indices; + ModuleUtils::collectSignatures(*curr, signatures, indices); + for (auto sig : signatures) { doIndent(o, indent); o << '('; printMedium(o, "type") << ' '; - printName(child->name, o) << ' '; - visitFunctionType(child.get()); + o << SigName(sig) << ' '; + handleSignature(sig); o << ")" << maybeNewLine; } ModuleUtils::iterImportedMemories( diff --git a/src/passes/ReReloop.cpp b/src/passes/ReReloop.cpp index 5d4190d18..53a232dc1 100644 --- a/src/passes/ReReloop.cpp +++ b/src/passes/ReReloop.cpp @@ -321,7 +321,7 @@ struct ReReloop final : public Pass { for (auto* cfgBlock : relooper->Blocks) { auto* block = cfgBlock->Code->cast<Block>(); if (cfgBlock->BranchesOut.empty() && block->type != unreachable) { - block->list.push_back(function->result == none + block->list.push_back(function->sig.results == Type::none ? (Expression*)builder->makeReturn() : (Expression*)builder->makeUnreachable()); block->finalize(); @@ -354,7 +354,8 @@ struct ReReloop final : public Pass { // because of the relooper's boilerplate switch-handling // code, for example, which could be optimized out later // but isn't yet), then make sure it has a proper type - if (function->result != none && function->body->type == none) { + if (function->sig.results != Type::none && + function->body->type == Type::none) { function->body = builder.makeSequence(function->body, builder.makeUnreachable()); } diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp index 50f7cfa3d..b3b81de3b 100644 --- a/src/passes/RemoveImports.cpp +++ b/src/passes/RemoveImports.cpp @@ -34,8 +34,8 @@ struct RemoveImports : public WalkerPass<PostWalker<RemoveImports>> { if (!func->imported()) { return; } - Type type = getModule()->getFunctionType(func->type)->result; - if (type == none) { + Type type = func->sig.results; + if (type == Type::none) { replaceCurrent(getModule()->allocator.alloc<Nop>()); } else { Literal nopLiteral; diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp index cd69894f3..f5000e3a4 100644 --- a/src/passes/RemoveUnusedModuleElements.cpp +++ b/src/passes/RemoveUnusedModuleElements.cpp @@ -130,24 +130,6 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { } }; -// Finds function type usage - -struct FunctionTypeAnalyzer : public PostWalker<FunctionTypeAnalyzer> { - std::vector<Function*> functions; - std::vector<CallIndirect*> indirectCalls; - std::vector<Event*> events; - - void visitFunction(Function* curr) { - if (curr->type.is()) { - functions.push_back(curr); - } - } - - void visitEvent(Event* curr) { events.push_back(curr); } - - void visitCallIndirect(CallIndirect* curr) { indirectCalls.push_back(curr); } -}; - struct RemoveUnusedModuleElements : public Pass { bool rootAllFunctions; @@ -155,11 +137,6 @@ struct RemoveUnusedModuleElements : public Pass { : rootAllFunctions(rootAllFunctions) {} void run(PassRunner* runner, Module* module) override { - optimizeGlobalsAndFunctionsAndEvents(module); - optimizeFunctionTypes(module); - } - - void optimizeGlobalsAndFunctionsAndEvents(Module* module) { std::vector<ModuleElement> roots; // Module start is a root. if (module->start.is()) { @@ -250,39 +227,6 @@ struct RemoveUnusedModuleElements : public Pass { } } } - - void optimizeFunctionTypes(Module* module) { - FunctionTypeAnalyzer analyzer; - analyzer.walkModule(module); - // maps each string signature to a single canonical function type - std::unordered_map<std::string, FunctionType*> canonicals; - std::unordered_set<FunctionType*> needed; - auto canonicalize = [&](Name name) { - if (!name.is()) { - return name; - } - FunctionType* type = module->getFunctionType(name); - auto sig = getSig(type); - auto iter = canonicals.find(sig); - if (iter == canonicals.end()) { - needed.insert(type); - canonicals[sig] = type; - return type->name; - } else { - return iter->second->name; - } - }; - // canonicalize all uses of function types - for (auto* func : analyzer.functions) { - func->type = canonicalize(func->type); - } - for (auto* call : analyzer.indirectCalls) { - call->fullType = canonicalize(call->fullType); - } - // remove no-longer used types - module->removeFunctionTypes( - [&](FunctionType* type) { return needed.count(type) == 0; }); - } }; Pass* createRemoveUnusedModuleElementsPass() { diff --git a/src/passes/ReorderLocals.cpp b/src/passes/ReorderLocals.cpp index 3315d0e02..f5736930a 100644 --- a/src/passes/ReorderLocals.cpp +++ b/src/passes/ReorderLocals.cpp @@ -65,15 +65,16 @@ struct ReorderLocals : public WalkerPass<PostWalker<ReorderLocals>> { return counts[a] > counts[b]; }); // sorting left params in front, perhaps slightly reordered. verify and fix. - for (size_t i = 0; i < curr->params.size(); i++) { - assert(newToOld[i] < curr->params.size()); + size_t numParams = curr->sig.params.size(); + for (size_t i = 0; i < numParams; i++) { + assert(newToOld[i] < numParams); } - for (size_t i = 0; i < curr->params.size(); i++) { + for (size_t i = 0; i < numParams; i++) { newToOld[i] = i; } // sort vars, and drop unused ones - auto oldVars = curr->vars; - curr->vars.clear(); + std::vector<Type> oldVars; + std::swap(oldVars, curr->vars); for (size_t i = curr->getVarIndexBase(); i < newToOld.size(); i++) { Index index = newToOld[i]; if (counts[index] > 0) { @@ -102,15 +103,11 @@ struct ReorderLocals : public WalkerPass<PostWalker<ReorderLocals>> { : func(func), oldToNew(oldToNew) {} void visitLocalGet(LocalGet* curr) { - if (func->isVar(curr->index)) { - curr->index = oldToNew[curr->index]; - } + curr->index = oldToNew[curr->index]; } void visitLocalSet(LocalSet* curr) { - if (func->isVar(curr->index)) { - curr->index = oldToNew[curr->index]; - } + curr->index = oldToNew[curr->index]; } }; ReIndexer reIndexer(curr, oldToNew); diff --git a/src/passes/SafeHeap.cpp b/src/passes/SafeHeap.cpp index f29a1cf66..37610a90b 100644 --- a/src/passes/SafeHeap.cpp +++ b/src/passes/SafeHeap.cpp @@ -23,7 +23,6 @@ #include "asm_v_wasm.h" #include "asmjs/shared-constants.h" #include "ir/bits.h" -#include "ir/function-type-utils.h" #include "ir/import-utils.h" #include "ir/load-utils.h" #include "pass.h" @@ -137,9 +136,7 @@ struct SafeHeap : public Pass { import->name = getSbrkPtr = GET_SBRK_PTR_IMPORT; import->module = ENV; import->base = GET_SBRK_PTR_IMPORT; - auto* functionType = ensureFunctionType("i", module); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = Signature(Type::none, Type::i32); module->addFunction(import); } if (auto* existing = info.getImportedFunction(ENV, SEGFAULT_IMPORT)) { @@ -149,9 +146,7 @@ struct SafeHeap : public Pass { import->name = segfault = SEGFAULT_IMPORT; import->module = ENV; import->base = SEGFAULT_IMPORT; - auto* functionType = ensureFunctionType("v", module); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = Signature(Type::none, Type::none); module->addFunction(import); } if (auto* existing = info.getImportedFunction(ENV, ALIGNFAULT_IMPORT)) { @@ -161,9 +156,7 @@ struct SafeHeap : public Pass { import->name = alignfault = ALIGNFAULT_IMPORT; import->module = ENV; import->base = ALIGNFAULT_IMPORT; - auto* functionType = ensureFunctionType("v", module); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = Signature(Type::none, Type::none); module->addFunction(import); } } @@ -251,10 +244,9 @@ struct SafeHeap : public Pass { } auto* func = new Function; func->name = name; - func->params.push_back(i32); // pointer - func->params.push_back(i32); // offset + // pointer, offset + func->sig = Signature({Type::i32, Type::i32}, style.type); func->vars.push_back(i32); // pointer + offset - func->result = style.type; Builder builder(*module); auto* block = builder.makeBlock(); block->list.push_back(builder.makeLocalSet( @@ -291,11 +283,9 @@ struct SafeHeap : public Pass { } auto* func = new Function; func->name = name; - func->params.push_back(i32); // pointer - func->params.push_back(i32); // offset - func->params.push_back(style.valueType); // value + // pointer, offset, value + func->sig = Signature({Type::i32, Type::i32, style.valueType}, Type::none); func->vars.push_back(i32); // pointer + offset - func->result = none; Builder builder(*module); auto* block = builder.makeBlock(); block->list.push_back(builder.makeLocalSet( diff --git a/src/passes/TrapMode.cpp b/src/passes/TrapMode.cpp index c00c34eca..9c855d8c1 100644 --- a/src/passes/TrapMode.cpp +++ b/src/passes/TrapMode.cpp @@ -22,7 +22,6 @@ #include "asm_v_wasm.h" #include "asmjs/shared-constants.h" -#include "ir/function-type-utils.h" #include "ir/trapping.h" #include "mixed_arena.h" #include "pass.h" @@ -125,9 +124,7 @@ Function* generateBinaryFunc(Module& wasm, Binary* curr) { } auto func = new Function; func->name = getBinaryFuncName(curr); - func->params.push_back(type); - func->params.push_back(type); - func->result = type; + func->sig = Signature({type, type}, type); func->body = builder.makeIf(builder.makeUnary(eqZOp, builder.makeLocalGet(1, type)), builder.makeConst(zeroLit), @@ -188,8 +185,7 @@ Function* generateUnaryFunc(Module& wasm, Unary* curr) { auto func = new Function; func->name = getUnaryFuncName(curr); - func->params.push_back(type); - func->result = retType; + func->sig = Signature(type, retType); func->body = builder.makeUnary(truncOp, builder.makeLocalGet(0, type)); // too small XXX this is different than asm.js, which does frem. here we // clamp, which is much simpler/faster, and similar to native builds @@ -240,14 +236,12 @@ void ensureF64ToI64JSImport(TrappingFunctionContainer& trappingFunctions) { return; } - Module& wasm = trappingFunctions.getModule(); - auto import = new Function; // f64-to-int = asm2wasm.f64-to-int; + // f64-to-int = asm2wasm.f64-to-int; + auto import = new Function; import->name = F64_TO_INT; import->module = ASM2WASM; import->base = F64_TO_INT; - auto* functionType = ensureFunctionType("id", &wasm); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = Signature(Type::f64, Type::i32); trappingFunctions.addImport(import); } diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp index c2911f2b2..26fd2bd0e 100644 --- a/src/passes/Vacuum.cpp +++ b/src/passes/Vacuum.cpp @@ -410,13 +410,14 @@ struct Vacuum : public WalkerPass<ExpressionStackWalker<Vacuum>> { } void visitFunction(Function* curr) { - auto* optimized = optimize(curr->body, curr->result != none, true); + auto* optimized = + optimize(curr->body, curr->sig.results != Type::none, true); if (optimized) { curr->body = optimized; } else { ExpressionManipulator::nop(curr->body); } - if (curr->result == none && + if (curr->sig.results == Type::none && !EffectAnalyzer(getPassOptions(), curr->body).hasSideEffects()) { ExpressionManipulator::nop(curr->body); } diff --git a/src/shell-interface.h b/src/shell-interface.h index 63cbd9807..52533f37c 100644 --- a/src/shell-interface.h +++ b/src/shell-interface.h @@ -149,7 +149,7 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { Literal callTable(Index index, LiteralList& arguments, - Type result, + Type results, ModuleInstance& instance) override { if (index >= table.size()) { trap("callTable overflow"); @@ -158,15 +158,16 @@ struct ShellExternalInterface : ModuleInstance::ExternalInterface { if (!func) { trap("uninitialized table element"); } - if (func->params.size() != arguments.size()) { + const std::vector<Type>& params = func->sig.params.expand(); + if (params.size() != arguments.size()) { trap("callIndirect: bad # of arguments"); } - for (size_t i = 0; i < func->params.size(); i++) { - if (func->params[i] != arguments[i].type) { + for (size_t i = 0; i < params.size(); i++) { + if (params[i] != arguments[i].type) { trap("callIndirect: bad argument type"); } } - if (func->result != result) { + if (func->sig.results != results) { trap("callIndirect: bad result type"); } if (func->imported()) { diff --git a/src/tools/execution-results.h b/src/tools/execution-results.h index 9ef3d2e1e..c0c7428cc 100644 --- a/src/tools/execution-results.h +++ b/src/tools/execution-results.h @@ -67,7 +67,7 @@ struct ExecutionResults { } std::cout << "[fuzz-exec] calling " << exp->name << "\n"; auto* func = wasm.getFunction(exp->value); - if (func->result != none) { + if (func->sig.results != Type::none) { // this has a result results[exp->name] = run(func, wasm, instance); // ignore the result if we hit an unreachable and returned no value @@ -136,7 +136,7 @@ struct ExecutionResults { instance.callFunction(ex->value, arguments); } // call the method - for (Type param : func->params) { + for (Type param : func->sig.params.expand()) { // zeros in arguments TODO: more? arguments.push_back(Literal(param)); } diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h index 8de47900c..3dcb5c665 100644 --- a/src/tools/fuzzing.h +++ b/src/tools/fuzzing.h @@ -370,8 +370,7 @@ private: contents.push_back(builder.makeLocalGet(0, i32)); auto* body = builder.makeBlock(contents); auto* hasher = wasm.addFunction(builder.makeFunction( - "hashMemory", std::vector<Type>{}, i32, {i32}, body)); - hasher->type = ensureFunctionType(getSig(hasher), &wasm)->name; + "hashMemory", Signature(Type::none, Type::i32), {i32}, body)); wasm.addExport( builder.makeExport(hasher->name, hasher->name, ExternalKind::Function)); // Export memory so JS fuzzing can use it @@ -435,7 +434,7 @@ private: auto* func = new Function; func->name = "hangLimitInitializer"; - func->result = none; + func->sig = Signature(Type::none, Type::none); func->body = builder.makeGlobalSet( glob->name, builder.makeConst(Literal(int32_t(HANG_LIMIT)))); wasm.addFunction(func); @@ -454,9 +453,7 @@ private: func->name = name; func->module = "fuzzing-support"; func->base = name; - func->params.push_back(type); - func->result = none; - func->type = ensureFunctionType(getSig(func), &wasm)->name; + func->sig = Signature(type, Type::none); wasm.addFunction(func); } } @@ -478,8 +475,7 @@ private: auto add = [&](Name name, Type type, Literal literal, BinaryOp op) { auto* func = new Function; func->name = name; - func->params.push_back(type); - func->result = type; + func->sig = Signature(type, type); func->body = builder.makeIf( builder.makeBinary( op, builder.makeLocalGet(0, type), builder.makeLocalGet(0, type)), @@ -521,25 +517,27 @@ private: Index num = wasm.functions.size(); func = new Function; func->name = std::string("func_") + std::to_string(num); - func->result = getReachableType(); assert(typeLocals.empty()); Index numParams = upToSquared(MAX_PARAMS); + std::vector<Type> params; + params.reserve(numParams); for (Index i = 0; i < numParams; i++) { auto type = getConcreteType(); - typeLocals[type].push_back(func->params.size()); - func->params.push_back(type); + typeLocals[type].push_back(params.size()); + params.push_back(type); } + func->sig = Signature(Type(params), getReachableType()); Index numVars = upToSquared(MAX_VARS); for (Index i = 0; i < numVars; i++) { auto type = getConcreteType(); - typeLocals[type].push_back(func->params.size() + func->vars.size()); + typeLocals[type].push_back(params.size() + func->vars.size()); func->vars.push_back(type); } labelIndex = 0; assert(breakableStack.empty()); assert(hangStack.empty()); // with small chance, make the body unreachable - auto bodyType = func->result; + auto bodyType = func->sig.results; if (oneIn(10)) { bodyType = unreachable; } @@ -568,7 +566,6 @@ private: // export some, but not all (to allow inlining etc.). make sure to // export at least one, though, to keep each testcase interesting if (num == 0 || oneIn(2)) { - func->type = ensureFunctionType(getSig(func), &wasm)->name; auto* export_ = new Export; export_->name = func->name; export_->value = func->name; @@ -779,11 +776,12 @@ private: std::vector<Expression*> invocations; while (oneIn(2) && !finishedInput) { std::vector<Expression*> args; - for (auto type : func->params) { + for (auto type : func->sig.params.expand()) { args.push_back(makeConst(type)); } - Expression* invoke = builder.makeCall(func->name, args, func->result); - if (func->result.isConcrete()) { + Expression* invoke = + builder.makeCall(func->name, args, func->sig.results); + if (func->sig.results.isConcrete()) { invoke = builder.makeDrop(invoke); } invocations.push_back(invoke); @@ -797,10 +795,9 @@ private: } auto* invoker = new Function; invoker->name = func->name.str + std::string("_invoker"); - invoker->result = none; + invoker->sig = Signature(Type::none, Type::none); invoker->body = builder.makeBlock(invocations); wasm.addFunction(invoker); - invoker->type = ensureFunctionType(getSig(invoker), &wasm)->name; auto* export_ = new Export; export_->name = invoker->name; export_->value = invoker->name; @@ -1000,8 +997,8 @@ private: } assert(type == unreachable); Expression* ret = nullptr; - if (func->result.isConcrete()) { - ret = makeTrivial(func->result); + if (func->sig.results.isConcrete()) { + ret = makeTrivial(func->sig.results); } return builder.makeReturn(ret); } @@ -1201,14 +1198,14 @@ private: if (!wasm.functions.empty() && !oneIn(wasm.functions.size())) { target = pick(wasm.functions).get(); } - isReturn = type == unreachable && wasm.features.hasTailCall() && - func->result == target->result; - if (target->result != type && !isReturn) { + isReturn = type == Type::unreachable && wasm.features.hasTailCall() && + func->sig.results == target->sig.results; + if (target->sig.results != type && !isReturn) { continue; } // we found one! std::vector<Expression*> args; - for (auto argType : target->params) { + for (auto argType : target->sig.params.expand()) { args.push_back(make(argType)); } return builder.makeCall(target->name, args, type, isReturn); @@ -1231,8 +1228,8 @@ private: // TODO: handle unreachable targetFn = wasm.getFunction(data[i]); isReturn = type == unreachable && wasm.features.hasTailCall() && - func->result == targetFn->result; - if (targetFn->result == type || isReturn) { + func->sig.results == targetFn->sig.results; + if (targetFn->sig.results == type || isReturn) { break; } i++; @@ -1252,12 +1249,10 @@ private: target = make(i32); } std::vector<Expression*> args; - for (auto type : targetFn->params) { + for (auto type : targetFn->sig.params.expand()) { args.push_back(make(type)); } - targetFn->type = ensureFunctionType(getSig(targetFn), &wasm)->name; - return builder.makeCallIndirect( - targetFn->type, target, args, targetFn->result, isReturn); + return builder.makeCallIndirect(target, args, targetFn->sig, isReturn); } Expression* makeLocalGet(Type type) { @@ -2241,8 +2236,8 @@ private: } Expression* makeReturn(Type type) { - return builder.makeReturn(func->result.isConcrete() ? make(func->result) - : nullptr); + return builder.makeReturn( + func->sig.results.isConcrete() ? make(func->sig.results) : nullptr); } Expression* makeNop(Type type) { diff --git a/src/tools/js-wrapper.h b/src/tools/js-wrapper.h index 34a823e97..a2d481f42 100644 --- a/src/tools/js-wrapper.h +++ b/src/tools/js-wrapper.h @@ -91,7 +91,7 @@ static std::string generateJSWrapper(Module& wasm) { ret += "try {\n"; ret += std::string(" console.log('[fuzz-exec] calling $") + exp->name.str + "');\n"; - if (func->result != none) { + if (func->sig.results != Type::none) { ret += std::string(" console.log('[fuzz-exec] note result: $") + exp->name.str + " => ' + literal("; } else { @@ -99,7 +99,7 @@ static std::string generateJSWrapper(Module& wasm) { } ret += std::string("instance.exports.") + exp->name.str + "("; bool first = true; - for (Type param : func->params) { + for (Type param : func->sig.params.expand()) { // zeros in arguments TODO more? if (first) { first = false; @@ -112,8 +112,8 @@ static std::string generateJSWrapper(Module& wasm) { } } ret += ")"; - if (func->result != none) { - ret += ", '" + func->result.toString() + "'))"; + if (func->sig.results != Type::none) { + ret += ", '" + func->sig.results.toString() + "'))"; // TODO: getTempRet } ret += ";\n"; diff --git a/src/tools/spec-wrapper.h b/src/tools/spec-wrapper.h index bcdf2e113..beada1b4b 100644 --- a/src/tools/spec-wrapper.h +++ b/src/tools/spec-wrapper.h @@ -30,7 +30,7 @@ static std::string generateSpecWrapper(Module& wasm) { } ret += std::string("(invoke \"hangLimitInitializer\") (invoke \"") + exp->name.str + "\" "; - for (Type param : func->params) { + for (Type param : func->sig.params.expand()) { // zeros in arguments TODO more? switch (param) { case i32: diff --git a/src/tools/wasm-reduce.cpp b/src/tools/wasm-reduce.cpp index 3e1e464f8..274b6de29 100644 --- a/src/tools/wasm-reduce.cpp +++ b/src/tools/wasm-reduce.cpp @@ -866,8 +866,7 @@ struct Reducer auto* func = module->functions[0].get(); // We can't remove something that might have breaks to it. if (!func->imported() && !Properties::isNamedControlFlow(func->body)) { - auto funcType = func->type; - auto funcResult = func->result; + auto funcSig = func->sig; auto* funcBody = func->body; for (auto* child : ChildIterator(func->body)) { if (!(child->type.isConcrete() || child->type == none)) { @@ -875,8 +874,7 @@ struct Reducer } // Try to replace the body with the child, fixing up the function // to accept it. - func->type = Name(); - func->result = child->type; + func->sig.results = child->type; func->body = child; if (writeAndTestReduction()) { // great, we succeeded! @@ -885,8 +883,7 @@ struct Reducer break; } // Undo. - func->type = funcType; - func->result = funcResult; + func->sig = funcSig; func->body = funcBody; } } diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index 53436c9d3..d5ee60d8a 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -111,7 +111,7 @@ static void run_asserts(Name moduleName, std::cerr << "Unknown entry " << entry << std::endl; } else { LiteralList arguments; - for (Type param : function->params) { + for (Type param : function->sig.params.expand()) { arguments.push_back(Literal(param)); } try { diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 34919cc79..c453ad1f5 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -1023,7 +1023,7 @@ private: Module* wasm; BufferWithRandomAccess& o; ModuleUtils::BinaryIndexes indexes; - std::unordered_map<Signature, Index> typeIndexes; + std::unordered_map<Signature, Index> typeIndices; std::vector<Signature> types; bool debugInfo = true; @@ -1058,6 +1058,9 @@ class WasmBinaryBuilder { std::set<BinaryConsts::Section> seenSections; + // All signatures present in the type section + std::vector<Signature> signatures; + public: WasmBinaryBuilder(Module& wasm, const std::vector<char>& input) : wasm(wasm), allocator(wasm.allocator), input(input), sourceMap(nullptr), @@ -1106,7 +1109,8 @@ public: Address defaultIfNoMax); void readImports(); - std::vector<FunctionType*> functionTypes; // types of defined functions + // The signatures of each function, given in the function section + std::vector<Signature> functionSignatures; void readFunctionSignatures(); size_t nextLabel; @@ -1133,7 +1137,7 @@ public: void readFunctions(); - std::map<Export*, Index> exportIndexes; + std::map<Export*, Index> exportIndices; std::vector<Export*> exportOrder; void readExports(); diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 533f83d94..22fa0e642 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -43,15 +43,13 @@ public: // make* functions, other globals Function* makeFunction(Name name, - std::vector<Type>&& params, - Type resultType, + Signature sig, std::vector<Type>&& vars, Expression* body = nullptr) { auto* func = new Function; func->name = name; - func->result = resultType; + func->sig = sig; func->body = body; - func->params.swap(params); func->vars.swap(vars); return func; } @@ -63,14 +61,15 @@ public: Expression* body = nullptr) { auto* func = new Function; func->name = name; - func->result = resultType; func->body = body; + std::vector<Type> paramVec; for (auto& param : params) { - func->params.push_back(param.type); + paramVec.push_back(param.type); Index index = func->localNames.size(); func->localIndices[param.name] = index; func->localNames[index] = param.name; } + func->sig = Signature(Type(paramVec), resultType); for (auto& var : vars) { func->vars.push_back(var.type); Index index = func->localNames.size(); @@ -210,27 +209,19 @@ public: call->finalize(); return call; } - CallIndirect* makeCallIndirect(FunctionType* type, - Expression* target, + CallIndirect* makeCallIndirect(Expression* target, const std::vector<Expression*>& args, - bool isReturn = false) { - return makeCallIndirect(type->name, target, args, type->result, isReturn); - } - CallIndirect* makeCallIndirect(Name fullType, - Expression* target, - const std::vector<Expression*>& args, - Type type, + Signature sig, bool isReturn = false) { auto* call = allocator.alloc<CallIndirect>(); - call->fullType = fullType; - call->type = type; + call->sig = sig; + call->type = sig.results; call->target = target; call->operands.set(args); call->isReturn = isReturn; call->finalize(); return call; } - // FunctionType LocalGet* makeLocalGet(Index index, Type type) { auto* ret = allocator.alloc<LocalGet>(); ret->index = index; @@ -582,9 +573,11 @@ public: static Index addParam(Function* func, Name name, Type type) { // only ok to add a param if no vars, otherwise indices are invalidated - assert(func->localIndices.size() == func->params.size()); + assert(func->localIndices.size() == func->sig.params.size()); assert(name.is()); - func->params.push_back(type); + std::vector<Type> params = func->sig.params.expand(); + params.push_back(type); + func->sig.params = Type(params); Index index = func->localNames.size(); func->localIndices[name] = index; func->localNames[index] = name; @@ -613,7 +606,7 @@ public: } static void clearLocals(Function* func) { - func->params.clear(); + func->sig.params = Type::none; func->vars.clear(); clearLocalNames(func); } diff --git a/src/wasm-emscripten.h b/src/wasm-emscripten.h index 9dc82bc1a..8aa911ef9 100644 --- a/src/wasm-emscripten.h +++ b/src/wasm-emscripten.h @@ -71,12 +71,12 @@ private: bool useStackPointerGlobal; // Used by generateDynCallThunk to track all the dynCall functions created // so far. - std::unordered_set<std::string> sigs; + std::unordered_set<Signature> sigs; Global* getStackPointerGlobal(); Expression* generateLoadStackPointer(); Expression* generateStoreStackPointer(Function* func, Expression* value); - void generateDynCallThunk(std::string sig); + void generateDynCallThunk(Signature sig); void generateStackSaveFunction(); void generateStackAllocFunction(); void generateStackRestoreFunction(); diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 1a90fdf71..caf8a9a3d 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -1449,20 +1449,21 @@ private: FunctionScope(Function* function, const LiteralList& arguments) : function(function) { - if (function->params.size() != arguments.size()) { + if (function->sig.params.size() != arguments.size()) { std::cerr << "Function `" << function->name << "` expects " - << function->params.size() << " parameters, got " + << function->sig.params.size() << " parameters, got " << arguments.size() << " arguments." << std::endl; WASM_UNREACHABLE("invalid param count"); } locals.resize(function->getNumLocals()); + const std::vector<Type>& params = function->sig.params.expand(); for (size_t i = 0; i < function->getNumLocals(); i++) { if (i < arguments.size()) { - assert(function->isParam(i)); - if (function->params[i] != arguments[i].type) { + assert(i < params.size()); + if (params[i] != arguments[i].type) { std::cerr << "Function `" << function->name << "` expects type " - << function->params[i] << " for parameter " << i - << ", got " << arguments[i].type << "." << std::endl; + << params[i] << " for parameter " << i << ", got " + << arguments[i].type << "." << std::endl; WASM_UNREACHABLE("invalid param count"); } locals[i] = arguments[i]; @@ -1544,7 +1545,7 @@ private: return target; } Index index = target.value.geti32(); - Type type = curr->isReturn ? scope.function->result : curr->type; + Type type = curr->isReturn ? scope.function->sig.results : curr->type; Flow ret = instance.externalInterface->callTable( index, arguments, type, *instance.self()); // TODO: make this a proper tail call (return first) @@ -2053,9 +2054,10 @@ public: // cannot still be breaking, it means we missed our stop assert(!flow.breaking() || flow.breakTo == RETURN_FLOW); Literal ret = flow.value; - if (function->result != ret.type) { + if (function->sig.results != ret.type) { std::cerr << "calling " << function->name << " resulted in " << ret - << " but the function type is " << function->result << '\n'; + << " but the function type is " << function->sig.results + << '\n'; WASM_UNREACHABLE("unexpect result type"); } // may decrease more than one, if we jumped up the stack diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index 924c1d968..d7324d756 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -111,6 +111,8 @@ private: class SExpressionWasmBuilder { Module& wasm; MixedArena& allocator; + std::vector<Signature> signatures; + std::unordered_map<std::string, size_t> signatureIndices; std::vector<Name> functionNames; std::vector<Name> globalNames; std::vector<Name> eventNames; @@ -141,8 +143,8 @@ private: UniqueNameMapper nameMapper; + Signature getFunctionSignature(Element& s); Name getFunctionName(Element& s); - Name getFunctionTypeName(Element& s); Name getGlobalName(Element& s); Name getEventName(Element& s); void parseStart(Element& s) { wasm.addStart(getFunctionName(*s[1])); } @@ -234,19 +236,14 @@ private: Index parseMemoryLimits(Element& s, Index i); std::vector<Type> parseParamOrLocal(Element& s); std::vector<NameType> parseParamOrLocal(Element& s, size_t& localIndex); - Type parseResult(Element& s); - FunctionType* parseTypeRef(Element& s); + Type parseResults(Element& s); + Signature parseTypeRef(Element& s); size_t parseTypeUse(Element& s, size_t startPos, - FunctionType*& functionType, - std::vector<NameType>& namedParams, - Type& result); - size_t parseTypeUse(Element& s, - size_t startPos, - FunctionType*& functionType, - std::vector<Type>& params, - Type& result); - size_t parseTypeUse(Element& s, size_t startPos, FunctionType*& functionType); + Signature& functionSignature, + std::vector<NameType>& namedParams); + size_t + parseTypeUse(Element& s, size_t startPos, Signature& functionSignature); void stringToBinary(const char* input, size_t size, std::vector<char>& data); void parseMemory(Element& s, bool preParseImport = false); diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 85b7ca415..9c6e78360 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -81,7 +81,6 @@ template<typename SubType, typename ReturnType = void> struct Visitor { ReturnType visitPush(Push* curr) { return ReturnType(); } ReturnType visitPop(Pop* curr) { return ReturnType(); } // Module-level visitors - ReturnType visitFunctionType(FunctionType* curr) { return ReturnType(); } ReturnType visitExport(Export* curr) { return ReturnType(); } ReturnType visitGlobal(Global* curr) { return ReturnType(); } ReturnType visitFunction(Function* curr) { return ReturnType(); } @@ -250,7 +249,6 @@ struct OverriddenVisitor { UNIMPLEMENTED(Unreachable); UNIMPLEMENTED(Push); UNIMPLEMENTED(Pop); - UNIMPLEMENTED(FunctionType); UNIMPLEMENTED(Export); UNIMPLEMENTED(Global); UNIMPLEMENTED(Function); @@ -603,9 +601,6 @@ struct Walker : public VisitorType { void doWalkModule(Module* module) { // Dispatch statically through the SubType. SubType* self = static_cast<SubType*>(this); - for (auto& curr : module->functionTypes) { - self->visitFunctionType(curr.get()); - } for (auto& curr : module->exports) { self->visitExport(curr.get()); } diff --git a/src/wasm-type.h b/src/wasm-type.h index ddfb7c9a1..c4d719a35 100644 --- a/src/wasm-type.h +++ b/src/wasm-type.h @@ -99,7 +99,7 @@ struct ResultType { struct Signature { Type params; Type results; - Signature() = default; + Signature() : params(Type::none), results(Type::none) {} Signature(Type params, Type results) : params(params), results(results) {} bool operator==(const Signature& other) const { return params == other.params && results == other.results; diff --git a/src/wasm.h b/src/wasm.h index d9efe446a..fba962bdb 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -680,27 +680,11 @@ public: void finalize(); }; -class FunctionType { -public: - Name name; - Type result = none; - std::vector<Type> params; - - FunctionType() = default; - - bool structuralComparison(FunctionType& b); - bool structuralComparison(const std::vector<Type>& params, Type result); - - bool operator==(FunctionType& b); - bool operator!=(FunctionType& b); -}; - class CallIndirect : public SpecificExpression<Expression::CallIndirectId> { public: CallIndirect(MixedArena& allocator) : operands(allocator) {} - + Signature sig; ExpressionList operands; - Name fullType; Expression* target; bool isReturn = false; @@ -1147,10 +1131,8 @@ typedef std::vector<StackInst*> StackIR; class Function : public Importable { public: Name name; - Type result = none; - std::vector<Type> params; // function locals are + Signature sig; std::vector<Type> vars; // params plus vars - Name type; // if null, it is implicit in params and result // The body of the function Expression* body = nullptr; @@ -1334,7 +1316,6 @@ class Module { public: // wasm contents (generally you shouldn't access these from outside, except // maybe for iterating; use add*() and the get() functions) - std::vector<std::unique_ptr<FunctionType>> functionTypes; std::vector<std::unique_ptr<Export>> exports; std::vector<std::unique_ptr<Function>> functions; std::vector<std::unique_ptr<Global>> globals; @@ -1359,7 +1340,6 @@ public: private: // TODO: add a build option where Names are just indices, and then these // methods are not needed - std::map<Name, FunctionType*> functionTypesMap; // exports map is by the *exported* name, which is unique std::map<Name, Export*> exportsMap; std::map<Name, Function*> functionsMap; @@ -1369,19 +1349,16 @@ private: public: Module() = default; - FunctionType* getFunctionType(Name name); Export* getExport(Name name); Function* getFunction(Name name); Global* getGlobal(Name name); Event* getEvent(Name name); - FunctionType* getFunctionTypeOrNull(Name name); Export* getExportOrNull(Name name); Function* getFunctionOrNull(Name name); Global* getGlobalOrNull(Name name); Event* getEventOrNull(Name name); - FunctionType* addFunctionType(std::unique_ptr<FunctionType> curr); Export* addExport(Export* curr); Function* addFunction(Function* curr); Function* addFunction(std::unique_ptr<Function> curr); @@ -1390,13 +1367,11 @@ public: void addStart(const Name& s); - void removeFunctionType(Name name); void removeExport(Name name); void removeFunction(Name name); void removeGlobal(Name name); void removeEvent(Name name); - void removeFunctionTypes(std::function<bool(FunctionType*)> pred); void removeExports(std::function<bool(Export*)> pred); void removeFunctions(std::function<bool(Function*)> pred); void removeGlobals(std::function<bool(Global*)> pred); diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index b8b001927..d86eea8f5 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -30,56 +30,7 @@ namespace wasm { void WasmBinaryWriter::prepare() { // Collect function types and their frequencies. Collect information in each // function in parallel, then merge. - typedef std::unordered_map<Signature, size_t> Counts; - ModuleUtils::ParallelFunctionAnalysis<Counts> analysis( - *wasm, [&](Function* func, Counts& counts) { - if (func->imported()) { - return; - } - struct TypeCounter : PostWalker<TypeCounter> { - Module& wasm; - Counts& counts; - - TypeCounter(Module& wasm, Counts& counts) - : wasm(wasm), counts(counts) {} - - void visitCallIndirect(CallIndirect* curr) { - auto* type = wasm.getFunctionType(curr->fullType); - Signature sig(Type(type->params), type->result); - counts[sig]++; - } - }; - TypeCounter(*wasm, counts).walk(func->body); - }); - // Collect all the counts. - Counts counts; - for (auto& curr : wasm->functions) { - counts[Signature(Type(curr->params), curr->result)]++; - } - for (auto& curr : wasm->events) { - counts[curr->sig]++; - } - for (auto& pair : analysis.map) { - Counts& functionCounts = pair.second; - for (auto& innerPair : functionCounts) { - counts[innerPair.first] += innerPair.second; - } - } - std::vector<std::pair<Signature, size_t>> sorted(counts.begin(), - counts.end()); - std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) { - // order by frequency then simplicity - if (a.second != b.second) { - return a.second > b.second; - } else { - return a.first < b.first; - } - }); - for (Index i = 0; i < sorted.size(); ++i) { - typeIndexes[sorted[i].first] = i; - types.push_back(sorted[i].first); - } - + ModuleUtils::collectSignatures(*wasm, types, typeIndices); importInfo = wasm::make_unique<ImportInfo>(*wasm); } @@ -250,7 +201,7 @@ void WasmBinaryWriter::writeImports() { BYN_TRACE("write one function\n"); writeImportHeader(func); o << U32LEB(int32_t(ExternalKind::Function)); - o << U32LEB(getTypeIndex(Signature(Type(func->params), func->result))); + o << U32LEB(getTypeIndex(func->sig)); }); ModuleUtils::iterImportedGlobals(*wasm, [&](Global* global) { BYN_TRACE("write one global\n"); @@ -297,7 +248,7 @@ void WasmBinaryWriter::writeFunctionSignatures() { o << U32LEB(importInfo->getNumDefinedFunctions()); ModuleUtils::iterDefinedFunctions(*wasm, [&](Function* func) { BYN_TRACE("write one\n"); - o << U32LEB(getTypeIndex(Signature(Type(func->params), func->result))); + o << U32LEB(getTypeIndex(func->sig)); }); finishSection(start); } @@ -458,8 +409,8 @@ uint32_t WasmBinaryWriter::getEventIndex(Name name) const { } uint32_t WasmBinaryWriter::getTypeIndex(Signature sig) const { - auto it = typeIndexes.find(sig); - assert(it != typeIndexes.end()); + auto it = typeIndices.find(sig); + assert(it != typeIndices.end()); return it->second; } @@ -1124,7 +1075,8 @@ void WasmBinaryBuilder::readSignatures() { BYN_TRACE("num: " << numTypes << std::endl); for (size_t i = 0; i < numTypes; i++) { BYN_TRACE("read one\n"); - auto curr = make_unique<FunctionType>(); + std::vector<Type> params; + std::vector<Type> results; auto form = getS32LEB(); if (form != BinaryConsts::EncodedType::Func) { throwError("bad signature form " + std::to_string(form)); @@ -1132,19 +1084,14 @@ void WasmBinaryBuilder::readSignatures() { size_t numParams = getU32LEB(); BYN_TRACE("num params: " << numParams << std::endl); for (size_t j = 0; j < numParams; j++) { - curr->params.push_back(getConcreteType()); + params.push_back(getConcreteType()); } auto numResults = getU32LEB(); - if (numResults == 0) { - curr->result = none; - } else { - if (numResults != 1) { - throwError("signature must have 1 result"); - } - curr->result = getType(); + BYN_TRACE("num results: " << numResults << std::endl); + for (size_t j = 0; j < numResults; j++) { + results.push_back(getConcreteType()); } - curr->name = Name::fromInt(wasm.functionTypes.size()); - wasm.addFunctionType(std::move(curr)); + signatures.emplace_back(Type(params), Type(results)); } } @@ -1205,17 +1152,13 @@ void WasmBinaryBuilder::readImports() { case ExternalKind::Function: { auto name = Name(std::string("fimport$") + std::to_string(i)); auto index = getU32LEB(); - if (index >= wasm.functionTypes.size()) { + if (index > signatures.size()) { throwError("invalid function index " + std::to_string(index) + " / " + - std::to_string(wasm.functionTypes.size())); + std::to_string(signatures.size())); } - auto* functionType = wasm.functionTypes[index].get(); - auto params = functionType->params; - auto result = functionType->result; - auto* curr = builder.makeFunction(name, std::move(params), result, {}); + auto* curr = builder.makeFunction(name, signatures[index], {}); curr->module = module; curr->base = base; - curr->type = functionType->name; wasm.addFunction(curr); functionImports.push_back(curr); break; @@ -1267,13 +1210,11 @@ void WasmBinaryBuilder::readImports() { auto name = Name(std::string("eimport$") + std::to_string(i)); auto attribute = getU32LEB(); auto index = getU32LEB(); - if (index >= wasm.functionTypes.size()) { + if (index >= signatures.size()) { throwError("invalid event index " + std::to_string(index) + " / " + - std::to_string(wasm.functionTypes.size())); + std::to_string(signatures.size())); } - Type params = Type(wasm.functionTypes[index]->params); - auto* curr = - builder.makeEvent(name, attribute, Signature(params, Type::none)); + auto* curr = builder.makeEvent(name, attribute, signatures[index]); curr->module = module; curr->base = base; wasm.addEvent(curr); @@ -1302,17 +1243,17 @@ void WasmBinaryBuilder::readFunctionSignatures() { for (size_t i = 0; i < num; i++) { BYN_TRACE("read one\n"); auto index = getU32LEB(); - if (index >= wasm.functionTypes.size()) { + if (index >= signatures.size()) { throwError("invalid function type index for function"); } - functionTypes.push_back(wasm.functionTypes[index].get()); + functionSignatures.push_back(signatures[index]); } } void WasmBinaryBuilder::readFunctions() { BYN_TRACE("== readFunctions\n"); size_t total = getU32LEB(); - if (total != functionTypes.size()) { + if (total != functionSignatures.size()) { throwError("invalid function section size, must equal types"); } for (size_t i = 0; i < total; i++) { @@ -1325,17 +1266,12 @@ void WasmBinaryBuilder::readFunctions() { Function* func = new Function; func->name = Name::fromInt(i); + func->sig = functionSignatures[i]; currFunction = func; readNextDebugLocation(); - auto type = functionTypes[i]; BYN_TRACE("reading " << i << std::endl); - func->type = type->name; - func->result = type->result; - for (size_t j = 0; j < type->params.size(); j++) { - func->params.emplace_back(type->params[j]); - } size_t numLocalTypes = getU32LEB(); for (size_t t = 0; t < numLocalTypes; t++) { auto num = getU32LEB(); @@ -1357,7 +1293,7 @@ void WasmBinaryBuilder::readFunctions() { assert(breakStack.empty()); assert(expressionStack.empty()); assert(depth == 0); - func->body = getBlockOrSingleton(func->result); + func->body = getBlockOrSingleton(func->sig.results); assert(depth == 0); assert(breakStack.size() == 0); assert(breakTargetNames.size() == 0); @@ -1391,7 +1327,7 @@ void WasmBinaryBuilder::readExports() { names.insert(curr->name); curr->kind = (ExternalKind)getU32LEB(); auto index = getU32LEB(); - exportIndexes[curr] = index; + exportIndices[curr] = index; exportOrder.push_back(curr); } } @@ -1753,7 +1689,7 @@ void WasmBinaryBuilder::processFunctions() { } for (auto* curr : exportOrder) { - auto index = exportIndexes[curr]; + auto index = exportIndices[curr]; switch (curr->kind) { case ExternalKind::Function: { curr->value = getFunctionName(index); @@ -1787,8 +1723,8 @@ void WasmBinaryBuilder::processFunctions() { for (auto& pair : functionTable) { auto i = pair.first; - auto& indexes = pair.second; - for (auto j : indexes) { + auto& indices = pair.second; + for (auto j : indices) { wasm.table.segments[i].data.push_back(getFunctionName(j)); } } @@ -1884,13 +1820,12 @@ void WasmBinaryBuilder::readEvents() { BYN_TRACE("read one\n"); auto attribute = getU32LEB(); auto typeIndex = getU32LEB(); - if (typeIndex >= wasm.functionTypes.size()) { + if (typeIndex >= signatures.size()) { throwError("invalid event index " + std::to_string(typeIndex) + " / " + - std::to_string(wasm.functionTypes.size())); + std::to_string(signatures.size())); } - Type params = Type(wasm.functionTypes[typeIndex]->params); wasm.addEvent(Builder::makeEvent( - "event$" + std::to_string(i), attribute, Signature(params, Type::none))); + "event$" + std::to_string(i), attribute, signatures[typeIndex])); } } @@ -2463,24 +2398,23 @@ void WasmBinaryBuilder::visitSwitch(Switch* curr) { void WasmBinaryBuilder::visitCall(Call* curr) { BYN_TRACE("zz node: Call\n"); auto index = getU32LEB(); - FunctionType* type; + Signature sig; if (index < functionImports.size()) { auto* import = functionImports[index]; - type = wasm.getFunctionType(import->type); + sig = import->sig; } else { Index adjustedIndex = index - functionImports.size(); - if (adjustedIndex >= functionTypes.size()) { + if (adjustedIndex >= functionSignatures.size()) { throwError("invalid call index"); } - type = functionTypes[adjustedIndex]; + sig = functionSignatures[adjustedIndex]; } - assert(type); - auto num = type->params.size(); + auto num = sig.params.size(); curr->operands.resize(num); for (size_t i = 0; i < num; i++) { curr->operands[num - i - 1] = popNonVoidExpression(); } - curr->type = type->result; + curr->type = sig.results; functionCalls[index].push_back(curr); // we don't know function names yet curr->finalize(); } @@ -2488,22 +2422,20 @@ void WasmBinaryBuilder::visitCall(Call* curr) { void WasmBinaryBuilder::visitCallIndirect(CallIndirect* curr) { BYN_TRACE("zz node: CallIndirect\n"); auto index = getU32LEB(); - if (index >= wasm.functionTypes.size()) { + if (index >= signatures.size()) { throwError("bad call_indirect function index"); } - auto* fullType = wasm.functionTypes[index].get(); + curr->sig = signatures[index]; auto reserved = getU32LEB(); if (reserved != 0) { throwError("Invalid flags field in call_indirect"); } - curr->fullType = fullType->name; - auto num = fullType->params.size(); + auto num = curr->sig.params.size(); curr->operands.resize(num); curr->target = popNonVoidExpression(); for (size_t i = 0; i < num; i++) { curr->operands[num - i - 1] = popNonVoidExpression(); } - curr->type = fullType->result; curr->finalize(); } @@ -4299,7 +4231,7 @@ void WasmBinaryBuilder::visitSelect(Select* curr) { void WasmBinaryBuilder::visitReturn(Return* curr) { BYN_TRACE("zz node: Return\n"); requireFunctionContext("return"); - if (currFunction->result != none) { + if (currFunction->sig.results != Type::none) { curr->value = popNonVoidExpression(); } curr->finalize(); diff --git a/src/wasm/wasm-emscripten.cpp b/src/wasm/wasm-emscripten.cpp index 695561f68..9a0b145da 100644 --- a/src/wasm/wasm-emscripten.cpp +++ b/src/wasm/wasm-emscripten.cpp @@ -20,7 +20,6 @@ #include "asm_v_wasm.h" #include "asmjs/shared-constants.h" -#include "ir/function-type-utils.h" #include "ir/import-utils.h" #include "ir/literal-utils.h" #include "ir/module-utils.h" @@ -207,7 +206,7 @@ void EmscriptenGlueGenerator::generateRuntimeFunctions() { } static Function* -ensureFunctionImport(Module* module, Name name, std::string sig) { +ensureFunctionImport(Module* module, Name name, Signature sig) { // Then see if its already imported ImportInfo info(*module); if (Function* f = info.getImportedFunction(ENV, name)) { @@ -218,9 +217,7 @@ ensureFunctionImport(Module* module, Name name, std::string sig) { import->name = name; import->module = ENV; import->base = name; - auto* functionType = ensureFunctionType(sig, module); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = sig; module->addFunction(import); return import; } @@ -267,7 +264,7 @@ Function* EmscriptenGlueGenerator::generateAssignGOTEntriesFunction() { for (Global* g : gotMemEntries) { Name getter(std::string("g$") + g->base.c_str()); - ensureFunctionImport(&wasm, getter, "i"); + ensureFunctionImport(&wasm, getter, Signature(Type::none, Type::i32)); Expression* call = builder.makeCall(getter, {}, i32); GlobalSet* globalSet = builder.makeGlobalSet(g->name, call); block->list.push_back(globalSet); @@ -293,7 +290,7 @@ Function* EmscriptenGlueGenerator::generateAssignGOTEntriesFunction() { Name getter( (std::string("fp$") + g->base.c_str() + std::string("$") + getSig(f)) .c_str()); - ensureFunctionImport(&wasm, getter, "i"); + ensureFunctionImport(&wasm, getter, Signature(Type::none, Type::i32)); Expression* call = builder.makeCall(getter, {}, i32); GlobalSet* globalSet = builder.makeGlobalSet(g->name, call); block->list.push_back(globalSet); @@ -371,29 +368,28 @@ inline void exportFunction(Module& wasm, Name name, bool must_export) { wasm.addExport(exp); } -void EmscriptenGlueGenerator::generateDynCallThunk(std::string sig) { - auto* funcType = ensureFunctionType(sig, &wasm); +void EmscriptenGlueGenerator::generateDynCallThunk(Signature sig) { if (!sigs.insert(sig).second) { return; // sig is already in the set } - Name name = std::string("dynCall_") + sig; + Name name = std::string("dynCall_") + getSig(sig.results, sig.params); if (wasm.getFunctionOrNull(name) || wasm.getExportOrNull(name)) { return; // module already contains this dyncall } std::vector<NameType> params; params.emplace_back("fptr", i32); // function pointer param int p = 0; - for (const auto& ty : funcType->params) { + const std::vector<Type>& paramTypes = sig.params.expand(); + for (const auto& ty : paramTypes) { params.emplace_back(std::to_string(p++), ty); } - Function* f = - builder.makeFunction(name, std::move(params), funcType->result, {}); + Function* f = builder.makeFunction(name, std::move(params), sig.results, {}); Expression* fptr = builder.makeLocalGet(0, i32); std::vector<Expression*> args; - for (unsigned i = 0; i < funcType->params.size(); ++i) { - args.push_back(builder.makeLocalGet(i + 1, funcType->params[i])); + for (unsigned i = 0; i < paramTypes.size(); ++i) { + args.push_back(builder.makeLocalGet(i + 1, paramTypes[i])); } - Expression* call = builder.makeCallIndirect(funcType, fptr, args); + Expression* call = builder.makeCallIndirect(fptr, args, sig); f->body = call; wasm.addFunction(f); @@ -407,8 +403,7 @@ void EmscriptenGlueGenerator::generateDynCallThunks() { tableSegmentData = wasm.table.segments[0].data; } for (const auto& indirectFunc : tableSegmentData) { - std::string sig = getSig(wasm.getFunction(indirectFunc)); - generateDynCallThunk(sig); + generateDynCallThunk(wasm.getFunction(indirectFunc)->sig); } } @@ -479,10 +474,11 @@ void EmscriptenGlueGenerator::replaceStackPointerGlobal() { RemoveStackPointer walker(stackPointer); walker.walkModule(&wasm); if (walker.needStackSave) { - ensureFunctionImport(&wasm, STACK_SAVE, "i"); + ensureFunctionImport(&wasm, STACK_SAVE, Signature(Type::none, Type::i32)); } if (walker.needStackRestore) { - ensureFunctionImport(&wasm, STACK_RESTORE, "vi"); + ensureFunctionImport( + &wasm, STACK_RESTORE, Signature(Type::i32, Type::none)); } // Finally remove the stack pointer global itself. This avoids importing @@ -545,7 +541,7 @@ void EmscriptenGlueGenerator::enforceStackLimit() { void EmscriptenGlueGenerator::generateSetStackLimitFunction() { Function* function = - builder.makeFunction(SET_STACK_LIMIT, std::vector<Type>({i32}), none, {}); + builder.makeFunction(SET_STACK_LIMIT, Signature(Type::i32, Type::none), {}); LocalGet* getArg = builder.makeLocalGet(0, i32); Expression* store = builder.makeGlobalSet(STACK_LIMIT, getArg); function->body = store; @@ -562,9 +558,7 @@ Name EmscriptenGlueGenerator::importStackOverflowHandler() { import->name = STACK_OVERFLOW_IMPORT; import->module = ENV; import->base = STACK_OVERFLOW_IMPORT; - auto* functionType = ensureFunctionType("v", &wasm); - import->type = functionType->name; - FunctionTypeUtils::fillFunction(import, functionType); + import->sig = Signature(Type::none, Type::none); wasm.addFunction(import); return STACK_OVERFLOW_IMPORT; } @@ -696,14 +690,14 @@ struct AsmConstWalker : public LinearExecutionWalker<AsmConstWalker> { std::vector<Address> segmentOffsets; // segment index => address offset struct AsmConst { - std::set<std::string> sigs; + std::set<Signature> sigs; Address id; std::string code; Proxying proxy; }; std::vector<AsmConst> asmConsts; - std::set<std::pair<std::string, Proxying>> allSigs; + std::set<std::pair<Signature, Proxying>> allSigs; // last sets in the current basic block, per index std::map<Index, LocalSet*> sets; @@ -719,12 +713,12 @@ struct AsmConstWalker : public LinearExecutionWalker<AsmConstWalker> { void process(); private: - std::string fixupName(Name& name, std::string baseSig, Proxying proxy); + Signature fixupName(Name& name, Signature baseSig, Proxying proxy); AsmConst& - createAsmConst(uint32_t id, std::string code, std::string sig, Name name); - std::string asmConstSig(std::string baseSig); - Name nameForImportWithSig(std::string sig, Proxying proxy); - void queueImport(Name importName, std::string baseSig); + createAsmConst(uint32_t id, std::string code, Signature sig, Name name); + Signature asmConstSig(Signature baseSig); + Name nameForImportWithSig(Signature sig, Proxying proxy); + void queueImport(Name importName, Signature baseSig); void addImports(); Proxying proxyType(Name name); @@ -750,7 +744,7 @@ void AsmConstWalker::visitCall(Call* curr) { return; } - auto baseSig = getSig(curr); + auto baseSig = wasm.getFunction(curr->target)->sig; auto sig = asmConstSig(baseSig); auto* arg = curr->operands[0]; while (!arg->dynCast<Const>()) { @@ -816,9 +810,8 @@ void AsmConstWalker::visitTable(Table* curr) { for (auto& name : segment.data) { auto* func = wasm.getFunction(name); if (func->imported() && func->base.hasSubstring(EM_ASM_PREFIX)) { - std::string baseSig = getSig(func); auto proxy = proxyType(func->base); - fixupName(name, baseSig, proxy); + fixupName(name, func->sig, proxy); } } } @@ -832,8 +825,8 @@ void AsmConstWalker::process() { addImports(); } -std::string -AsmConstWalker::fixupName(Name& name, std::string baseSig, Proxying proxy) { +Signature +AsmConstWalker::fixupName(Name& name, Signature baseSig, Proxying proxy) { auto sig = asmConstSig(baseSig); auto importName = nameForImportWithSig(sig, proxy); name = importName; @@ -848,7 +841,7 @@ AsmConstWalker::fixupName(Name& name, std::string baseSig, Proxying proxy) { AsmConstWalker::AsmConst& AsmConstWalker::createAsmConst(uint32_t id, std::string code, - std::string sig, + Signature sig, Name name) { AsmConst asmConst; asmConst.id = id; @@ -859,31 +852,27 @@ AsmConstWalker::AsmConst& AsmConstWalker::createAsmConst(uint32_t id, return asmConsts.back(); } -std::string AsmConstWalker::asmConstSig(std::string baseSig) { - std::string sig = ""; - for (size_t i = 0; i < baseSig.size(); ++i) { - // Omit the signature of the "code" parameter, taken as a string, as the - // first argument - if (i != 1) { - sig += baseSig[i]; - } - } - return sig; +Signature AsmConstWalker::asmConstSig(Signature baseSig) { + std::vector<Type> params = baseSig.params.expand(); + assert(params.size() >= 1); + // Omit the signature of the "code" parameter, taken as a string, as the + // first argument + params.erase(params.begin()); + return Signature(Type(params), baseSig.results); } -Name AsmConstWalker::nameForImportWithSig(std::string sig, Proxying proxy) { - std::string fixedTarget = - EM_ASM_PREFIX.str + std::string("_") + proxyingSuffix(proxy) + sig; +Name AsmConstWalker::nameForImportWithSig(Signature sig, Proxying proxy) { + std::string fixedTarget = EM_ASM_PREFIX.str + std::string("_") + + proxyingSuffix(proxy) + + getSig(sig.results, sig.params); return Name(fixedTarget.c_str()); } -void AsmConstWalker::queueImport(Name importName, std::string baseSig) { +void AsmConstWalker::queueImport(Name importName, Signature baseSig) { auto import = new Function; import->name = import->base = importName; import->module = ENV; - auto* funcType = ensureFunctionType(baseSig, &wasm); - import->type = funcType->name; - FunctionTypeUtils::fillFunction(import, funcType); + import->sig = baseSig; queuedImports.push_back(std::unique_ptr<Function>(import)); } @@ -994,7 +983,7 @@ struct FixInvokeFunctionNamesWalker std::map<Name, Name> importRenames; std::vector<Name> toRemove; std::set<Name> newImports; - std::set<std::string> invokeSigs; + std::set<Signature> invokeSigs; FixInvokeFunctionNamesWalker(Module& _wasm) : wasm(_wasm) {} @@ -1017,7 +1006,7 @@ struct FixInvokeFunctionNamesWalker // This function converts the names of invoke wrappers based on their lowered // argument types and a return type. In the example above, the resulting new // wrapper name becomes "invoke_vii". - Name fixEmExceptionInvoke(const Name& name, const std::string& sig) { + Name fixEmExceptionInvoke(const Name& name, Signature sig) { std::string nameStr = name.c_str(); if (nameStr.front() == '"' && nameStr.back() == '"') { nameStr = nameStr.substr(1, nameStr.size() - 2); @@ -1025,12 +1014,16 @@ struct FixInvokeFunctionNamesWalker if (nameStr.find("__invoke_") != 0) { return name; } - std::string sigWoOrigFunc = sig.front() + sig.substr(2, sig.size() - 2); + + const std::vector<Type>& params = sig.params.expand(); + std::vector<Type> newParams(params.begin() + 1, params.end()); + Signature sigWoOrigFunc = Signature(Type(newParams), sig.results); invokeSigs.insert(sigWoOrigFunc); - return Name("invoke_" + sigWoOrigFunc); + return Name("invoke_" + + getSig(sigWoOrigFunc.results, sigWoOrigFunc.params)); } - Name fixEmEHSjLjNames(const Name& name, const std::string& sig) { + Name fixEmEHSjLjNames(const Name& name, Signature sig) { if (name == "emscripten_longjmp_jmpbuf") { return "emscripten_longjmp"; } @@ -1042,8 +1035,7 @@ struct FixInvokeFunctionNamesWalker return; } - FunctionType* func = wasm.getFunctionType(curr->type); - Name newname = fixEmEHSjLjNames(curr->base, getSig(func)); + Name newname = fixEmEHSjLjNames(curr->base, curr->sig); if (newname == curr->base) { return; } @@ -1085,16 +1077,16 @@ void EmscriptenGlueGenerator::fixInvokeFunctionNames() { } } -template<class C> void printSet(std::ostream& o, C& c) { +void printSignatures(std::ostream& o, const std::set<Signature>& c) { o << "["; bool first = true; - for (auto& item : c) { + for (auto& sig : c) { if (first) { first = false; } else { o << ","; } - o << '"' << item << '"'; + o << '"' << getSig(sig.results, sig.params) << '"'; } o << "]"; } @@ -1123,7 +1115,7 @@ std::string EmscriptenGlueGenerator::generateEmscriptenMetadata( for (auto& asmConst : emAsmWalker.asmConsts) { meta << nextElement(); meta << '"' << asmConst.id << "\": [\"" << asmConst.code << "\", "; - printSet(meta, asmConst.sigs); + printSignatures(meta, asmConst.sigs); meta << ", [\"" << proxyingSuffix(asmConst.proxy) << "\"]"; meta << "]"; @@ -1303,7 +1295,7 @@ void EmscriptenGlueGenerator::exportWasiStart() { {LiteralUtils::makeZero(i32, wasm), LiteralUtils::makeZero(i32, wasm)}, i32)); auto* func = - builder.makeFunction(_start, std::vector<wasm::Type>{}, none, {}, body); + builder.makeFunction(_start, Signature(Type::none, Type::none), {}, body); wasm.addFunction(func); wasm.addExport(builder.makeExport(_start, _start, ExternalKind::Function)); } diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 24a4fcb35..10b6aead7 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -23,7 +23,6 @@ #include "asm_v_wasm.h" #include "asmjs/shared-constants.h" #include "ir/branch-utils.h" -#include "ir/function-type-utils.h" #include "shared-constants.h" #include "wasm-binary.h" @@ -452,16 +451,20 @@ Name SExpressionWasmBuilder::getFunctionName(Element& s) { } } -Name SExpressionWasmBuilder::getFunctionTypeName(Element& s) { +Signature SExpressionWasmBuilder::getFunctionSignature(Element& s) { if (s.dollared()) { - return s.str(); + auto it = signatureIndices.find(s.str().str); + if (it == signatureIndices.end()) { + throw ParseException("unknown function type in getFunctionSignature"); + } + return signatures[it->second]; } else { // index size_t offset = atoi(s.str().c_str()); - if (offset >= wasm.functionTypes.size()) { - throw ParseException("unknown function type in getFunctionTypeName"); + if (offset >= signatures.size()) { + throw ParseException("unknown function type in getFunctionSignature"); } - return wasm.functionTypes[offset]->name; + return signatures[offset]; } } @@ -539,7 +542,7 @@ SExpressionWasmBuilder::parseParamOrLocal(Element& s, size_t& localIndex) { } // Parses (result type) element. (e.g. (result i32)) -Type SExpressionWasmBuilder::parseResult(Element& s) { +Type SExpressionWasmBuilder::parseResults(Element& s) { assert(elementStartsWith(s, RESULT)); if (s.size() != 2) { throw ParseException("invalid result arity", s.line, s.col); @@ -550,131 +553,95 @@ Type SExpressionWasmBuilder::parseResult(Element& s) { // Parses an element that references an entry in the type section. The element // should be in the form of (type name) or (type index). // (e.g. (type $a), (type 0)) -FunctionType* SExpressionWasmBuilder::parseTypeRef(Element& s) { +Signature SExpressionWasmBuilder::parseTypeRef(Element& s) { assert(elementStartsWith(s, TYPE)); if (s.size() != 2) { throw ParseException("invalid type reference", s.line, s.col); } - IString name = getFunctionTypeName(*s[1]); - FunctionType* functionType = wasm.getFunctionTypeOrNull(name); - if (!functionType) { - throw ParseException("bad function type for import", s[1]->line, s[1]->col); - } - return functionType; + return getFunctionSignature(*s[1]); } // Prases typeuse, a reference to a type definition. It is in the form of either // (type index) or (type name), possibly augmented by inlined (param) and -// (result) nodes. (type) node can be omitted as well, in which case we get an -// existing type if there's one with the same structure or create one. -// Outputs are returned by parameter references. +// (result) nodes. (type) node can be omitted as well. Outputs are returned by +// parameter references. // typeuse ::= (type index|name)+ | // (type index|name)+ (param ..)* (result ..)* | // (param ..)* (result ..)* -// TODO Remove FunctionType* parameter and the related logic to create -// FunctionType once we remove FunctionType class. -size_t SExpressionWasmBuilder::parseTypeUse(Element& s, - size_t startPos, - FunctionType*& functionType, - std::vector<NameType>& namedParams, - Type& result) { +size_t +SExpressionWasmBuilder::parseTypeUse(Element& s, + size_t startPos, + Signature& functionSignature, + std::vector<NameType>& namedParams) { + std::vector<Type> params, results; size_t i = startPos; - bool typeExists = false, paramOrResultExists = false; + + bool typeExists = false, paramsOrResultsExist = false; if (i < s.size() && elementStartsWith(*s[i], TYPE)) { typeExists = true; - functionType = parseTypeRef(*s[i++]); + functionSignature = parseTypeRef(*s[i++]); } - size_t paramPos = i; + size_t paramPos = i; size_t localIndex = 0; + while (i < s.size() && elementStartsWith(*s[i], PARAM)) { - paramOrResultExists = true; + paramsOrResultsExist = true; auto newParams = parseParamOrLocal(*s[i++], localIndex); namedParams.insert(namedParams.end(), newParams.begin(), newParams.end()); + for (auto p : newParams) { + params.push_back(p.type); + } } - result = none; - if (i < s.size() && elementStartsWith(*s[i], RESULT)) { - paramOrResultExists = true; - result = parseResult(*s[i++]); + + while (i < s.size() && elementStartsWith(*s[i], RESULT)) { + paramsOrResultsExist = true; + // TODO: make parseResults return a vector + results.push_back(parseResults(*s[i++])); } + + auto inlineSig = Signature(Type(params), Type(results)); + // If none of type/param/result exists, this is equivalent to a type that does // not have parameters and returns nothing. - if (!typeExists && !paramOrResultExists) { - paramOrResultExists = true; + if (!typeExists && !paramsOrResultsExist) { + paramsOrResultsExist = true; } - // verify if (type) and (params)/(result) match, if both are specified - if (typeExists && paramOrResultExists) { - size_t line = s[paramPos]->line, col = s[paramPos]->col; - const char* msg = "type and param/result don't match"; - if (functionType->result != result) { - throw ParseException(msg, line, col); - } - if (functionType->params.size() != namedParams.size()) { - throw ParseException(msg, line, col); - } - for (size_t i = 0, n = namedParams.size(); i < n; i++) { - if (functionType->params[i] != namedParams[i].type) { - throw ParseException(msg, line, col); - } + if (!typeExists) { + functionSignature = inlineSig; + } else if (paramsOrResultsExist) { + // verify that (type) and (params)/(result) match + if (inlineSig != functionSignature) { + throw ParseException("type and param/result don't match", + s[paramPos]->line, + s[paramPos]->col); } } - // If only (param)/(result) is specified, check if there's a matching type, - // and if there isn't, create one. - if (!typeExists) { - bool need = true; - std::vector<Type> params; - for (auto& p : namedParams) { - params.push_back(p.type); - } - for (auto& existing : wasm.functionTypes) { - if (existing->structuralComparison(params, result)) { - functionType = existing.get(); - need = false; - break; - } - } - if (need) { - functionType = ensureFunctionType(params, result, &wasm); - } + // Add implicitly defined type to global list so it has an index + if (std::find(signatures.begin(), signatures.end(), functionSignature) == + signatures.end()) { + signatures.push_back(functionSignature); } - // If only (type) is specified, populate params and result. - if (!paramOrResultExists) { - assert(functionType); - result = functionType->result; - for (size_t index = 0, e = functionType->params.size(); index < e; - index++) { - Type type = functionType->params[index]; - namedParams.emplace_back(Name::fromInt(index), type); + // If only (type) is specified, populate `namedParams` + if (!paramsOrResultsExist) { + const std::vector<Type>& funcParams = functionSignature.params.expand(); + for (size_t index = 0, e = funcParams.size(); index < e; index++) { + namedParams.emplace_back(Name::fromInt(index), funcParams[index]); } } return i; } -// Parses a typeuse. Ignores all parameter names. -size_t SExpressionWasmBuilder::parseTypeUse(Element& s, - size_t startPos, - FunctionType*& functionType, - std::vector<Type>& params, - Type& result) { - std::vector<NameType> namedParams; - size_t nextPos = parseTypeUse(s, startPos, functionType, namedParams, result); - for (auto& p : namedParams) { - params.push_back(p.type); - } - return nextPos; -} - // Parses a typeuse. Use this when only FunctionType* is needed. size_t SExpressionWasmBuilder::parseTypeUse(Element& s, size_t startPos, - FunctionType*& functionType) { - std::vector<Type> params; - Type result; - return parseTypeUse(s, startPos, functionType, params, result); + Signature& functionSignature) { + std::vector<NameType> params; + return parseTypeUse(s, startPos, functionSignature, params); } void SExpressionWasmBuilder::preParseFunctionType(Element& s) { @@ -694,11 +661,9 @@ void SExpressionWasmBuilder::preParseFunctionType(Element& s) { } functionNames.push_back(name); functionCounter++; - FunctionType* type = nullptr; - std::vector<Type> params; - parseTypeUse(s, i, type); - assert(type && "type should've been set by parseTypeUse"); - functionTypes[name] = type->result; + Signature sig; + parseTypeUse(s, i, sig); + functionTypes[name] = sig.results; } size_t SExpressionWasmBuilder::parseFunctionNames(Element& s, @@ -771,11 +736,9 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { } // parse typeuse: type/param/result - FunctionType* functionType = nullptr; + Signature sig; std::vector<NameType> params; - Type result = none; - i = parseTypeUse(s, i, functionType, params, result); - assert(functionType && "functionType should've been set by parseTypeUse"); + i = parseTypeUse(s, i, sig, params); // when (import) is inside a (func) element, this is not a function definition // but an import. @@ -790,9 +753,8 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { im->name = name; im->module = importModule; im->base = importBase; - im->type = functionType->name; - FunctionTypeUtils::fillFunction(im.get(), functionType); - functionTypes[name] = im->result; + im->sig = sig; + functionTypes[name] = sig.results; if (wasm.getFunctionOrNull(im->name)) { throw ParseException("duplicate import", s.line, s.col); } @@ -808,7 +770,6 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { throw ParseException("preParseImport in func"); } - result = functionType->result; size_t localIndex = params.size(); // local index for params and locals // parse locals @@ -820,8 +781,7 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { // make a new function currFunction = std::unique_ptr<Function>(Builder(wasm).makeFunction( - name, std::move(params), result, std::move(vars))); - currFunction->type = functionType->name; + name, std::move(params), sig.results, std::move(vars))); // parse body Block* autoBlock = nullptr; // may need to add a block for the very top level @@ -847,14 +807,11 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { autoBlock->name = FAKE_RETURN; } if (autoBlock) { - autoBlock->finalize(result); + autoBlock->finalize(sig.results); } if (!currFunction->body) { currFunction->body = allocator.alloc<Nop>(); } - if (currFunction->result != result) { - throw ParseException("bad func declaration", s.line, s.col); - } if (s.startLoc) { currFunction->prologLocation.insert(getDebugLocation(*s.startLoc)); } @@ -1677,13 +1634,9 @@ Expression* SExpressionWasmBuilder::makeCallIndirect(Element& s, if (!wasm.table.exists) { throw ParseException("no table"); } - auto ret = allocator.alloc<CallIndirect>(); Index i = 1; - FunctionType* functionType = nullptr; - i = parseTypeUse(s, i, functionType); - assert(functionType && "functionType should've been set by parseTypeUse"); - ret->fullType = functionType->name; - ret->type = functionType->result; + auto ret = allocator.alloc<CallIndirect>(); + i = parseTypeUse(s, i, ret->sig); parseCallOperands(s, i, s.size() - 1, ret); ret->target = parseExpression(s[s.size() - 1]); ret->isReturn = isReturn; @@ -2135,14 +2088,13 @@ void SExpressionWasmBuilder::parseImport(Element& s) { Element& inner = newStyle ? *s[3] : s; Index j = newStyle ? newStyleInner : i; if (kind == ExternalKind::Function) { - FunctionType* functionType = nullptr; auto func = make_unique<Function>(); - j = parseTypeUse(inner, j, functionType, func->params, func->result); + + j = parseTypeUse(inner, j, func->sig); func->name = name; func->module = module; func->base = base; - func->type = functionType->name; - functionTypes[name] = func->result; + functionTypes[name] = func->sig.results; wasm.addFunction(func.release()); } else if (kind == ExternalKind::Global) { Type type; @@ -2202,14 +2154,10 @@ void SExpressionWasmBuilder::parseImport(Element& s) { throw ParseException("invalid attribute", attrElem.line, attrElem.col); } event->attribute = atoi(attrElem[1]->c_str()); - std::vector<Type> paramTypes; - FunctionType* fakeFunctionType; // just to call parseTypeUse - Type results; - j = parseTypeUse(inner, j, fakeFunctionType, paramTypes, results); + j = parseTypeUse(inner, j, event->sig); event->name = name; event->module = module; event->base = base; - event->sig = Signature(Type(paramTypes), results); wasm.addEvent(event.release()); } // If there are more elements, they are invalid @@ -2406,10 +2354,15 @@ void SExpressionWasmBuilder::parseInnerElem(Element& s, } void SExpressionWasmBuilder::parseType(Element& s) { - std::unique_ptr<FunctionType> type = make_unique<FunctionType>(); + std::vector<Type> params; + std::vector<Type> results; size_t i = 1; if (s[i]->isStr()) { - type->name = s[i]->str(); + std::string name = s[i]->str().str; + if (signatureIndices.find(name) != signatureIndices.end()) { + throw ParseException("duplicate function type", s.line, s.col); + } + signatureIndices[name] = signatures.size(); i++; } Element& func = *s[i]; @@ -2417,25 +2370,13 @@ void SExpressionWasmBuilder::parseType(Element& s) { Element& curr = *func[k]; if (elementStartsWith(curr, PARAM)) { auto newParams = parseParamOrLocal(curr); - type->params.insert( - type->params.end(), newParams.begin(), newParams.end()); + params.insert(params.end(), newParams.begin(), newParams.end()); } else if (elementStartsWith(curr, RESULT)) { - type->result = parseResult(curr); + // TODO: Parse multiple results at once + results.push_back(parseResults(curr)); } } - while (type->name.is() && wasm.getFunctionTypeOrNull(type->name)) { - throw ParseException("duplicate function type", s.line, s.col); - } - // We allow duplicate types in the type section, i.e., we can have - // (func (param i32) (result i32)) many times. For unnamed types, find a name - // that does not clash with existing ones. - if (!type->name.is()) { - type->name = "FUNCSIG$" + getSig(type.get()); - } - while (wasm.getFunctionTypeOrNull(type->name)) { - type->name = Name(std::string(type->name.c_str()) + "_"); - } - wasm.addFunctionType(std::move(type)); + signatures.emplace_back(Type(params), Type(results)); } void SExpressionWasmBuilder::parseEvent(Element& s, bool preParseImport) { @@ -2515,11 +2456,7 @@ void SExpressionWasmBuilder::parseEvent(Element& s, bool preParseImport) { event->attribute = atoi(attrElem[1]->c_str()); // Parse typeuse - std::vector<Type> paramTypes; - Type results; - FunctionType* fakeFunctionType; // just co call parseTypeUse - i = parseTypeUse(s, i, fakeFunctionType, paramTypes, results); - event->sig = Signature(Type(paramTypes), results); + i = parseTypeUse(s, i, event->sig); // If there are more elements, they are invalid if (i < s.size()) { diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp index 0f12a2f2b..abaf17e05 100644 --- a/src/wasm/wasm-stack.cpp +++ b/src/wasm/wasm-stack.cpp @@ -68,9 +68,7 @@ void BinaryInstWriter::visitCall(Call* curr) { void BinaryInstWriter::visitCallIndirect(CallIndirect* curr) { int8_t op = curr->isReturn ? BinaryConsts::RetCallIndirect : BinaryConsts::CallIndirect; - auto* type = parent.getModule()->getFunctionType(curr->fullType); - Signature sig(Type(type->params), type->result); - o << op << U32LEB(parent.getTypeIndex(sig)) + o << op << U32LEB(parent.getTypeIndex(curr->sig)) << U32LEB(0); // Reserved flags field } diff --git a/src/wasm/wasm-type.cpp b/src/wasm/wasm-type.cpp index bc17e2193..7d29b0dcd 100644 --- a/src/wasm/wasm-type.cpp +++ b/src/wasm/wasm-type.cpp @@ -254,16 +254,23 @@ unsigned getTypeSize(Type type) { } FeatureSet getFeatures(Type type) { - switch (type) { - case v128: - return FeatureSet::SIMD; - case anyref: - return FeatureSet::ReferenceTypes; - case exnref: - return FeatureSet::ExceptionHandling; - default: - return FeatureSet(); + FeatureSet feats = FeatureSet::MVP; + for (Type t : type.expand()) { + switch (t) { + case v128: + feats |= FeatureSet::SIMD; + break; + case anyref: + feats |= FeatureSet::ReferenceTypes; + break; + case exnref: + feats |= FeatureSet::ExceptionHandling; + break; + default: + break; + } } + return feats; } Type getType(unsigned size, bool float_) { diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 75121d4f4..c6de444c1 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -593,14 +593,15 @@ void FunctionValidator::visitCall(Call* curr) { if (!shouldBeTrue(!!target, curr, "call target must exist")) { return; } - if (!shouldBeTrue(curr->operands.size() == target->params.size(), + const std::vector<Type> params = target->sig.params.expand(); + if (!shouldBeTrue(curr->operands.size() == params.size(), curr, "call param number must match")) { return; } for (size_t i = 0; i < curr->operands.size(); i++) { if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, - target->params[i], + params[i], curr, "call param types must match") && !info.quiet) { @@ -613,8 +614,8 @@ void FunctionValidator::visitCall(Call* curr) { curr, "return_call should have unreachable type"); shouldBeEqual( - getFunction()->result, - target->result, + getFunction()->sig.results, + target->sig.results, curr, "return_call callee return type must match caller return type"); } else { @@ -629,7 +630,7 @@ void FunctionValidator::visitCall(Call* curr) { "calls may only be unreachable if they have unreachable operands"); } else { shouldBeEqual(curr->type, - target->result, + target->sig.results, curr, "call type must match callee return type"); } @@ -643,20 +644,17 @@ void FunctionValidator::visitCallIndirect(CallIndirect* curr) { if (!info.validateGlobally) { return; } - auto* type = getModule()->getFunctionTypeOrNull(curr->fullType); - if (!shouldBeTrue(!!type, curr, "call_indirect type must exist")) { - return; - } + const std::vector<Type>& params = curr->sig.params.expand(); shouldBeEqualOrFirstIsUnreachable( curr->target->type, i32, curr, "indirect call target must be an i32"); - if (!shouldBeTrue(curr->operands.size() == type->params.size(), + if (!shouldBeTrue(curr->operands.size() == params.size(), curr, "call param number must match")) { return; } for (size_t i = 0; i < curr->operands.size(); i++) { if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, - type->params[i], + params[i], curr, "call param types must match") && !info.quiet) { @@ -669,8 +667,8 @@ void FunctionValidator::visitCallIndirect(CallIndirect* curr) { curr, "return_call_indirect should have unreachable type"); shouldBeEqual( - getFunction()->result, - type->result, + getFunction()->sig.results, + curr->sig.results, curr, "return_call_indirect callee return type must match caller return type"); } else { @@ -687,7 +685,7 @@ void FunctionValidator::visitCallIndirect(CallIndirect* curr) { } } else { shouldBeEqual(curr->type, - type->result, + curr->sig.results, curr, "call_indirect type must match callee return type"); } @@ -701,31 +699,35 @@ void FunctionValidator::visitConst(Const* curr) { } void FunctionValidator::visitLocalGet(LocalGet* curr) { - shouldBeTrue(curr->index < getFunction()->getNumLocals(), - curr, - "local.get index must be small enough"); shouldBeTrue(curr->type.isConcrete(), curr, "local.get must have a valid type - check what you provided " "when you constructed the node"); - shouldBeTrue(curr->type == getFunction()->getLocalType(curr->index), - curr, - "local.get must have proper type"); + if (shouldBeTrue(curr->index < getFunction()->getNumLocals(), + curr, + "local.get index must be small enough")) { + shouldBeTrue(curr->type == getFunction()->getLocalType(curr->index), + curr, + "local.get must have proper type"); + } } void FunctionValidator::visitLocalSet(LocalSet* curr) { - shouldBeTrue(curr->index < getFunction()->getNumLocals(), - curr, - "local.set index must be small enough"); - if (curr->value->type != unreachable) { - if (curr->type != none) { // tee is ok anyhow - shouldBeEqualOrFirstIsUnreachable( - curr->value->type, curr->type, curr, "local.set type must be correct"); + if (shouldBeTrue(curr->index < getFunction()->getNumLocals(), + curr, + "local.set index must be small enough")) { + if (curr->value->type != unreachable) { + if (curr->type != none) { // tee is ok anyhow + shouldBeEqualOrFirstIsUnreachable(curr->value->type, + curr->type, + curr, + "local.set type must be correct"); + } + shouldBeEqual(getFunction()->getLocalType(curr->index), + curr->value->type, + curr, + "local.set type must match function"); } - shouldBeEqual(getFunction()->getLocalType(curr->index), - curr->value->type, - curr, - "local.set type must match function"); } } @@ -1754,28 +1756,35 @@ void FunctionValidator::visitBrOnExn(BrOnExn* curr) { } void FunctionValidator::visitFunction(Function* curr) { - FeatureSet typeFeatures = getFeatures(curr->result); - for (auto type : curr->params) { - typeFeatures |= getFeatures(type); + shouldBeTrue(!curr->sig.results.isMulti(), + curr->body, + "Multivalue functions not allowed yet"); + FeatureSet features; + for (auto type : curr->sig.params.expand()) { + features |= getFeatures(type); shouldBeTrue(type.isConcrete(), curr, "params must be concretely typed"); } + for (auto type : curr->sig.results.expand()) { + features |= getFeatures(type); + shouldBeTrue(type.isConcrete(), curr, "results must be concretely typed"); + } for (auto type : curr->vars) { - typeFeatures |= getFeatures(type); + features |= getFeatures(type); shouldBeTrue(type.isConcrete(), curr, "vars must be concretely typed"); } - shouldBeTrue(typeFeatures <= getModule()->features, + shouldBeTrue(features <= getModule()->features, curr, "all used types should be allowed"); // if function has no result, it is ignored // if body is unreachable, it might be e.g. a return if (curr->body->type != unreachable) { - shouldBeEqual(curr->result, + shouldBeEqual(curr->sig.results, curr->body->type, curr->body, "function body type must match, if function returns"); } if (returnType != unreachable) { - shouldBeEqual(curr->result, + shouldBeEqual(curr->sig.results, returnType, curr->body, "function result must match, if function has returns"); @@ -1784,22 +1793,6 @@ void FunctionValidator::visitFunction(Function* curr) { breakInfos.empty(), curr->body, "all named break targets must exist"); returnType = unreachable; labelNames.clear(); - // if function has a named type, it must match up with the function's params - // and result - if (info.validateGlobally && curr->type.is()) { - auto* ft = getModule()->getFunctionType(curr->type); - shouldBeTrue(ft->params == curr->params, - curr->name, - "function params must match its declared type"); - shouldBeTrue(ft->result == curr->result, - curr->name, - "function result must match its declared type"); - } - if (curr->imported()) { - shouldBeTrue(curr->type.is(), - curr->name, - "imported functions must have a function type"); - } // validate optional local names std::set<Name> seen; for (auto& pair : curr->localNames) { @@ -1925,17 +1918,18 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) { static void validateImports(Module& module, ValidationInfo& info) { ModuleUtils::iterImportedFunctions(module, [&](Function* curr) { if (info.validateWeb) { - auto* functionType = module.getFunctionType(curr->type); - info.shouldBeUnequal(functionType->result, - i64, - curr->name, - "Imported function must not have i64 return type"); - for (Type param : functionType->params) { + for (Type param : curr->sig.params.expand()) { info.shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters"); } + for (Type result : curr->sig.results.expand()) { + info.shouldBeUnequal(result, + i64, + curr->name, + "Imported function must not have i64 results"); + } } }); if (!module.features.hasMutableGlobals()) { @@ -1951,17 +1945,19 @@ static void validateExports(Module& module, ValidationInfo& info) { if (curr->kind == ExternalKind::Function) { if (info.validateWeb) { Function* f = module.getFunction(curr->value); - info.shouldBeUnequal(f->result, - i64, - f->name, - "Exported function must not have i64 return type"); - for (auto param : f->params) { + for (auto param : f->sig.params.expand()) { info.shouldBeUnequal( param, i64, f->name, "Exported function must not have i64 parameters"); } + for (auto result : f->sig.results.expand()) { + info.shouldBeUnequal(result, + i64, + f->name, + "Exported function must not have i64 results"); + } } } else if (curr->kind == ExternalKind::Global && !module.features.hasMutableGlobals()) { @@ -2133,10 +2129,12 @@ static void validateModule(Module& module, ValidationInfo& info) { auto func = module.getFunctionOrNull(module.start); if (info.shouldBeTrue( func != nullptr, module.start, "start must be found")) { - info.shouldBeTrue( - func->params.size() == 0, module.start, "start must have 0 params"); - info.shouldBeTrue( - func->result == none, module.start, "start must not return a value"); + info.shouldBeTrue(func->sig.params == Type::none, + module.start, + "start must have 0 params"); + info.shouldBeTrue(func->sig.results == Type::none, + module.start, + "start must not return a value"); } } } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 5a722e600..83829418e 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -432,6 +432,7 @@ void Call::finalize() { } void CallIndirect::finalize() { + type = sig.results; handleUnreachableOperands(this); if (isReturn) { type = unreachable; @@ -441,34 +442,6 @@ void CallIndirect::finalize() { } } -bool FunctionType::structuralComparison(FunctionType& b) { - return structuralComparison(b.params, b.result); -} - -bool FunctionType::structuralComparison(const std::vector<Type>& otherParams, - Type otherResult) { - if (result != otherResult) { - return false; - } - if (params.size() != otherParams.size()) { - return false; - } - for (size_t i = 0; i < params.size(); i++) { - if (params[i] != otherParams[i]) { - return false; - } - } - return true; -} - -bool FunctionType::operator==(FunctionType& b) { - if (name != b.name) { - return false; - } - return structuralComparison(b); -} -bool FunctionType::operator!=(FunctionType& b) { return !(*this == b); } - bool LocalSet::isTee() { return type != none; } void LocalSet::setTee(bool is) { @@ -932,15 +905,23 @@ void Push::finalize() { } } -size_t Function::getNumParams() { return params.size(); } +size_t Function::getNumParams() { return sig.params.size(); } size_t Function::getNumVars() { return vars.size(); } -size_t Function::getNumLocals() { return params.size() + vars.size(); } +size_t Function::getNumLocals() { return sig.params.size() + vars.size(); } -bool Function::isParam(Index index) { return index < params.size(); } +bool Function::isParam(Index index) { + size_t size = sig.params.size(); + assert(index < size + vars.size()); + return index < size; +} -bool Function::isVar(Index index) { return index >= params.size(); } +bool Function::isVar(Index index) { + auto base = getVarIndexBase(); + assert(index < base + vars.size()); + return index >= base; +} bool Function::hasLocalName(Index index) const { return localNames.find(index) != localNames.end(); @@ -973,13 +954,14 @@ Index Function::getLocalIndex(Name name) { return iter->second; } -Index Function::getVarIndexBase() { return params.size(); } +Index Function::getVarIndexBase() { return sig.params.size(); } Type Function::getLocalType(Index index) { - if (isParam(index)) { + const std::vector<Type>& params = sig.params.expand(); + if (index < params.size()) { return params[index]; } else if (isVar(index)) { - return vars[index - getVarIndexBase()]; + return vars[index - params.size()]; } else { WASM_UNREACHABLE("invalid local index"); } @@ -994,14 +976,6 @@ void Function::clearDebugInfo() { epilogLocation.clear(); } -FunctionType* Module::getFunctionType(Name name) { - auto iter = functionTypesMap.find(name); - if (iter == functionTypesMap.end()) { - Fatal() << "Module::getFunctionType: " << name << " does not exist"; - } - return iter->second; -} - Export* Module::getExport(Name name) { auto iter = exportsMap.find(name); if (iter == exportsMap.end()) { @@ -1035,14 +1009,6 @@ Event* Module::getEvent(Name name) { return iter->second; } -FunctionType* Module::getFunctionTypeOrNull(Name name) { - auto iter = functionTypesMap.find(name); - if (iter == functionTypesMap.end()) { - return nullptr; - } - return iter->second; -} - Export* Module::getExportOrNull(Name name) { auto iter = exportsMap.find(name); if (iter == exportsMap.end()) { @@ -1075,19 +1041,6 @@ Event* Module::getEventOrNull(Name name) { return iter->second; } -FunctionType* Module::addFunctionType(std::unique_ptr<FunctionType> curr) { - if (!curr->name.is()) { - Fatal() << "Module::addFunctionType: empty name"; - } - if (getFunctionTypeOrNull(curr->name)) { - Fatal() << "Module::addFunctionType: " << curr->name << " already exists"; - } - auto* p = curr.get(); - functionTypes.emplace_back(std::move(curr)); - functionTypesMap[p->name] = p; - return p; -} - Export* Module::addExport(Export* curr) { if (!curr->name.is()) { Fatal() << "Module::addExport: empty name"; @@ -1166,9 +1119,6 @@ void removeModuleElement(Vector& v, Map& m, Name name) { } } -void Module::removeFunctionType(Name name) { - removeModuleElement(functionTypes, functionTypesMap, name); -} void Module::removeExport(Name name) { removeModuleElement(exports, exportsMap, name); } @@ -1198,9 +1148,6 @@ void removeModuleElements(Vector& v, v.end()); } -void Module::removeFunctionTypes(std::function<bool(FunctionType*)> pred) { - removeModuleElements(functionTypes, functionTypesMap, pred); -} void Module::removeExports(std::function<bool(Export*)> pred) { removeModuleElements(exports, exportsMap, pred); } @@ -1219,10 +1166,6 @@ void Module::updateMaps() { for (auto& curr : functions) { functionsMap[curr->name] = curr.get(); } - functionTypesMap.clear(); - for (auto& curr : functionTypes) { - functionTypesMap[curr->name] = curr.get(); - } exportsMap.clear(); for (auto& curr : exports) { exportsMap[curr->name] = curr.get(); diff --git a/src/wasm2js.h b/src/wasm2js.h index 906a4b9dc..2cacbe8e5 100644 --- a/src/wasm2js.h +++ b/src/wasm2js.h @@ -406,14 +406,11 @@ Ref Wasm2JSBuilder::processWasm(Module* wasm, Name funcName) { }); if (generateFetchHighBits) { Builder builder(allocator); - std::vector<Type> params; - std::vector<Type> vars; asmFunc[3]->push_back(processFunction( wasm, builder.makeFunction(WASM_FETCH_HIGH_BITS, - std::move(params), - i32, - std::move(vars), + Signature(Type::none, Type::i32), + {}, builder.makeReturn(builder.makeGlobalGet( INT64_TO_32_HIGH_BITS, i32))))); auto e = new Export(); |