diff options
author | Thomas Lively <7121787+tlively@users.noreply.github.com> | 2021-07-01 01:56:23 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-06-30 18:56:23 -0700 |
commit | ca27f40a2f1070a16ee7c0efc18ff35d342d8027 (patch) | |
tree | ab0f2b1b731737bc409db21f677b97be16f67c0f /src/wasm | |
parent | 10ef52d62468aec5762742930630e882dc5e5c0b (diff) | |
download | binaryen-ca27f40a2f1070a16ee7c0efc18ff35d342d8027.tar.gz binaryen-ca27f40a2f1070a16ee7c0efc18ff35d342d8027.tar.bz2 binaryen-ca27f40a2f1070a16ee7c0efc18ff35d342d8027.zip |
Preserve Function HeapTypes (#3952)
When using nominal types, func.ref of two functions with identical signatures
but different HeapTypes will yield different types. To preserve these semantics,
Functions need to track their HeapTypes, not just their Signatures.
This PR replaces the Signature field in Function with a HeapType field and adds
new utility methods to make it almost as simple to update and query the function
HeapType as it was to update and query the Function Signature.
Diffstat (limited to 'src/wasm')
-rw-r--r-- | src/wasm/wasm-binary.cpp | 16 | ||||
-rw-r--r-- | src/wasm/wasm-s-parser.cpp | 20 | ||||
-rw-r--r-- | src/wasm/wasm-validator.cpp | 28 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 12 |
4 files changed, 36 insertions, 40 deletions
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 779df5c65..31fe1bc6e 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -269,7 +269,7 @@ void WasmBinaryWriter::writeImports() { BYN_TRACE("write one function\n"); writeImportHeader(func); o << U32LEB(int32_t(ExternalKind::Function)); - o << U32LEB(getTypeIndex(func->sig)); + o << U32LEB(getTypeIndex(func->type)); }); ModuleUtils::iterImportedGlobals(*wasm, [&](Global* global) { BYN_TRACE("write one global\n"); @@ -318,7 +318,7 @@ void WasmBinaryWriter::writeFunctionSignatures() { o << U32LEB(importInfo->getNumDefinedFunctions()); ModuleUtils::iterDefinedFunctions(*wasm, [&](Function* func) { BYN_TRACE("write one\n"); - o << U32LEB(getTypeIndex(func->sig)); + o << U32LEB(getTypeIndex(func->type)); }); finishSection(start); } @@ -2008,8 +2008,7 @@ void WasmBinaryBuilder::readImports() { Name name(std::string("fimport$") + std::to_string(functionCounter++)); auto index = getU32LEB(); functionTypes.push_back(getTypeByIndex(index)); - auto curr = - builder.makeFunction(name, getSignatureByTypeIndex(index), {}); + auto curr = builder.makeFunction(name, getTypeByIndex(index), {}); curr->module = module; curr->base = base; functionImports.push_back(curr.get()); @@ -2158,7 +2157,7 @@ void WasmBinaryBuilder::readFunctions() { auto* func = new Function; func->name = Name::fromInt(i); - func->sig = getSignatureByFunctionIndex(functionImports.size() + i); + func->type = getTypeByFunctionIndex(functionImports.size() + i); currFunction = func; if (DWARF) { @@ -2197,7 +2196,7 @@ void WasmBinaryBuilder::readFunctions() { auto currFunctionIndex = functionImports.size() + functions.size(); bool isStart = startIndex == currFunctionIndex; if (!skipFunctionBodies || isStart) { - func->body = getBlockOrSingleton(func->sig.results); + func->body = getBlockOrSingleton(func->getResults()); } else { // When skipping the function body we need to put something valid in // their place so we validate. An unreachable is always acceptable @@ -6004,8 +6003,9 @@ void WasmBinaryBuilder::visitSelect(Select* curr, uint8_t code) { void WasmBinaryBuilder::visitReturn(Return* curr) { BYN_TRACE("zz node: Return\n"); requireFunctionContext("return"); - if (currFunction->sig.results.isConcrete()) { - curr->value = popTypedExpression(currFunction->sig.results); + Type type = currFunction->getResults(); + if (type.isConcrete()) { + curr->value = popTypedExpression(type); } curr->finalize(); } diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index bda94dec5..25be13f58 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -914,9 +914,7 @@ void SExpressionWasmBuilder::preParseFunctionType(Element& s) { } functionNames.push_back(name); functionCounter++; - HeapType type; - parseTypeUse(s, i, type); - functionTypes[name] = type; + parseTypeUse(s, i, functionTypes[name]); } size_t SExpressionWasmBuilder::parseFunctionNames(Element& s, @@ -1007,12 +1005,12 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { im->setName(name, hasExplicitName); im->module = importModule; im->base = importBase; - im->sig = type.getSignature(); + im->type = type; functionTypes[name] = type; if (wasm.getFunctionOrNull(im->name)) { throw ParseException("duplicate import", s.line, s.col); } - wasm.addFunction(im.release()); + wasm.addFunction(std::move(im)); if (currFunction) { throw ParseException("import module inside function dec", s.line, s.col); } @@ -1034,8 +1032,8 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) { } // make a new function - currFunction = std::unique_ptr<Function>(Builder(wasm).makeFunction( - name, std::move(params), type.getSignature().results, std::move(vars))); + currFunction = std::unique_ptr<Function>( + Builder(wasm).makeFunction(name, std::move(params), type, std::move(vars))); currFunction->profile = profile; // parse body @@ -3045,13 +3043,11 @@ void SExpressionWasmBuilder::parseImport(Element& s) { if (kind == ExternalKind::Function) { auto func = make_unique<Function>(); - HeapType funcType; - j = parseTypeUse(inner, j, funcType); - func->sig = funcType.getSignature(); + j = parseTypeUse(inner, j, func->type); func->setName(name, hasExplicitName); func->module = module; func->base = base; - functionTypes[name] = funcType; + functionTypes[name] = func->type; wasm.addFunction(func.release()); } else if (kind == ExternalKind::Global) { Type type; @@ -3403,7 +3399,7 @@ ElementSegment* SExpressionWasmBuilder::parseElemFinish( for (; i < s.size(); i++) { auto func = getFunctionName(*s[i]); segment->data.push_back( - Builder(wasm).makeRefFunc(func, functionTypes[func].getSignature())); + Builder(wasm).makeRefFunc(func, functionTypes[func])); } } return wasm.addElementSegment(std::move(segment)); diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index bed04b92b..f492c7110 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -461,7 +461,7 @@ private: curr, "return_call* should have unreachable type"); shouldBeEqual( - getFunction()->sig.results, + getFunction()->getResults(), sig.results, curr, "return_call* callee return type must match caller return type"); @@ -785,7 +785,7 @@ void FunctionValidator::visitCall(Call* curr) { if (!shouldBeTrue(!!target, curr, "call target must exist")) { return; } - validateCallParamsAndResult(curr, target->sig); + validateCallParamsAndResult(curr, target->getSig()); } void FunctionValidator::visitCallIndirect(CallIndirect* curr) { @@ -2481,17 +2481,17 @@ void FunctionValidator::visitArrayCopy(ArrayCopy* curr) { } void FunctionValidator::visitFunction(Function* curr) { - if (curr->sig.results.isTuple()) { + if (curr->getResults().isTuple()) { shouldBeTrue(getModule()->features.hasMultivalue(), curr->body, "Multivalue function results (multivalue is not enabled)"); } FeatureSet features; - for (const auto& param : curr->sig.params) { + for (const auto& param : curr->getParams()) { features |= param.getFeatures(); shouldBeTrue(param.isConcrete(), curr, "params must be concretely typed"); } - for (const auto& result : curr->sig.results) { + for (const auto& result : curr->getResults()) { features |= result.getFeatures(); shouldBeTrue(result.isConcrete(), curr, "results must be concretely typed"); } @@ -2512,12 +2512,12 @@ void FunctionValidator::visitFunction(Function* curr) { // if function has no result, it is ignored // if body is unreachable, it might be e.g. a return shouldBeSubType(curr->body->type, - curr->sig.results, + curr->getResults(), curr->body, "function body type must match, if function returns"); for (Type returnType : returnTypes) { shouldBeSubType(returnType, - curr->sig.results, + curr->getResults(), curr->body, "function result must match, if function has returns"); } @@ -2657,20 +2657,20 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) { static void validateImports(Module& module, ValidationInfo& info) { ModuleUtils::iterImportedFunctions(module, [&](Function* curr) { - if (curr->sig.results.isTuple()) { + if (curr->getResults().isTuple()) { info.shouldBeTrue(module.features.hasMultivalue(), curr->name, "Imported multivalue function " "(multivalue is not enabled)"); } if (info.validateWeb) { - for (const auto& param : curr->sig.params) { + for (const auto& param : curr->getParams()) { info.shouldBeUnequal(param, Type(Type::i64), curr->name, "Imported function must not have i64 parameters"); } - for (const auto& result : curr->sig.results) { + for (const auto& result : curr->getResults()) { info.shouldBeUnequal(result, Type(Type::i64), curr->name, @@ -2693,14 +2693,14 @@ static void validateExports(Module& module, ValidationInfo& info) { if (curr->kind == ExternalKind::Function) { if (info.validateWeb) { Function* f = module.getFunction(curr->value); - for (const auto& param : f->sig.params) { + for (const auto& param : f->getParams()) { info.shouldBeUnequal( param, Type(Type::i64), f->name, "Exported function must not have i64 parameters"); } - for (const auto& result : f->sig.results) { + for (const auto& result : f->getResults()) { info.shouldBeUnequal(result, Type(Type::i64), f->name, @@ -3007,10 +3007,10 @@ static void validateModule(Module& module, ValidationInfo& info) { auto func = module.getFunctionOrNull(module.start); if (info.shouldBeTrue( func != nullptr, module.start, "start must be found")) { - info.shouldBeTrue(func->sig.params == Type::none, + info.shouldBeTrue(func->getParams() == Type::none, module.start, "start must have 0 params"); - info.shouldBeTrue(func->sig.results == Type::none, + info.shouldBeTrue(func->getResults() == Type::none, module.start, "start must not return a value"); } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 54c82159f..5b1163ae0 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -1087,14 +1087,14 @@ void RefAs::finalize() { } } -size_t Function::getNumParams() { return sig.params.size(); } +size_t Function::getNumParams() { return getParams().size(); } size_t Function::getNumVars() { return vars.size(); } -size_t Function::getNumLocals() { return sig.params.size() + vars.size(); } +size_t Function::getNumLocals() { return getParams().size() + vars.size(); } bool Function::isParam(Index index) { - size_t size = sig.params.size(); + size_t size = getParams().size(); assert(index < size + vars.size()); return index < size; } @@ -1141,12 +1141,12 @@ Index Function::getLocalIndex(Name name) { return iter->second; } -Index Function::getVarIndexBase() { return sig.params.size(); } +Index Function::getVarIndexBase() { return getParams().size(); } Type Function::getLocalType(Index index) { - auto numParams = sig.params.size(); + auto numParams = getParams().size(); if (index < numParams) { - return sig.params[index]; + return getParams()[index]; } else if (isVar(index)) { return vars[index - numParams]; } else { |