summaryrefslogtreecommitdiff
path: root/src/wasm
diff options
context:
space:
mode:
authorThomas Lively <7121787+tlively@users.noreply.github.com>2021-07-01 01:56:23 +0000
committerGitHub <noreply@github.com>2021-06-30 18:56:23 -0700
commitca27f40a2f1070a16ee7c0efc18ff35d342d8027 (patch)
treeab0f2b1b731737bc409db21f677b97be16f67c0f /src/wasm
parent10ef52d62468aec5762742930630e882dc5e5c0b (diff)
downloadbinaryen-ca27f40a2f1070a16ee7c0efc18ff35d342d8027.tar.gz
binaryen-ca27f40a2f1070a16ee7c0efc18ff35d342d8027.tar.bz2
binaryen-ca27f40a2f1070a16ee7c0efc18ff35d342d8027.zip
Preserve Function HeapTypes (#3952)
When using nominal types, func.ref of two functions with identical signatures but different HeapTypes will yield different types. To preserve these semantics, Functions need to track their HeapTypes, not just their Signatures. This PR replaces the Signature field in Function with a HeapType field and adds new utility methods to make it almost as simple to update and query the function HeapType as it was to update and query the Function Signature.
Diffstat (limited to 'src/wasm')
-rw-r--r--src/wasm/wasm-binary.cpp16
-rw-r--r--src/wasm/wasm-s-parser.cpp20
-rw-r--r--src/wasm/wasm-validator.cpp28
-rw-r--r--src/wasm/wasm.cpp12
4 files changed, 36 insertions, 40 deletions
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp
index 779df5c65..31fe1bc6e 100644
--- a/src/wasm/wasm-binary.cpp
+++ b/src/wasm/wasm-binary.cpp
@@ -269,7 +269,7 @@ void WasmBinaryWriter::writeImports() {
BYN_TRACE("write one function\n");
writeImportHeader(func);
o << U32LEB(int32_t(ExternalKind::Function));
- o << U32LEB(getTypeIndex(func->sig));
+ o << U32LEB(getTypeIndex(func->type));
});
ModuleUtils::iterImportedGlobals(*wasm, [&](Global* global) {
BYN_TRACE("write one global\n");
@@ -318,7 +318,7 @@ void WasmBinaryWriter::writeFunctionSignatures() {
o << U32LEB(importInfo->getNumDefinedFunctions());
ModuleUtils::iterDefinedFunctions(*wasm, [&](Function* func) {
BYN_TRACE("write one\n");
- o << U32LEB(getTypeIndex(func->sig));
+ o << U32LEB(getTypeIndex(func->type));
});
finishSection(start);
}
@@ -2008,8 +2008,7 @@ void WasmBinaryBuilder::readImports() {
Name name(std::string("fimport$") + std::to_string(functionCounter++));
auto index = getU32LEB();
functionTypes.push_back(getTypeByIndex(index));
- auto curr =
- builder.makeFunction(name, getSignatureByTypeIndex(index), {});
+ auto curr = builder.makeFunction(name, getTypeByIndex(index), {});
curr->module = module;
curr->base = base;
functionImports.push_back(curr.get());
@@ -2158,7 +2157,7 @@ void WasmBinaryBuilder::readFunctions() {
auto* func = new Function;
func->name = Name::fromInt(i);
- func->sig = getSignatureByFunctionIndex(functionImports.size() + i);
+ func->type = getTypeByFunctionIndex(functionImports.size() + i);
currFunction = func;
if (DWARF) {
@@ -2197,7 +2196,7 @@ void WasmBinaryBuilder::readFunctions() {
auto currFunctionIndex = functionImports.size() + functions.size();
bool isStart = startIndex == currFunctionIndex;
if (!skipFunctionBodies || isStart) {
- func->body = getBlockOrSingleton(func->sig.results);
+ func->body = getBlockOrSingleton(func->getResults());
} else {
// When skipping the function body we need to put something valid in
// their place so we validate. An unreachable is always acceptable
@@ -6004,8 +6003,9 @@ void WasmBinaryBuilder::visitSelect(Select* curr, uint8_t code) {
void WasmBinaryBuilder::visitReturn(Return* curr) {
BYN_TRACE("zz node: Return\n");
requireFunctionContext("return");
- if (currFunction->sig.results.isConcrete()) {
- curr->value = popTypedExpression(currFunction->sig.results);
+ Type type = currFunction->getResults();
+ if (type.isConcrete()) {
+ curr->value = popTypedExpression(type);
}
curr->finalize();
}
diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp
index bda94dec5..25be13f58 100644
--- a/src/wasm/wasm-s-parser.cpp
+++ b/src/wasm/wasm-s-parser.cpp
@@ -914,9 +914,7 @@ void SExpressionWasmBuilder::preParseFunctionType(Element& s) {
}
functionNames.push_back(name);
functionCounter++;
- HeapType type;
- parseTypeUse(s, i, type);
- functionTypes[name] = type;
+ parseTypeUse(s, i, functionTypes[name]);
}
size_t SExpressionWasmBuilder::parseFunctionNames(Element& s,
@@ -1007,12 +1005,12 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) {
im->setName(name, hasExplicitName);
im->module = importModule;
im->base = importBase;
- im->sig = type.getSignature();
+ im->type = type;
functionTypes[name] = type;
if (wasm.getFunctionOrNull(im->name)) {
throw ParseException("duplicate import", s.line, s.col);
}
- wasm.addFunction(im.release());
+ wasm.addFunction(std::move(im));
if (currFunction) {
throw ParseException("import module inside function dec", s.line, s.col);
}
@@ -1034,8 +1032,8 @@ void SExpressionWasmBuilder::parseFunction(Element& s, bool preParseImport) {
}
// make a new function
- currFunction = std::unique_ptr<Function>(Builder(wasm).makeFunction(
- name, std::move(params), type.getSignature().results, std::move(vars)));
+ currFunction = std::unique_ptr<Function>(
+ Builder(wasm).makeFunction(name, std::move(params), type, std::move(vars)));
currFunction->profile = profile;
// parse body
@@ -3045,13 +3043,11 @@ void SExpressionWasmBuilder::parseImport(Element& s) {
if (kind == ExternalKind::Function) {
auto func = make_unique<Function>();
- HeapType funcType;
- j = parseTypeUse(inner, j, funcType);
- func->sig = funcType.getSignature();
+ j = parseTypeUse(inner, j, func->type);
func->setName(name, hasExplicitName);
func->module = module;
func->base = base;
- functionTypes[name] = funcType;
+ functionTypes[name] = func->type;
wasm.addFunction(func.release());
} else if (kind == ExternalKind::Global) {
Type type;
@@ -3403,7 +3399,7 @@ ElementSegment* SExpressionWasmBuilder::parseElemFinish(
for (; i < s.size(); i++) {
auto func = getFunctionName(*s[i]);
segment->data.push_back(
- Builder(wasm).makeRefFunc(func, functionTypes[func].getSignature()));
+ Builder(wasm).makeRefFunc(func, functionTypes[func]));
}
}
return wasm.addElementSegment(std::move(segment));
diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp
index bed04b92b..f492c7110 100644
--- a/src/wasm/wasm-validator.cpp
+++ b/src/wasm/wasm-validator.cpp
@@ -461,7 +461,7 @@ private:
curr,
"return_call* should have unreachable type");
shouldBeEqual(
- getFunction()->sig.results,
+ getFunction()->getResults(),
sig.results,
curr,
"return_call* callee return type must match caller return type");
@@ -785,7 +785,7 @@ void FunctionValidator::visitCall(Call* curr) {
if (!shouldBeTrue(!!target, curr, "call target must exist")) {
return;
}
- validateCallParamsAndResult(curr, target->sig);
+ validateCallParamsAndResult(curr, target->getSig());
}
void FunctionValidator::visitCallIndirect(CallIndirect* curr) {
@@ -2481,17 +2481,17 @@ void FunctionValidator::visitArrayCopy(ArrayCopy* curr) {
}
void FunctionValidator::visitFunction(Function* curr) {
- if (curr->sig.results.isTuple()) {
+ if (curr->getResults().isTuple()) {
shouldBeTrue(getModule()->features.hasMultivalue(),
curr->body,
"Multivalue function results (multivalue is not enabled)");
}
FeatureSet features;
- for (const auto& param : curr->sig.params) {
+ for (const auto& param : curr->getParams()) {
features |= param.getFeatures();
shouldBeTrue(param.isConcrete(), curr, "params must be concretely typed");
}
- for (const auto& result : curr->sig.results) {
+ for (const auto& result : curr->getResults()) {
features |= result.getFeatures();
shouldBeTrue(result.isConcrete(), curr, "results must be concretely typed");
}
@@ -2512,12 +2512,12 @@ void FunctionValidator::visitFunction(Function* curr) {
// if function has no result, it is ignored
// if body is unreachable, it might be e.g. a return
shouldBeSubType(curr->body->type,
- curr->sig.results,
+ curr->getResults(),
curr->body,
"function body type must match, if function returns");
for (Type returnType : returnTypes) {
shouldBeSubType(returnType,
- curr->sig.results,
+ curr->getResults(),
curr->body,
"function result must match, if function has returns");
}
@@ -2657,20 +2657,20 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) {
static void validateImports(Module& module, ValidationInfo& info) {
ModuleUtils::iterImportedFunctions(module, [&](Function* curr) {
- if (curr->sig.results.isTuple()) {
+ if (curr->getResults().isTuple()) {
info.shouldBeTrue(module.features.hasMultivalue(),
curr->name,
"Imported multivalue function "
"(multivalue is not enabled)");
}
if (info.validateWeb) {
- for (const auto& param : curr->sig.params) {
+ for (const auto& param : curr->getParams()) {
info.shouldBeUnequal(param,
Type(Type::i64),
curr->name,
"Imported function must not have i64 parameters");
}
- for (const auto& result : curr->sig.results) {
+ for (const auto& result : curr->getResults()) {
info.shouldBeUnequal(result,
Type(Type::i64),
curr->name,
@@ -2693,14 +2693,14 @@ static void validateExports(Module& module, ValidationInfo& info) {
if (curr->kind == ExternalKind::Function) {
if (info.validateWeb) {
Function* f = module.getFunction(curr->value);
- for (const auto& param : f->sig.params) {
+ for (const auto& param : f->getParams()) {
info.shouldBeUnequal(
param,
Type(Type::i64),
f->name,
"Exported function must not have i64 parameters");
}
- for (const auto& result : f->sig.results) {
+ for (const auto& result : f->getResults()) {
info.shouldBeUnequal(result,
Type(Type::i64),
f->name,
@@ -3007,10 +3007,10 @@ static void validateModule(Module& module, ValidationInfo& info) {
auto func = module.getFunctionOrNull(module.start);
if (info.shouldBeTrue(
func != nullptr, module.start, "start must be found")) {
- info.shouldBeTrue(func->sig.params == Type::none,
+ info.shouldBeTrue(func->getParams() == Type::none,
module.start,
"start must have 0 params");
- info.shouldBeTrue(func->sig.results == Type::none,
+ info.shouldBeTrue(func->getResults() == Type::none,
module.start,
"start must not return a value");
}
diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp
index 54c82159f..5b1163ae0 100644
--- a/src/wasm/wasm.cpp
+++ b/src/wasm/wasm.cpp
@@ -1087,14 +1087,14 @@ void RefAs::finalize() {
}
}
-size_t Function::getNumParams() { return sig.params.size(); }
+size_t Function::getNumParams() { return getParams().size(); }
size_t Function::getNumVars() { return vars.size(); }
-size_t Function::getNumLocals() { return sig.params.size() + vars.size(); }
+size_t Function::getNumLocals() { return getParams().size() + vars.size(); }
bool Function::isParam(Index index) {
- size_t size = sig.params.size();
+ size_t size = getParams().size();
assert(index < size + vars.size());
return index < size;
}
@@ -1141,12 +1141,12 @@ Index Function::getLocalIndex(Name name) {
return iter->second;
}
-Index Function::getVarIndexBase() { return sig.params.size(); }
+Index Function::getVarIndexBase() { return getParams().size(); }
Type Function::getLocalType(Index index) {
- auto numParams = sig.params.size();
+ auto numParams = getParams().size();
if (index < numParams) {
- return sig.params[index];
+ return getParams()[index];
} else if (isVar(index)) {
return vars[index - numParams];
} else {