diff options
Diffstat (limited to 'src')
60 files changed, 1299 insertions, 1654 deletions
diff --git a/src/asm2wasm.h b/src/asm2wasm.h index 3ea9a3a1e..44666c1e5 100644 --- a/src/asm2wasm.h +++ b/src/asm2wasm.h @@ -33,7 +33,9 @@ #include "parsing.h" #include "ir/bits.h" #include "ir/branch-utils.h" +#include "ir/function-type-utils.h" #include "ir/literal-utils.h" +#include "ir/module-utils.h" #include "ir/trapping.h" #include "ir/utils.h" #include "wasm-builder.h" @@ -343,8 +345,8 @@ struct Asm2WasmPreProcessor { } }; -static CallImport* checkDebugInfo(Expression* curr) { - if (auto* call = curr->dynCast<CallImport>()) { +static Call* checkDebugInfo(Expression* curr) { + if (auto* call = curr->dynCast<Call>()) { if (call->target == EMSCRIPTEN_DEBUGINFO) { return call; } @@ -478,7 +480,7 @@ private: std::map<IString, std::unique_ptr<FunctionType>> importedFunctionTypes; - void noteImportedFunctionCall(Ref ast, Type resultType, CallImport* call) { + void noteImportedFunctionCall(Ref ast, Type resultType, Call* call) { assert(ast[0] == CALL && ast[1]->isString()); IString importName = ast[1]->getIString(); auto type = make_unique<FunctionType>(); @@ -696,7 +698,6 @@ private: void fixCallType(Expression* call, Type type) { if (call->is<Call>()) call->cast<Call>()->type = type; - if (call->is<CallImport>()) call->cast<CallImport>()->type = type; else if (call->is<CallIndirect>()) call->cast<CallIndirect>()->type = type; } @@ -762,45 +763,35 @@ void Asm2WasmBuilder::processAsm(Ref ast) { } // import memory - auto memoryImport = make_unique<Import>(); - memoryImport->name = MEMORY; - memoryImport->module = ENV; - memoryImport->base = MEMORY; - memoryImport->kind = ExternalKind::Memory; + wasm.memory.name = MEMORY; + wasm.memory.module = ENV; + wasm.memory.base = MEMORY; wasm.memory.exists = true; - wasm.memory.imported = true; - wasm.addImport(memoryImport.release()); // import table - auto tableImport = make_unique<Import>(); - tableImport->name = TABLE; - tableImport->module = ENV; - tableImport->base = TABLE; - tableImport->kind = ExternalKind::Table; - wasm.addImport(tableImport.release()); + wasm.table.name = TABLE; + wasm.table.module = ENV; + wasm.table.base = TABLE; wasm.table.exists = true; - wasm.table.imported = true; // Import memory offset, if not already there { - auto* import = new Import; - import->name = Name("memoryBase"); - import->module = Name("env"); - import->base = Name("memoryBase"); - import->kind = ExternalKind::Global; - import->globalType = i32; - wasm.addImport(import); + auto* import = new Global; + import->name = "memoryBase"; + import->module = "env"; + import->base = "memoryBase"; + import->type = i32; + wasm.addGlobal(import); } // Import table offset, if not already there { - auto* import = new Import; - import->name = Name("tableBase"); - import->module = Name("env"); - import->base = Name("tableBase"); - import->kind = ExternalKind::Global; - import->globalType = i32; - wasm.addImport(import); + auto* import = new Global; + import->name = "tableBase"; + import->module = "env"; + import->base = "tableBase"; + import->type = i32; + wasm.addGlobal(import); } auto addImport = [&](IString name, Ref imported, Type type) { @@ -907,18 +898,18 @@ void Asm2WasmBuilder::processAsm(Ref ast) { } } } - auto import = make_unique<Import>(); - import->name = name; - import->module = moduleName; - import->base = imported[2]->getIString(); + auto base = imported[2]->getIString(); // special-case some asm builtins - if (import->module == GLOBAL && (import->base == NAN_ || import->base == INFINITY_)) { + if (module == GLOBAL && (base == NAN_ || base == INFINITY_)) { type = Type::f64; } if (type != Type::none) { // this is a global - import->kind = ExternalKind::Global; - import->globalType = type; + auto* import = new Global; + import->name = name; + import->module = moduleName; + import->base = base; + import->type = type; mappedGlobals.emplace(name, type); // tableBase and memoryBase are used as segment/element offsets, and must be constant; // otherwise, an asm.js import of a constant is mutable, e.g. STACKTOP @@ -935,15 +926,19 @@ void Asm2WasmBuilder::processAsm(Ref ast) { )); } } + if ((name == "tableBase" || name == "memoryBase") && + wasm.getGlobalOrNull(import->base)) { + return; + } + wasm.addGlobal(import); } else { - import->kind = ExternalKind::Function; + // this is a function + auto* import = new Function; + import->name = name; + import->module = moduleName; + import->base = base; + wasm.addFunction(import); } - // we may have already created an import for this manually - if ((name == "tableBase" || name == "memoryBase") && - (wasm.getImportOrNull(import->base) || wasm.getGlobalOrNull(import->base))) { - return; - } - wasm.addImport(import.release()); }; IString Int8Array, Int16Array, Int32Array, UInt8Array, UInt16Array, UInt32Array, Float32Array, Float64Array; @@ -1205,27 +1200,32 @@ void Asm2WasmBuilder::processAsm(Ref ast) { std::vector<IString> toErase; - for (auto& import : wasm.imports) { - if (import->kind != ExternalKind::Function) continue; + ModuleUtils::iterImportedFunctions(wasm, [&](Function* import) { IString name = import->name; if (importedFunctionTypes.find(name) != importedFunctionTypes.end()) { // special math builtins FunctionType* builtin = getBuiltinFunctionType(import->module, import->base); if (builtin) { - import->functionType = builtin->name; - continue; + import->type = builtin->name; + } else { + import->type = ensureFunctionType(getSig(importedFunctionTypes[name].get()), &wasm)->name; } - import->functionType = ensureFunctionType(getSig(importedFunctionTypes[name].get()), &wasm)->name; } else if (import->module != ASM2WASM) { // special-case the special module // never actually used, which means we don't know the function type since the usage tells us, so illegal for it to remain toErase.push_back(name); } - } + }); for (auto curr : toErase) { - wasm.removeImport(curr); + wasm.removeFunction(curr); } + // Finalize function imports now that we've seen all the calls + + ModuleUtils::iterImportedFunctions(wasm, [&](Function* func) { + FunctionTypeUtils::fillFunction(func, wasm.getFunctionType(func->type)); + }); + // Finalize calls now that everything is known and generated struct FinalizeCalls : public WalkerPass<PostWalker<FinalizeCalls>> { @@ -1258,96 +1258,94 @@ void Asm2WasmBuilder::processAsm(Ref ast) { } void visitCall(Call* curr) { + // The call target may not exist if it is one of our special fake imports for callIndirect fixups auto* calledFunc = getModule()->getFunctionOrNull(curr->target); - if (!calledFunc) { - std::cerr << "invalid call target: " << curr->target << '\n'; - WASM_UNREACHABLE(); - } - // The result type of the function being called is now known, and can be applied. - auto result = calledFunc->result; - if (curr->type != result) { - curr->type = result; - } - // Handle mismatched numbers of arguments. In clang, if a function is declared one way - // but called in another, it inserts bitcasts to make things work. Those end up - // working since it is "ok" to drop or add parameters in native platforms, even - // though it's undefined behavior. We warn about it here, but tolerate it, if there is - // a simple solution. - if (curr->operands.size() < calledFunc->params.size()) { - notifyAboutWrongOperands("warning: asm2wasm adding operands", calledFunc); - while (curr->operands.size() < calledFunc->params.size()) { - // Add params as necessary, with zeros. - curr->operands.push_back( - LiteralUtils::makeZero(calledFunc->params[curr->operands.size()], *getModule()) - ); + if (calledFunc && !calledFunc->imported()) { + // The result type of the function being called is now known, and can be applied. + auto result = calledFunc->result; + if (curr->type != result) { + curr->type = result; } - } - if (curr->operands.size() > calledFunc->params.size()) { - notifyAboutWrongOperands("warning: asm2wasm dropping operands", calledFunc); - curr->operands.resize(calledFunc->params.size()); - } - // If the types are wrong, validation will fail later anyhow, but add a warning here, - // it may help people. - for (Index i = 0; i < curr->operands.size(); i++) { - auto sent = curr->operands[i]->type; - auto expected = calledFunc->params[i]; - if (sent != unreachable && sent != expected) { - notifyAboutWrongOperands("error: asm2wasm seeing an invalid argument type at index " + std::to_string(i) + " (this will not validate)", calledFunc); + // Handle mismatched numbers of arguments. In clang, if a function is declared one way + // but called in another, it inserts bitcasts to make things work. Those end up + // working since it is "ok" to drop or add parameters in native platforms, even + // though it's undefined behavior. We warn about it here, but tolerate it, if there is + // a simple solution. + if (curr->operands.size() < calledFunc->params.size()) { + notifyAboutWrongOperands("warning: asm2wasm adding operands", calledFunc); + while (curr->operands.size() < calledFunc->params.size()) { + // Add params as necessary, with zeros. + curr->operands.push_back( + LiteralUtils::makeZero(calledFunc->params[curr->operands.size()], *getModule()) + ); + } } - } - } - - void visitCallImport(CallImport* curr) { - // fill out call_import - add extra params as needed, etc. asm tolerates ffi overloading, wasm does not - auto iter = parent->importedFunctionTypes.find(curr->target); - if (iter == parent->importedFunctionTypes.end()) return; // one of our fake imports for callIndirect fixups - auto type = iter->second.get(); - for (size_t i = 0; i < type->params.size(); i++) { - if (i >= curr->operands.size()) { - // add a new param - auto val = parent->allocator.alloc<Const>(); - val->type = val->value.type = type->params[i]; - curr->operands.push_back(val); - } else if (curr->operands[i]->type != type->params[i]) { - // if the param is used, then we have overloading here and the combined type must be f64; - // if this is an unreachable param, then it doesn't matter. - assert(type->params[i] == f64 || curr->operands[i]->type == unreachable); - // overloaded, upgrade to f64 - switch (curr->operands[i]->type) { - case i32: curr->operands[i] = parent->builder.makeUnary(ConvertSInt32ToFloat64, curr->operands[i]); break; - case f32: curr->operands[i] = parent->builder.makeUnary(PromoteFloat32, curr->operands[i]); break; - default: {} // f64, unreachable, etc., are all good + if (curr->operands.size() > calledFunc->params.size()) { + notifyAboutWrongOperands("warning: asm2wasm dropping operands", calledFunc); + curr->operands.resize(calledFunc->params.size()); + } + // If the types are wrong, validation will fail later anyhow, but add a warning here, + // it may help people. + for (Index i = 0; i < curr->operands.size(); i++) { + auto sent = curr->operands[i]->type; + auto expected = calledFunc->params[i]; + if (sent != unreachable && sent != expected) { + notifyAboutWrongOperands("error: asm2wasm seeing an invalid argument type at index " + std::to_string(i) + " (this will not validate)", calledFunc); } } - } - Module* wasm = getModule(); - auto importResult = wasm->getFunctionType(wasm->getImport(curr->target)->functionType)->result; - if (curr->type != importResult) { - auto old = curr->type; - curr->type = importResult; - if (importResult == f64) { - // we use a JS f64 value which is the most general, and convert to it - switch (old) { - case i32: { - Unary* trunc = parent->builder.makeUnary(TruncSFloat64ToInt32, curr); - replaceCurrent(makeTrappingUnary(trunc, parent->trappingFunctions)); - break; - } - case f32: { - replaceCurrent(parent->builder.makeUnary(DemoteFloat64, curr)); - break; + } else { + // A call to an import + // fill things out: add extra params as needed, etc. asm tolerates ffi overloading, wasm does not + auto iter = parent->importedFunctionTypes.find(curr->target); + if (iter == parent->importedFunctionTypes.end()) return; // one of our fake imports for callIndirect fixups + auto type = iter->second.get(); + for (size_t i = 0; i < type->params.size(); i++) { + if (i >= curr->operands.size()) { + // add a new param + auto val = parent->allocator.alloc<Const>(); + val->type = val->value.type = type->params[i]; + curr->operands.push_back(val); + } else if (curr->operands[i]->type != type->params[i]) { + // if the param is used, then we have overloading here and the combined type must be f64; + // if this is an unreachable param, then it doesn't matter. + assert(type->params[i] == f64 || curr->operands[i]->type == unreachable); + // overloaded, upgrade to f64 + switch (curr->operands[i]->type) { + case i32: curr->operands[i] = parent->builder.makeUnary(ConvertSInt32ToFloat64, curr->operands[i]); break; + case f32: curr->operands[i] = parent->builder.makeUnary(PromoteFloat32, curr->operands[i]); break; + default: {} // f64, unreachable, etc., are all good } - case none: { - // this function returns a value, but we are not using it, so it must be dropped. - // autodrop will do that for us. - break; + } + } + Module* wasm = getModule(); + auto importResult = wasm->getFunctionType(wasm->getFunction(curr->target)->type)->result; + if (curr->type != importResult) { + auto old = curr->type; + curr->type = importResult; + if (importResult == f64) { + // we use a JS f64 value which is the most general, and convert to it + switch (old) { + case i32: { + Unary* trunc = parent->builder.makeUnary(TruncSFloat64ToInt32, curr); + replaceCurrent(makeTrappingUnary(trunc, parent->trappingFunctions)); + break; + } + case f32: { + replaceCurrent(parent->builder.makeUnary(DemoteFloat64, curr)); + break; + } + case none: { + // this function returns a value, but we are not using it, so it must be dropped. + // autodrop will do that for us. + break; + } + default: WASM_UNREACHABLE(); } - default: WASM_UNREACHABLE(); + } else { + assert(old == none); + // we don't want a return value here, but the import does provide one + // autodrop will do that for us. } - } else { - assert(old == none); - // we don't want a return value here, but the import does provide one - // autodrop will do that for us. } } } @@ -1361,7 +1359,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) { target = block->list.back(); } // the something might have been optimized out, leaving only the call - if (auto* call = target->dynCast<CallImport>()) { + if (auto* call = target->dynCast<Call>()) { auto tableName = call->target; if (parent->functionTableStarts.find(tableName) == parent->functionTableStarts.end()) return; curr->target = parent->builder.makeConst(Literal((int32_t)parent->functionTableStarts[tableName])); @@ -1369,13 +1367,13 @@ void Asm2WasmBuilder::processAsm(Ref ast) { } auto* add = target->dynCast<Binary>(); if (!add) return; - if (add->right->is<CallImport>()) { - auto* offset = add->right->cast<CallImport>(); + if (add->right->is<Call>()) { + auto* offset = add->right->cast<Call>(); auto tableName = offset->target; if (parent->functionTableStarts.find(tableName) == parent->functionTableStarts.end()) return; add->right = parent->builder.makeConst(Literal((int32_t)parent->functionTableStarts[tableName])); } else { - auto* offset = add->left->dynCast<CallImport>(); + auto* offset = add->left->dynCast<Call>(); if (!offset) return; auto tableName = offset->target; if (parent->functionTableStarts.find(tableName) == parent->functionTableStarts.end()) return; @@ -1400,7 +1398,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) { name = "apply-debug-info"; } - CallImport* lastDebugInfo = nullptr; + Call* lastDebugInfo = nullptr; void visitExpression(Expression* curr) { if (auto* call = checkDebugInfo(curr)) { @@ -1483,7 +1481,7 @@ void Asm2WasmBuilder::processAsm(Ref ast) { // remove the debug info intrinsic if (preprocessor.debugInfo) { - wasm.removeImport(EMSCRIPTEN_DEBUGINFO); + wasm.removeFunction(EMSCRIPTEN_DEBUGINFO); } if (udivmoddi4.is() && getTempRet0.is()) { @@ -1630,19 +1628,20 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { return ret; } if (name == DEBUGGER) { - CallImport *call = allocator.alloc<CallImport>(); + Call *call = allocator.alloc<Call>(); call->target = DEBUGGER; call->type = none; static bool addedImport = false; if (!addedImport) { addedImport = true; - auto import = new Import; // debugger = asm2wasm.debugger; + auto import = new Function; // debugger = asm2wasm.debugger; import->name = DEBUGGER; import->module = ASM2WASM; import->base = DEBUGGER; - import->functionType = ensureFunctionType("v", &wasm)->name; - import->kind = ExternalKind::Function; - wasm.addImport(import); + auto* functionType = ensureFunctionType("v", &wasm); + import->type = functionType->name; + FunctionTypeUtils::fillFunction(import, functionType); + wasm.addFunction(import); } return call; } @@ -1732,7 +1731,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->finalize(); if (ret->op == BinaryOp::RemSInt32 && isFloatType(ret->type)) { // WebAssembly does not have floating-point remainder, we have to emit a call to a special import of ours - CallImport *call = allocator.alloc<CallImport>(); + Call *call = allocator.alloc<Call>(); call->target = F64_REM; call->operands.push_back(ensureDouble(ret->left)); call->operands.push_back(ensureDouble(ret->right)); @@ -1740,13 +1739,14 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { static bool addedImport = false; if (!addedImport) { addedImport = true; - auto import = new Import; // f64-rem = asm2wasm.f64-rem; + auto import = new Function; // f64-rem = asm2wasm.f64-rem; import->name = F64_REM; import->module = ASM2WASM; import->base = F64_REM; - import->functionType = ensureFunctionType("ddd", &wasm)->name; - import->kind = ExternalKind::Function; - wasm.addImport(import); + auto* functionType = ensureFunctionType("ddd", &wasm); + import->type = functionType->name; + FunctionTypeUtils::fillFunction(import, functionType); + wasm.addFunction(import); } return call; } @@ -2200,7 +2200,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { } Expression* ret; ExpressionList* operands; - CallImport* callImport = nullptr; + bool callImport = false; Index firstOperand = 0; Ref args = ast[2]; if (tableCall) { @@ -2209,12 +2209,11 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { firstOperand = 1; operands = &specific->operands; ret = specific; - } else if (wasm.getImportOrNull(name)) { - callImport = allocator.alloc<CallImport>(); - callImport->target = name; - operands = &callImport->operands; - ret = callImport; } else { + // if we call an import, it definitely exists already; if it's a + // defined function then it might not have been seen yet + auto* target = wasm.getFunctionOrNull(name); + callImport = target && target->imported(); auto specific = allocator.alloc<Call>(); specific->target = name; operands = &specific->operands; @@ -2239,8 +2238,9 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { // this is important as we run the optimizer on functions before we get // to finalizeCalls (which we can only do once we've read all the functions, // and we optimize in parallel starting earlier). - callImport->type = getResultTypeOfCallUsingParent(astStackHelper.getParent(), &asmData); - noteImportedFunctionCall(ast, callImport->type, callImport); + auto* call = ret->cast<Call>(); + call->type = getResultTypeOfCallUsingParent(astStackHelper.getParent(), &asmData); + noteImportedFunctionCall(ast, call->type, call); } return ret; } @@ -2257,7 +2257,7 @@ Function* Asm2WasmBuilder::processFunction(Ref ast) { ret->fullType = fullType->name; ret->type = fullType->result; // we don't know the table offset yet. emit target = target + callImport(tableName), which we fix up later when we know how asm function tables are layed out inside the wasm table. - ret->target = builder.makeBinary(BinaryOp::AddInt32, ret->target, builder.makeCallImport(target[1]->getIString(), {}, i32)); + ret->target = builder.makeBinary(BinaryOp::AddInt32, ret->target, builder.makeCall(target[1]->getIString(), {}, i32)); return ret; } else if (what == RETURN) { Type type = !!ast[1] ? detectWasmType(ast[1], &asmData) : none; diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index 45000ba95..2d082c088 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -31,6 +31,7 @@ #include "wasm-validator.h" #include "wasm2js.h" #include "cfg/Relooper.h" +#include "ir/function-type-utils.h" #include "ir/utils.h" #include "shell-interface.h" @@ -85,7 +86,7 @@ void traceNameOrNULL(const char* name) { std::map<BinaryenFunctionTypeRef, size_t> functionTypes; std::map<BinaryenExpressionRef, size_t> expressions; std::map<BinaryenFunctionRef, size_t> functions; -std::map<BinaryenImportRef, size_t> imports; +std::map<BinaryenGlobalRef, size_t> globals; std::map<BinaryenExportRef, size_t> exports; std::map<RelooperBlockRef, size_t> relooperBlocks; @@ -128,7 +129,6 @@ BinaryenExpressionId BinaryenLoopId(void) { return Expression::Id::LoopId; } BinaryenExpressionId BinaryenBreakId(void) { return Expression::Id::BreakId; } BinaryenExpressionId BinaryenSwitchId(void) { return Expression::Id::SwitchId; } BinaryenExpressionId BinaryenCallId(void) { return Expression::Id::CallId; } -BinaryenExpressionId BinaryenCallImportId(void) { return Expression::Id::CallImportId; } BinaryenExpressionId BinaryenCallIndirectId(void) { return Expression::Id::CallIndirectId; } BinaryenExpressionId BinaryenGetLocalId(void) { return Expression::Id::GetLocalId; } BinaryenExpressionId BinaryenSetLocalId(void) { return Expression::Id::SetLocalId; } @@ -174,13 +174,13 @@ void BinaryenModuleDispose(BinaryenModuleRef module) { std::cout << " functionTypes.clear();\n"; std::cout << " expressions.clear();\n"; std::cout << " functions.clear();\n"; - std::cout << " imports.clear();\n"; + std::cout << " globals.clear();\n"; std::cout << " exports.clear();\n"; std::cout << " relooperBlocks.clear();\n"; functionTypes.clear(); expressions.clear(); functions.clear(); - imports.clear(); + globals.clear(); exports.clear(); relooperBlocks.clear(); } @@ -508,31 +508,6 @@ BinaryenExpressionRef BinaryenCall(BinaryenModuleRef module, const char* target, ret->finalize(); return static_cast<Expression*>(ret); } -BinaryenExpressionRef BinaryenCallImport(BinaryenModuleRef module, const char* target, BinaryenExpressionRef* operands, BinaryenIndex numOperands, BinaryenType returnType) { - auto* ret = ((Module*)module)->allocator.alloc<CallImport>(); - - if (tracing) { - std::cout << " {\n"; - std::cout << " BinaryenExpressionRef operands[] = { "; - for (BinaryenIndex i = 0; i < numOperands; i++) { - if (i > 0) std::cout << ", "; - std::cout << "expressions[" << expressions[operands[i]] << "]"; - } - if (numOperands == 0) std::cout << "0"; // ensure the array is not empty, otherwise a compiler error on VS - std::cout << " };\n"; - auto id = noteExpression(ret); - std::cout << " expressions[" << id << "] = BinaryenCallImport(the_module, \"" << target << "\", operands, " << numOperands << ", " << returnType << ");\n"; - std::cout << " }\n"; - } - - ret->target = target; - for (BinaryenIndex i = 0; i < numOperands; i++) { - ret->operands.push_back((Expression*)operands[i]); - } - ret->type = Type(returnType); - ret->finalize(); - return static_cast<Expression*>(ret); -} BinaryenExpressionRef BinaryenCallIndirect(BinaryenModuleRef module, BinaryenExpressionRef target, BinaryenExpressionRef* operands, BinaryenIndex numOperands, const char* type) { auto* wasm = (Module*)module; auto* ret = wasm->allocator.alloc<CallIndirect>(); @@ -1056,35 +1031,6 @@ BinaryenExpressionRef BinaryenCallGetOperand(BinaryenExpressionRef expr, Binarye assert(index < static_cast<Call*>(expression)->operands.size()); return static_cast<Call*>(expression)->operands[index]; } -// CallImport -const char* BinaryenCallImportGetTarget(BinaryenExpressionRef expr) { - if (tracing) { - std::cout << " BinaryenCallImportGetTarget(expressions[" << expressions[expr] << "]);\n"; - } - - auto* expression = (Expression*)expr; - assert(expression->is<CallImport>()); - return static_cast<CallImport*>(expression)->target.c_str(); -} -BinaryenIndex BinaryenCallImportGetNumOperands(BinaryenExpressionRef expr) { - if (tracing) { - std::cout << " BinaryenCallImportGetNumOperands(expressions[" << expressions[expr] << "]);\n"; - } - - auto* expression = (Expression*)expr; - assert(expression->is<CallImport>()); - return static_cast<CallImport*>(expression)->operands.size(); -} -BinaryenExpressionRef BinaryenCallImportGetOperand(BinaryenExpressionRef expr, BinaryenIndex index) { - if (tracing) { - std::cout << " BinaryenCallImportGetOperand(expressions[" << expressions[expr] << "], " << index << ");\n"; - } - - auto* expression = (Expression*)expr; - assert(expression->is<CallImport>()); - assert(index < static_cast<CallImport*>(expression)->operands.size()); - return static_cast<CallImport*>(expression)->operands[index]; -} // CallIndirect BinaryenExpressionRef BinaryenCallIndirectGetTarget(BinaryenExpressionRef expr) { if (tracing) { @@ -1702,102 +1648,54 @@ BinaryenGlobalRef BinaryenAddGlobal(BinaryenModuleRef module, const char* name, // Imports -WASM_DEPRECATED BinaryenImportRef BinaryenAddImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName, BinaryenFunctionTypeRef type) { - return BinaryenAddFunctionImport(module, internalName, externalModuleName, externalBaseName, type); -} -BinaryenImportRef BinaryenAddFunctionImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName, BinaryenFunctionTypeRef functionType) { - auto* ret = new Import(); +void BinaryenAddFunctionImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName, BinaryenFunctionTypeRef functionType) { auto* wasm = (Module*)module; + auto* ret = new Function(); if (tracing) { - auto id = imports.size(); - imports[ret] = id; - std::cout << " imports[" << id << "] = BinaryenAddFunctionImport(the_module, \"" << internalName << "\", \"" << externalModuleName << "\", \"" << externalBaseName << "\", functionTypes[" << functionTypes[functionType] << "]);\n"; + std::cout << " BinaryenAddFunctionImport(the_module, \"" << internalName << "\", \"" << externalModuleName << "\", \"" << externalBaseName << "\", functionTypes[" << functionTypes[functionType] << "]);\n"; } ret->name = internalName; ret->module = externalModuleName; ret->base = externalBaseName; - ret->functionType = ((FunctionType*)functionType)->name; - ret->kind = ExternalKind::Function; - wasm->addImport(ret); - return ret; + ret->type = ((FunctionType*)functionType)->name; + FunctionTypeUtils::fillFunction(ret, (FunctionType*)functionType); + wasm->addFunction(ret); } -BinaryenImportRef BinaryenAddTableImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName) { +void BinaryenAddTableImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName) { auto* wasm = (Module*)module; - auto* ret = new Import(); if (tracing) { - auto id = imports.size(); - imports[ret] = id; - std::cout << " imports[" << id << "] = BinaryenAddTableImport(the_module, \"" << internalName << "\", \"" << externalModuleName << "\", \"" << externalBaseName << "\");\n"; + std::cout << " BinaryenAddTableImport(the_module, \"" << internalName << "\", \"" << externalModuleName << "\", \"" << externalBaseName << "\");\n"; } - ret->name = internalName; - ret->module = externalModuleName; - ret->base = externalBaseName; - ret->kind = ExternalKind::Table; - if (wasm->table.name == ret->name) { - wasm->table.imported = true; - } - wasm->addImport(ret); - return ret; + wasm->table.module = externalModuleName; + wasm->table.base = externalBaseName; } -BinaryenImportRef BinaryenAddMemoryImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName) { +void BinaryenAddMemoryImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName) { auto* wasm = (Module*)module; - auto* ret = new Import(); if (tracing) { - auto id = imports.size(); - imports[ret] = id; - std::cout << " imports[" << id << "] = BinaryenAddMemoryImport(the_module, \"" << internalName << "\", \"" << externalModuleName << "\", \"" << externalBaseName << "\");\n"; + std::cout << " BinaryenAddMemoryImport(the_module, \"" << internalName << "\", \"" << externalModuleName << "\", \"" << externalBaseName << "\");\n"; } - ret->name = internalName; - ret->module = externalModuleName; - ret->base = externalBaseName; - ret->kind = ExternalKind::Memory; - if (wasm->memory.name == ret->name) { - wasm->memory.imported = true; - } - wasm->addImport(ret); - return ret; + wasm->memory.module = externalModuleName; + wasm->memory.base = externalBaseName; } -BinaryenImportRef BinaryenAddGlobalImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName, BinaryenType globalType) { +void BinaryenAddGlobalImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName, BinaryenType globalType) { auto* wasm = (Module*)module; - auto* ret = new Import(); + auto* ret = new Global(); if (tracing) { - auto id = imports.size(); - imports[ret] = id; - std::cout << " imports[" << id << "] = BinaryenAddGlobalImport(the_module, \"" << internalName << "\", \"" << externalModuleName << "\", \"" << externalBaseName << "\", " << globalType << ");\n"; + std::cout << " BinaryenAddGlobalImport(the_module, \"" << internalName << "\", \"" << externalModuleName << "\", \"" << externalBaseName << "\", " << globalType << ");\n"; } ret->name = internalName; ret->module = externalModuleName; ret->base = externalBaseName; - ret->globalType = Type(globalType); - ret->kind = ExternalKind::Global; - wasm->addImport(ret); - return ret; -} -void BinaryenRemoveImport(BinaryenModuleRef module, const char* internalName) { - if (tracing) { - std::cout << " BinaryenRemoveImport(the_module, \"" << internalName << "\");\n"; - } - - auto* wasm = (Module*)module; - auto* import = wasm->getImport(internalName); - if (import->kind == ExternalKind::Table) { - if (import->name == wasm->table.name) { - wasm->table.imported = false; - } - } else if (import->kind == ExternalKind::Memory) { - if (import->name == wasm->memory.name) { - wasm->memory.imported = false; - } - } - wasm->removeImport(internalName); + ret->type = Type(globalType); + wasm->addGlobal(ret); } // Exports @@ -2385,47 +2283,53 @@ void BinaryenFunctionSetDebugLocation(BinaryenFunctionRef func, BinaryenExpressi // =========== Import operations =========== // -BinaryenExternalKind BinaryenImportGetKind(BinaryenImportRef import) { +const char* BinaryenFunctionImportGetModule(BinaryenFunctionRef import) { if (tracing) { - std::cout << " BinaryenImportGetKind(imports[" << imports[import] << "]);\n"; + std::cout << " BinaryenFunctionImportGetModule(functions[" << functions[import] << "]);\n"; } - return BinaryenExternalKind(((Import*)import)->kind); -} -const char* BinaryenImportGetModule(BinaryenImportRef import) { - if (tracing) { - std::cout << " BinaryenImportGetModule(imports[" << imports[import] << "]);\n"; + auto* func = (Function*)import; + if (func->imported()) { + return func->module.c_str(); + } else { + return ""; } - - return ((Import*)import)->module.c_str(); } -const char* BinaryenImportGetBase(BinaryenImportRef import) { +const char* BinaryenGlobalImportGetModule(BinaryenGlobalRef import) { if (tracing) { - std::cout << " BinaryenImportGetBase(imports[" << imports[import] << "]);\n"; + std::cout << " BinaryenGlobalImportGetModule(globals[" << globals[import] << "]);\n"; } - return ((Import*)import)->base.c_str(); -} -const char* BinaryenImportGetName(BinaryenImportRef import) { - if (tracing) { - std::cout << " BinaryenImportGetName(imports[" << imports[import] << "]);\n"; + auto* global = (Global*)import; + if (global->imported()) { + return global->module.c_str(); + } else { + return ""; } - - return ((Import*)import)->name.c_str(); } -BinaryenType BinaryenImportGetGlobalType(BinaryenImportRef import) { +const char* BinaryenFunctionImportGetBase(BinaryenFunctionRef import) { if (tracing) { - std::cout << " BinaryenImportGetGlobalType(imports[" << imports[import] << "]);\n"; + std::cout << " BinaryenFunctionImportGetBase(functions[" << functions[import] << "]);\n"; } - return ((Import*)import)->globalType; + auto* func = (Function*)import; + if (func->imported()) { + return func->base.c_str(); + } else { + return ""; + } } -const char* BinaryenImportGetFunctionType(BinaryenImportRef import) { +const char* BinaryenGlobalImportGetBase(BinaryenGlobalRef import) { if (tracing) { - std::cout << " BinaryenImportGetFunctionType(imports[" << imports[import] << "]);\n"; + std::cout << " BinaryenGlobalImportGetBase(globals[" << globals[import] << "]);\n"; } - return ((Import*)import)->functionType.c_str(); + auto* global = (Global*)import; + if (global->imported()) { + return global->base.c_str(); + } else { + return ""; + } } // @@ -2557,7 +2461,7 @@ void BinaryenSetAPITracing(int on) { " std::map<size_t, BinaryenFunctionTypeRef> functionTypes;\n" " std::map<size_t, BinaryenExpressionRef> expressions;\n" " std::map<size_t, BinaryenFunctionRef> functions;\n" - " std::map<size_t, BinaryenImportRef> imports;\n" + " std::map<size_t, BinaryenGlobalRef> globals;\n" " std::map<size_t, BinaryenExportRef> exports;\n" " std::map<size_t, RelooperBlockRef> relooperBlocks;\n" " BinaryenModuleRef the_module = NULL;\n" diff --git a/src/binaryen-c.h b/src/binaryen-c.h index 95892443d..9a2750353 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -96,7 +96,6 @@ BinaryenExpressionId BinaryenLoopId(void); BinaryenExpressionId BinaryenBreakId(void); BinaryenExpressionId BinaryenSwitchId(void); BinaryenExpressionId BinaryenCallId(void); -BinaryenExpressionId BinaryenCallImportId(void); BinaryenExpressionId BinaryenCallIndirectId(void); BinaryenExpressionId BinaryenGetLocalId(void); BinaryenExpressionId BinaryenSetLocalId(void); @@ -339,15 +338,11 @@ BinaryenExpressionRef BinaryenLoop(BinaryenModuleRef module, const char* in, Bin BinaryenExpressionRef BinaryenBreak(BinaryenModuleRef module, const char* name, BinaryenExpressionRef condition, BinaryenExpressionRef value); // Switch: value can be NULL BinaryenExpressionRef BinaryenSwitch(BinaryenModuleRef module, const char** names, BinaryenIndex numNames, const char* defaultName, BinaryenExpressionRef condition, BinaryenExpressionRef value); -// Call, CallImport: Note the 'returnType' parameter. You must declare the -// type returned by the function being called, as that -// function might not have been created yet, so we don't -// know what it is. -// Also note that WebAssembly does not differentiate -// between Call and CallImport, but Binaryen does, so you -// must use CallImport if calling an import, and vice versa. +// Call: Note the 'returnType' parameter. You must declare the +// type returned by the function being called, as that +// function might not have been created yet, so we don't +// know what it is. BinaryenExpressionRef BinaryenCall(BinaryenModuleRef module, const char* target, BinaryenExpressionRef* operands, BinaryenIndex numOperands, BinaryenType returnType); -BinaryenExpressionRef BinaryenCallImport(BinaryenModuleRef module, const char* target, BinaryenExpressionRef* operands, BinaryenIndex numOperands, BinaryenType returnType); BinaryenExpressionRef BinaryenCallIndirect(BinaryenModuleRef module, BinaryenExpressionRef target, BinaryenExpressionRef* operands, BinaryenIndex numOperands, const char* type); // GetLocal: Note the 'type' parameter. It might seem redundant, since the // local at that index must have a type. However, this API lets you @@ -441,13 +436,6 @@ BinaryenIndex BinaryenCallGetNumOperands(BinaryenExpressionRef expr); // Gets the nested operand expression at the specified index within the specified `Call` expression. BinaryenExpressionRef BinaryenCallGetOperand(BinaryenExpressionRef expr, BinaryenIndex index); -// Gets the name of the target of the specified `CallImport` expression. -const char* BinaryenCallImportGetTarget(BinaryenExpressionRef expr); -// Gets the number of nested operand expressions within the specified `CallImport` expression. -BinaryenIndex BinaryenCallImportGetNumOperands(BinaryenExpressionRef expr); -// Gets the nested operand expression at the specified index within the specified `CallImport` expression. -BinaryenExpressionRef BinaryenCallImportGetOperand(BinaryenExpressionRef expr, BinaryenIndex index); - // Gets the nested target expression of the specified `CallIndirect` expression. BinaryenExpressionRef BinaryenCallIndirectGetTarget(BinaryenExpressionRef expr); // Gets the number of nested operand expressions within the specified `CallIndirect` expression. @@ -604,14 +592,10 @@ void BinaryenRemoveFunction(BinaryenModuleRef module, const char* name); // Imports -typedef void* BinaryenImportRef; - -WASM_DEPRECATED BinaryenImportRef BinaryenAddImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName, BinaryenFunctionTypeRef type); -BinaryenImportRef BinaryenAddFunctionImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName, BinaryenFunctionTypeRef functionType); -BinaryenImportRef BinaryenAddTableImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName); -BinaryenImportRef BinaryenAddMemoryImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName); -BinaryenImportRef BinaryenAddGlobalImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName, BinaryenType globalType); -void BinaryenRemoveImport(BinaryenModuleRef module, const char* internalName); +void BinaryenAddFunctionImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName, BinaryenFunctionTypeRef functionType); +void BinaryenAddTableImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName); +void BinaryenAddMemoryImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName); +void BinaryenAddGlobalImport(BinaryenModuleRef module, const char* internalName, const char* externalModuleName, const char* externalBaseName, BinaryenType globalType); // Exports @@ -790,18 +774,12 @@ void BinaryenFunctionSetDebugLocation(BinaryenFunctionRef func, BinaryenExpressi // ========== Import Operations ========== // -// Gets the external kind of the specified import. -BinaryenExternalKind BinaryenImportGetKind(BinaryenImportRef import); // Gets the external module name of the specified import. -const char* BinaryenImportGetModule(BinaryenImportRef import); +const char* BinaryenFunctionImportGetModule(BinaryenFunctionRef import); +const char* BinaryeGlobalImportGetModule(BinaryenGlobalRef import); // Gets the external base name of the specified import. -const char* BinaryenImportGetBase(BinaryenImportRef import); -// Gets the internal name of the specified import. -const char* BinaryenImportGetName(BinaryenImportRef import); -// Gets the type of the imported global, if referencing a `Global`. -BinaryenType BinaryenImportGetGlobalType(BinaryenImportRef import); -// Gets the name of the function type of the imported function, if referencing a `Function`. -const char* BinaryenImportGetFunctionType(BinaryenImportRef import); +const char* BinaryenFunctionImportGetBase(BinaryenFunctionRef import); +const char* BinaryenGlobalImportGetBase(BinaryenGlobalRef import); // // ========== Export Operations ========== diff --git a/src/ir/ExpressionAnalyzer.cpp b/src/ir/ExpressionAnalyzer.cpp index d91bc370c..44abc2f64 100644 --- a/src/ir/ExpressionAnalyzer.cpp +++ b/src/ir/ExpressionAnalyzer.cpp @@ -170,14 +170,6 @@ bool ExpressionAnalyzer::flexibleEqual(Expression* left, Expression* right, Expr } break; } - case Expression::Id::CallImportId: { - CHECK(CallImport, target); - CHECK(CallImport, operands.size()); - for (Index i = 0; i < left->cast<CallImport>()->operands.size(); i++) { - PUSH(CallImport, operands[i]); - } - break; - } case Expression::Id::CallIndirectId: { PUSH(CallIndirect, target); CHECK(CallIndirect, fullType); @@ -423,14 +415,6 @@ uint32_t ExpressionAnalyzer::hash(Expression* curr) { } break; } - case Expression::Id::CallImportId: { - HASH_NAME(CallImport, target); - HASH(CallImport, operands.size()); - for (Index i = 0; i < curr->cast<CallImport>()->operands.size(); i++) { - PUSH(CallImport, operands[i]); - } - break; - } case Expression::Id::CallIndirectId: { PUSH(CallIndirect, target); HASH_NAME(CallIndirect, fullType); diff --git a/src/ir/ExpressionManipulator.cpp b/src/ir/ExpressionManipulator.cpp index a3bff2d82..d65509c52 100644 --- a/src/ir/ExpressionManipulator.cpp +++ b/src/ir/ExpressionManipulator.cpp @@ -63,13 +63,6 @@ Expression* flexibleCopy(Expression* original, Module& wasm, CustomCopier custom } return ret; } - Expression* visitCallImport(CallImport *curr) { - auto* ret = builder.makeCallImport(curr->target, {}, curr->type); - for (Index i = 0; i < curr->operands.size(); i++) { - ret->operands.push_back(copy(curr->operands[i])); - } - return ret; - } Expression* visitCallIndirect(CallIndirect *curr) { auto* ret = builder.makeCallIndirect(curr->fullType, copy(curr->target), {}, curr->type); for (Index i = 0; i < curr->operands.size(); i++) { diff --git a/src/ir/cost.h b/src/ir/cost.h index 9a97574f4..6d1078094 100644 --- a/src/ir/cost.h +++ b/src/ir/cost.h @@ -51,15 +51,12 @@ struct CostAnalyzer : public Visitor<CostAnalyzer, Index> { return 2 + visit(curr->condition) + maybeVisit(curr->value); } Index visitCall(Call *curr) { + // XXX this does not take into account if the call is to an import, which + // may be costlier in general Index ret = 4; for (auto* child : curr->operands) ret += visit(child); return ret; } - Index visitCallImport(CallImport *curr) { - Index ret = 15; - for (auto* child : curr->operands) ret += visit(child); - return ret; - } Index visitCallIndirect(CallIndirect *curr) { Index ret = 6 + visit(curr->target); for (auto* child : curr->operands) ret += visit(child); diff --git a/src/ir/effects.h b/src/ir/effects.h index 98687c0dc..8919c6da4 100644 --- a/src/ir/effects.h +++ b/src/ir/effects.h @@ -176,13 +176,13 @@ struct EffectAnalyzer : public PostWalker<EffectAnalyzer> { } } - void visitCall(Call *curr) { calls = true; } - void visitCallImport(CallImport *curr) { + void visitCall(Call *curr) { calls = true; if (debugInfo) { // debugInfo call imports must be preserved very strongly, do not // move code around them - branches = true; // ! + // FIXME: we could check if the call is to an import + branches = true; } } void visitCallIndirect(CallIndirect *curr) { calls = true; } diff --git a/src/ir/function-type-utils.h b/src/ir/function-type-utils.h new file mode 100644 index 000000000..3c98cb16b --- /dev/null +++ b/src/ir/function-type-utils.h @@ -0,0 +1,35 @@ +/* + * Copyright 2018 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_function_type_utils_h +#define wasm_ir_function_type_utils_h + + +namespace wasm { + +namespace FunctionTypeUtils { + +// Fills in function info from a function type +inline void fillFunction(Function* func, FunctionType* type) { + func->params = type->params; + func->result = type->result; +} + +} // namespace FunctionTypeUtils + +} // namespace wasm + +#endif // wasm_ir_function_type_utils_h diff --git a/src/ir/function-utils.h b/src/ir/function-utils.h index 317b2f1b1..3895b33bf 100644 --- a/src/ir/function-utils.h +++ b/src/ir/function-utils.h @@ -35,7 +35,10 @@ inline bool equal(Function* left, Function* right) { } if (left->result != right->result) return false; if (left->type != right->type) return false; - return ExpressionAnalyzer::equal(left->body, right->body); + if (!left->imported() && !right->imported()) { + return ExpressionAnalyzer::equal(left->body, right->body); + } + return left->imported() && right->imported(); } } // namespace FunctionUtils diff --git a/src/ir/global-utils.h b/src/ir/global-utils.h index bcf0dae72..02bbbf2d2 100644 --- a/src/ir/global-utils.h +++ b/src/ir/global-utils.h @@ -22,6 +22,7 @@ #include "literal.h" #include "wasm.h" +#include "ir/module-utils.h" namespace wasm { @@ -30,22 +31,22 @@ namespace GlobalUtils { inline Global* getGlobalInitializedToImport(Module& wasm, Name module, Name base) { // find the import Name imported; - for (auto& import : wasm.imports) { + ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) { if (import->module == module && import->base == base) { imported = import->name; - break; } - } + }); if (imported.isNull()) return nullptr; // find a global inited to it - for (auto& global : wasm.globals) { - if (auto* init = global->init->dynCast<GetGlobal>()) { + Global* ret = nullptr; + ModuleUtils::iterDefinedGlobals(wasm, [&](Global* defined) { + if (auto* init = defined->init->dynCast<GetGlobal>()) { if (init->name == imported) { - return global.get(); + ret = defined; } } - } - return nullptr; + }); + return ret; } }; diff --git a/src/ir/import-utils.h b/src/ir/import-utils.h index f3f01c266..5f7bbc9ed 100644 --- a/src/ir/import-utils.h +++ b/src/ir/import-utils.h @@ -22,17 +22,66 @@ namespace wasm { -namespace ImportUtils { - // find an import by the module.base that is being imported. - // return the internal name - inline Import* getImport(Module& wasm, Name module, Name base) { - for (auto& import : wasm.imports) { +// Collects info on imports, into a form convenient for summarizing +// and searching. +struct ImportInfo { + Module& wasm; + + std::vector<Global*> importedGlobals; + std::vector<Function*> importedFunctions; + + ImportInfo(Module& wasm) : wasm(wasm) { + for (auto& import : wasm.globals) { + if (import->imported()) { + importedGlobals.push_back(import.get()); + } + } + for (auto& import : wasm.functions) { + if (import->imported()) { + importedFunctions.push_back(import.get()); + } + } + } + + Global* getImportedGlobal(Name module, Name base) { + for (auto* import : importedGlobals) { + if (import->module == module && import->base == base) { + return import; + } + } + return nullptr; + } + + Function* getImportedFunction(Name module, Name base) { + for (auto* import : importedFunctions) { if (import->module == module && import->base == base) { - return import.get(); + return import; } } return nullptr; } + + Index getNumImportedGlobals() { + return importedGlobals.size(); + } + + Index getNumImportedFunctions() { + return importedFunctions.size(); + } + + Index getNumImports() { + return getNumImportedGlobals() + getNumImportedFunctions() + + (wasm.memory.imported() ? 1 : 0) + + (wasm.table.imported() ? 1 : 0); + } + + Index getNumDefinedGlobals() { + return wasm.globals.size() - getNumImportedGlobals(); + } + + Index getNumDefinedFunctions() { + return wasm.functions.size() - getNumImportedFunctions(); + } }; } // namespace wasm diff --git a/src/ir/module-utils.h b/src/ir/module-utils.h index 4205514e5..1d0512401 100644 --- a/src/ir/module-utils.h +++ b/src/ir/module-utils.h @@ -25,30 +25,45 @@ namespace wasm { namespace ModuleUtils { // Computes the indexes in a wasm binary, i.e., with function imports -// and function implementations sharing a single index space, etc. +// and function implementations sharing a single index space, etc., +// and with the imports first (the Module's functions and globals +// arrays are not assumed to be in a particular order, so we can't +// just use them directly). struct BinaryIndexes { std::unordered_map<Name, Index> functionIndexes; std::unordered_map<Name, Index> globalIndexes; BinaryIndexes(Module& wasm) { - for (Index i = 0; i < wasm.imports.size(); i++) { - auto& import = wasm.imports[i]; - if (import->kind == ExternalKind::Function) { - auto index = functionIndexes.size(); - functionIndexes[import->name] = index; - } else if (import->kind == ExternalKind::Global) { - auto index = globalIndexes.size(); - globalIndexes[import->name] = index; + auto addGlobal = [&](Global* curr) { + auto index = globalIndexes.size(); + globalIndexes[curr->name] = index; + }; + for (auto& curr : wasm.globals) { + if (curr->imported()) { + addGlobal(curr.get()); + } + } + for (auto& curr : wasm.globals) { + if (!curr->imported()) { + addGlobal(curr.get()); } } - for (Index i = 0; i < wasm.functions.size(); i++) { + assert(globalIndexes.size() == wasm.globals.size()); + auto addFunction = [&](Function* curr) { auto index = functionIndexes.size(); - functionIndexes[wasm.functions[i]->name] = index; + functionIndexes[curr->name] = index; + }; + for (auto& curr : wasm.functions) { + if (curr->imported()) { + addFunction(curr.get()); + } } - for (Index i = 0; i < wasm.globals.size(); i++) { - auto index = globalIndexes.size(); - globalIndexes[wasm.globals[i]->name] = index; + for (auto& curr : wasm.functions) { + if (!curr->imported()) { + addFunction(curr.get()); + } } + assert(functionIndexes.size() == wasm.functions.size()); } }; @@ -63,21 +78,34 @@ inline Function* copyFunction(Function* func, Module& out) { ret->localIndices = func->localIndices; ret->debugLocations = func->debugLocations; ret->body = ExpressionManipulator::copy(func->body, out); + ret->module = func->module; + ret->base = func->base; // TODO: copy Stack IR assert(!func->stackIR); out.addFunction(ret); return ret; } +inline Global* copyGlobal(Global* global, Module& out) { + auto* ret = new Global(); + ret->name = global->name; + ret->type = global->type; + ret->mutable_ = global->mutable_; + if (global->imported()) { + ret->init = nullptr; + } else { + ret->init = ExpressionManipulator::copy(global->init, out); + } + out.addGlobal(ret); + return ret; +} + inline void copyModule(Module& in, Module& out) { // we use names throughout, not raw points, so simple copying is fine // for everything *but* expressions for (auto& curr : in.functionTypes) { out.addFunctionType(new FunctionType(*curr)); } - for (auto& curr : in.imports) { - out.addImport(new Import(*curr)); - } for (auto& curr : in.exports) { out.addExport(new Export(*curr)); } @@ -85,7 +113,7 @@ inline void copyModule(Module& in, Module& out) { copyFunction(curr.get(), out); } for (auto& curr : in.globals) { - out.addGlobal(new Global(*curr)); + copyGlobal(curr.get(), out); } out.table = in.table; for (auto& segment : out.table.segments) { @@ -100,6 +128,44 @@ inline void copyModule(Module& in, Module& out) { out.debugInfoFileNames = in.debugInfoFileNames; } +// Convenient iteration over imported/non-imported functions/globals + +template<typename T> +inline void iterImportedGlobals(Module& wasm, T visitor) { + for (auto& import : wasm.globals) { + if (import->imported()) { + visitor(import.get()); + } + } +} + +template<typename T> +inline void iterDefinedGlobals(Module& wasm, T visitor) { + for (auto& import : wasm.globals) { + if (!import->imported()) { + visitor(import.get()); + } + } +} + +template<typename T> +inline void iterImportedFunctions(Module& wasm, T visitor) { + for (auto& import : wasm.functions) { + if (import->imported()) { + visitor(import.get()); + } + } +} + +template<typename T> +inline void iterDefinedFunctions(Module& wasm, T visitor) { + for (auto& import : wasm.functions) { + if (!import->imported()) { + visitor(import.get()); + } + } +} + } // namespace ModuleUtils } // namespace wasm diff --git a/src/ir/trapping.h b/src/ir/trapping.h index a3a87f8ef..48f485faa 100644 --- a/src/ir/trapping.h +++ b/src/ir/trapping.h @@ -57,10 +57,10 @@ public: wasm.addFunction(function); } } - void addImport(Import* import) { + void addImport(Function* import) { imports[import->name] = import; if (immediate) { - wasm.addImport(import); + wasm.addFunction(import); } } @@ -70,7 +70,7 @@ public: wasm.addFunction(pair.second); } for (auto &pair : imports) { - wasm.addImport(pair.second); + wasm.addFunction(pair.second); } } functions.clear(); @@ -91,7 +91,7 @@ public: private: std::map<Name, Function*> functions; - std::map<Name, Import*> imports; + std::map<Name, Function*> imports; TrapMode mode; Module& wasm; diff --git a/src/ir/utils.h b/src/ir/utils.h index 61d8917be..5eff0a52d 100644 --- a/src/ir/utils.h +++ b/src/ir/utils.h @@ -158,7 +158,6 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, OverriddenVisitor<R updateBreakValueType(curr->default_, valueType); } void visitCall(Call* curr) { curr->finalize(); } - void visitCallImport(CallImport* curr) { curr->finalize(); } void visitCallIndirect(CallIndirect* curr) { curr->finalize(); } void visitGetLocal(GetLocal* curr) { curr->finalize(); } void visitSetLocal(SetLocal* curr) { curr->finalize(); } @@ -190,7 +189,6 @@ struct ReFinalize : public WalkerPass<PostWalker<ReFinalize, OverriddenVisitor<R } void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); } - void visitImport(Import* curr) { WASM_UNREACHABLE(); } void visitExport(Export* curr) { WASM_UNREACHABLE(); } void visitGlobal(Global* curr) { WASM_UNREACHABLE(); } void visitTable(Table* curr) { WASM_UNREACHABLE(); } @@ -217,7 +215,6 @@ struct ReFinalizeNode : public OverriddenVisitor<ReFinalizeNode> { void visitBreak(Break* curr) { curr->finalize(); } void visitSwitch(Switch* curr) { curr->finalize(); } void visitCall(Call* curr) { curr->finalize(); } - void visitCallImport(CallImport* curr) { curr->finalize(); } void visitCallIndirect(CallIndirect* curr) { curr->finalize(); } void visitGetLocal(GetLocal* curr) { curr->finalize(); } void visitSetLocal(SetLocal* curr) { curr->finalize(); } @@ -240,7 +237,6 @@ struct ReFinalizeNode : public OverriddenVisitor<ReFinalizeNode> { void visitUnreachable(Unreachable* curr) { curr->finalize(); } void visitFunctionType(FunctionType* curr) { WASM_UNREACHABLE(); } - void visitImport(Import* curr) { WASM_UNREACHABLE(); } void visitExport(Export* curr) { WASM_UNREACHABLE(); } void visitGlobal(Global* curr) { WASM_UNREACHABLE(); } void visitTable(Table* curr) { WASM_UNREACHABLE(); } diff --git a/src/js/binaryen.js-post.js b/src/js/binaryen.js-post.js index 6b5aab0fc..6b4a0ae16 100644 --- a/src/js/binaryen.js-post.js +++ b/src/js/binaryen.js-post.js @@ -39,7 +39,6 @@ Module['LoopId'] = Module['_BinaryenLoopId'](); Module['BreakId'] = Module['_BinaryenBreakId'](); Module['SwitchId'] = Module['_BinaryenSwitchId'](); Module['CallId'] = Module['_BinaryenCallId'](); -Module['CallImportId'] = Module['_BinaryenCallImportId'](); Module['CallIndirectId'] = Module['_BinaryenCallIndirectId'](); Module['GetLocalId'] = Module['_BinaryenGetLocalId'](); Module['SetLocalId'] = Module['_BinaryenSetLocalId'](); @@ -250,11 +249,6 @@ Module['Module'] = function(module) { return Module['_BinaryenCall'](module, strToStack(name), i32sToStack(operands), operands.length, type); }); }; - this['callImport'] = this['call_import'] = function(name, operands, type) { - return preserveStack(function() { - return Module['_BinaryenCallImport'](module, strToStack(name), i32sToStack(operands), operands.length, type); - }); - }; this['callIndirect'] = this['call_indirect'] = function(target, operands, type) { return preserveStack(function() { return Module['_BinaryenCallIndirect'](module, target, i32sToStack(operands), operands.length, strToStack(type)); @@ -1072,7 +1066,6 @@ Module['Module'] = function(module) { return Module['_BinaryenAddGlobal'](module, strToStack(name), type, mutable, init); }); } - this['addImport'] = // deprecated this['addFunctionImport'] = function(internalName, externalModuleName, externalBaseName, functionType) { return preserveStack(function() { return Module['_BinaryenAddFunctionImport'](module, strToStack(internalName), strToStack(externalModuleName), strToStack(externalBaseName), functionType); @@ -1093,11 +1086,6 @@ Module['Module'] = function(module) { return Module['_BinaryenAddGlobalImport'](module, strToStack(internalName), strToStack(externalModuleName), strToStack(externalBaseName), globalType); }); }; - this['removeImport'] = function(internalName) { - return preserveStack(function() { - return Module['_BinaryenRemoveImport'](module, strToStack(internalName)); - }); - }; this['addExport'] = // deprecated this['addFunctionExport'] = function(internalName, externalName) { return preserveStack(function() { @@ -1332,13 +1320,6 @@ Module['getExpressionInfo'] = function(expr) { 'target': Pointer_stringify(Module['_BinaryenCallGetTarget'](expr)), 'operands': getAllNested(expr, Module[ '_BinaryenCallGetNumOperands'], Module['_BinaryenCallGetOperand']) }; - case Module['CallImportId']: - return { - 'id': id, - 'type': type, - 'target': Pointer_stringify(Module['_BinaryenCallImportGetTarget'](expr)), - 'operands': getAllNested(expr, Module['_BinaryenCallImportGetNumOperands'], Module['_BinaryenCallImportGetOperand']), - }; case Module['CallIndirectId']: return { 'id': id, @@ -1513,6 +1494,8 @@ Module['getFunctionTypeInfo'] = function(func) { Module['getFunctionInfo'] = function(func) { return { 'name': Pointer_stringify(Module['_BinaryenFunctionGetName'](func)), + 'module': Pointer_stringify(Module['_BinaryenFunctionImportGetModule'](func)), + 'base': Pointer_stringify(Module['_BinaryenFunctionImportGetBase'](func)), 'type': Pointer_stringify(Module['_BinaryenFunctionGetType'](func)), 'params': getAllNested(func, Module['_BinaryenFunctionGetNumParams'], Module['_BinaryenFunctionGetParam']), 'result': Module['_BinaryenFunctionGetResult'](func), @@ -1521,15 +1504,13 @@ Module['getFunctionInfo'] = function(func) { }; }; -// Obtains information about an 'Import' -Module['getImportInfo'] = function(import_) { +// Obtains information about a 'Global' +Module['getGlobalInfo'] = function(func) { return { - 'kind': Module['_BinaryenImportGetKind'](import_), - 'module': Pointer_stringify(Module['_BinaryenImportGetModule'](import_)), - 'base': Pointer_stringify(Module['_BinaryenImportGetBase'](import_)), - 'name': Pointer_stringify(Module['_BinaryenImportGetName'](import_)), - 'globalType': Module['_BinaryenImportGetGlobalType'](import_), - 'functionType': Pointer_stringify(Module['_BinaryenImportGetFunctionType'](import_)) + 'name': Pointer_stringify(Module['_BinaryenGlobalGetName'](func)), + 'module': Pointer_stringify(Module['_BinaryenGlobalImportGetModule'](func)), + 'base': Pointer_stringify(Module['_BinaryenGlobalImportGetBase'](func)), + 'type': Pointer_stringify(Module['_BinaryenGlobalGetType'](func)) }; }; diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index b1a1008f4..434ea1ece 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -36,13 +36,14 @@ #include <unordered_map> #include <unordered_set> -#include <wasm.h> -#include <pass.h> -#include <wasm-builder.h> -#include <cfg/cfg-traversal.h> -#include <ir/effects.h> -#include <passes/opt-utils.h> -#include <support/sorted_vector.h> +#include "wasm.h" +#include "pass.h" +#include "wasm-builder.h" +#include "cfg/cfg-traversal.h" +#include "ir/effects.h" +#include "ir/module-utils.h" +#include "passes/opt-utils.h" +#include "support/sorted_vector.h" namespace wasm { @@ -110,7 +111,9 @@ struct DAEScanner : public WalkerPass<CFGWalker<DAEScanner, Visitor<DAEScanner>, } void visitCall(Call* curr) { - info->calls[curr->target].push_back(curr); + if (!getModule()->getFunction(curr->target)->imported()) { + info->calls[curr->target].push_back(curr); + } } // main entry point @@ -196,9 +199,9 @@ struct DAE : public Pass { void run(PassRunner* runner, Module* module) override { DAEFunctionInfoMap infoMap; // Ensure they all exist so the parallel threads don't modify the data structure. - for (auto& func : module->functions) { + ModuleUtils::iterDefinedFunctions(*module, [&](Function* func) { infoMap[func->name]; - } + }); // Check the influence of the table and exports. for (auto& curr : module->exports) { if (curr->kind == ExternalKind::Function) { diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp index 1879b1fc8..87dbacf64 100644 --- a/src/passes/DeadCodeElimination.cpp +++ b/src/passes/DeadCodeElimination.cpp @@ -237,7 +237,6 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination>> case Expression::Id::BreakId: DELEGATE(Break); case Expression::Id::SwitchId: DELEGATE(Switch); case Expression::Id::CallId: DELEGATE(Call); - case Expression::Id::CallImportId: DELEGATE(CallImport); case Expression::Id::CallIndirectId: DELEGATE(CallIndirect); case Expression::Id::GetLocalId: DELEGATE(GetLocal); case Expression::Id::SetLocalId: DELEGATE(SetLocal); @@ -312,10 +311,6 @@ struct DeadCodeElimination : public WalkerPass<PostWalker<DeadCodeElimination>> handleCall(curr); } - void visitCallImport(CallImport* curr) { - handleCall(curr); - } - void visitCallIndirect(CallIndirect* curr) { if (handleCall(curr) != curr) return; if (isUnreachable(curr->target)) { diff --git a/src/passes/DuplicateFunctionElimination.cpp b/src/passes/DuplicateFunctionElimination.cpp index 60667cdd6..b7fcb556c 100644 --- a/src/passes/DuplicateFunctionElimination.cpp +++ b/src/passes/DuplicateFunctionElimination.cpp @@ -25,6 +25,7 @@ #include "ir/utils.h" #include "ir/function-utils.h" #include "ir/hashed.h" +#include "ir/module-utils.h" namespace wasm { @@ -72,9 +73,9 @@ struct DuplicateFunctionElimination : public Pass { hasherRunner.run(); // Find hash-equal groups std::map<uint32_t, std::vector<Function*>> hashGroups; - for (auto& func : module->functions) { - hashGroups[hashes[func.get()]].push_back(func.get()); - } + ModuleUtils::iterDefinedFunctions(*module, [&](Function* func) { + hashGroups[hashes[func]].push_back(func); + }); // Find actually equal functions and prepare to replace them std::map<Name, Name> replacements; std::set<Name> duplicates; diff --git a/src/passes/FuncCastEmulation.cpp b/src/passes/FuncCastEmulation.cpp index 013e9403e..30ae3e5e8 100644 --- a/src/passes/FuncCastEmulation.cpp +++ b/src/passes/FuncCastEmulation.cpp @@ -200,18 +200,15 @@ private: Fatal() << "FuncCastEmulation::makeThunk seems a thunk name already in use. Was the pass already run on this code?"; } // The item in the table may be a function or a function import. - auto* func = module->getFunctionOrNull(name); - Import* imp = nullptr; - if (!func) imp = module->getImport(name); - std::vector<Type>& params = func ? func->params : module->getFunctionType(imp->functionType)->params; - Type type = func ? func->result : module->getFunctionType(imp->functionType)->result; + auto* func = module->getFunction(name); + std::vector<Type>& params = func->params; + Type type = func->result; Builder builder(*module); std::vector<Expression*> callOperands; for (Index i = 0; i < params.size(); i++) { callOperands.push_back(fromABI(builder.makeGetLocal(i, i64), params[i], module)); } - Expression* call = func ? (Expression*)builder.makeCall(name, callOperands, type) - : (Expression*)builder.makeCallImport(name, callOperands, type); + auto* call = builder.makeCall(name, callOperands, type); std::vector<Type> thunkParams; for (Index i = 0; i < NUM_PARAMS; i++) { thunkParams.push_back(i64); diff --git a/src/passes/I64ToI32Lowering.cpp b/src/passes/I64ToI32Lowering.cpp index 2986deb9f..697df1eef 100644 --- a/src/passes/I64ToI32Lowering.cpp +++ b/src/passes/I64ToI32Lowering.cpp @@ -105,7 +105,7 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { auto& curr = module->globals[i]; if (curr->type != i64) continue; curr->type = i32; - auto* high = new Global(*curr); + auto* high = ModuleUtils::copyGlobal(curr.get(), *module); high->name = makeHighName(curr->name); module->addGlobal(high); } @@ -361,11 +361,6 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> { ); } - void visitCallImport(CallImport* curr) { - // imports cannot contain i64s - return; - } - void visitCallIndirect(CallIndirect* curr) { visitGenericCall<CallIndirect>( curr, diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index a507e5fc8..ebc33bf97 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -32,13 +32,14 @@ #include <atomic> -#include <wasm.h> -#include <pass.h> -#include <wasm-builder.h> -#include <ir/utils.h> -#include <ir/literal-utils.h> -#include <parsing.h> -#include <passes/opt-utils.h> +#include "wasm.h" +#include "pass.h" +#include "wasm-builder.h" +#include "ir/literal-utils.h" +#include "ir/module-utils.h" +#include "ir/utils.h" +#include "parsing.h" +#include "passes/opt-utils.h" namespace wasm { @@ -117,11 +118,6 @@ struct FunctionInfoScanner : public WalkerPass<PostWalker<FunctionInfoScanner>> (*infos)[getFunction()->name].lightweight = false; } - void visitCallImport(CallImport* curr) { - // having a call is not lightweight - (*infos)[getFunction()->name].lightweight = false; - } - void visitFunction(Function* curr) { (*infos)[curr->name].size = Measurer::measure(curr->body); } @@ -279,9 +275,7 @@ struct Inlining : public Pass { } for (auto& segment : module->table.segments) { for (auto name : segment.data) { - if (module->getFunctionOrNull(name)) { - infos[name].usedGlobally = true; - } + infos[name].usedGlobally = true; } } } @@ -289,12 +283,11 @@ struct Inlining : public Pass { bool iteration(PassRunner* runner, Module* module) { // decide which to inline InliningState state; - for (auto& func : module->functions) { - // on the first iteration, allow multiple inlinings per function + ModuleUtils::iterDefinedFunctions(*module, [&](Function* func) { if (infos[func->name].worthInlining(runner->options)) { state.worthInlining.insert(func->name); } - } + }); if (state.worthInlining.size() == 0) return false; // fill in actionsForFunction, as we operate on it in parallel (each function to its own entry) for (auto& func : module->functions) { diff --git a/src/passes/InstrumentLocals.cpp b/src/passes/InstrumentLocals.cpp index 22b8ebf70..4da41da17 100644 --- a/src/passes/InstrumentLocals.cpp +++ b/src/passes/InstrumentLocals.cpp @@ -49,6 +49,7 @@ #include "shared-constants.h" #include "asmjs/shared-constants.h" #include "asm_v_wasm.h" +#include "ir/function-type-utils.h" namespace wasm { @@ -74,7 +75,7 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> { default: WASM_UNREACHABLE(); } replaceCurrent( - builder.makeCallImport( + builder.makeCall( import, { builder.makeConst(Literal(int32_t(id++))), @@ -97,7 +98,7 @@ struct InstrumentLocals : public WalkerPass<PostWalker<InstrumentLocals>> { case unreachable: return; // nothing to do here default: WASM_UNREACHABLE(); } - curr->value = builder.makeCallImport( + curr->value = builder.makeCall( import, { builder.makeConst(Literal(int32_t(id++))), @@ -123,13 +124,14 @@ private: Index id = 0; void addImport(Module* wasm, Name name, std::string sig) { - auto import = new Import; + auto import = new Function; import->name = name; import->module = INSTRUMENT; import->base = name; - import->functionType = ensureFunctionType(sig, wasm)->name; - import->kind = ExternalKind::Function; - wasm->addImport(import); + auto* functionType = ensureFunctionType(sig, wasm); + import->type = functionType->name; + FunctionTypeUtils::fillFunction(import, functionType); + wasm->addFunction(import); } }; diff --git a/src/passes/InstrumentMemory.cpp b/src/passes/InstrumentMemory.cpp index d9a5a4316..17b5850f4 100644 --- a/src/passes/InstrumentMemory.cpp +++ b/src/passes/InstrumentMemory.cpp @@ -61,6 +61,7 @@ #include "shared-constants.h" #include "asmjs/shared-constants.h" #include "asm_v_wasm.h" +#include "ir/function-type-utils.h" namespace wasm { @@ -76,13 +77,14 @@ struct InstrumentMemory : public WalkerPass<PostWalker<InstrumentMemory>> { makeStoreCall(curr); } void addImport(Module *curr, Name name, std::string sig) { - auto import = new Import; + auto import = new Function; import->name = name; import->module = INSTRUMENT; import->base = name; - import->functionType = ensureFunctionType(sig, curr)->name; - import->kind = ExternalKind::Function; - curr->addImport(import); + auto* functionType = ensureFunctionType(sig, curr); + import->type = functionType->name; + FunctionTypeUtils::fillFunction(import, functionType); + curr->addFunction(import); } void visitModule(Module *curr) { @@ -94,7 +96,7 @@ private: std::atomic<Index> id; Expression* makeLoadCall(Load* curr) { Builder builder(*getModule()); - curr->ptr = builder.makeCallImport(load, + curr->ptr = builder.makeCall(load, { builder.makeConst(Literal(int32_t(id.fetch_add(1)))), builder.makeConst(Literal(int32_t(curr->bytes))), builder.makeConst(Literal(int32_t(curr->offset.addr))), @@ -106,7 +108,7 @@ private: Expression* makeStoreCall(Store* curr) { Builder builder(*getModule()); - curr->ptr = builder.makeCallImport(store, + curr->ptr = builder.makeCall(store, { builder.makeConst(Literal(int32_t(id.fetch_add(1)))), builder.makeConst(Literal(int32_t(curr->bytes))), builder.makeConst(Literal(int32_t(curr->offset.addr))), diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp index 1f02cd4d7..5f64ad2e6 100644 --- a/src/passes/LegalizeJSInterface.cpp +++ b/src/passes/LegalizeJSInterface.cpp @@ -26,11 +26,12 @@ // disallow f32s. TODO: an option to not do that, if it matters? // -#include <wasm.h> -#include <pass.h> -#include <wasm-builder.h> -#include <ir/utils.h> -#include <ir/literal-utils.h> +#include "wasm.h" +#include "pass.h" +#include "wasm-builder.h" +#include "ir/function-type-utils.h" +#include "ir/literal-utils.h" +#include "ir/utils.h" namespace wasm { @@ -44,22 +45,23 @@ struct LegalizeJSInterface : public Pass { for (auto& ex : module->exports) { if (ex->kind == ExternalKind::Function) { // if it's an import, ignore it - if (auto* func = module->getFunctionOrNull(ex->value)) { - if (isIllegal(func)) { - auto legalName = makeLegalStub(func, module); - ex->value = legalName; - } + auto* func = module->getFunction(ex->value); + if (isIllegal(func)) { + auto legalName = makeLegalStub(func, module); + ex->value = legalName; } } } + // Avoid iterator invalidation later. + std::vector<Function*> originalFunctions; + for (auto& func : module->functions) { + originalFunctions.push_back(func.get()); + } // for each illegal import, we must call a legalized stub instead - std::vector<Import*> newImports; // add them at the end, to not invalidate the iter - for (auto& im : module->imports) { - if (im->kind == ExternalKind::Function && isIllegal(module->getFunctionType(im->functionType))) { - Name funcName; - auto* legal = makeLegalStub(im.get(), module, funcName); - illegalToLegal[im->name] = funcName; - newImports.push_back(legal); + for (auto* im : originalFunctions) { + if (im->imported() && isIllegal(module->getFunctionType(im->type))) { + auto funcName = makeLegalStubForCalledImport(im, module); + illegalImportsToLegal[im->name] = funcName; // we need to use the legalized version in the table, as the import from JS // is legal for JS. Our stub makes it look like a native wasm function. for (auto& segment : module->table.segments) { @@ -71,13 +73,9 @@ struct LegalizeJSInterface : public Pass { } } } - if (illegalToLegal.size() > 0) { - for (auto& pair : illegalToLegal) { - module->removeImport(pair.first); - } - - for (auto* im : newImports) { - module->addImport(im); + if (illegalImportsToLegal.size() > 0) { + for (auto& pair : illegalImportsToLegal) { + module->removeFunction(pair.first); } // fix up imports: call_import of an illegal must be turned to a call of a legal @@ -85,15 +83,15 @@ struct LegalizeJSInterface : public Pass { struct FixImports : public WalkerPass<PostWalker<FixImports>> { bool isFunctionParallel() override { return true; } - Pass* create() override { return new FixImports(illegalToLegal); } + Pass* create() override { return new FixImports(illegalImportsToLegal); } - std::map<Name, Name>* illegalToLegal; + std::map<Name, Name>* illegalImportsToLegal; - FixImports(std::map<Name, Name>* illegalToLegal) : illegalToLegal(illegalToLegal) {} + FixImports(std::map<Name, Name>* illegalImportsToLegal) : illegalImportsToLegal(illegalImportsToLegal) {} - void visitCallImport(CallImport* curr) { - auto iter = illegalToLegal->find(curr->target); - if (iter == illegalToLegal->end()) return; + void visitCall(Call* curr) { + auto iter = illegalImportsToLegal->find(curr->target); + if (iter == illegalImportsToLegal->end()) return; if (iter->second == getFunction()->name) return; // inside the stub function itself, is the one safe place to do the call replaceCurrent(Builder(*getModule()).makeCall(iter->second, curr->operands, curr->type)); @@ -102,7 +100,7 @@ struct LegalizeJSInterface : public Pass { PassRunner passRunner(module); passRunner.setIsNested(true); - passRunner.add<FixImports>(&illegalToLegal); + passRunner.add<FixImports>(&illegalImportsToLegal); passRunner.run(); } @@ -113,7 +111,7 @@ struct LegalizeJSInterface : public Pass { private: // map of illegal to legal names for imports - std::map<Name, Name> illegalToLegal; + std::map<Name, Name> illegalImportsToLegal; bool needTempRet0Helpers = false; @@ -179,24 +177,22 @@ private: } // wasm calls the import, so it must call a stub that calls the actual legal JS import - Import* makeLegalStub(Import* im, Module* module, Name& funcName) { + Name makeLegalStubForCalledImport(Function* im, Module* module) { Builder builder(*module); - auto* type = new FunctionType(); + auto* type = new FunctionType; type->name = Name(std::string("legaltype$") + im->name.str); - auto* legal = new Import(); + auto* legal = new Function; legal->name = Name(std::string("legalimport$") + im->name.str); legal->module = im->module; legal->base = im->base; - legal->kind = ExternalKind::Function; - legal->functionType = type->name; - auto* func = new Function(); + legal->type = type->name; + auto* func = new Function; func->name = Name(std::string("legalfunc$") + im->name.str); - funcName = func->name; - auto* call = module->allocator.alloc<CallImport>(); + auto* call = module->allocator.alloc<Call>(); call->target = legal->name; - auto* imFunctionType = module->getFunctionType(im->functionType); + auto* imFunctionType = module->getFunctionType(im->type); for (auto param : imFunctionType->params) { if (param == i64) { @@ -231,10 +227,18 @@ private: type->result = imFunctionType->result; } func->result = imFunctionType->result; + FunctionTypeUtils::fillFunction(legal, type); - module->addFunction(func); - module->addFunctionType(type); - return legal; + if (!module->getFunctionOrNull(func->name)) { + module->addFunction(func); + } + if (!module->getFunctionTypeOrNull(type->name)) { + module->addFunctionType(type); + } + if (!module->getFunctionOrNull(legal->name)) { + module->addFunction(legal); + } + return func->name; } void ensureTempRet0(Module* module) { diff --git a/src/passes/LogExecution.cpp b/src/passes/LogExecution.cpp index 8d555fefe..45a29eae9 100644 --- a/src/passes/LogExecution.cpp +++ b/src/passes/LogExecution.cpp @@ -30,6 +30,7 @@ #include "shared-constants.h" #include "asmjs/shared-constants.h" #include "asm_v_wasm.h" +#include "ir/function-type-utils.h" namespace wasm { @@ -46,13 +47,14 @@ struct LogExecution : public WalkerPass<PostWalker<LogExecution>> { void visitModule(Module *curr) { // Add the import - auto import = new Import; + auto import = new Function; import->name = LOGGER; import->module = ENV; import->base = LOGGER; - import->functionType = ensureFunctionType("vi", curr)->name; - import->kind = ExternalKind::Function; - curr->addImport(import); + auto* functionType = ensureFunctionType("vi", curr); + import->type = functionType->name; + FunctionTypeUtils::fillFunction(import, functionType); + curr->addFunction(import); } private: @@ -60,7 +62,7 @@ private: static Index id = 0; Builder builder(*getModule()); return builder.makeSequence( - builder.makeCallImport( + builder.makeCall( LOGGER, { builder.makeConst(Literal(int32_t(id++))) }, none diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp index a68b202d7..352a0a08d 100644 --- a/src/passes/MergeBlocks.cpp +++ b/src/passes/MergeBlocks.cpp @@ -453,10 +453,6 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> { handleCall(curr); } - void visitCallImport(CallImport* curr) { - handleCall(curr); - } - void visitCallIndirect(CallIndirect* curr) { Block* outer = nullptr; for (Index i = 0; i < curr->operands.size(); i++) { diff --git a/src/passes/Metrics.cpp b/src/passes/Metrics.cpp index 81706042b..5176f8762 100644 --- a/src/passes/Metrics.cpp +++ b/src/passes/Metrics.cpp @@ -46,25 +46,26 @@ struct Metrics : public WalkerPass<PostWalker<Metrics, UnifiedExpressionVisitor< } void doWalkModule(Module* module) { + ImportInfo imports(*module); + // global things for (auto& curr : module->functionTypes) { visitFunctionType(curr.get()); } - for (auto& curr : module->imports) { - visitImport(curr.get()); - } for (auto& curr : module->exports) { visitExport(curr.get()); } - for (auto& curr : module->globals) { - walkGlobal(curr.get()); - } + ModuleUtils::iterDefinedGlobals(*module, [&](Global* curr) { + walkGlobal(curr); + }); walkTable(&module->table); walkMemory(&module->memory); + // add imports + counts["[imports]"] = imports.getNumImports(); // add functions - counts["[funcs]"] = module->functions.size(); + counts["[funcs]"] = imports.getNumDefinedFunctions(); // add memory and table if (module->memory.exists) { Index size = 0; @@ -89,14 +90,14 @@ struct Metrics : public WalkerPass<PostWalker<Metrics, UnifiedExpressionVisitor< WasmBinaryWriter writer(module, buffer); writer.write(); // print for each function - for (Index i = 0; i < module->functions.size(); i++) { - auto* func = module->functions[i].get(); + Index binaryIndex = 0; + ModuleUtils::iterDefinedFunctions(*module, [&](Function* func) { counts.clear(); walkFunction(func); counts["[vars]"] = func->getNumVars(); - counts["[binary-bytes]"] = writer.tableOfContents.functionBodies[i].size; + counts["[binary-bytes]"] = writer.tableOfContents.functionBodies[binaryIndex++].size; printCounts(std::string("func: ") + func->name.str); - } + }); // print for each export how much code size is due to it, i.e., // how much the module could shrink without it. auto sizeAfterGlobalCleanup = [](Module* module) { @@ -138,10 +139,10 @@ struct Metrics : public WalkerPass<PostWalker<Metrics, UnifiedExpressionVisitor< } else { // add function info size_t vars = 0; - for (auto& func : module->functions) { - walkFunction(func.get()); + ModuleUtils::iterDefinedFunctions(*module, [&](Function* func) { + walkFunction(func); vars += func->getNumVars(); - } + }); counts["[vars]"] = vars; // print printCounts("total"); diff --git a/src/passes/NameList.cpp b/src/passes/NameList.cpp index ebc3a5c55..6b1d528e4 100644 --- a/src/passes/NameList.cpp +++ b/src/passes/NameList.cpp @@ -20,15 +20,16 @@ #include "wasm.h" #include "pass.h" +#include "ir/module-utils.h" #include "ir/utils.h" namespace wasm { struct NameList : public Pass { void run(PassRunner* runner, Module* module) override { - for (auto& func : module->functions) { + ModuleUtils::iterDefinedFunctions(*module, [&](Function* func) { std::cout << " " << func->name << " : " << Measurer::measure(func->body) << '\n'; - } + }); } }; diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 73f78611b..255ff4cbd 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -43,132 +43,6 @@ Name I32_EXPR = "i32.expr", F64_EXPR = "f64.expr", ANY_EXPR = "any.expr"; -// A pattern -struct Pattern { - Expression* input; - Expression* output; - - Pattern(Expression* input, Expression* output) : input(input), output(output) {} -}; - -#if 0 -// Database of patterns -struct PatternDatabase { - Module wasm; - - char* input; - - std::map<Expression::Id, std::vector<Pattern>> patternMap; // root expression id => list of all patterns for it TODO optimize more - - PatternDatabase() { - // generate module - input = strdup( - #include "OptimizeInstructions.wast.processed" - ); - try { - SExpressionParser parser(input); - Element& root = *parser.root; - SExpressionWasmBuilder builder(wasm, *root[0]); - // parse module form - auto* func = wasm.getFunction("patterns"); - auto* body = func->body->cast<Block>(); - for (auto* item : body->list) { - auto* pair = item->cast<Block>(); - patternMap[pair->list[0]->_id].emplace_back(pair->list[0], pair->list[1]); - } - } catch (ParseException& p) { - p.dump(std::cerr); - Fatal() << "error in parsing wasm binary"; - } - } - - ~PatternDatabase() { - free(input); - }; -}; - -static PatternDatabase* database = nullptr; - -struct DatabaseEnsurer { - DatabaseEnsurer() { - assert(!database); - database = new PatternDatabase; - } -}; -#endif - -// Check for matches and apply them -struct Match { - Module& wasm; - Pattern& pattern; - - Match(Module& wasm, Pattern& pattern) : wasm(wasm), pattern(pattern) {} - - std::vector<Expression*> wildcards; // id in i32.any(id) etc. => the expression it represents in this match - - // Comparing/checking - - // Check if we can match to this pattern, updating ourselves with the info if so - bool check(Expression* seen) { - // compare seen to the pattern input, doing a special operation for our "wildcards" - assert(wildcards.size() == 0); - auto compare = [this](Expression* subInput, Expression* subSeen) { - CallImport* call = subInput->dynCast<CallImport>(); - if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return false; - Index index = call->operands[0]->cast<Const>()->value.geti32(); - // handle our special functions - auto checkMatch = [&](Type type) { - if (type != none && subSeen->type != type) return false; - while (index >= wildcards.size()) { - wildcards.push_back(nullptr); - } - if (!wildcards[index]) { - // new wildcard - wildcards[index] = subSeen; // NB: no need to copy - return true; - } else { - // We are seeing this index for a second or later time, check it matches - return ExpressionAnalyzer::equal(subSeen, wildcards[index]); - }; - }; - if (call->target == I32_EXPR) { - if (checkMatch(i32)) return true; - } else if (call->target == I64_EXPR) { - if (checkMatch(i64)) return true; - } else if (call->target == F32_EXPR) { - if (checkMatch(f32)) return true; - } else if (call->target == F64_EXPR) { - if (checkMatch(f64)) return true; - } else if (call->target == ANY_EXPR) { - if (checkMatch(none)) return true; - } - return false; - }; - - return ExpressionAnalyzer::flexibleEqual(pattern.input, seen, compare); - } - - - // Applying/copying - - // Apply the match, generate an output expression from the matched input, performing substitutions as necessary - Expression* apply() { - // When copying a wildcard, perform the substitution. - // TODO: we can reuse nodes, not copying a wildcard when it appears just once, and we can reuse other individual nodes when they are discarded anyhow. - auto copy = [this](Expression* curr) -> Expression* { - CallImport* call = curr->dynCast<CallImport>(); - if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return nullptr; - Index index = call->operands[0]->cast<Const>()->value.geti32(); - // handle our special functions - if (call->target == I32_EXPR || call->target == I64_EXPR || call->target == F32_EXPR || call->target == F64_EXPR || call->target == ANY_EXPR) { - return ExpressionManipulator::copy(wildcards.at(index), wasm); - } - return nullptr; - }; - return ExpressionManipulator::flexibleCopy(pattern.output, wasm, copy); - } -}; - // Utilities // returns the maximum amount of bits used in an integer expression diff --git a/src/passes/PostEmscripten.cpp b/src/passes/PostEmscripten.cpp index a7f0e6282..72c0d8808 100644 --- a/src/passes/PostEmscripten.cpp +++ b/src/passes/PostEmscripten.cpp @@ -92,11 +92,12 @@ struct PostEmscripten : public WalkerPass<PostWalker<PostEmscripten>> { optimizeMemoryAccess(curr->ptr, curr->offset); } - void visitCallImport(CallImport* curr) { + void visitCall(Call* curr) { // special asm.js imports can be optimized - auto* import = getModule()->getImport(curr->target); - if (import->module == GLOBAL_MATH) { - if (import->base == POW) { + auto* func = getModule()->getFunction(curr->target); + if (!func->imported()) return; + if (func->module == GLOBAL_MATH) { + if (func->base == POW) { if (auto* exponent = curr->operands[1]->dynCast<Const>()) { if (exponent->value == Literal(double(2.0))) { // This is just a square operation, do a multiply diff --git a/src/passes/Precompute.cpp b/src/passes/Precompute.cpp index 9735b91d7..b6b7d5097 100644 --- a/src/passes/Precompute.cpp +++ b/src/passes/Precompute.cpp @@ -61,9 +61,6 @@ public: Flow visitCall(Call* curr) { return Flow(NOTPRECOMPUTABLE_FLOW); } - Flow visitCallImport(CallImport* curr) { - return Flow(NOTPRECOMPUTABLE_FLOW); - } Flow visitCallIndirect(CallIndirect* curr) { return Flow(NOTPRECOMPUTABLE_FLOW); } @@ -89,11 +86,9 @@ public: return Flow(NOTPRECOMPUTABLE_FLOW); } Flow visitGetGlobal(GetGlobal *curr) { - auto* global = module->getGlobalOrNull(curr->name); - if (global) { - if (!global->mutable_) { - return visit(global->init); - } + auto* global = module->getGlobal(curr->name); + if (!global->imported() && !global->mutable_) { + return visit(global->init); } return Flow(NOTPRECOMPUTABLE_FLOW); } diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index 77562af4b..3ab1cc75d 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -109,10 +109,6 @@ struct PrintExpressionContents : public Visitor<PrintExpressionContents> { printMedium(o, "call "); printName(curr->target, o); } - void visitCallImport(CallImport* curr) { - printMedium(o, "call "); - printName(curr->target, o); - } void visitCallIndirect(CallIndirect* curr) { printMedium(o, "call_indirect (type ") << curr->fullType << ')'; } @@ -478,6 +474,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } void decIndent() { if (!minify) { + assert(indent > 0); indent--; doIndent(o, indent); } @@ -636,11 +633,6 @@ struct PrintSExpression : public Visitor<PrintSExpression> { PrintExpressionContents(currFunction, o).visit(curr); printCallOperands(curr); } - void visitCallImport(CallImport* curr) { - o << '('; - PrintExpressionContents(currFunction, o).visit(curr); - printCallOperands(curr); - } void visitCallIndirect(CallIndirect* curr) { o << '('; PrintExpressionContents(currFunction, o).visit(curr); @@ -818,20 +810,6 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } o << ")"; } - void visitImport(Import* curr) { - o << '('; - printMedium(o, "import "); - printText(o, curr->module.str) << ' '; - printText(o, curr->base.str) << ' '; - switch (curr->kind) { - case ExternalKind::Function: if (curr->functionType.is()) visitFunctionType(currModule->getFunctionType(curr->functionType), &curr->name); break; - case ExternalKind::Table: printTableHeader(&currModule->table); break; - case ExternalKind::Memory: printMemoryHeader(&currModule->memory); break; - case ExternalKind::Global: o << "(global " << curr->name << ' ' << printType(curr->globalType) << ")"; break; - default: WASM_UNREACHABLE(); - } - o << ')'; - } void visitExport(Export* curr) { o << '('; printMedium(o, "export "); @@ -846,19 +824,66 @@ struct PrintSExpression : public Visitor<PrintSExpression> { o << ' '; printName(curr->value, o) << "))"; } + void emitImportHeader(Importable* curr) { + printMedium(o, "import "); + printText(o, curr->module.str) << ' '; + printText(o, curr->base.str) << ' '; + } void visitGlobal(Global* curr) { - o << '('; - printMedium(o, "global "); - printName(curr->name, o) << ' '; + if (curr->imported()) { + visitImportedGlobal(curr); + } else { + visitDefinedGlobal(curr); + } + } + void emitGlobalType(Global* curr) { if (curr->mutable_) { - o << "(mut " << printType(curr->type) << ") "; + o << "(mut " << printType(curr->type) << ')'; } else { - o << printType(curr->type) << ' '; + o << printType(curr->type); } + } + void visitImportedGlobal(Global* curr) { + doIndent(o, indent); + o << '('; + emitImportHeader(curr); + o << "(global "; + printName(curr->name, o) << ' '; + emitGlobalType(curr); + o << "))" << maybeNewLine; + } + void visitDefinedGlobal(Global* curr) { + doIndent(o, indent); + o << '('; + printMedium(o, "global "); + printName(curr->name, o) << ' '; + emitGlobalType(curr); + o << ' '; visit(curr->init); o << ')'; + o << maybeNewLine; } void visitFunction(Function* curr) { + if (curr->imported()) { + visitImportedFunction(curr); + } else { + visitDefinedFunction(curr); + } + } + void visitImportedFunction(Function* curr) { + doIndent(o, indent); + currFunction = curr; + lastPrintedLocation = { 0, 0, 0 }; + o << '('; + emitImportHeader(curr); + if (curr->type.is()) { + visitFunctionType(currModule->getFunctionType(curr->type), &curr->name); + } + o << ')'; + o << maybeNewLine; + } + void visitDefinedFunction(Function* curr) { + doIndent(o, indent); currFunction = curr; lastPrintedLocation = { 0, 0, 0 }; if (currFunction->prologLocation.size()) { @@ -927,6 +952,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } else { decIndent(); } + o << maybeNewLine; } void printTableHeader(Table* curr) { o << '('; @@ -937,8 +963,13 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } void visitTable(Table* curr) { if (!curr->exists) return; - // if table wasn't imported, declare it - if (!curr->imported) { + if (curr->imported()) { + doIndent(o, indent); + o << '('; + emitImportHeader(curr); + printTableHeader(&currModule->table); + o << ')' << maybeNewLine; + } else { doIndent(o, indent); printTableHeader(curr); o << maybeNewLine; @@ -954,7 +985,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> { o << ' '; printName(name, o); } - o << ")\n"; + o << ')' << maybeNewLine; } } void printMemoryHeader(Memory* curr) { @@ -972,8 +1003,13 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } void visitMemory(Memory* curr) { if (!curr->exists) return; - // if memory wasn't imported, declare it - if (!curr->imported) { + if (curr->imported()) { + doIndent(o, indent); + o << '('; + emitImportHeader(curr); + printMemoryHeader(&currModule->memory); + o << ')' << maybeNewLine; + } else { doIndent(o, indent); printMemoryHeader(curr); o << '\n'; @@ -1004,7 +1040,7 @@ struct PrintSExpression : public Visitor<PrintSExpression> { } } } - o << "\")\n"; + o << "\")" << maybeNewLine; } } void visitModule(Module* curr) { @@ -1020,20 +1056,19 @@ struct PrintSExpression : public Visitor<PrintSExpression> { visitFunctionType(child.get()); o << ")" << maybeNewLine; } - for (auto& child : curr->imports) { - doIndent(o, indent); - visitImport(child.get()); - o << maybeNewLine; - } - for (auto& child : curr->globals) { - doIndent(o, indent); - visitGlobal(child.get()); - o << maybeNewLine; - } + visitMemory(&curr->memory); if (curr->table.exists) { visitTable(&curr->table); // Prints its own newlines } - visitMemory(&curr->memory); + ModuleUtils::iterImportedGlobals(*curr, [&](Global* global) { + visitGlobal(global); + }); + ModuleUtils::iterImportedFunctions(*curr, [&](Function* func) { + visitFunction(func); + }); + ModuleUtils::iterDefinedGlobals(*curr, [&](Global* global) { + visitGlobal(global); + }); for (auto& child : curr->exports) { doIndent(o, indent); visitExport(child.get()); @@ -1045,11 +1080,9 @@ struct PrintSExpression : public Visitor<PrintSExpression> { printMedium(o, "start") << ' ' << curr->start << ')'; o << maybeNewLine; } - for (auto& child : curr->functions) { - doIndent(o, indent); - visitFunction(child.get()); - o << maybeNewLine; - } + ModuleUtils::iterDefinedFunctions(*curr, [&](Function* func) { + visitFunction(func); + }); for (auto& section : curr->userSections) { doIndent(o, indent); o << ";; custom section \"" << section.name << "\", size " << section.data.size(); diff --git a/src/passes/PrintCallGraph.cpp b/src/passes/PrintCallGraph.cpp index fa58e3859..2a82b7aa1 100644 --- a/src/passes/PrintCallGraph.cpp +++ b/src/passes/PrintCallGraph.cpp @@ -24,6 +24,7 @@ #include "wasm.h" #include "pass.h" +#include "ir/module-utils.h" #include "ir/utils.h" namespace wasm { @@ -46,19 +47,17 @@ struct PrintCallGraph : public Pass { " }\n\n" " node [shape=box, fontname=courier, fontsize=10];\n"; - // All Functions - for (auto& func : module->functions) { - std::cout << " \"" << func.get()->name << "\" [style=\"filled\", fillcolor=\"white\"];\n"; - } + // Defined functions + ModuleUtils::iterDefinedFunctions(*module, [&](Function* curr) { + std::cout << " \"" << curr->name << "\" [style=\"filled\", fillcolor=\"white\"];\n"; + }); - // Imports Nodes - for (auto& curr : module->imports) { - if (curr->kind == ExternalKind::Function) { - o << " \"" << curr->name << "\" [style=\"filled\", fillcolor=\"turquoise\"];\n"; - } - } + // Imported functions + ModuleUtils::iterImportedFunctions(*module, [&](Function* curr) { + o << " \"" << curr->name << "\" [style=\"filled\", fillcolor=\"turquoise\"];\n"; + }); - // Exports Nodes + // Exports for (auto& curr : module->exports) { if (curr->kind == ExternalKind::Function) { Function* func = module->getFunction(curr->value); @@ -73,11 +72,11 @@ struct PrintCallGraph : public Pass { std::vector<Function*> allIndirectTargets; CallPrinter(Module *module) : module(module) { // Walk function bodies. - for (auto& func : module->functions) { - currFunction = func.get(); + ModuleUtils::iterDefinedFunctions(*module, [&](Function* curr) { + currFunction = curr; visitedTargets.clear(); - walk(func.get()->body); - } + walk(curr->body); + }); } void visitCall(Call *curr) { auto* target = module->getFunction(curr->target); @@ -85,12 +84,6 @@ struct PrintCallGraph : public Pass { visitedTargets.insert(target->name); std::cout << " \"" << currFunction->name << "\" -> \"" << target->name << "\"; // call\n"; } - void visitCallImport(CallImport *curr) { - auto name = curr->target; - if (visitedTargets.count(name) > 0) return; - visitedTargets.insert(name); - std::cout << " \"" << currFunction->name << "\" -> \"" << name << "\"; // callImport\n"; - } }; CallPrinter printer(module); diff --git a/src/passes/RemoveImports.cpp b/src/passes/RemoveImports.cpp index 116cd3296..e70cbb3ac 100644 --- a/src/passes/RemoveImports.cpp +++ b/src/passes/RemoveImports.cpp @@ -22,14 +22,19 @@ // look at all the rest of the code). // -#include <wasm.h> -#include <pass.h> +#include "wasm.h" +#include "pass.h" +#include "ir/module-utils.h" namespace wasm { struct RemoveImports : public WalkerPass<PostWalker<RemoveImports>> { - void visitCallImport(CallImport *curr) { - Type type = getModule()->getFunctionType(getModule()->getImport(curr->target)->functionType)->result; + void visitCall(Call *curr) { + auto* func = getModule()->getFunction(curr->target); + if (!func->imported()) { + return; + } + Type type = getModule()->getFunctionType(func->type)->result; if (type == none) { replaceCurrent(getModule()->allocator.alloc<Nop>()); } else { @@ -41,13 +46,11 @@ struct RemoveImports : public WalkerPass<PostWalker<RemoveImports>> { void visitModule(Module *curr) { std::vector<Name> names; - for (auto& import : curr->imports) { - if (import->kind == ExternalKind::Function) { - names.push_back(import->name); - } - } + ModuleUtils::iterImportedFunctions(*curr, [&](Function* func) { + names.push_back(func->name); + }); for (auto& name : names) { - curr->removeImport(name); + curr->removeFunction(name); } } }; diff --git a/src/passes/RemoveUnusedModuleElements.cpp b/src/passes/RemoveUnusedModuleElements.cpp index 6cd050da9..069137ff6 100644 --- a/src/passes/RemoveUnusedModuleElements.cpp +++ b/src/passes/RemoveUnusedModuleElements.cpp @@ -25,6 +25,7 @@ #include "wasm.h" #include "pass.h" +#include "ir/module-utils.h" #include "ir/utils.h" #include "asm_v_wasm.h" @@ -63,15 +64,15 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { reachable.insert(curr); if (curr.first == ModuleElementKind::Function) { // if not an import, walk it - auto* func = module->getFunctionOrNull(curr.second); - if (func) { + auto* func = module->getFunction(curr.second); + if (!func->imported()) { walk(func->body); } } else { // if not imported, it has an init expression we need to walk - auto* glob = module->getGlobalOrNull(curr.second); - if (glob) { - walk(glob->init); + auto* global = module->getGlobal(curr.second); + if (!global->imported()) { + walk(global->init); } } } @@ -83,11 +84,6 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { queue.emplace_back(ModuleElementKind::Function, curr->target); } } - void visitCallImport(CallImport* curr) { - if (reachable.count(ModuleElement(ModuleElementKind::Function, curr->target)) == 0) { - queue.emplace_back(ModuleElementKind::Function, curr->target); - } - } void visitCallIndirect(CallIndirect* curr) { usesTable = true; } @@ -131,19 +127,17 @@ struct ReachabilityAnalyzer : public PostWalker<ReachabilityAnalyzer> { // Finds function type usage struct FunctionTypeAnalyzer : public PostWalker<FunctionTypeAnalyzer> { - std::vector<Import*> functionImports; + std::vector<Function*> functionImports; std::vector<Function*> functions; std::vector<CallIndirect*> indirectCalls; - void visitImport(Import* curr) { - if (curr->kind == ExternalKind::Function && curr->functionType.is()) { - functionImports.push_back(curr); - } - } - void visitFunction(Function* curr) { if (curr->type.is()) { - functions.push_back(curr); + if (curr->imported()) { + functionImports.push_back(curr); + } else { + functions.push_back(curr); + } } } @@ -176,9 +170,9 @@ struct RemoveUnusedModuleElements : public Pass { } // If told to, root all the functions if (rootAllFunctions) { - for (auto& func : module->functions) { + ModuleUtils::iterDefinedFunctions(*module, [&](Function* func) { roots.emplace_back(ModuleElementKind::Function, func->name); - } + }); } // Exports are roots. bool exportsMemory = false; @@ -194,15 +188,14 @@ struct RemoveUnusedModuleElements : public Pass { exportsTable = true; } } - // Check for special imports are roots. + // Check for special imports, which are roots. bool importsMemory = false; bool importsTable = false; - for (auto& curr : module->imports) { - if (curr->kind == ExternalKind::Memory) { - importsMemory = true; - } else if (curr->kind == ExternalKind::Table) { - importsTable = true; - } + if (module->memory.imported()) { + importsMemory = true; + } + if (module->table.imported()) { + importsTable = true; } // For now, all functions that can be called indirectly are marked as roots. for (auto& segment : module->table.segments) { @@ -225,17 +218,6 @@ struct RemoveUnusedModuleElements : public Pass { return analyzer.reachable.count(ModuleElement(ModuleElementKind::Global, curr->name)) == 0; }), v.end()); } - { - auto& v = module->imports; - v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Import>& curr) { - if (curr->kind == ExternalKind::Function) { - return analyzer.reachable.count(ModuleElement(ModuleElementKind::Function, curr->name)) == 0; - } else if (curr->kind == ExternalKind::Global) { - return analyzer.reachable.count(ModuleElement(ModuleElementKind::Global, curr->name)) == 0; - } - return false; - }), v.end()); - } module->updateMaps(); // Handle the memory and table if (!exportsMemory && !analyzer.usesMemory) { @@ -245,10 +227,9 @@ struct RemoveUnusedModuleElements : public Pass { } if (module->memory.segments.empty()) { module->memory.exists = false; - module->memory.imported = false; + module->memory.module = module->memory.base = Name(); module->memory.initial = 0; module->memory.max = 0; - removeImport(ExternalKind::Memory, module); } } if (!exportsTable && !analyzer.usesTable) { @@ -258,21 +239,13 @@ struct RemoveUnusedModuleElements : public Pass { } if (module->table.segments.empty()) { module->table.exists = false; - module->table.imported = false; + module->table.module = module->table.base = Name(); module->table.initial = 0; module->table.max = 0; - removeImport(ExternalKind::Table, module); } } } - void removeImport(ExternalKind kind, Module* module) { - auto& v = module->imports; - v.erase(std::remove_if(v.begin(), v.end(), [&](const std::unique_ptr<Import>& curr) { - return curr->kind == kind; - }), v.end()); - } - void optimizeFunctionTypes(Module* module) { FunctionTypeAnalyzer analyzer; analyzer.walkModule(module); @@ -294,7 +267,7 @@ struct RemoveUnusedModuleElements : public Pass { }; // canonicalize all uses of function types for (auto* import : analyzer.functionImports) { - import->functionType = canonicalize(import->functionType); + import->type = canonicalize(import->type); } for (auto* func : analyzer.functions) { func->type = canonicalize(func->type); diff --git a/src/passes/SafeHeap.cpp b/src/passes/SafeHeap.cpp index 325516f4d..5c1980f28 100644 --- a/src/passes/SafeHeap.cpp +++ b/src/passes/SafeHeap.cpp @@ -26,6 +26,7 @@ #include "asmjs/shared-constants.h" #include "wasm-builder.h" #include "ir/bits.h" +#include "ir/function-type-utils.h" #include "ir/import-utils.h" namespace wasm { @@ -114,39 +115,40 @@ struct SafeHeap : public Pass { Name dynamicTopPtr, segfault, alignfault; void addImports(Module* module) { - // imports - if (auto* existing = ImportUtils::getImport(*module, ENV, DYNAMICTOP_PTR_IMPORT)) { + ImportInfo info(*module); + if (auto* existing = info.getImportedGlobal(ENV, DYNAMICTOP_PTR_IMPORT)) { dynamicTopPtr = existing->name; } else { - auto* import = new Import; + auto* import = new Global; import->name = dynamicTopPtr = DYNAMICTOP_PTR_IMPORT; import->module = ENV; import->base = DYNAMICTOP_PTR_IMPORT; - import->kind = ExternalKind::Global; - import->globalType = i32; - module->addImport(import); + import->type = i32; + module->addGlobal(import); } - if (auto* existing = ImportUtils::getImport(*module, ENV, SEGFAULT_IMPORT)) { + if (auto* existing = info.getImportedFunction(ENV, SEGFAULT_IMPORT)) { segfault = existing->name; } else { - auto* import = new Import; + auto* import = new Function; import->name = segfault = SEGFAULT_IMPORT; import->module = ENV; import->base = SEGFAULT_IMPORT; - import->kind = ExternalKind::Function; - import->functionType = ensureFunctionType("v", module)->name; - module->addImport(import); + auto* functionType = ensureFunctionType("v", module); + import->type = functionType->name; + FunctionTypeUtils::fillFunction(import, functionType); + module->addFunction(import); } - if (auto* existing = ImportUtils::getImport(*module, ENV, ALIGNFAULT_IMPORT)) { + if (auto* existing = info.getImportedFunction(ENV, ALIGNFAULT_IMPORT)) { alignfault = existing->name; } else { - auto* import = new Import; + auto* import = new Function; import->name = alignfault = ALIGNFAULT_IMPORT; import->module = ENV; import->base = ALIGNFAULT_IMPORT; - import->kind = ExternalKind::Function; - import->functionType = ensureFunctionType("v", module)->name; - module->addImport(import); + auto* functionType = ensureFunctionType("v", module); + import->type = functionType->name; + FunctionTypeUtils::fillFunction(import, functionType); + module->addFunction(import); } } @@ -291,7 +293,7 @@ struct SafeHeap : public Pass { builder.makeGetLocal(local, i32), builder.makeConst(Literal(int32_t(align - 1))) ), - builder.makeCallImport(alignfault, {}, none) + builder.makeCall(alignfault, {}, none) ); } @@ -316,7 +318,7 @@ struct SafeHeap : public Pass { ) ) ), - builder.makeCallImport(segfault, {}, none) + builder.makeCall(segfault, {}, none) ); } }; diff --git a/src/passes/SpillPointers.cpp b/src/passes/SpillPointers.cpp index d65c89d3e..36c2ae948 100644 --- a/src/passes/SpillPointers.cpp +++ b/src/passes/SpillPointers.cpp @@ -60,9 +60,6 @@ struct SpillPointers : public WalkerPass<LivenessWalker<SpillPointers, Visitor<S void visitCall(Call* curr) { visitSpillable(curr); } - void visitCallImport(CallImport* curr) { - visitSpillable(curr); - } void visitCallIndirect(CallIndirect* curr) { visitSpillable(curr); } @@ -161,10 +158,6 @@ struct SpillPointers : public WalkerPass<LivenessWalker<SpillPointers, Visitor<S for (auto*& operand : call->cast<Call>()->operands) { handleOperand(operand); } - } else if (call->is<CallImport>()) { - for (auto*& operand : call->cast<CallImport>()->operands) { - handleOperand(operand); - } } else if (call->is<CallIndirect>()) { for (auto*& operand : call->cast<CallIndirect>()->operands) { handleOperand(operand); diff --git a/src/passes/TrapMode.cpp b/src/passes/TrapMode.cpp index 68d6aad4b..19301ee3a 100644 --- a/src/passes/TrapMode.cpp +++ b/src/passes/TrapMode.cpp @@ -22,6 +22,7 @@ #include "asm_v_wasm.h" #include "asmjs/shared-constants.h" +#include "ir/function-type-utils.h" #include "ir/trapping.h" #include "mixed_arena.h" #include "pass.h" @@ -222,12 +223,13 @@ void ensureF64ToI64JSImport(TrappingFunctionContainer &trappingFunctions) { } Module& wasm = trappingFunctions.getModule(); - auto import = new Import; // f64-to-int = asm2wasm.f64-to-int; + auto import = new Function; // f64-to-int = asm2wasm.f64-to-int; import->name = F64_TO_INT; import->module = ASM2WASM; import->base = F64_TO_INT; - import->functionType = ensureFunctionType("id", &wasm)->name; - import->kind = ExternalKind::Function; + auto* functionType = ensureFunctionType("id", &wasm); + import->type = functionType->name; + FunctionTypeUtils::fillFunction(import, functionType); trappingFunctions.addImport(import); } @@ -263,7 +265,7 @@ Expression* makeTrappingUnary(Unary* curr, TrappingFunctionContainer &trappingFu // WebAssembly traps on float-to-int overflows, but asm.js wouldn't, so we must emulate that ensureF64ToI64JSImport(trappingFunctions); Expression* f64Value = ensureDouble(curr->value, wasm.allocator); - return builder.makeCallImport(F64_TO_INT, {f64Value}, i32); + return builder.makeCall(F64_TO_INT, {f64Value}, i32); } ensureUnaryFunc(curr, wasm, trappingFunctions); diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp index f24f76881..f59fb287a 100644 --- a/src/passes/Vacuum.cpp +++ b/src/passes/Vacuum.cpp @@ -63,7 +63,6 @@ struct Vacuum : public WalkerPass<PostWalker<Vacuum>> { case Expression::Id::BreakId: case Expression::Id::SwitchId: case Expression::Id::CallId: - case Expression::Id::CallImportId: case Expression::Id::CallIndirectId: case Expression::Id::SetLocalId: case Expression::Id::StoreId: diff --git a/src/passes/pass.cpp b/src/passes/pass.cpp index 9215341e3..66ffb9b30 100644 --- a/src/passes/pass.cpp +++ b/src/passes/pass.cpp @@ -17,12 +17,13 @@ #include <chrono> #include <sstream> -#include <support/colors.h> -#include <passes/passes.h> -#include <pass.h> -#include <wasm-validator.h> -#include <wasm-io.h> +#include "support/colors.h" +#include "passes/passes.h" +#include "pass.h" +#include "wasm-validator.h" +#include "wasm-io.h" #include "ir/hashed.h" +#include "ir/module-utils.h" namespace wasm { @@ -267,9 +268,9 @@ void PassRunner::run() { auto before = std::chrono::steady_clock::now(); if (pass->isFunctionParallel()) { // function-parallel passes should get a new instance per function - for (auto& func : wasm->functions) { - runPassOnFunction(pass, func.get()); - } + ModuleUtils::iterDefinedFunctions(*wasm, [&](Function* func) { + runPassOnFunction(pass, func); + }); } else { runPass(pass); } @@ -320,9 +321,11 @@ void PassRunner::run() { return ThreadWorkState::Finished; // nothing left } Function* func = this->wasm->functions[index].get(); - // do the current task: run all passes on this function - for (auto* pass : stack) { - runPassOnFunction(pass, func); + if (!func->imported()) { + // do the current task: run all passes on this function + for (auto* pass : stack) { + runPassOnFunction(pass, func); + } } if (index + 1 == numFunctions) { return ThreadWorkState::Finished; // we did the last one diff --git a/src/shell-interface.h b/src/shell-interface.h index 5b01b4be6..abfd4ac18 100644 --- a/src/shell-interface.h +++ b/src/shell-interface.h @@ -26,6 +26,7 @@ #include "support/name.h" #include "wasm.h" #include "wasm-interpreter.h" +#include "ir/module-utils.h" namespace wasm { @@ -118,24 +119,25 @@ struct ShellExternalInterface final : ModuleInstance::ExternalInterface { void importGlobals(std::map<Name, Literal>& globals, Module& wasm) override { // add spectest globals - for (auto& import : wasm.imports) { - if (import->kind == ExternalKind::Global && import->module == SPECTEST && import->base == GLOBAL) { - switch (import->globalType) { + ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) { + if (import->module == SPECTEST && import->base == GLOBAL) { + switch (import->type) { case i32: globals[import->name] = Literal(int32_t(666)); break; case i64: globals[import->name] = Literal(int64_t(666)); break; case f32: globals[import->name] = Literal(float(666.6)); break; case f64: globals[import->name] = Literal(double(666.6)); break; default: WASM_UNREACHABLE(); } - } else if (import->kind == ExternalKind::Memory && import->module == SPECTEST && import->base == MEMORY) { - // imported memory has initial 1 and max 2 - wasm.memory.initial = 1; - wasm.memory.max = 2; } + }); + if (wasm.memory.imported() && wasm.memory.module == SPECTEST && wasm.memory.base == MEMORY) { + // imported memory has initial 1 and max 2 + wasm.memory.initial = 1; + wasm.memory.max = 2; } } - Literal callImport(Import *import, LiteralList& arguments) override { + Literal callImport(Function* import, LiteralList& arguments) override { if (import->module == SPECTEST && import->base == PRINT) { for (auto argument : arguments) { std::cout << '(' << argument << ')' << '\n'; @@ -163,7 +165,11 @@ struct ShellExternalInterface final : ModuleInstance::ExternalInterface { if (func->result != result) { trap("callIndirect: bad result type"); } - return instance.callFunctionInternal(func->name, arguments); + if (func->imported()) { + return callImport(func, arguments); + } else { + return instance.callFunctionInternal(func->name, arguments); + } } int8_t load8s(Address addr) override { return memory.get<int8_t>(addr); } diff --git a/src/tools/execution-results.h b/src/tools/execution-results.h index 1e8fba4ff..ca4b819d9 100644 --- a/src/tools/execution-results.h +++ b/src/tools/execution-results.h @@ -20,6 +20,7 @@ #include "wasm.h" #include "shell-interface.h" +#include "ir/import-utils.h" namespace wasm { @@ -32,7 +33,7 @@ struct ExecutionResults { // get results of execution void get(Module& wasm) { - if (wasm.imports.size() > 0) { + if (ImportInfo(wasm).getNumImports() > 0) { std::cout << "[fuzz-exec] imports, so quitting\n"; return; } diff --git a/src/tools/wasm-ctor-eval.cpp b/src/tools/wasm-ctor-eval.cpp index e11454b04..4272c5dba 100644 --- a/src/tools/wasm-ctor-eval.cpp +++ b/src/tools/wasm-ctor-eval.cpp @@ -36,6 +36,7 @@ #include "ir/global-utils.h" #include "ir/import-utils.h" #include "ir/literal-utils.h" +#include "ir/module-utils.h" using namespace wasm; @@ -124,23 +125,23 @@ public: EvallingModuleInstance(Module& wasm, ExternalInterface* externalInterface) : ModuleInstanceBase(wasm, externalInterface) { // if any global in the module has a non-const constructor, it is using a global import, // which we don't have, and is illegal to use - for (auto& global : wasm.globals) { + ModuleUtils::iterDefinedGlobals(wasm, [&](Global* global) { if (!global->init->is<Const>()) { // some constants are ok to use if (auto* get = global->init->dynCast<GetGlobal>()) { auto name = get->name; - auto* import = wasm.getImport(name); + auto* import = wasm.getGlobal(name); if (import->module == Name("env") && ( import->base == Name("STACKTOP") || // stack constants are special, we handle them import->base == Name("STACK_MAX") )) { - continue; // this is fine + return; // this is fine } } // this global is dangerously initialized by an import, so if it is used, we must fail globals.addDangerous(global->name); } - } + }); } std::vector<char> stack; @@ -173,34 +174,33 @@ struct CtorEvalExternalInterface : EvallingModuleInstance::ExternalInterface { void importGlobals(EvallingGlobalManager& globals, Module& wasm_) override { // fill usable values for stack imports, and globals initialized to them - if (auto* stackTop = ImportUtils::getImport(wasm_, "env", "STACKTOP")) { + ImportInfo imports(wasm_); + if (auto* stackTop = imports.getImportedGlobal("env", "STACKTOP")) { globals[stackTop->name] = Literal(int32_t(STACK_START)); if (auto* stackTop = GlobalUtils::getGlobalInitializedToImport(wasm_, "env", "STACKTOP")) { globals[stackTop->name] = Literal(int32_t(STACK_START)); } } - if (auto* stackMax = ImportUtils::getImport(wasm_, "env", "STACK_MAX")) { + if (auto* stackMax = imports.getImportedGlobal("env", "STACK_MAX")) { globals[stackMax->name] = Literal(int32_t(STACK_START)); if (auto* stackMax = GlobalUtils::getGlobalInitializedToImport(wasm_, "env", "STACK_MAX")) { globals[stackMax->name] = Literal(int32_t(STACK_START)); } } // fill in fake values for everything else, which is dangerous to use - for (auto& global : wasm_.globals) { - if (globals.find(global->name) == globals.end()) { - globals[global->name] = LiteralUtils::makeLiteralZero(global->type); + ModuleUtils::iterDefinedGlobals(wasm_, [&](Global* defined) { + if (globals.find(defined->name) == globals.end()) { + globals[defined->name] = LiteralUtils::makeLiteralZero(defined->type); } - } - for (auto& import : wasm_.imports) { - if (import->kind == ExternalKind::Global) { - if (globals.find(import->name) == globals.end()) { - globals[import->name] = LiteralUtils::makeLiteralZero(import->globalType); - } + }); + ModuleUtils::iterImportedGlobals(wasm_, [&](Global* import) { + if (globals.find(import->name) == globals.end()) { + globals[import->name] = LiteralUtils::makeLiteralZero(import->type); } - } + }); } - Literal callImport(Import *import, LiteralList& arguments) override { + Literal callImport(Function* import, LiteralList& arguments) override { std::string extra; if (import->module == "env" && import->base == "___cxa_atexit") { extra = "\nrecommendation: build with -s NO_EXIT_RUNTIME=1 so that calls to atexit are not emitted"; @@ -227,7 +227,8 @@ struct CtorEvalExternalInterface : EvallingModuleInstance::ExternalInterface { if (start <= index && index < end) { auto name = segment.data[index - start]; // if this is one of our functions, we can call it; if it was imported, fail - if (wasm->getFunctionOrNull(name)) { + auto* func = wasm->getFunction(name); + if (!func->imported()) { return instance.callFunctionInternal(name, arguments); } else { throw FailToEvalException(std::string("callTable on imported function: ") + name.str); diff --git a/src/tools/wasm-merge.cpp b/src/tools/wasm-merge.cpp index e8910860e..70611e269 100644 --- a/src/tools/wasm-merge.cpp +++ b/src/tools/wasm-merge.cpp @@ -34,20 +34,10 @@ #include "wasm-binary.h" #include "wasm-builder.h" #include "wasm-validator.h" +#include "ir/module-utils.h" using namespace wasm; -// Calls note() on every import that has form "env".(base) -static void findImportsByBase(Module& wasm, Name base, std::function<void (Name)> note) { - for (auto& curr : wasm.imports) { - if (curr->module == ENV) { - if (curr->base == base) { - note(curr->name); - } - } - } -} - // Ensure a memory or table is of at least a size template<typename T> static void ensureSize(T& what, Index size) { @@ -119,31 +109,33 @@ struct Mergeable { } void findImports() { - findImportsByBase(wasm, MEMORY_BASE, [&](Name name) { - memoryBaseGlobals.insert(name); + ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) { + if (import->module == ENV && import->base == MEMORY_BASE) { + memoryBaseGlobals.insert(import->name); + } }); if (memoryBaseGlobals.size() == 0) { // add one - auto* import = new Import; + auto* import = new Global; import->name = MEMORY_BASE; import->module = ENV; import->base = MEMORY_BASE; - import->kind = ExternalKind::Global; - import->globalType = i32; - wasm.addImport(import); + import->type = i32; + wasm.addGlobal(import); memoryBaseGlobals.insert(import->name); } - findImportsByBase(wasm, TABLE_BASE, [&](Name name) { - tableBaseGlobals.insert(name); + ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) { + if (import->module == ENV && import->base == TABLE_BASE) { + tableBaseGlobals.insert(import->name); + } }); if (tableBaseGlobals.size() == 0) { - auto* import = new Import; + auto* import = new Global; import->name = TABLE_BASE; import->module = ENV; import->base = TABLE_BASE; - import->kind = ExternalKind::Global; - import->globalType = i32; - wasm.addImport(import); + import->type = i32; + wasm.addGlobal(import); tableBaseGlobals.insert(import->name); } } @@ -244,7 +236,7 @@ struct Mergeable { struct OutputMergeable : public PostWalker<OutputMergeable, Visitor<OutputMergeable>>, public Mergeable { OutputMergeable(Module& wasm) : Mergeable(wasm) {} - void visitCallImport(CallImport* curr) { + void visitCall(Call* curr) { auto iter = implementedFunctionImports.find(curr->target); if (iter != implementedFunctionImports.end()) { // this import is now in the module - call it @@ -264,10 +256,10 @@ struct OutputMergeable : public PostWalker<OutputMergeable, Visitor<OutputMergea void visitModule(Module* curr) { // remove imports that are being implemented for (auto& pair : implementedFunctionImports) { - curr->removeImport(pair.first); + curr->removeFunction(pair.first); } for (auto& pair : implementedGlobalImports) { - curr->removeImport(pair.first); + curr->removeGlobal(pair.first); } } }; @@ -288,11 +280,6 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp std::map<Name, Name> gNames; // globals void visitCall(Call* curr) { - curr->target = fNames[curr->target]; - assert(curr->target.is()); - } - - void visitCallImport(CallImport* curr) { auto iter = implementedFunctionImports.find(curr->target); if (iter != implementedFunctionImports.end()) { // this import is now in the module - call it @@ -333,29 +320,38 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp void merge() { // find function imports in us that are implemented in the output // TODO make maps, avoid N^2 - for (auto& imp : wasm.imports) { + ModuleUtils::iterImportedFunctions(wasm, [&](Function* import) { // per wasm dynamic library rules, we expect to see exports on 'env' - if ((imp->kind == ExternalKind::Function || imp->kind == ExternalKind::Global) && imp->module == ENV) { + if (import->module == ENV) { // seek an export on the other side that matches for (auto& exp : outputMergeable.wasm.exports) { - if (exp->kind == imp->kind && exp->name == imp->base) { + if (exp->name == import->base) { // fits! - if (imp->kind == ExternalKind::Function) { - implementedFunctionImports[imp->name] = exp->value; - } else { - implementedGlobalImports[imp->name] = exp->value; - } + implementedFunctionImports[import->name] = exp->value; break; } } } - } + }); + ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) { + // per wasm dynamic library rules, we expect to see exports on 'env' + if (import->module == ENV) { + // seek an export on the other side that matches + for (auto& exp : outputMergeable.wasm.exports) { + if (exp->name == import->base) { + // fits! + implementedGlobalImports[import->name] = exp->value; + break; + } + } + } + }); // remove the unneeded ones for (auto& pair : implementedFunctionImports) { - wasm.removeImport(pair.first); + wasm.removeFunction(pair.first); } for (auto& pair : implementedGlobalImports) { - wasm.removeImport(pair.first); + wasm.removeGlobal(pair.first); } // find new names @@ -364,27 +360,26 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp return outputMergeable.wasm.getFunctionTypeOrNull(name); }); } - for (auto& curr : wasm.imports) { - if (curr->kind == ExternalKind::Function) { - curr->name = fNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool { - return !!outputMergeable.wasm.getImportOrNull(name) || !!outputMergeable.wasm.getFunctionOrNull(name); - }); - } else if (curr->kind == ExternalKind::Global) { - curr->name = gNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool { - return !!outputMergeable.wasm.getImportOrNull(name) || !!outputMergeable.wasm.getGlobalOrNull(name); - }); - } - } - for (auto& curr : wasm.functions) { + ModuleUtils::iterImportedFunctions(wasm, [&](Function* curr) { + curr->name = fNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool { + return !!outputMergeable.wasm.getFunctionOrNull(name); + }); + }); + ModuleUtils::iterImportedGlobals(wasm, [&](Global* curr) { + curr->name = gNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool { + return !!outputMergeable.wasm.getGlobalOrNull(name); + }); + }); + ModuleUtils::iterDefinedFunctions(wasm, [&](Function* curr) { curr->name = fNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool { return outputMergeable.wasm.getFunctionOrNull(name); }); - } - for (auto& curr : wasm.globals) { + }); + ModuleUtils::iterDefinedGlobals(wasm, [&](Global* curr) { curr->name = gNames[curr->name] = getNonColliding(curr->name, [&](Name name) -> bool { return outputMergeable.wasm.getGlobalOrNull(name); }); - } + }); // update global names in input { @@ -403,20 +398,26 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp } // find function imports in output that are implemented in the input - for (auto& imp : outputMergeable.wasm.imports) { - if ((imp->kind == ExternalKind::Function || imp->kind == ExternalKind::Global) && imp->module == ENV) { + ModuleUtils::iterImportedFunctions(outputMergeable.wasm, [&](Function* import) { + if (import->module == ENV) { for (auto& exp : wasm.exports) { - if (exp->kind == imp->kind && exp->name == imp->base) { - if (imp->kind == ExternalKind::Function) { - outputMergeable.implementedFunctionImports[imp->name] = fNames[exp->value]; - } else { - outputMergeable.implementedGlobalImports[imp->name] = gNames[exp->value]; - } + if (exp->name == import->base) { + outputMergeable.implementedFunctionImports[import->name] = fNames[exp->value]; break; } } } - } + }); + ModuleUtils::iterImportedGlobals(outputMergeable.wasm, [&](Global* import) { + if (import->module == ENV) { + for (auto& exp : wasm.exports) { + if (exp->name == import->base) { + outputMergeable.implementedGlobalImports[import->name] = gNames[exp->value]; + break; + } + } + } + }); // update the output before bringing anything in. avoid doing so when possible, as in the // common case the output module is very large. @@ -448,16 +449,19 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp for (auto& curr : wasm.functionTypes) { outputMergeable.wasm.addFunctionType(curr.release()); } - for (auto& curr : wasm.imports) { - if (curr->kind == ExternalKind::Memory || curr->kind == ExternalKind::Table) { - continue; // wasm has just 1 of each, they must match + for (auto& curr : wasm.globals) { + if (curr->imported()) { + outputMergeable.wasm.addGlobal(curr.release()); } - // update and add - if (curr->functionType.is()) { - curr->functionType = ftNames[curr->functionType]; - assert(curr->functionType.is()); + } + for (auto& curr : wasm.functions) { + if (curr->imported()) { + if (curr->type.is()) { + curr->type = ftNames[curr->type]; + assert(curr->type.is()); + } + outputMergeable.wasm.addFunction(curr.release()); } - outputMergeable.wasm.addImport(curr.release()); } for (auto& curr : wasm.exports) { if (curr->kind == ExternalKind::Memory || curr->kind == ExternalKind::Table) { @@ -477,13 +481,21 @@ struct InputMergeable : public ExpressionStackWalker<InputMergeable, Visitor<Inp } } } + // Copy over the remaining non-imports (we have already transferred + // the imports, and they are nullptrs). for (auto& curr : wasm.functions) { - curr->type = ftNames[curr->type]; - assert(curr->type.is()); - outputMergeable.wasm.addFunction(curr.release()); + if (curr) { + assert(!curr->imported()); + curr->type = ftNames[curr->type]; + assert(curr->type.is()); + outputMergeable.wasm.addFunction(curr.release()); + } } for (auto& curr : wasm.globals) { - outputMergeable.wasm.addGlobal(curr.release()); + if (curr) { + assert(!curr->imported()); + outputMergeable.wasm.addGlobal(curr.release()); + } } } diff --git a/src/tools/wasm-metadce.cpp b/src/tools/wasm-metadce.cpp index 4f90ee612..5caf8fea7 100644 --- a/src/tools/wasm-metadce.cpp +++ b/src/tools/wasm-metadce.cpp @@ -33,7 +33,7 @@ #include "support/colors.h" #include "wasm-io.h" #include "wasm-builder.h" -#include "ir/import-utils.h" +#include "ir/module-utils.h" using namespace wasm; @@ -69,8 +69,13 @@ struct MetaDCEGraph { return std::string(module.str) + " (*) " + std::string(base.str); } - ImportId getImportId(Name name) { - auto* imp = wasm.getImport(name); + ImportId getFunctionImportId(Name name) { + auto* imp = wasm.getFunction(name); + return getImportId(imp->module, imp->base); + } + + ImportId getGlobalImportId(Name name) { + auto* imp = wasm.getGlobal(name); return getImportId(imp->module, imp->base); } @@ -86,28 +91,33 @@ struct MetaDCEGraph { // Add an entry for everything we might need ahead of time, so parallel work // does not alter parent state, just adds to things pointed by it, independently // (each thread will add for one function, etc.) - for (auto& func : wasm.functions) { + ModuleUtils::iterDefinedFunctions(wasm, [&](Function* func) { auto dceName = getName("func", func->name.str); DCENodeToFunction[dceName] = func->name; functionToDCENode[func->name] = dceName; nodes[dceName] = DCENode(dceName); - } - for (auto& global : wasm.globals) { + }); + ModuleUtils::iterDefinedGlobals(wasm, [&](Global* global) { auto dceName = getName("global", global->name.str); DCENodeToGlobal[dceName] = global->name; globalToDCENode[global->name] = dceName; nodes[dceName] = DCENode(dceName); - } - for (auto& imp : wasm.imports) { - // only process function and global imports - the table and memory are always there - if (imp->kind == ExternalKind::Function || imp->kind == ExternalKind::Global) { - auto id = getImportId(imp->module, imp->base); - if (importIdToDCENode.find(id) == importIdToDCENode.end()) { - auto dceName = getName("importId", imp->name.str); - importIdToDCENode[id] = dceName; - } + }); + // only process function and global imports - the table and memory are always there + ModuleUtils::iterImportedFunctions(wasm, [&](Function* import) { + auto id = getImportId(import->module, import->base); + if (importIdToDCENode.find(id) == importIdToDCENode.end()) { + auto dceName = getName("importId", import->name.str); + importIdToDCENode[id] = dceName; } - } + }); + ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) { + auto id = getImportId(import->module, import->base); + if (importIdToDCENode.find(id) == importIdToDCENode.end()) { + auto dceName = getName("importId", import->name.str); + importIdToDCENode[id] = dceName; + } + }); for (auto& exp : wasm.exports) { if (exportToDCENode.find(exp->name) == exportToDCENode.end()) { auto dceName = getName("export", exp->name.str); @@ -118,16 +128,16 @@ struct MetaDCEGraph { // we can also link the export to the thing being exported auto& node = nodes[exportToDCENode[exp->name]]; if (exp->kind == ExternalKind::Function) { - if (wasm.getFunctionOrNull(exp->value)) { + if (!wasm.getFunction(exp->value)->imported()) { node.reaches.push_back(functionToDCENode[exp->value]); } else { - node.reaches.push_back(importIdToDCENode[getImportId(exp->value)]); + node.reaches.push_back(importIdToDCENode[getFunctionImportId(exp->value)]); } } else if (exp->kind == ExternalKind::Global) { - if (wasm.getGlobalOrNull(exp->value)) { + if (!wasm.getGlobal(exp->value)->imported()) { node.reaches.push_back(globalToDCENode[exp->value]); } else { - node.reaches.push_back(importIdToDCENode[getImportId(exp->value)]); + node.reaches.push_back(importIdToDCENode[getGlobalImportId(exp->value)]); } } } @@ -150,12 +160,12 @@ struct MetaDCEGraph { void handleGlobal(Name name) { Name dceName; - if (getModule()->getGlobalOrNull(name)) { - // its a global + if (!getModule()->getGlobal(name)->imported()) { + // its a defined global dceName = parent->globalToDCENode[name]; } else { // it's an import. - dceName = parent->importIdToDCENode[parent->getImportId(name)]; + dceName = parent->importIdToDCENode[parent->getGlobalImportId(name)]; } if (parentDceName.isNull()) { parent->roots.insert(parentDceName); @@ -164,11 +174,11 @@ struct MetaDCEGraph { } } }; - for (auto& global : wasm.globals) { + ModuleUtils::iterDefinedGlobals(wasm, [&](Global* global) { InitScanner scanner(this, globalToDCENode[global->name]); scanner.setModule(&wasm); scanner.walk(global->init); - } + }); // we can't remove segments, so root what they need InitScanner rooter(this, Name()); rooter.setModule(&wasm); @@ -176,10 +186,10 @@ struct MetaDCEGraph { // TODO: currently, all functions in the table are roots, but we // should add an option to refine that for (auto& name : segment.data) { - if (wasm.getFunctionOrNull(name)) { + if (!wasm.getFunction(name)->imported()) { roots.insert(functionToDCENode[name]); } else { - roots.insert(importIdToDCENode[getImportId(name)]); + roots.insert(importIdToDCENode[getFunctionImportId(name)]); } } rooter.walk(segment.offset); @@ -199,15 +209,16 @@ struct MetaDCEGraph { } void visitCall(Call* curr) { - parent->nodes[parent->functionToDCENode[getFunction()->name]].reaches.push_back( - parent->functionToDCENode[curr->target] - ); - } - void visitCallImport(CallImport* curr) { - assert(parent->functionToDCENode.count(getFunction()->name) > 0); - parent->nodes[parent->functionToDCENode[getFunction()->name]].reaches.push_back( - parent->importIdToDCENode[parent->getImportId(curr->target)] - ); + if (!getModule()->getFunction(curr->target)->imported()) { + parent->nodes[parent->functionToDCENode[getFunction()->name]].reaches.push_back( + parent->functionToDCENode[curr->target] + ); + } else { + assert(parent->functionToDCENode.count(getFunction()->name) > 0); + parent->nodes[parent->functionToDCENode[getFunction()->name]].reaches.push_back( + parent->importIdToDCENode[parent->getFunctionImportId(curr->target)] + ); + } } void visitGetGlobal(GetGlobal* curr) { handleGlobal(curr->name); @@ -222,12 +233,12 @@ struct MetaDCEGraph { void handleGlobal(Name name) { if (!getFunction()) return; // non-function stuff (initializers) are handled separately Name dceName; - if (getModule()->getGlobalOrNull(name)) { + if (!getModule()->getGlobal(name)->imported()) { // its a global dceName = parent->globalToDCENode[name]; } else { // it's an import. - dceName = parent->importIdToDCENode[parent->getImportId(name)]; + dceName = parent->importIdToDCENode[parent->getGlobalImportId(name)]; } parent->nodes[parent->functionToDCENode[getFunction()->name]].reaches.push_back(dceName); } diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index 2a59be167..9ad1fa799 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -146,24 +146,29 @@ static void run_asserts(Name moduleName, size_t* i, bool* checked, Module* wasm, } if (!invalid && id == ASSERT_UNLINKABLE) { // validate "instantiating" the mdoule - for (auto& import : wasm.imports) { + auto reportUnknownImport = [&](Importable* import) { + std::cerr << "unknown import: " << import->module << '.' << import->base << '\n'; + invalid = true; + }; + ModuleUtils::iterImportedGlobals(wasm, reportUnknownImport); + ModuleUtils::iterImportedFunctions(wasm, [&](Importable* import) { if (import->module == SPECTEST && import->base == PRINT) { - if (import->kind != ExternalKind::Function) { - std::cerr << "spectest.print should be a function, but is " << int32_t(import->kind) << '\n'; - invalid = true; - break; - } + // We can handle it. } else { - std::cerr << "unknown import: " << import->module << '.' << import->base << '\n'; - invalid = true; - break; + reportUnknownImport(import); } + }); + if (wasm.memory.imported()) { + reportUnknownImport(&wasm.memory); + } + if (wasm.table.imported()) { + reportUnknownImport(&wasm.table); } for (auto& segment : wasm.table.segments) { for (auto name : segment.data) { // spec tests consider it illegal to use spectest.print in a table - if (auto* import = wasm.getImportOrNull(name)) { - if (import->module == SPECTEST && import->base == PRINT) { + if (auto* import = wasm.getFunction(name)) { + if (import->imported() && import->module == SPECTEST && import->base == PRINT) { std::cerr << "cannot put spectest.print in table\n"; invalid = true; } diff --git a/src/wasm-binary.h b/src/wasm-binary.h index 05c883acb..2a455b808 100644 --- a/src/wasm-binary.h +++ b/src/wasm-binary.h @@ -32,6 +32,7 @@ #include "wasm-builder.h" #include "parsing.h" #include "wasm-validator.h" +#include "ir/import-utils.h" namespace wasm { @@ -763,6 +764,8 @@ private: size_t sourceMapLocationsSizeAtSectionStart; Function::DebugLocation lastDebugLocation; + std::unique_ptr<ImportInfo> importInfo; + void prepare(); }; @@ -833,9 +836,8 @@ public: // We read functions before we know their names, so we need to backpatch the names later std::vector<Function*> functions; // we store functions here before wasm.addFunction after we know their names - std::vector<Import*> functionImports; // we store function imports here before wasm.addFunctionImport after we know their names - std::map<Index, std::vector<Call*>> functionCalls; // at index i we have all calls to the defined function i - std::map<Index, std::vector<CallImport*>> functionImportCalls; // at index i we have all callImports to the imported function i + std::vector<Function*> functionImports; // we store function imports here before wasm.addFunctionImport after we know their names + std::map<Index, std::vector<Call*>> functionCalls; // at index i we have all calls to the function i Function* currFunction = nullptr; Index endOfFunction = -1; // before we see a function (like global init expressions), there is no end of function to check @@ -924,18 +926,7 @@ public: void visitBreak(Break *curr, uint8_t code); void visitSwitch(Switch* curr); - template<typename T> - void fillCall(T* call, FunctionType* type) { - assert(type); - auto num = type->params.size(); - call->operands.resize(num); - for (size_t i = 0; i < num; i++) { - call->operands[num - i - 1] = popNonVoidExpression(); - } - call->type = type->result; - } - - Expression* visitCall(); + void visitCall(Call* curr); void visitCallIndirect(CallIndirect* curr); void visitGetLocal(GetLocal* curr); void visitSetLocal(SetLocal *curr, uint8_t code); diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 4032eb10a..f59646869 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -167,13 +167,6 @@ public: call->operands.set(args); return call; } - CallImport* makeCallImport(Name target, const std::vector<Expression*>& args, Type type) { - auto* call = allocator.alloc<CallImport>(); - call->type = type; // similar to makeCall, for consistency - call->target = target; - call->operands.set(args); - return call; - } template<typename T> Call* makeCall(Name target, const T& args, Type type) { auto* call = allocator.alloc<Call>(); @@ -182,14 +175,6 @@ public: call->operands.set(args); return call; } - template<typename T> - CallImport* makeCallImport(Name target, const T& args, Type type) { - auto* call = allocator.alloc<CallImport>(); - call->type = type; // similar to makeCall, for consistency - call->target = target; - call->operands.set(args); - return call; - } CallIndirect* makeCallIndirect(FunctionType* type, Expression* target, const std::vector<Expression*>& args) { auto* call = allocator.alloc<CallIndirect>(); call->fullType = type->name; diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index 820ba0049..8bd951eef 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -31,6 +31,7 @@ #include "support/safe_integer.h" #include "wasm.h" #include "wasm-traversal.h" +#include "ir/module-utils.h" #ifdef WASM_INTERPRETER_DEBUG #include "wasm-printing.h" @@ -511,7 +512,6 @@ public: Flow visitLoop(Loop* curr) { WASM_UNREACHABLE(); } Flow visitCall(Call* curr) { WASM_UNREACHABLE(); } - Flow visitCallImport(CallImport* curr) { WASM_UNREACHABLE(); } Flow visitCallIndirect(CallIndirect* curr) { WASM_UNREACHABLE(); } Flow visitGetLocal(GetLocal *curr) { WASM_UNREACHABLE(); } Flow visitSetLocal(SetLocal *curr) { WASM_UNREACHABLE(); } @@ -545,7 +545,7 @@ public: struct ExternalInterface { virtual void init(Module& wasm, SubType& instance) {} virtual void importGlobals(GlobalManager& globals, Module& wasm) = 0; - virtual Literal callImport(Import* import, LiteralList& arguments) = 0; + virtual Literal callImport(Function* import, LiteralList& arguments) = 0; virtual Literal callTable(Index index, LiteralList& arguments, Type result, SubType& instance) = 0; virtual void growMemory(Address oldSize, Address newSize) = 0; virtual void trap(const char* why) = 0; @@ -636,9 +636,9 @@ public: // prepare memory memorySize = wasm.memory.initial; // generate internal (non-imported) globals - for (auto& global : wasm.globals) { + ModuleUtils::iterDefinedGlobals(wasm, [&](Global* global) { globals[global->name] = ConstantExpressionRunner<GlobalManager>(globals).visit(global->init).value; - } + }); // initialize the rest of the external interface externalInterface->init(wasm, *self()); // run start, if present @@ -757,19 +757,18 @@ public: LiteralList arguments; Flow flow = generateArguments(curr->operands, arguments); if (flow.breaking()) return flow; - Flow ret = instance.callFunctionInternal(curr->target, arguments); + auto* func = instance.wasm.getFunction(curr->target); + Flow ret; + if (func->imported()) { + ret = instance.externalInterface->callImport(func, arguments); + } else { + ret = instance.callFunctionInternal(curr->target, arguments); + } #ifdef WASM_INTERPRETER_DEBUG std::cout << "(returned to " << scope.function->name << ")\n"; #endif return ret; } - Flow visitCallImport(CallImport *curr) { - NOTE_ENTER("CallImport"); - LiteralList arguments; - Flow flow = generateArguments(curr->operands, arguments); - if (flow.breaking()) return flow; - return instance.externalInterface->callImport(instance.wasm.getImport(curr->target), arguments); - } Flow visitCallIndirect(CallIndirect *curr) { NOTE_ENTER("CallIndirect"); LiteralList arguments; diff --git a/src/wasm-js.cpp b/src/wasm-js.cpp index 5bcd7e133..a777295d7 100644 --- a/src/wasm-js.cpp +++ b/src/wasm-js.cpp @@ -28,6 +28,7 @@ #include "wasm-s-parser.h" #include "wasm-binary.h" #include "wasm-printing.h" +#include "ir/module-utils.h" using namespace cashew; using namespace wasm; @@ -158,15 +159,15 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { } } - // verify imports are provided - for (auto& import : module->imports) { + auto verifyImportIsProvided = [&](Importable* import) { EM_ASM_({ var mod = Pointer_stringify($0); var base = Pointer_stringify($1); - var name = Pointer_stringify($2); - assert(Module['lookupImport'](mod, base) !== undefined, 'checking import ' + name + ' = ' + mod + '.' + base); - }, import->module.str, import->base.str, import->name.str); - } + assert(Module['lookupImport'](mod, base) !== undefined, 'checking import ' + mod + '.' + base); + }, import->module.str, import->base.str); + }; + ModuleUtils::iterImportedFunctions(*module, verifyImportIsProvided); + ModuleUtils::iterImportedGlobals(*module, verifyImportIsProvided); if (wasmJSDebug) std::cerr << "creating instance...\n"; @@ -176,24 +177,15 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { void init(Module& wasm, ModuleInstance& instance) override { module = &wasm; // look for imported memory - { - bool found = false; - for (auto& import : wasm.imports) { - if (import->module == ENV && import->base == MEMORY) { - assert(import->kind == ExternalKind::Memory); - // memory is imported - EM_ASM({ - Module['asmExports']['memory'] = Module['lookupImport']('env', 'memory'); - }); - found = true; - } - } - if (!found) { - // no memory import; create a new buffer here, just like native wasm support would. - EM_ASM_({ - Module['asmExports']['memory'] = Module['outside']['newBuffer'] = new ArrayBuffer($0); - }, wasm.memory.initial * Memory::kPageSize); - } + if (wasm.memory.imported()) { + EM_ASM({ + Module['asmExports']['memory'] = Module['lookupImport']('env', 'memory'); + }); + } else { + // no memory import; create a new buffer here, just like native wasm support would. + EM_ASM_({ + Module['asmExports']['memory'] = Module['outside']['newBuffer'] = new ArrayBuffer($0); + }, wasm.memory.initial * Memory::kPageSize); } for (auto segment : wasm.memory.segments) { EM_ASM_({ @@ -203,24 +195,15 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { }, ConstantExpressionRunner<TrivialGlobalManager>(instance.globals).visit(segment.offset).value.geti32(), &segment.data[0], segment.data.size()); } // look for imported table - { - bool found = false; - for (auto& import : wasm.imports) { - if (import->module == ENV && import->base == TABLE) { - assert(import->kind == ExternalKind::Table); - // table is imported - EM_ASM({ - Module['outside']['wasmTable'] = Module['lookupImport']('env', 'table'); - }); - found = true; - } - } - if (!found) { - // no table import; create a new one here, just like native wasm support would. - EM_ASM_({ - Module['outside']['wasmTable'] = new Array($0); - }, wasm.table.initial); - } + if (wasm.table.imported()) { + EM_ASM({ + Module['outside']['wasmTable'] = Module['lookupImport']('env', 'table'); + }); + } else { + // no table import; create a new one here, just like native wasm support would. + EM_ASM_({ + Module['outside']['wasmTable'] = new Array($0); + }, wasm.table.initial); } EM_ASM({ Module['asmExports']['table'] = Module['outside']['wasmTable']; @@ -232,16 +215,15 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { assert(offset + segment.data.size() <= wasm.table.initial); for (size_t i = 0; i != segment.data.size(); ++i) { Name name = segment.data[i]; - auto* func = wasm.getFunctionOrNull(name); - if (func) { + auto* func = wasm.getFunction(name); + if (!func->imported()) { EM_ASM_({ Module['outside']['wasmTable'][$0] = $1; }, offset + i, func); } else { - auto* import = wasm.getImport(name); EM_ASM_({ Module['outside']['wasmTable'][$0] = Module['lookupImport'](Pointer_stringify($1), Pointer_stringify($2)); - }, offset + i, import->module.str, import->base.str); + }, offset + i, func->module.str, func->base.str); } } } @@ -275,23 +257,21 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { } void importGlobals(std::map<Name, Literal>& globals, Module& wasm) override { - for (auto& import : wasm.imports) { - if (import->kind == ExternalKind::Global) { - double ret = EM_ASM_DOUBLE({ - var mod = Pointer_stringify($0); - var base = Pointer_stringify($1); - var lookup = Module['lookupImport'](mod, base); - return lookup; - }, import->module.str, import->base.str); - - if (wasmJSDebug) std::cout << "calling importGlobal for " << import->name << " returning " << ret << '\n'; - - globals[import->name] = getResultFromJS(ret, import->globalType); - } - } + ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) { + double ret = EM_ASM_DOUBLE({ + var mod = Pointer_stringify($0); + var base = Pointer_stringify($1); + var lookup = Module['lookupImport'](mod, base); + return lookup; + }, import->module.str, import->base.str); + + if (wasmJSDebug) std::cout << "calling importGlobal for " << import->name << " returning " << ret << '\n'; + + globals[import->name] = getResultFromJS(ret, import->type); + }); } - Literal callImport(Import *import, LiteralList& arguments) override { + Literal callImport(Function *import, LiteralList& arguments) override { if (wasmJSDebug) std::cout << "calling import " << import->name.str << '\n'; prepareTempArgments(arguments); double ret = EM_ASM_DOUBLE({ @@ -303,9 +283,9 @@ extern "C" void EMSCRIPTEN_KEEPALIVE instantiate() { return lookup.apply(null, tempArguments); }, import->module.str, import->base.str); - if (wasmJSDebug) std::cout << "calling import returning " << ret << " and function type is " << module->getFunctionType(import->functionType)->result << '\n'; + if (wasmJSDebug) std::cout << "calling import returning " << ret << " and function type is " << module->getFunctionType(import->type)->result << '\n'; - return getResultFromJS(ret, module->getFunctionType(import->functionType)->result); + return getResultFromJS(ret, module->getFunctionType(import->type)->result); } Literal callTable(Index index, LiteralList& arguments, Type result, ModuleInstance& instance) override { diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h index d492d0c3c..c27f63f1b 100644 --- a/src/wasm-s-parser.h +++ b/src/wasm-s-parser.h @@ -192,7 +192,6 @@ private: Expression* makeMaybeBlock(Element& s, size_t i, Type type); Expression* makeLoop(Element& s); Expression* makeCall(Element& s); - Expression* makeCallImport(Element& s); Expression* makeCallIndirect(Element& s); template<class T> void parseCallOperands(Element& s, Index i, Index j, T* call) { diff --git a/src/wasm-stack.h b/src/wasm-stack.h index 91ea3f3be..dc3416789 100644 --- a/src/wasm-stack.h +++ b/src/wasm-stack.h @@ -124,7 +124,6 @@ public: void visitBreak(Break* curr); void visitSwitch(Switch* curr); void visitCall(Call* curr); - void visitCallImport(CallImport* curr); void visitCallIndirect(CallIndirect* curr); void visitGetLocal(GetLocal* curr); void visitSetLocal(SetLocal* curr); @@ -545,16 +544,6 @@ void StackWriter<Mode, Parent>::visitCall(Call* curr) { } template<StackWriterMode Mode, typename Parent> -void StackWriter<Mode, Parent>::visitCallImport(CallImport* curr) { - if (debug) std::cerr << "zz node: CallImport" << std::endl; - for (auto* operand : curr->operands) { - visitChild(operand); - } - if (justAddToStack(curr)) return; - o << int8_t(BinaryConsts::CallFunction) << U32LEB(parent.getFunctionIndex(curr->target)); -} - -template<StackWriterMode Mode, typename Parent> void StackWriter<Mode, Parent>::visitCallIndirect(CallIndirect* curr) { if (debug) std::cerr << "zz node: CallIndirect" << std::endl; for (auto* operand : curr->operands) { diff --git a/src/wasm-traversal.h b/src/wasm-traversal.h index 0c775e872..5bb176756 100644 --- a/src/wasm-traversal.h +++ b/src/wasm-traversal.h @@ -43,7 +43,6 @@ struct Visitor { ReturnType visitBreak(Break* curr) { return ReturnType(); } ReturnType visitSwitch(Switch* curr) { return ReturnType(); } ReturnType visitCall(Call* curr) { return ReturnType(); } - ReturnType visitCallImport(CallImport* curr) { return ReturnType(); } ReturnType visitCallIndirect(CallIndirect* curr) { return ReturnType(); } ReturnType visitGetLocal(GetLocal* curr) { return ReturnType(); } ReturnType visitSetLocal(SetLocal* curr) { return ReturnType(); } @@ -66,7 +65,6 @@ struct Visitor { ReturnType visitUnreachable(Unreachable* curr) { return ReturnType(); } // Module-level visitors ReturnType visitFunctionType(FunctionType* curr) { return ReturnType(); } - ReturnType visitImport(Import* curr) { return ReturnType(); } ReturnType visitExport(Export* curr) { return ReturnType(); } ReturnType visitGlobal(Global* curr) { return ReturnType(); } ReturnType visitFunction(Function* curr) { return ReturnType(); } @@ -88,7 +86,6 @@ struct Visitor { case Expression::Id::BreakId: DELEGATE(Break); case Expression::Id::SwitchId: DELEGATE(Switch); case Expression::Id::CallId: DELEGATE(Call); - case Expression::Id::CallImportId: DELEGATE(CallImport); case Expression::Id::CallIndirectId: DELEGATE(CallIndirect); case Expression::Id::GetLocalId: DELEGATE(GetLocal); case Expression::Id::SetLocalId: DELEGATE(SetLocal); @@ -134,7 +131,6 @@ struct OverriddenVisitor { UNIMPLEMENTED(Break); UNIMPLEMENTED(Switch); UNIMPLEMENTED(Call); - UNIMPLEMENTED(CallImport); UNIMPLEMENTED(CallIndirect); UNIMPLEMENTED(GetLocal); UNIMPLEMENTED(SetLocal); @@ -156,7 +152,6 @@ struct OverriddenVisitor { UNIMPLEMENTED(Nop); UNIMPLEMENTED(Unreachable); UNIMPLEMENTED(FunctionType); - UNIMPLEMENTED(Import); UNIMPLEMENTED(Export); UNIMPLEMENTED(Global); UNIMPLEMENTED(Function); @@ -180,7 +175,6 @@ struct OverriddenVisitor { case Expression::Id::BreakId: DELEGATE(Break); case Expression::Id::SwitchId: DELEGATE(Switch); case Expression::Id::CallId: DELEGATE(Call); - case Expression::Id::CallImportId: DELEGATE(CallImport); case Expression::Id::CallIndirectId: DELEGATE(CallIndirect); case Expression::Id::GetLocalId: DELEGATE(GetLocal); case Expression::Id::SetLocalId: DELEGATE(SetLocal); @@ -224,7 +218,6 @@ struct UnifiedExpressionVisitor : public Visitor<SubType, ReturnType> { ReturnType visitBreak(Break* curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitSwitch(Switch* curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitCall(Call* curr) { return static_cast<SubType*>(this)->visitExpression(curr); } - ReturnType visitCallImport(CallImport* curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitCallIndirect(CallIndirect* curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitGetLocal(GetLocal* curr) { return static_cast<SubType*>(this)->visitExpression(curr); } ReturnType visitSetLocal(SetLocal* curr) { return static_cast<SubType*>(this)->visitExpression(curr); } @@ -340,17 +333,22 @@ struct Walker : public VisitorType { for (auto& curr : module->functionTypes) { self->visitFunctionType(curr.get()); } - for (auto& curr : module->imports) { - self->visitImport(curr.get()); - } for (auto& curr : module->exports) { self->visitExport(curr.get()); } for (auto& curr : module->globals) { - self->walkGlobal(curr.get()); + if (curr->imported()) { + self->visitGlobal(curr.get()); + } else { + self->walkGlobal(curr.get()); + } } for (auto& curr : module->functions) { - self->walkFunction(curr.get()); + if (curr->imported()) { + self->visitFunction(curr.get()); + } else { + self->walkFunction(curr.get()); + } } self->walkTable(&module->table); self->walkMemory(&module->memory); @@ -405,7 +403,6 @@ struct Walker : public VisitorType { static void doVisitBreak(SubType* self, Expression** currp) { self->visitBreak((*currp)->cast<Break>()); } static void doVisitSwitch(SubType* self, Expression** currp) { self->visitSwitch((*currp)->cast<Switch>()); } static void doVisitCall(SubType* self, Expression** currp) { self->visitCall((*currp)->cast<Call>()); } - static void doVisitCallImport(SubType* self, Expression** currp) { self->visitCallImport((*currp)->cast<CallImport>()); } static void doVisitCallIndirect(SubType* self, Expression** currp) { self->visitCallIndirect((*currp)->cast<CallIndirect>()); } static void doVisitGetLocal(SubType* self, Expression** currp) { self->visitGetLocal((*currp)->cast<GetLocal>()); } static void doVisitSetLocal(SubType* self, Expression** currp) { self->visitSetLocal((*currp)->cast<SetLocal>()); } @@ -493,14 +490,6 @@ struct PostWalker : public Walker<SubType, VisitorType> { } break; } - case Expression::Id::CallImportId: { - self->pushTask(SubType::doVisitCallImport, currp); - auto& list = curr->cast<CallImport>()->operands; - for (int i = int(list.size()) - 1; i >= 0; i--) { - self->pushTask(SubType::scan, &list[i]); - } - break; - } case Expression::Id::CallIndirectId: { self->pushTask(SubType::doVisitCallIndirect, currp); auto& list = curr->cast<CallIndirect>()->operands; diff --git a/src/wasm.h b/src/wasm.h index 09c81e500..57591d811 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -160,7 +160,6 @@ public: BreakId, SwitchId, CallId, - CallImportId, CallIndirectId, GetLocalId, SetLocalId, @@ -324,16 +323,6 @@ public: void finalize(); }; -class CallImport : public SpecificExpression<Expression::CallImportId> { -public: - CallImport(MixedArena& allocator) : operands(allocator) {} - - ExpressionList operands; - Name target; - - void finalize(); -}; - class FunctionType { public: Name name; @@ -582,6 +571,15 @@ public: // Globals +struct Importable { + // If these are set, then this is an import, as module.base + Name module, base; + + bool imported() { + return module.is(); + } +}; + // Forward declarations of Stack IR, as functions can contain it, see // the stackIR property. // Stack IR is a secondary IR to the main IR defined in this file (Binaryen @@ -589,16 +587,16 @@ public: class StackInst; typedef std::vector<StackInst*> StackIR; -class Function { +class Function : public Importable { public: Name name; - Type result; + Type result = none; std::vector<Type> params; // function locals are std::vector<Type> vars; // params plus vars Name type; // if null, it is implicit in params and result // The body of the function - Expression* body; + Expression* body = nullptr; // If present, this stack IR was generated from the main Binaryen IR body, // and possibly optimized. If it is present when writing to wasm binary, @@ -627,8 +625,6 @@ public: std::set<DebugLocation> prologLocation; std::set<DebugLocation> epilogLocation; - Function() : result(none), prologLocation(), epilogLocation() {} - size_t getNumParams(); size_t getNumVars(); size_t getNumLocals(); @@ -647,21 +643,13 @@ public: bool hasLocalName(Index index) const; }; +// The kind of an import or export. enum class ExternalKind { Function = 0, Table = 1, Memory = 2, - Global = 3 -}; - -class Import { -public: - Import() : globalType(none) {} - - Name name, module, base; // name = module.base - ExternalKind kind; - Name functionType; // for Function imports - Type globalType; // for Global imports + Global = 3, + Invalid = -1 }; class Export { @@ -671,7 +659,7 @@ public: ExternalKind kind; }; -class Table { +class Table : public Importable { public: static const Address::address_t kPageSize = 1; static const Index kMaxSize = Index(-1); @@ -689,18 +677,17 @@ public: // Currently the wasm object always 'has' one Table. It 'exists' if it has been defined or imported. // The table can exist but be empty and have no defined initial or max size. bool exists; - bool imported; Name name; Address initial, max; std::vector<Segment> segments; - Table() : exists(false), imported(false), initial(0), max(kMaxSize) { + Table() : exists(false), initial(0), max(kMaxSize) { name = Name::fromInt(0); } bool hasMax() { return max != kMaxSize; } }; -class Memory { +class Memory : public Importable { public: static const Address::address_t kPageSize = 64 * 1024; static const Address::address_t kMaxSize = ~Address::address_t(0) / kPageSize; @@ -726,21 +713,20 @@ public: // See comment in Table. bool exists; - bool imported; bool shared; - Memory() : initial(0), max(kMaxSize), exists(false), imported(false), shared(false) { + Memory() : initial(0), max(kMaxSize), exists(false), shared(false) { name = Name::fromInt(0); } bool hasMax() { return max != kMaxSize; } }; -class Global { +class Global : public Importable { public: Name name; Type type; Expression* init; - bool mutable_; + bool mutable_ = false; }; // "Opaque" data, not part of the core wasm spec, that is held in binaries. @@ -755,7 +741,6 @@ class Module { public: // wasm contents (generally you shouldn't access these from outside, except maybe for iterating; use add*() and the get() functions) std::vector<std::unique_ptr<FunctionType>> functionTypes; - std::vector<std::unique_ptr<Import>> imports; std::vector<std::unique_ptr<Export>> exports; std::vector<std::unique_ptr<Function>> functions; std::vector<std::unique_ptr<Global>> globals; @@ -772,7 +757,6 @@ public: private: // TODO: add a build option where Names are just indices, and then these methods are not needed std::map<Name, FunctionType*> functionTypesMap; - std::map<Name, Import*> importsMap; std::map<Name, Export*> exportsMap; // exports map is by the *exported* name, which is unique std::map<Name, Function*> functionsMap; std::map<Name, Global*> globalsMap; @@ -781,30 +765,26 @@ public: Module() {}; FunctionType* getFunctionType(Name name); - Import* getImport(Name name); Export* getExport(Name name); Function* getFunction(Name name); Global* getGlobal(Name name); FunctionType* getFunctionTypeOrNull(Name name); - Import* getImportOrNull(Name name); Export* getExportOrNull(Name name); Function* getFunctionOrNull(Name name); Global* getGlobalOrNull(Name name); void addFunctionType(FunctionType* curr); - void addImport(Import* curr); void addExport(Export* curr); void addFunction(Function* curr); void addGlobal(Global* curr); void addStart(const Name& s); - void removeImport(Name name); + void removeFunctionType(Name name); void removeExport(Name name); void removeFunction(Name name); - void removeFunctionType(Name name); - // TODO: remove* for other elements + void removeGlobal(Name name); void updateMaps(); }; diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index b643eaec8..1839c9cbd 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -35,6 +35,8 @@ void WasmBinaryWriter::prepare() { ModuleUtils::BinaryIndexes indexes(*wasm); mappedFunctions = std::move(indexes.functionIndexes); mappedGlobals = std::move(indexes.globalIndexes); + + importInfo = wasm::make_unique<ImportInfo>(*wasm); } void WasmBinaryWriter::write() { @@ -136,7 +138,7 @@ void WasmBinaryWriter::writeStart() { } void WasmBinaryWriter::writeMemory() { - if (!wasm->memory.exists || wasm->memory.imported) return; + if (!wasm->memory.exists || wasm->memory.imported()) return; if (debug) std::cerr << "== writeMemory" << std::endl; auto start = startSection(BinaryConsts::Section::Memory); o << U32LEB(1); // Define 1 memory @@ -176,46 +178,54 @@ int32_t WasmBinaryWriter::getFunctionTypeIndex(Name type) { } void WasmBinaryWriter::writeImports() { - if (wasm->imports.size() == 0) return; + auto num = importInfo->getNumImports(); + if (num == 0) return; if (debug) std::cerr << "== writeImports" << std::endl; auto start = startSection(BinaryConsts::Section::Import); - o << U32LEB(wasm->imports.size()); - for (auto& import : wasm->imports) { - if (debug) std::cerr << "write one" << std::endl; + o << U32LEB(num); + auto writeImportHeader = [&](Importable* import) { writeInlineString(import->module.str); writeInlineString(import->base.str); - o << U32LEB(int32_t(import->kind)); - switch (import->kind) { - case ExternalKind::Function: o << U32LEB(getFunctionTypeIndex(import->functionType)); break; - case ExternalKind::Table: { - o << S32LEB(BinaryConsts::EncodedType::AnyFunc); - writeResizableLimits(wasm->table.initial, wasm->table.max, wasm->table.max != Table::kMaxSize, /*shared=*/false); - break; - } - case ExternalKind::Memory: { - writeResizableLimits(wasm->memory.initial, wasm->memory.max, - wasm->memory.max != Memory::kMaxSize, wasm->memory.shared); - break; - } - case ExternalKind::Global: - o << binaryType(import->globalType); - o << U32LEB(0); // Mutable global's can't be imported for now. - break; - default: WASM_UNREACHABLE(); - } + }; + ModuleUtils::iterImportedFunctions(*wasm, [&](Function* func) { + if (debug) std::cerr << "write one function" << std::endl; + writeImportHeader(func); + o << U32LEB(int32_t(ExternalKind::Function)); + o << U32LEB(getFunctionTypeIndex(func->type)); + }); + ModuleUtils::iterImportedGlobals(*wasm, [&](Global* global) { + if (debug) std::cerr << "write one global" << std::endl; + writeImportHeader(global); + o << U32LEB(int32_t(ExternalKind::Global)); + o << binaryType(global->type); + o << U32LEB(0); // Mutable globals can't be imported for now. + }); + if (wasm->memory.imported()) { + if (debug) std::cerr << "write one memory" << std::endl; + writeImportHeader(&wasm->memory); + o << U32LEB(int32_t(ExternalKind::Memory)); + writeResizableLimits(wasm->memory.initial, wasm->memory.max, + wasm->memory.max != Memory::kMaxSize, wasm->memory.shared); + } + if (wasm->table.imported()) { + if (debug) std::cerr << "write one table" << std::endl; + writeImportHeader(&wasm->table); + o << U32LEB(int32_t(ExternalKind::Table)); + o << S32LEB(BinaryConsts::EncodedType::AnyFunc); + writeResizableLimits(wasm->table.initial, wasm->table.max, wasm->table.max != Table::kMaxSize, /*shared=*/false); } finishSection(start); } void WasmBinaryWriter::writeFunctionSignatures() { - if (wasm->functions.size() == 0) return; + if (importInfo->getNumDefinedFunctions() == 0) return; if (debug) std::cerr << "== writeFunctionSignatures" << std::endl; auto start = startSection(BinaryConsts::Section::Function); - o << U32LEB(wasm->functions.size()); - for (auto& curr : wasm->functions) { + o << U32LEB(importInfo->getNumDefinedFunctions()); + ModuleUtils::iterDefinedFunctions(*wasm, [&](Function* func) { if (debug) std::cerr << "write one" << std::endl; - o << U32LEB(getFunctionTypeIndex(curr->type)); - } + o << U32LEB(getFunctionTypeIndex(func->type)); + }); finishSection(start); } @@ -224,17 +234,15 @@ void WasmBinaryWriter::writeExpression(Expression* curr) { } void WasmBinaryWriter::writeFunctions() { - if (wasm->functions.size() == 0) return; + if (importInfo->getNumDefinedFunctions() == 0) return; if (debug) std::cerr << "== writeFunctions" << std::endl; auto start = startSection(BinaryConsts::Section::Code); - size_t total = wasm->functions.size(); - o << U32LEB(total); - for (size_t i = 0; i < total; i++) { + o << U32LEB(importInfo->getNumDefinedFunctions()); + ModuleUtils::iterDefinedFunctions(*wasm, [&](Function* func) { size_t sourceMapLocationsSizeAtFunctionStart = sourceMapLocations.size(); if (debug) std::cerr << "write one at" << o.size() << std::endl; size_t sizePos = writeU32LEBPlaceholder(); size_t start = o.size(); - Function* func = wasm->functions[i].get(); if (debug) std::cerr << "writing" << func->name << std::endl; // Emit Stack IR if present, and if we can if (func->stackIR && !sourceMap) { @@ -261,22 +269,23 @@ void WasmBinaryWriter::writeFunctions() { } } tableOfContents.functionBodies.emplace_back(func->name, sizePos + sizeFieldSize, size); - } + }); finishSection(start); } void WasmBinaryWriter::writeGlobals() { - if (wasm->globals.size() == 0) return; + if (importInfo->getNumDefinedGlobals() == 0) return; if (debug) std::cerr << "== writeglobals" << std::endl; auto start = startSection(BinaryConsts::Section::Global); - o << U32LEB(wasm->globals.size()); - for (auto& curr : wasm->globals) { + auto num = importInfo->getNumDefinedGlobals(); + o << U32LEB(num); + ModuleUtils::iterDefinedGlobals(*wasm, [&](Global* global) { if (debug) std::cerr << "write one" << std::endl; - o << binaryType(curr->type); - o << U32LEB(curr->mutable_); - writeExpression(curr->init); + o << binaryType(global->type); + o << U32LEB(global->mutable_); + writeExpression(global->init); o << int8_t(BinaryConsts::End); - } + }); finishSection(start); } @@ -404,7 +413,7 @@ uint32_t WasmBinaryWriter::getGlobalIndex(Name name) { } void WasmBinaryWriter::writeFunctionTableDeclaration() { - if (!wasm->table.exists || wasm->table.imported) return; + if (!wasm->table.exists || wasm->table.imported()) return; if (debug) std::cerr << "== writeFunctionTableDeclaration" << std::endl; auto start = startSection(BinaryConsts::Section::Table); o << U32LEB(1); // Declare 1 table. @@ -436,14 +445,6 @@ void WasmBinaryWriter::writeNames() { if (wasm->functions.size() > 0) { hasContents = true; getFunctionIndex(wasm->functions[0]->name); // generate mappedFunctions - } else { - for (auto& import : wasm->imports) { - if (import->kind == ExternalKind::Function) { - hasContents = true; - getFunctionIndex(import->name); // generate mappedFunctions - break; - } - } } if (!hasContents) return; if (debug) std::cerr << "== writeNames" << std::endl; @@ -452,18 +453,13 @@ void WasmBinaryWriter::writeNames() { auto substart = startSubsection(BinaryConsts::UserSections::Subsection::NameFunction); o << U32LEB(mappedFunctions.size()); Index emitted = 0; - for (auto& import : wasm->imports) { - if (import->kind == ExternalKind::Function) { - o << U32LEB(emitted); - writeEscapedName(import->name.str); - emitted++; - } - } - for (auto& curr : wasm->functions) { + auto add = [&](Function* curr) { o << U32LEB(emitted); writeEscapedName(curr->name.str); emitted++; - } + }; + ModuleUtils::iterImportedFunctions(*wasm, add); + ModuleUtils::iterDefinedFunctions(*wasm, add); assert(emitted == mappedFunctions.size()); finishSubsection(substart); /* TODO: locals */ @@ -480,14 +476,11 @@ void WasmBinaryWriter::writeSourceMapUrl() { void WasmBinaryWriter::writeSymbolMap() { std::ofstream file(symbolMap); - for (auto& import : wasm->imports) { - if (import->kind == ExternalKind::Function) { - file << getFunctionIndex(import->name) << ":" << import->name.str << std::endl; - } - } - for (auto& func : wasm->functions) { + auto write = [&](Function* func) { file << getFunctionIndex(func->name) << ":" << func->name.str << std::endl; - } + }; + ModuleUtils::iterImportedFunctions(*wasm, write); + ModuleUtils::iterDefinedFunctions(*wasm, write); file.close(); } @@ -940,17 +933,7 @@ void WasmBinaryBuilder::readSignatures() { } Name WasmBinaryBuilder::getFunctionIndexName(Index i) { - if (i < functionImports.size()) { - auto* import = functionImports[i]; - assert(import->kind == ExternalKind::Function); - return import->name; - } else { - i -= functionImports.size(); - if (i >= wasm.functions.size()) { - throwError("bad function index"); - } - return wasm.functions[i]->name; - } + return wasm.functions[i]->name; } void WasmBinaryBuilder::getResizableLimits(Address& initial, Address& max, bool &shared, Address defaultIfNoMax) { @@ -968,61 +951,69 @@ void WasmBinaryBuilder::readImports() { if (debug) std::cerr << "== readImports" << std::endl; size_t num = getU32LEB(); if (debug) std::cerr << "num: " << num << std::endl; + Builder builder(wasm); for (size_t i = 0; i < num; i++) { if (debug) std::cerr << "read one" << std::endl; - auto curr = new Import; - curr->module = getInlineString(); - curr->base = getInlineString(); - curr->kind = (ExternalKind)getU32LEB(); + auto module = getInlineString(); + auto base = getInlineString(); + auto kind = (ExternalKind)getU32LEB(); // We set a unique prefix for the name based on the kind. This ensures no collisions // between them, which can't occur here (due to the index i) but could occur later // due to the names section. - switch (curr->kind) { + switch (kind) { case ExternalKind::Function: { - curr->name = Name(std::string("fimport$") + std::to_string(i)); + auto name = Name(std::string("fimport$") + std::to_string(i)); auto index = getU32LEB(); if (index >= wasm.functionTypes.size()) { throwError("invalid function index " + std::to_string(index) + " / " + std::to_string(wasm.functionTypes.size())); } - curr->functionType = wasm.functionTypes[index]->name; - assert(curr->functionType.is()); + auto* functionType = wasm.functionTypes[index].get(); + auto params = functionType->params; + auto result = functionType->result; + auto* curr = builder.makeFunction(name, std::move(params), result, {}); + curr->module = module; + curr->base = base; + curr->type = functionType->name; + wasm.addFunction(curr); functionImports.push_back(curr); - continue; // don't add the import yet, we add them later after we know their names break; } case ExternalKind::Table: { - curr->name = Name(std::string("timport$") + std::to_string(i)); + wasm.table.module = module; + wasm.table.base = base; + wasm.table.name = Name(std::string("timport$") + std::to_string(i)); auto elementType = getS32LEB(); WASM_UNUSED(elementType); if (elementType != BinaryConsts::EncodedType::AnyFunc) throwError("Imported table type is not AnyFunc"); wasm.table.exists = true; - wasm.table.imported = true; bool is_shared; getResizableLimits(wasm.table.initial, wasm.table.max, is_shared, Table::kMaxSize); if (is_shared) throwError("Tables may not be shared"); break; } case ExternalKind::Memory: { - curr->name = Name(std::string("mimport$") + std::to_string(i)); + wasm.memory.module = module; + wasm.memory.base = base; + wasm.memory.name = Name(std::to_string(i)); wasm.memory.exists = true; - wasm.memory.imported = true; getResizableLimits(wasm.memory.initial, wasm.memory.max, wasm.memory.shared, Memory::kMaxSize); break; } case ExternalKind::Global: { - curr->name = Name(std::string("gimport$") + std::to_string(i)); - curr->globalType = getConcreteType(); - auto globalMutable = getU32LEB(); - // TODO: actually use the globalMutable flag. Currently mutable global - // imports is a future feature, to be implemented with thread support. - (void)globalMutable; + auto name = Name(std::string("gimport$") + std::to_string(i)); + auto type = getConcreteType(); + auto mutable_ = getU32LEB(); + assert(!mutable_); // for now, until mutable globals + auto* curr = builder.makeGlobal(name, type, nullptr, mutable_ ? Builder::Mutable : Builder::Immutable); + curr->module = module; + curr->base = base; + wasm.addGlobal(curr); break; } default: { throwError("bad import kind"); } } - wasm.addImport(curr); } } @@ -1332,7 +1323,7 @@ void WasmBinaryBuilder::readGlobals() { if (mutable_ & ~1) throwError("Global mutability must be 0 or 1"); auto* init = readExpression(); wasm.addGlobal(Builder::makeGlobal( - "global$" + std::to_string(wasm.globals.size()), + "global$" + std::to_string(i), type, init, mutable_ ? Builder::Mutable : Builder::Immutable @@ -1460,15 +1451,12 @@ Expression* WasmBinaryBuilder::popNonVoidExpression() { Name WasmBinaryBuilder::getGlobalName(Index index) { if (!mappedGlobals.size()) { // Create name => index mapping. - for (auto& import : wasm.imports) { - if (import->kind != ExternalKind::Global) continue; - auto index = mappedGlobals.size(); - mappedGlobals[index] = import->name; - } - for (size_t i = 0; i < wasm.globals.size(); i++) { + auto add = [&](Global* curr) { auto index = mappedGlobals.size(); - mappedGlobals[index] = wasm.globals[i]->name; - } + mappedGlobals[index] = curr->name; + }; + ModuleUtils::iterImportedGlobals(wasm, add); + ModuleUtils::iterDefinedGlobals(wasm, add); } if (index == Index(-1)) return Name("null"); // just a force-rebuild if (mappedGlobals.count(index) == 0) { @@ -1482,17 +1470,6 @@ void WasmBinaryBuilder::processFunctions() { wasm.addFunction(func); } - for (auto* import : functionImports) { - wasm.addImport(import); - } - - // we should have seen all the functions - // we assume this later down in fact, when we read wasm.functions[index], - // as index was validated vs functionTypes.size() - if (wasm.functions.size() != functionTypes.size()) { - throwError("did not see the right number of functions"); - } - // now that we have names for each function, apply things if (startIndex != static_cast<Index>(-1)) { @@ -1518,15 +1495,7 @@ void WasmBinaryBuilder::processFunctions() { size_t index = iter.first; auto& calls = iter.second; for (auto* call : calls) { - call->target = wasm.functions[index]->name; - } - } - - for (auto& iter : functionImportCalls) { - size_t index = iter.first; - auto& calls = iter.second; - for (auto* call : calls) { - call->target = functionImports[index]->name; + call->target = getFunctionIndexName(index); } } @@ -1537,6 +1506,10 @@ void WasmBinaryBuilder::processFunctions() { wasm.table.segments[i].data.push_back(getFunctionIndexName(j)); } } + + // Everything now has its proper name. + + wasm.updateMaps(); } void WasmBinaryBuilder::readDataSegments() { @@ -1689,7 +1662,7 @@ BinaryConsts::ASTNodes WasmBinaryBuilder::readExpression(Expression*& curr) { case BinaryConsts::Br: case BinaryConsts::BrIf: visitBreak((curr = allocator.alloc<Break>())->cast<Break>(), code); break; // code distinguishes br from br_if case BinaryConsts::TableSwitch: visitSwitch((curr = allocator.alloc<Switch>())->cast<Switch>()); break; - case BinaryConsts::CallFunction: curr = visitCall(); break; // we don't know if it's a call or call_import yet + case BinaryConsts::CallFunction: visitCall((curr = allocator.alloc<Call>())->cast<Call>()); break; case BinaryConsts::CallIndirect: visitCallIndirect((curr = allocator.alloc<CallIndirect>())->cast<CallIndirect>()); break; case BinaryConsts::GetLocal: visitGetLocal((curr = allocator.alloc<GetLocal>())->cast<GetLocal>()); break; case BinaryConsts::TeeLocal: @@ -1928,35 +1901,26 @@ void WasmBinaryBuilder::visitSwitch(Switch* curr) { curr->finalize(); } -Expression* WasmBinaryBuilder::visitCall() { +void WasmBinaryBuilder::visitCall(Call* curr) { if (debug) std::cerr << "zz node: Call" << std::endl; auto index = getU32LEB(); FunctionType* type; - Expression* ret; if (index < functionImports.size()) { - // this is a call of an imported function - auto* call = allocator.alloc<CallImport>(); auto* import = functionImports[index]; - type = wasm.getFunctionType(import->functionType); - functionImportCalls[index].push_back(call); - call->target = import->name; // name section may modify it - fillCall(call, type); - call->finalize(); - ret = call; + type = wasm.getFunctionType(import->type); } else { - // this is a call of a defined function - auto* call = allocator.alloc<Call>(); auto adjustedIndex = index - functionImports.size(); - if (adjustedIndex >= functionTypes.size()) { - throwError("bad call index"); - } type = functionTypes[adjustedIndex]; - fillCall(call, type); - functionCalls[adjustedIndex].push_back(call); // we don't know function names yet - call->finalize(); - ret = call; } - return ret; + assert(type); + auto num = type->params.size(); + curr->operands.resize(num); + for (size_t i = 0; i < num; i++) { + curr->operands[num - i - 1] = popNonVoidExpression(); + } + curr->type = type->result; + functionCalls[index].push_back(curr); // we don't know function names yet + curr->finalize(); } void WasmBinaryBuilder::visitCallIndirect(CallIndirect* curr) { @@ -2007,17 +1971,7 @@ void WasmBinaryBuilder::visitGetGlobal(GetGlobal* curr) { if (debug) std::cerr << "zz node: GetGlobal " << pos << std::endl; auto index = getU32LEB(); curr->name = getGlobalName(index); - auto* global = wasm.getGlobalOrNull(curr->name); - if (global) { - curr->type = global->type; - return; - } - auto* import = wasm.getImportOrNull(curr->name); - if (import && import->kind == ExternalKind::Global) { - curr->type = import->globalType; - return; - } - throwError("bad get_global"); + curr->type = wasm.getGlobal(curr->name)->type; } void WasmBinaryBuilder::visitSetGlobal(SetGlobal* curr) { diff --git a/src/wasm/wasm-emscripten.cpp b/src/wasm/wasm-emscripten.cpp index 65d385870..635b5dfc5 100644 --- a/src/wasm/wasm-emscripten.cpp +++ b/src/wasm/wasm-emscripten.cpp @@ -24,6 +24,8 @@ #include "wasm-builder.h" #include "wasm-traversal.h" #include "wasm.h" +#include "ir/function-type-utils.h" +#include "ir/module-utils.h" namespace wasm { @@ -181,12 +183,7 @@ void EmscriptenGlueGenerator::generateDynCallThunks() { if (indirectFunc == dummyFunction) { continue; } - std::string sig; - if (auto import = wasm.getImportOrNull(indirectFunc)) { - sig = getSig(wasm.getFunctionType(import->functionType)); - } else { - sig = getSig(wasm.getFunction(indirectFunc)); - } + std::string sig = getSig(wasm.getFunction(indirectFunc)); auto* funcType = ensureFunctionType(sig, &wasm); if (hasI64ResultOrParam(funcType)) continue; // Can't export i64s on the web. if (!sigs.insert(sig).second) continue; // Sig is already in the set @@ -269,12 +266,12 @@ void EmscriptenGlueGenerator::generateJSCallThunks( // function would have signature 'vii'.) std::string importSig = std::string(1, sig[0]) + 'i' + sig.substr(1); FunctionType *importType = ensureFunctionType(importSig, &wasm); - auto import = new Import; + auto import = new Function; import->name = import->base = "jsCall_" + sig; import->module = ENV; - import->functionType = importType->name; - import->kind = ExternalKind::Function; - wasm.addImport(import); + import->type = importType->name; + FunctionTypeUtils::fillFunction(import, importType); + wasm.addFunction(import); FunctionType *funcType = ensureFunctionType(sig, &wasm); // Create jsCall_sig_index thunks (e.g. jsCall_vi_0, jsCall_vi_1, ...) @@ -297,7 +294,7 @@ void EmscriptenGlueGenerator::generateJSCallThunks( args.push_back(builder.makeGetLocal(i, funcType->params[i])); } Expression* call = - builder.makeCallImport(import->name, args, funcType->result); + builder.makeCall(import->name, args, funcType->result); f->body = call; wasm.addFunction(f); tableSegmentData.push_back(f->name); @@ -378,7 +375,7 @@ struct AsmConstWalker : public PostWalker<AsmConstWalker> { : wasm(_wasm), segmentOffsets(getSegmentOffsets(wasm)) { } - void visitCallImport(CallImport* curr); + void visitCall(Call* curr); private: Literal idLiteralForCode(std::string code); @@ -387,9 +384,9 @@ private: void addImport(Name importName, std::string baseSig); }; -void AsmConstWalker::visitCallImport(CallImport* curr) { - Import* import = wasm.getImport(curr->target); - if (import->base.hasSubstring(EMSCRIPTEN_ASM_CONST)) { +void AsmConstWalker::visitCall(Call* curr) { + auto* import = wasm.getFunction(curr->target); + if (import->imported() && import->base.hasSubstring(EMSCRIPTEN_ASM_CONST)) { auto arg = curr->operands[0]->cast<Const>(); auto code = codeForConstAddr(wasm, segmentOffsets, arg); arg->value = idLiteralForCode(code); @@ -434,20 +431,19 @@ Name AsmConstWalker::nameForImportWithSig(std::string sig) { } void AsmConstWalker::addImport(Name importName, std::string baseSig) { - auto import = new Import; + auto import = new Function; import->name = import->base = importName; import->module = ENV; - import->functionType = ensureFunctionType(baseSig, &wasm)->name; - import->kind = ExternalKind::Function; - wasm.addImport(import); + import->type = ensureFunctionType(baseSig, &wasm)->name; + wasm.addFunction(import); } AsmConstWalker fixEmAsmConstsAndReturnWalker(Module& wasm) { // Collect imports to remove // This would find our generated functions if we ran it later std::vector<Name> toRemove; - for (auto& import : wasm.imports) { - if (import->base.hasSubstring(EMSCRIPTEN_ASM_CONST)) { + for (auto& import : wasm.functions) { + if (import->imported() && import->base.hasSubstring(EMSCRIPTEN_ASM_CONST)) { toRemove.push_back(import->name); } } @@ -458,7 +454,7 @@ AsmConstWalker fixEmAsmConstsAndReturnWalker(Module& wasm) { // Remove the base functions that we didn't generate for (auto importName : toRemove) { - wasm.removeImport(importName); + wasm.removeFunction(importName); } return walker; } @@ -590,19 +586,21 @@ struct FixInvokeFunctionNamesWalker : public PostWalker<FixInvokeFunctionNamesWa return fixEmExceptionInvoke(name, sig); } - void visitImport(Import* curr) { - if (curr->kind != ExternalKind::Function) + void visitFunction(Function* curr) { + if (!curr->imported()) { return; + } - FunctionType* func = wasm.getFunctionType(curr->functionType); + FunctionType* func = wasm.getFunctionType(curr->type); Name newname = fixEmEHSjLjNames(curr->base, getSig(func)); - if (newname == curr->base) + if (newname == curr->base) { return; + } assert(importRenames.count(curr->name) == 0); importRenames[curr->name] = newname; // Either rename or remove the existing import - if (wasm.getImportOrNull(newname) || !newImports.insert(newname).second) { + if (wasm.getFunctionOrNull(newname) || !newImports.insert(newname).second) { toRemove.push_back(curr->name); } else { curr->base = newname; @@ -621,16 +619,18 @@ struct FixInvokeFunctionNamesWalker : public PostWalker<FixInvokeFunctionNamesWa } } - void visitCallImport(CallImport* curr) { - auto it = importRenames.find(curr->target); - if (it != importRenames.end()) { - curr->target = it->second; + void visitCall(Call* curr) { + if (wasm.getFunction(curr->target)->imported()) { + auto it = importRenames.find(curr->target); + if (it != importRenames.end()) { + curr->target = it->second; + } } } void visitModule(Module* curr) { for (auto importName : toRemove) { - wasm.removeImport(importName); + wasm.removeFunction(importName); } wasm.updateMaps(); } @@ -731,25 +731,23 @@ std::string EmscriptenGlueGenerator::generateEmscriptenMetadata( // see. meta << ", \"declares\": ["; commaFirst = true; - for (const auto& import : wasm.imports) { - if (import->kind == ExternalKind::Function && - (emJsWalker.codeByName.count(import->base.str) == 0) && + ModuleUtils::iterImportedFunctions(wasm, [&](Function* import) { + if (emJsWalker.codeByName.count(import->base.str) == 0 && !import->base.startsWith(EMSCRIPTEN_ASM_CONST.str) && !import->base.startsWith("invoke_") && !import->base.startsWith("jsCall_")) { - if (declares.insert(import->base.str).second) + if (declares.insert(import->base.str).second) { meta << maybeComma() << '"' << import->base.str << '"'; + } } - } + }); meta << "]"; meta << ", \"externs\": ["; commaFirst = true; - for (const auto& import : wasm.imports) { - if (import->kind == ExternalKind::Global) { - meta << maybeComma() << "\"_" << import->base.str << '"'; - } - } + ModuleUtils::iterImportedGlobals(wasm, [&](Global* import) { + meta << maybeComma() << "\"_" << import->base.str << '"'; + }); meta << "]"; meta << ", \"implementedFunctions\": ["; @@ -770,12 +768,13 @@ std::string EmscriptenGlueGenerator::generateEmscriptenMetadata( meta << ", \"invokeFuncs\": ["; commaFirst = true; - for (const auto& import : wasm.imports) { + ModuleUtils::iterImportedFunctions(wasm, [&](Function* import) { if (import->base.startsWith("invoke_")) { - if (invokeFuncs.insert(import->base.str).second) + if (invokeFuncs.insert(import->base.str).second) { meta << maybeComma() << '"' << import->base.str << '"'; + } } - } + }); meta << "]"; meta << " }\n"; diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index e4c171b5c..7085666bb 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -23,6 +23,7 @@ #include "asm_v_wasm.h" #include "asmjs/shared-constants.h" #include "ir/branch-utils.h" +#include "ir/function-type-utils.h" #include "shared-constants.h" #include "wasm-binary.h" #include "wasm-builder.h" @@ -581,14 +582,15 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { if (importModule.is()) { // this is an import, actually if (!preParseImport) throw ParseException("!preParseImport in func"); - std::unique_ptr<Import> im = make_unique<Import>(); + auto im = make_unique<Function>(); im->name = name; im->module = importModule; im->base = importBase; - im->kind = ExternalKind::Function; - im->functionType = wasm.getFunctionType(type)->name; - if (wasm.getImportOrNull(im->name)) throw ParseException("duplicate import", s.line, s.col); - wasm.addImport(im.release()); + im->type = type; + FunctionTypeUtils::fillFunction(im.get(), wasm.getFunctionType(type)); + functionTypes[name] = im->result; + if (wasm.getFunctionOrNull(im->name)) throw ParseException("duplicate import", s.line, s.col); + wasm.addFunction(im.release()); if (currFunction) throw ParseException("import module inside function dec"); currLocalTypes.clear(); nameMapper.clear(); @@ -821,7 +823,6 @@ Expression* SExpressionWasmBuilder::makeExpression(Element& s) { case 'c': { if (str[1] == 'a') { if (id == CALL) return makeCall(s); - if (id == CALL_IMPORT) return makeCallImport(s); if (id == CALL_INDIRECT) return makeCallIndirect(s); } else if (str[1] == 'u') return makeHost(s, HostOp::CurrentMemory); abort_on(str); @@ -1040,16 +1041,11 @@ Expression* SExpressionWasmBuilder::makeGetGlobal(Element& s) { auto ret = allocator.alloc<GetGlobal>(); ret->name = getGlobalName(*s[1]); auto* global = wasm.getGlobalOrNull(ret->name); - if (global) { - ret->type = global->type; - return ret; + if (!global) { + throw ParseException("bad get_global name", s.line, s.col); } - auto* import = wasm.getImportOrNull(ret->name); - if (import && import->kind == ExternalKind::Global) { - ret->type = import->globalType; - return ret; - } - throw ParseException("bad get_global name", s.line, s.col); + ret->type = global->type; + return ret; } Expression* SExpressionWasmBuilder::makeSetGlobal(Element& s) { @@ -1377,15 +1373,6 @@ Expression* SExpressionWasmBuilder::makeLoop(Element& s) { Expression* SExpressionWasmBuilder::makeCall(Element& s) { auto target = getFunctionName(*s[1]); - auto* import = wasm.getImportOrNull(target); - if (import && import->kind == ExternalKind::Function) { - auto ret = allocator.alloc<CallImport>(); - ret->target = target; - Import* import = wasm.getImport(ret->target); - ret->type = wasm.getFunctionType(import->functionType)->result; - parseCallOperands(s, 2, s.size(), ret); - return ret; - } auto ret = allocator.alloc<Call>(); ret->target = target; ret->type = functionTypes[ret->target]; @@ -1394,16 +1381,6 @@ Expression* SExpressionWasmBuilder::makeCall(Element& s) { return ret; } -Expression* SExpressionWasmBuilder::makeCallImport(Element& s) { - auto ret = allocator.alloc<CallImport>(); - ret->target = s[1]->str(); - Import* import = wasm.getImport(ret->target); - ret->type = wasm.getFunctionType(import->functionType)->result; - parseCallOperands(s, 2, s.size(), ret); - ret->finalize(); - return ret; -} - Expression* SExpressionWasmBuilder::makeCallIndirect(Element& s) { if (!wasm.table.exists) throw ParseException("no table"); auto ret = allocator.alloc<CallIndirect>(); @@ -1543,7 +1520,6 @@ Index SExpressionWasmBuilder::parseMemoryLimits(Element& s, Index i) { void SExpressionWasmBuilder::parseMemory(Element& s, bool preParseImport) { if (wasm.memory.exists) throw ParseException("too many memories"); wasm.memory.exists = true; - wasm.memory.imported = preParseImport; wasm.memory.shared = false; Index i = 1; if (s[i]->dollared()) { @@ -1561,15 +1537,8 @@ void SExpressionWasmBuilder::parseMemory(Element& s, bool preParseImport) { wasm.addExport(ex.release()); i++; } else if (inner[0]->str() == IMPORT) { - importModule = inner[1]->str(); - importBase = inner[2]->str(); - auto im = make_unique<Import>(); - im->kind = ExternalKind::Memory; - im->module = importModule; - im->base = importBase; - im->name = importModule; - if (wasm.getImportOrNull(im->name)) throw ParseException("duplicate import", s.line, s.col); - wasm.addImport(im.release()); + wasm.memory.module = inner[1]->str(); + wasm.memory.base = inner[2]->str(); i++; } else if (inner[0]->str() == "shared") { wasm.memory.shared = true; @@ -1675,70 +1644,69 @@ void SExpressionWasmBuilder::parseExport(Element& s) { } void SExpressionWasmBuilder::parseImport(Element& s) { - std::unique_ptr<Import> im = make_unique<Import>(); size_t i = 1; bool newStyle = s.size() == 4 && s[3]->isList(); // (import "env" "STACKTOP" (global $stackTop i32)) + auto kind = ExternalKind::Invalid; if (newStyle) { if ((*s[3])[0]->str() == FUNC) { - im->kind = ExternalKind::Function; + kind = ExternalKind::Function; } else if ((*s[3])[0]->str() == MEMORY) { - im->kind = ExternalKind::Memory; + kind = ExternalKind::Memory; if (wasm.memory.exists) throw ParseException("more than one memory"); wasm.memory.exists = true; - wasm.memory.imported = true; } else if ((*s[3])[0]->str() == TABLE) { - im->kind = ExternalKind::Table; + kind = ExternalKind::Table; if (wasm.table.exists) throw ParseException("more than one table"); wasm.table.exists = true; - wasm.table.imported = true; } else if ((*s[3])[0]->str() == GLOBAL) { - im->kind = ExternalKind::Global; + kind = ExternalKind::Global; } else { newStyle = false; // either (param..) or (result..) } } Index newStyleInner = 1; + Name name; if (s.size() > 3 && s[3]->isStr()) { - im->name = s[i++]->str(); + name = s[i++]->str(); } else if (newStyle && newStyleInner < s[3]->size() && (*s[3])[newStyleInner]->dollared()) { - im->name = (*s[3])[newStyleInner++]->str(); - } - if (!im->name.is()) { - if (im->kind == ExternalKind::Function) { - im->name = Name("import$function$" + std::to_string(functionCounter++)); - functionNames.push_back(im->name); - } else if (im->kind == ExternalKind::Global) { - im->name = Name("import$global" + std::to_string(globalCounter++)); - globalNames.push_back(im->name); - } else if (im->kind == ExternalKind::Memory) { - im->name = Name("import$memory$" + std::to_string(0)); - } else if (im->kind == ExternalKind::Table) { - im->name = Name("import$table$" + std::to_string(0)); + name = (*s[3])[newStyleInner++]->str(); + } + if (!name.is()) { + if (kind == ExternalKind::Function) { + name = Name("import$function$" + std::to_string(functionCounter++)); + functionNames.push_back(name); + } else if (kind == ExternalKind::Global) { + name = Name("import$global" + std::to_string(globalCounter++)); + globalNames.push_back(name); + } else if (kind == ExternalKind::Memory) { + name = Name("import$memory$" + std::to_string(0)); + } else if (kind == ExternalKind::Table) { + name = Name("import$table$" + std::to_string(0)); } else { throw ParseException("invalid import"); } } if (!s[i]->quoted()) { if (s[i]->str() == MEMORY) { - im->kind = ExternalKind::Memory; + kind = ExternalKind::Memory; } else if (s[i]->str() == TABLE) { - im->kind = ExternalKind::Table; + kind = ExternalKind::Table; } else if (s[i]->str() == GLOBAL) { - im->kind = ExternalKind::Global; + kind = ExternalKind::Global; } else { throw ParseException("invalid ext import"); } i++; } else if (!newStyle) { - im->kind = ExternalKind::Function; + kind = ExternalKind::Function; } - im->module = s[i++]->str(); + auto module = s[i++]->str(); if (!s[i]->isStr()) throw ParseException("no name for import"); - im->base = s[i++]->str(); + auto base = s[i++]->str(); // parse internals Element& inner = newStyle ? *s[3] : s; Index j = newStyle ? newStyleInner : i; - if (im->kind == ExternalKind::Function) { + if (kind == ExternalKind::Function) { std::unique_ptr<FunctionType> type = make_unique<FunctionType>(); if (inner.size() > j) { Element& params = *inner[j]; @@ -1762,17 +1730,34 @@ void SExpressionWasmBuilder::parseImport(Element& s) { type->result = stringToType(result[1]->str()); } } - im->functionType = ensureFunctionType(getSig(type.get()), &wasm)->name; - } else if (im->kind == ExternalKind::Global) { + auto func = make_unique<Function>(); + func->name = name; + func->module = module; + func->base = base; + auto* functionType = ensureFunctionType(getSig(type.get()), &wasm); + func->type = functionType->name; + FunctionTypeUtils::fillFunction(func.get(), functionType); + functionTypes[name] = func->result; + wasm.addFunction(func.release()); + } else if (kind == ExternalKind::Global) { + Type type; if (inner[j]->isStr()) { - im->globalType = stringToType(inner[j]->str()); + type = stringToType(inner[j]->str()); } else { auto& inner2 = *inner[j]; if (inner2[0]->str() != MUT) throw ParseException("expected mut"); - im->globalType = stringToType(inner2[1]->str()); + type = stringToType(inner2[1]->str()); throw ParseException("cannot import a mutable global", s.line, s.col); } - } else if (im->kind == ExternalKind::Table) { + auto global = make_unique<Global>(); + global->name = name; + global->module = module; + global->base = base; + global->type = type; + wasm.addGlobal(global.release()); + } else if (kind == ExternalKind::Table) { + wasm.table.module = module; + wasm.table.base = base; if (j < inner.size() - 1) { wasm.table.initial = getCheckedAddress(inner[j++], "excessive table init size"); } @@ -1782,7 +1767,9 @@ void SExpressionWasmBuilder::parseImport(Element& s) { wasm.table.max = Table::kMaxSize; } // ends with the table element type - } else if (im->kind == ExternalKind::Memory) { + } else if (kind == ExternalKind::Memory) { + wasm.memory.module = module; + wasm.memory.base = base; if (inner[j]->isList()) { auto& limits = *inner[j]; if (!(limits[0]->isStr() && limits[0]->str() == "shared")) throw ParseException("bad memory limit declaration"); @@ -1792,8 +1779,6 @@ void SExpressionWasmBuilder::parseImport(Element& s) { parseMemoryLimits(inner, j); } } - if (wasm.getImportOrNull(im->name)) throw ParseException("duplicate import", s.line, s.col); - wasm.addImport(im.release()); } void SExpressionWasmBuilder::parseGlobal(Element& s, bool preParseImport) { @@ -1841,14 +1826,13 @@ void SExpressionWasmBuilder::parseGlobal(Element& s, bool preParseImport) { // this is an import, actually if (!preParseImport) throw ParseException("!preParseImport in global"); if (mutable_) throw ParseException("cannot import a mutable global", s.line, s.col); - std::unique_ptr<Import> im = make_unique<Import>(); + auto im = make_unique<Global>(); im->name = global->name; im->module = importModule; im->base = importBase; - im->kind = ExternalKind::Global; - im->globalType = type; - if (wasm.getImportOrNull(im->name)) throw ParseException("duplicate import", s.line, s.col); - wasm.addImport(im.release()); + im->type = type; + if (wasm.getGlobalOrNull(im->name)) throw ParseException("duplicate import", s.line, s.col); + wasm.addGlobal(im.release()); return; } if (preParseImport) throw ParseException("preParseImport in global"); @@ -1868,7 +1852,6 @@ void SExpressionWasmBuilder::parseGlobal(Element& s, bool preParseImport) { void SExpressionWasmBuilder::parseTable(Element& s, bool preParseImport) { if (wasm.table.exists) throw ParseException("more than one table"); wasm.table.exists = true; - wasm.table.imported = preParseImport; Index i = 1; if (i == s.size()) return; // empty table in old notation if (s[i]->dollared()) { @@ -1887,16 +1870,9 @@ void SExpressionWasmBuilder::parseTable(Element& s, bool preParseImport) { wasm.addExport(ex.release()); i++; } else if (inner[0]->str() == IMPORT) { - importModule = inner[1]->str(); - importBase = inner[2]->str(); if (!preParseImport) throw ParseException("!preParseImport in table"); - auto im = make_unique<Import>(); - im->kind = ExternalKind::Table; - im->module = importModule; - im->base = importBase; - im->name = importModule; - if (wasm.getImportOrNull(im->name)) throw ParseException("duplicate import", s.line, s.col); - wasm.addImport(im.release()); + wasm.table.module = inner[1]->str(); + wasm.table.base = inner[2]->str(); i++; } else { throw ParseException("invalid table"); diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index d57607ce6..36069f8db 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -24,6 +24,7 @@ #include "wasm-validator.h" #include "ir/utils.h" #include "ir/branch-utils.h" +#include "ir/module-utils.h" #include "support/colors.h" @@ -233,7 +234,6 @@ public: void visitBreak(Break* curr); void visitSwitch(Switch* curr); void visitCall(Call* curr); - void visitCallImport(CallImport* curr); void visitCallIndirect(CallIndirect* curr); void visitGetLocal(GetLocal* curr); void visitSetLocal(SetLocal* curr); @@ -444,12 +444,7 @@ void FunctionValidator::visitSwitch(Switch* curr) { void FunctionValidator::visitCall(Call* curr) { if (!info.validateGlobally) return; auto* target = getModule()->getFunctionOrNull(curr->target); - if (!shouldBeTrue(!!target, curr, "call target must exist")) { - if (getModule()->getImportOrNull(curr->target) && !info.quiet) { - getStream() << "(perhaps it should be a CallImport instead of Call?)\n"; - } - return; - } + if (!shouldBeTrue(!!target, curr, "call target must exist")) return; if (!shouldBeTrue(curr->operands.size() == target->params.size(), curr, "call param number must match")) return; for (size_t i = 0; i < curr->operands.size(); i++) { if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, target->params[i], curr, "call param types must match") && !info.quiet) { @@ -458,20 +453,6 @@ void FunctionValidator::visitCall(Call* curr) { } } -void FunctionValidator::visitCallImport(CallImport* curr) { - if (!info.validateGlobally) return; - auto* import = getModule()->getImportOrNull(curr->target); - if (!shouldBeTrue(!!import, curr, "call_import target must exist")) return; - if (!shouldBeTrue(!!import->functionType.is(), curr, "called import must be function")) return; - auto* type = getModule()->getFunctionType(import->functionType); - if (!shouldBeTrue(curr->operands.size() == type->params.size(), curr, "call param number must match")) return; - for (size_t i = 0; i < curr->operands.size(); i++) { - if (!shouldBeEqualOrFirstIsUnreachable(curr->operands[i]->type, type->params[i], curr, "call param types must match") && !info.quiet) { - getStream() << "(on argument " << i << ")\n"; - } - } -} - void FunctionValidator::visitCallIndirect(CallIndirect* curr) { if (!info.validateGlobally) return; auto* type = getModule()->getFunctionTypeOrNull(curr->fullType); @@ -503,13 +484,13 @@ void FunctionValidator::visitSetLocal(SetLocal* curr) { void FunctionValidator::visitGetGlobal(GetGlobal* curr) { if (!info.validateGlobally) return; - shouldBeTrue(getModule()->getGlobalOrNull(curr->name) || getModule()->getImportOrNull(curr->name), curr, "get_global name must be valid"); + shouldBeTrue(getModule()->getGlobalOrNull(curr->name), curr, "get_global name must be valid"); } void FunctionValidator::visitSetGlobal(SetGlobal* curr) { if (!info.validateGlobally) return; auto* global = getModule()->getGlobalOrNull(curr->name); - if (shouldBeTrue(global != NULL, curr, "set_global name must be valid (and not an import; imports can't be modified)")) { + if (shouldBeTrue(global, curr, "set_global name must be valid (and not an import; imports can't be modified)")) { shouldBeTrue(global->mutable_, curr, "set_global global must be mutable"); shouldBeEqualOrFirstIsUnreachable(curr->value->type, global->type, curr, "set_global value must have right type"); } @@ -926,23 +907,15 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) { // Main validator class static void validateImports(Module& module, ValidationInfo& info) { - for (auto& curr : module.imports) { - if (curr->kind == ExternalKind::Function) { - if (info.validateWeb) { - auto* functionType = module.getFunctionType(curr->functionType); - info.shouldBeUnequal(functionType->result, i64, curr->name, "Imported function must not have i64 return type"); - for (Type param : functionType->params) { - info.shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters"); - } + ModuleUtils::iterImportedFunctions(module, [&](Function* curr) { + if (info.validateWeb) { + auto* functionType = module.getFunctionType(curr->type); + info.shouldBeUnequal(functionType->result, i64, curr->name, "Imported function must not have i64 return type"); + for (Type param : functionType->params) { + info.shouldBeUnequal(param, i64, curr->name, "Imported function must not have i64 parameters"); } } - if (curr->kind == ExternalKind::Table) { - info.shouldBeTrue(module.table.imported, curr->name, "Table import record exists but table is not marked as imported"); - } - if (curr->kind == ExternalKind::Memory) { - info.shouldBeTrue(module.memory.imported, curr->name, "Memory import record exists but memory is not marked as imported"); - } - } + }); } static void validateExports(Module& module, ValidationInfo& info) { @@ -961,13 +934,9 @@ static void validateExports(Module& module, ValidationInfo& info) { for (auto& exp : module.exports) { Name name = exp->value; if (exp->kind == ExternalKind::Function) { - Import* imp; - info.shouldBeTrue(module.getFunctionOrNull(name) || - ((imp = module.getImportOrNull(name)) && imp->kind == ExternalKind::Function), name, "module function exports must be found"); + info.shouldBeTrue(module.getFunctionOrNull(name), name, "module function exports must be found"); } else if (exp->kind == ExternalKind::Global) { - Import* imp; - info.shouldBeTrue(module.getGlobalOrNull(name) || - ((imp = module.getImportOrNull(name)) && imp->kind == ExternalKind::Global), name, "module global exports must be found"); + info.shouldBeTrue(module.getGlobalOrNull(name), name, "module global exports must be found"); } else if (exp->kind == ExternalKind::Table) { info.shouldBeTrue(name == Name("0") || name == module.table.name, name, "module table exports must be found"); } else if (exp->kind == ExternalKind::Memory) { @@ -982,13 +951,13 @@ static void validateExports(Module& module, ValidationInfo& info) { } static void validateGlobals(Module& module, ValidationInfo& info) { - for (auto& curr : module.globals) { + ModuleUtils::iterDefinedGlobals(module, [&](Global* curr) { info.shouldBeTrue(curr->init != nullptr, curr->name, "global init must be non-null"); info.shouldBeTrue(curr->init->is<Const>() || curr->init->is<GetGlobal>(), curr->name, "global init must be valid"); if (!info.shouldBeEqual(curr->type, curr->init->type, curr->init, "global init must have correct type") && !info.quiet) { info.getStream(nullptr) << "(on global " << curr->name << ")\n"; } - } + }); } static void validateMemory(Module& module, ValidationInfo& info) { @@ -1016,7 +985,7 @@ static void validateTable(Module& module, ValidationInfo& info) { info.shouldBeEqual(segment.offset->type, i32, segment.offset, "segment offset should be i32"); info.shouldBeTrue(checkOffset(segment.offset, segment.data.size(), module.table.initial * Table::kPageSize), segment.offset, "segment offset should be reasonable"); for (auto name : segment.data) { - info.shouldBeTrue(module.getFunctionOrNull(name) || module.getImportOrNull(name), name, "segment name should be valid"); + info.shouldBeTrue(module.getFunctionOrNull(name), name, "segment name should be valid"); } } } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index ac3623cee..ad4d6343a 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -53,7 +53,6 @@ Name GROW_WASM_MEMORY("__growWasmMemory"), LOCAL("local"), TYPE("type"), CALL("call"), - CALL_IMPORT("call_import"), CALL_INDIRECT("call_indirect"), BLOCK("block"), BR_IF("br_if"), @@ -83,7 +82,6 @@ const char* getExpressionName(Expression* curr) { case Expression::Id::BreakId: return "break"; case Expression::Id::SwitchId: return "switch"; case Expression::Id::CallId: return "call"; - case Expression::Id::CallImportId: return "call_import"; case Expression::Id::CallIndirectId: return "call_indirect"; case Expression::Id::GetLocalId: return "get_local"; case Expression::Id::SetLocalId: return "set_local"; @@ -324,10 +322,6 @@ void Call::finalize() { handleUnreachableOperands(this); } -void CallImport::finalize() { - handleUnreachableOperands(this); -} - void CallIndirect::finalize() { handleUnreachableOperands(this); if (target->type == unreachable) { @@ -650,14 +644,6 @@ FunctionType* Module::getFunctionType(Name name) { return iter->second; } -Import* Module::getImport(Name name) { - auto iter = importsMap.find(name); - if (iter == importsMap.end()) { - Fatal() << "Module::getImport: " << name << " does not exist"; - } - return iter->second; -} - Export* Module::getExport(Name name) { auto iter = exportsMap.find(name); if (iter == exportsMap.end()) { @@ -690,14 +676,6 @@ FunctionType* Module::getFunctionTypeOrNull(Name name) { return iter->second; } -Import* Module::getImportOrNull(Name name) { - auto iter = importsMap.find(name); - if (iter == importsMap.end()) { - return nullptr; - } - return iter->second; -} - Export* Module::getExportOrNull(Name name) { auto iter = exportsMap.find(name); if (iter == exportsMap.end()) { @@ -733,17 +711,6 @@ void Module::addFunctionType(FunctionType* curr) { functionTypesMap[curr->name] = curr; } -void Module::addImport(Import* curr) { - if (!curr->name.is()) { - Fatal() << "Module::addImport: empty name"; - } - if (getImportOrNull(curr->name)) { - Fatal() << "Module::addImport: " << curr->name << " already exists"; - } - imports.push_back(std::unique_ptr<Import>(curr)); - importsMap[curr->name] = curr; -} - void Module::addExport(Export* curr) { if (!curr->name.is()) { Fatal() << "Module::addExport: empty name"; @@ -781,14 +748,14 @@ void Module::addStart(const Name& s) { start = s; } -void Module::removeImport(Name name) { - for (size_t i = 0; i < imports.size(); i++) { - if (imports[i]->name == name) { - imports.erase(imports.begin() + i); +void Module::removeFunctionType(Name name) { + for (size_t i = 0; i < functionTypes.size(); i++) { + if (functionTypes[i]->name == name) { + functionTypes.erase(functionTypes.begin() + i); break; } } - importsMap.erase(name); + functionTypesMap.erase(name); } void Module::removeExport(Name name) { @@ -811,14 +778,14 @@ void Module::removeFunction(Name name) { functionsMap.erase(name); } -void Module::removeFunctionType(Name name) { - for (size_t i = 0; i < functionTypes.size(); i++) { - if (functionTypes[i]->name == name) { - functionTypes.erase(functionTypes.begin() + i); +void Module::removeGlobal(Name name) { + for (size_t i = 0; i < globals.size(); i++) { + if (globals[i]->name == name) { + globals.erase(globals.begin() + i); break; } } - functionTypesMap.erase(name); + globalsMap.erase(name); } // TODO: remove* for other elements @@ -832,10 +799,6 @@ void Module::updateMaps() { for (auto& curr : functionTypes) { functionTypesMap[curr->name] = curr.get(); } - importsMap.clear(); - for (auto& curr : imports) { - importsMap[curr->name] = curr.get(); - } exportsMap.clear(); for (auto& curr : exports) { exportsMap[curr->name] = curr.get(); diff --git a/src/wasm2js.h b/src/wasm2js.h index a0519effd..080315490 100644 --- a/src/wasm2js.h +++ b/src/wasm2js.h @@ -34,6 +34,8 @@ #include "emscripten-optimizer/optimizer.h" #include "mixed_arena.h" #include "asm_v_wasm.h" +#include "ir/import-utils.h" +#include "ir/module-utils.h" #include "ir/names.h" #include "ir/utils.h" #include "passes/passes.h" @@ -272,7 +274,7 @@ private: void addEsmImports(Ref ast, Module* wasm); void addEsmExportsAndInstantiate(Ref ast, Module* wasm, Name funcName); void addBasics(Ref ast); - void addImport(Ref ast, Import* import); + void addFunctionImport(Ref ast, Function* import); void addTables(Ref ast, Module* wasm); void addExports(Ref ast, Module* wasm); void addGlobal(Ref ast, Global* global); @@ -341,9 +343,12 @@ Ref Wasm2JSBuilder::processWasm(Module* wasm, Name funcName) { asmFunc[3]->push_back(ValueBuilder::makeStatement(ValueBuilder::makeString(USE_ASM))); // create heaps, etc addBasics(asmFunc[3]); - for (auto& import : wasm->imports) { - addImport(asmFunc[3], import.get()); - } + ModuleUtils::iterImportedFunctions(*wasm, [&](Function* import) { + addFunctionImport(asmFunc[3], import); + }); + ModuleUtils::iterImportedGlobals(*wasm, [&](Global* import) { + addGlobal(asmFunc[3], import); + }); // figure out the table size tableSize = std::accumulate(wasm->table.segments.begin(), wasm->table.segments.end(), @@ -368,16 +373,16 @@ Ref Wasm2JSBuilder::processWasm(Module* wasm, Name funcName) { fromName(WASM_FETCH_HIGH_BITS, NameScope::Top); // globals bool generateFetchHighBits = false; - for (auto& global : wasm->globals) { - addGlobal(asmFunc[3], global.get()); + ModuleUtils::iterDefinedGlobals(*wasm, [&](Global* global) { + addGlobal(asmFunc[3], global); if (flags.allowAsserts && global->name == INT64_TO_32_HIGH_BITS) { generateFetchHighBits = true; } - } + }); // functions - for (auto& func : wasm->functions) { - asmFunc[3]->push_back(processFunction(wasm, func.get())); - } + ModuleUtils::iterDefinedFunctions(*wasm, [&](Function* func) { + asmFunc[3]->push_back(processFunction(wasm, func)); + }); if (generateFetchHighBits) { Builder builder(allocator); std::vector<Type> params; @@ -402,19 +407,15 @@ Ref Wasm2JSBuilder::processWasm(Module* wasm, Name funcName) { return ret; } -void Wasm2JSBuilder::addEsmImports(Ref ast, Module *wasm) { +void Wasm2JSBuilder::addEsmImports(Ref ast, Module* wasm) { std::unordered_map<Name, Name> nameMap; - for (auto& import : wasm->imports) { - // Only function imports are supported for now, but eventually imported - // memories can probably be supported at least. - switch (import->kind) { - case ExternalKind::Function: break; - default: - Fatal() << "non-function imports aren't supported yet\n"; - abort(); - } - + ImportInfo imports(*wasm); + if (imports.getNumImportedGlobals() > 0) { + Fatal() << "non-function imports aren't supported yet\n"; + abort(); + } + ModuleUtils::iterImportedFunctions(*wasm, [&](Function* import) { // Right now codegen requires a flat namespace going into the module, // meaning we don't importing the same name from multiple namespaces yet. if (nameMap.count(import->base) && nameMap[import->base] != import->module) { @@ -434,7 +435,7 @@ void Wasm2JSBuilder::addEsmImports(Ref ast, Module *wasm) { std::string os = out.str(); IString name(os.c_str(), false); flattenAppend(ast, ValueBuilder::makeName(name)); - } + }); } static std::string base64Encode(std::vector<char> &data) { @@ -553,13 +554,10 @@ void Wasm2JSBuilder::addEsmExportsAndInstantiate(Ref ast, Module *wasm, Name fun << "}, {"; construct << "abort:function() { throw new Error('abort'); }"; - for (auto& import : wasm->imports) { - switch (import->kind) { - case ExternalKind::Function: break; - default: continue; - } + + ModuleUtils::iterImportedFunctions(*wasm, [&](Function* import) { construct << "," << import->base.str; - } + }); construct << "},mem" << funcName.str << ")"; std::string sconstruct = construct.str(); IString name(sconstruct.c_str(), false); @@ -678,7 +676,7 @@ void Wasm2JSBuilder::addBasics(Ref ast) { ); } -void Wasm2JSBuilder::addImport(Ref ast, Import* import) { +void Wasm2JSBuilder::addFunctionImport(Ref ast, Function* import) { Ref theVar = ValueBuilder::makeVar(); ast->push_back(theVar); Ref module = ValueBuilder::makeName(ENV); // TODO: handle nested module imports @@ -940,14 +938,6 @@ void Wasm2JSBuilder::scanFunctionBody(Expression* curr) { } } } - void visitCallImport(CallImport* curr) { - for (auto item : curr->operands) { - if (parent->isStatement(item)) { - parent->setStatement(curr); - break; - } - } - } void visitCallIndirect(CallIndirect* curr) { // TODO: this is a pessimization that probably wants to get tweaked in // the future. If none of the arguments have any side effects then we @@ -1255,10 +1245,6 @@ Ref Wasm2JSBuilder::processFunctionBody(Module* m, Function* func, IString resul return visitGenericCall(curr, curr->target, curr->operands); } - Ref visitCallImport(CallImport* curr) { - return visitGenericCall(curr, curr->target, curr->operands); - } - Ref visitCallIndirect(CallIndirect* curr) { // TODO: the codegen here is a pessimization of what the ideal codegen // looks like. Eventually if necessary this should be tightened up in the @@ -2197,7 +2183,7 @@ Ref Wasm2JSBuilder::makeAssertReturnNanFunc(SExpressionWasmBuilder& sexpBuilder, Name testFuncName, Name asmModule) { Expression* actual = sexpBuilder.parseExpression(e[1]); - Expression* body = wasmBuilder.makeCallImport("isNaN", {actual}, i32); + Expression* body = wasmBuilder.makeCall("isNaN", {actual}, i32); std::unique_ptr<Function> testFunc( wasmBuilder.makeFunction( testFuncName, |