summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
author: Thomas Lively <tlively@google.com> 2023-10-18 20:37:48 +0200
committer: GitHub <noreply@github.com> 2023-10-18 18:37:48 +0000
commit: f50e933f639c24f3a5814980fb20e6f7e1435184 (patch)
tree: c118bfbfbb54a5c0e1879af650131503ad59bd97 /src
parent: 89c02aad305474aea1d413e110aadd68278a13d6 (diff)
download: binaryen-f50e933f639c24f3a5814980fb20e6f7e1435184.tar.gz
binaryen-f50e933f639c24f3a5814980fb20e6f7e1435184.tar.bz2
binaryen-f50e933f639c24f3a5814980fb20e6f7e1435184.zip
Reuse existing function types for blocks (#6022)
Type annotations on multivalue blocks (and loops, ifs, and trys) are type indices that refer to function types in the type section. For these type annotations, the identities of the function types do not matter. As long as the referenced type has the correct parameters and results, it will be valid to use. Previously, when collecting module types, we always used the "default" function type for multivalue control flow, i.e. we used a final function type with no supertypes in a singleton rec group. However, in cases where the program already contains another function type with the expected signature, using the default type is unnecessary and bloats the type section. Update the type collecting code to reuse existing function types for multivalue control flow where possible rather than unconditionally adding the default function type. Similarly, update the binary writer to use the first heap type with the required signature when emitting annotations on multivalue control flow structures. To make this all testable, update the printer to print the type annotations as well, rather than just the result types. Since the parser was not able to parse those newly emitted type annotations, update the parser as well.
Diffstat (limited to 'src')
-rw-r--r-- src/ir/module-utils.cpp 139
-rw-r--r-- src/passes/Print.cpp 35
-rw-r--r-- src/wasm-binary.h 2
-rw-r--r-- src/wasm-s-parser.h 2
-rw-r--r-- src/wasm/wasm-binary.cpp 16
-rw-r--r-- src/wasm/wasm-s-parser.cpp 44
-rw-r--r-- src/wasm/wasm-stack.cpp 2
7 files changed, 176 insertions, 64 deletions
diff --git a/src/ir/module-utils.cpp b/src/ir/module-utils.cpp
index 0da47e811..cd865a79c 100644
--- a/src/ir/module-utils.cpp
+++ b/src/ir/module-utils.cpp
@@ -230,10 +230,18 @@ void renameFunction(Module& wasm, Name oldName, Name newName) {
namespace {
// Helper for collecting HeapTypes and their frequencies.
-struct Counts : public InsertOrderedMap<HeapType, size_t> {
+struct Counts {
+ InsertOrderedMap<HeapType, size_t> counts;
+
+ // Multivalue control flow structures need a function type, but the identity
+ // of the function type (i.e. what recursion group it is in or whether it is
+ // final) doesn't matter. Save them for the end to see if we can re-use an
+ // existing function type with the necessary signature.
+ InsertOrderedMap<Signature, size_t> controlFlowSignatures;
+
void note(HeapType type) {
if (!type.isBasic()) {
- (*this)[type]++;
+ counts[type]++;
}
}
void note(Type type) {
@@ -244,7 +252,7 @@ struct Counts : public InsertOrderedMap<HeapType, size_t> {
// Ensure a type is included without increasing its count.
void include(HeapType type) {
if (!type.isBasic()) {
- (*this)[type];
+ counts[type];
}
}
void include(Type type) {
@@ -252,6 +260,18 @@ struct Counts : public InsertOrderedMap<HeapType, size_t> {
include(ht);
}
}
+ void noteControlFlow(Signature sig) {
+ // TODO: support control flow input parameters.
+ assert(sig.params.size() == 0);
+ if (sig.results.isTuple()) {
+ // We have to use a function type.
+ controlFlowSignatures[sig]++;
+ } else if (sig.results != Type::none) {
+ // The result type can be emitted directly instead of using a function
+ // type.
+ note(sig.results[0]);
+ }
+ }
};
struct CodeScanner
@@ -319,12 +339,7 @@ struct CodeScanner
} else if (auto* set = curr->dynCast<ArraySet>()) {
counts.note(set->ref->type);
} else if (Properties::isControlFlowStructure(curr)) {
- if (curr->type.isTuple()) {
- // TODO: Allow control flow to have input types as well
- counts.note(Signature(Type::none, curr->type));
- } else {
- counts.note(curr->type);
- }
+ counts.noteControlFlow(Signature(Type::none, curr->type));
}
}
};
@@ -332,7 +347,8 @@ struct CodeScanner
// Count the number of times each heap type that would appear in the binary is
// referenced. If `prune`, exclude types that are never referenced, even though
// a binary would be invalid without them.
-Counts getHeapTypeCounts(Module& wasm, bool prune = false) {
+InsertOrderedMap<HeapType, size_t> getHeapTypeCounts(Module& wasm,
+ bool prune = false) {
// Collect module-level info.
Counts counts;
CodeScanner(wasm, counts).walkModuleCode(&wasm);
@@ -363,18 +379,21 @@ Counts getHeapTypeCounts(Module& wasm, bool prune = false) {
// Combine the function info with the module info.
for (auto& [_, functionCounts] : analysis.map) {
- for (auto& [sig, count] : functionCounts) {
- counts[sig] += count;
+ for (auto& [type, count] : functionCounts.counts) {
+ counts.counts[type] += count;
+ }
+ for (auto& [sig, count] : functionCounts.controlFlowSignatures) {
+ counts.controlFlowSignatures[sig] += count;
}
}
if (prune) {
// Remove types that are not actually used.
- auto it = counts.begin();
- while (it != counts.end()) {
+ auto it = counts.counts.begin();
+ while (it != counts.counts.end()) {
if (it->second == 0) {
auto deleted = it++;
- counts.erase(deleted);
+ counts.counts.erase(deleted);
} else {
++it;
}
@@ -388,50 +407,75 @@ Counts getHeapTypeCounts(Module& wasm, bool prune = false) {
// appear in the type section once, so we just need to visit it once. Also
// track which recursion groups we've already processed to avoid quadratic
// behavior when there is a single large group.
- InsertOrderedSet<HeapType> newTypes;
- for (auto& [type, _] : counts) {
- newTypes.insert(type);
+ UniqueNonrepeatingDeferredQueue<HeapType> newTypes;
+ std::unordered_map<Signature, HeapType> seenSigs;
+ auto noteNewType = [&](HeapType type) {
+ newTypes.push(type);
+ if (type.isSignature()) {
+ seenSigs.insert({type.getSignature(), type});
+ }
+ };
+ for (auto& [type, _] : counts.counts) {
+ noteNewType(type);
}
+ auto controlFlowIt = counts.controlFlowSignatures.begin();
std::unordered_set<RecGroup> includedGroups;
while (!newTypes.empty()) {
- auto iter = newTypes.begin();
- auto ht = *iter;
- newTypes.erase(iter);
- for (HeapType child : ht.getHeapTypeChildren()) {
- if (!child.isBasic()) {
- if (!counts.count(child)) {
- newTypes.insert(child);
+ while (!newTypes.empty()) {
+ auto ht = newTypes.pop();
+ for (HeapType child : ht.getHeapTypeChildren()) {
+ if (!child.isBasic()) {
+ if (!counts.counts.count(child)) {
+ noteNewType(child);
+ }
+ counts.note(child);
}
- counts.note(child);
}
- }
- if (auto super = ht.getDeclaredSuperType()) {
- if (!counts.count(*super)) {
- newTypes.insert(*super);
- // We should unconditionally count supertypes, but while the type system
- // is in flux, skip counting them to keep the type orderings in nominal
- // test outputs more similar to the orderings in the equirecursive
- // outputs. FIXME
- counts.include(*super);
+ if (auto super = ht.getDeclaredSuperType()) {
+ if (!counts.counts.count(*super)) {
+ noteNewType(*super);
+ // We should unconditionally count supertypes, but while the type
+ // system is in flux, skip counting them to keep the type orderings in
+ // nominal test outputs more similar to the orderings in the
+ // equirecursive outputs. FIXME
+ counts.include(*super);
+ }
}
- }
- // Make sure we've noted the complete recursion group of each type as well.
- if (!prune) {
- auto recGroup = ht.getRecGroup();
- if (includedGroups.insert(recGroup).second) {
- for (auto type : recGroup) {
- if (!counts.count(type)) {
- newTypes.insert(type);
- counts.include(type);
+ // Make sure we've noted the complete recursion group of each type as
+ // well.
+ if (!prune) {
+ auto recGroup = ht.getRecGroup();
+ if (includedGroups.insert(recGroup).second) {
+ for (auto type : recGroup) {
+ if (!counts.counts.count(type)) {
+ noteNewType(type);
+ counts.include(type);
+ }
}
}
}
}
+
+ // We've found all the types there are to find without considering more
+ // control flow types. Consider one more control flow type and repeat.
+ for (; controlFlowIt != counts.controlFlowSignatures.end();
+ ++controlFlowIt) {
+ auto& [sig, count] = *controlFlowIt;
+ if (auto it = seenSigs.find(sig); it != seenSigs.end()) {
+ counts.counts[it->second] += count;
+ } else {
+ // We've never seen this signature before, so add a type for it.
+ HeapType type(sig);
+ noteNewType(type);
+ counts.counts[type] += count;
+ break;
+ }
+ }
}
- return counts;
+ return counts.counts;
}
void setIndices(IndexedHeapTypes& indexedTypes) {
@@ -561,12 +605,11 @@ std::vector<HeapType> getPrivateHeapTypes(Module& wasm) {
}
IndexedHeapTypes getOptimizedIndexedHeapTypes(Module& wasm) {
- Counts counts = getHeapTypeCounts(wasm);
+ auto counts = getHeapTypeCounts(wasm);
// Types have to be arranged into topologically ordered recursion groups.
// Under isorecrsive typing, the topological sort has to take all referenced
- // rec groups into account but under nominal typing it only has to take
- // supertypes into account. First, sort the groups by average use count among
+ // rec groups into account. First, sort the groups by average use count among
// their members so that the later topological sort will place frequently used
// types first.
struct GroupInfo {
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index edbb136a3..51d261126 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -169,6 +169,7 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
int controlFlowDepth = 0;
std::vector<HeapType> heapTypes;
+ std::unordered_map<Signature, HeapType> signatureTypes;
// Track the print indent so that we can see when it changes. That affects how
// we print debug annotations. In particular, we don't want to print repeated
@@ -242,6 +243,22 @@ struct PrintSExpression : public UnifiedExpressionVisitor<PrintSExpression> {
return printPrefixedTypes("param", type);
}
+ std::ostream& printBlockType(Signature sig) {
+ assert(sig.params == Type::none);
+ if (sig.results == Type::none) {
+ return o;
+ }
+ if (sig.results.isTuple()) {
+ if (auto it = signatureTypes.find(sig); it != signatureTypes.end()) {
+ o << "(type ";
+ printHeapType(it->second);
+ o << ") ";
+ }
+ }
+ printResultType(sig.results);
+ return o;
+ }
+
void printDebugLocation(const Function::DebugLocation& location);
void printDebugLocation(Expression* curr);
@@ -370,6 +387,10 @@ struct PrintExpressionContents
return parent.printParamType(type);
}
+ std::ostream& printBlockType(Signature sig) {
+ return parent.printBlockType(sig);
+ }
+
void visitBlock(Block* curr) {
printMedium(o, "block");
if (curr->name.is()) {
@@ -378,14 +399,14 @@ struct PrintExpressionContents
}
if (curr->type.isConcrete()) {
o << ' ';
- printResultType(curr->type);
+ printBlockType(Signature(Type::none, curr->type));
}
}
void visitIf(If* curr) {
printMedium(o, "if");
if (curr->type.isConcrete()) {
o << ' ';
- printResultType(curr->type);
+ printBlockType(Signature(Type::none, curr->type));
}
}
void visitLoop(Loop* curr) {
@@ -396,7 +417,7 @@ struct PrintExpressionContents
}
if (curr->type.isConcrete()) {
o << ' ';
- printResultType(curr->type);
+ printBlockType(Signature(Type::none, curr->type));
}
}
void visitBreak(Break* curr) {
@@ -1937,7 +1958,7 @@ struct PrintExpressionContents
}
if (curr->type.isConcrete()) {
o << ' ';
- printResultType(curr->type);
+ printBlockType(Signature(Type::none, curr->type));
}
}
void visitThrow(Throw* curr) {
@@ -2369,8 +2390,14 @@ void PrintSExpression::setModule(Module* module) {
currModule = module;
if (module) {
heapTypes = ModuleUtils::getOptimizedIndexedHeapTypes(*module).types;
+ for (auto type : heapTypes) {
+ if (type.isSignature()) {
+ signatureTypes.insert({type.getSignature(), type});
+ }
+ }
} else {
heapTypes = {};
+ signatureTypes = {};
}
// Reset the type printer for this module's types (or absence thereof).
typePrinter.~TypePrinter();
diff --git a/src/wasm-binary.h b/src/wasm-binary.h
index 8a4d6969f..ee92fe90e 100644
--- a/src/wasm-binary.h
+++ b/src/wasm-binary.h
@@ -1422,6 +1422,7 @@ public:
uint32_t getDataSegmentIndex(Name name) const;
uint32_t getElementSegmentIndex(Name name) const;
uint32_t getTypeIndex(HeapType type) const;
+ uint32_t getSignatureIndex(Signature sig) const;
uint32_t getStringIndex(Name string) const;
void writeTableDeclarations();
@@ -1476,6 +1477,7 @@ private:
BufferWithRandomAccess& o;
BinaryIndexes indexes;
ModuleUtils::IndexedHeapTypes indexedTypes;
+ std::unordered_map<Signature, uint32_t> signatureIndexes;
bool debugInfo = true;
diff --git a/src/wasm-s-parser.h b/src/wasm-s-parser.h
index a2faea771..6853f2c01 100644
--- a/src/wasm-s-parser.h
+++ b/src/wasm-s-parser.h
@@ -327,7 +327,7 @@ private:
Expression* makeStringSliceIter(Element& s);
// Helper functions
- Type parseOptionalResultType(Element& s, Index& i);
+ Type parseBlockType(Element& s, Index& i);
Index parseMemoryLimits(Element& s, Index i, std::unique_ptr<Memory>& memory);
Index parseMemoryIndex(Element& s, Index i, std::unique_ptr<Memory>& memory);
Index parseMemoryForInstruction(const std::string& instrName,
diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp
index 2b96e839c..b12cb50c7 100644
--- a/src/wasm/wasm-binary.cpp
+++ b/src/wasm/wasm-binary.cpp
@@ -35,6 +35,11 @@ void WasmBinaryWriter::prepare() {
// Collect function types and their frequencies. Collect information in each
// function in parallel, then merge.
indexedTypes = ModuleUtils::getOptimizedIndexedHeapTypes(*wasm);
+ for (Index i = 0, size = indexedTypes.types.size(); i < size; ++i) {
+ if (indexedTypes.types[i].isSignature()) {
+ signatureIndexes.insert({indexedTypes.types[i].getSignature(), i});
+ }
+ }
importInfo = std::make_unique<ImportInfo>(*wasm);
}
@@ -686,6 +691,17 @@ uint32_t WasmBinaryWriter::getTypeIndex(HeapType type) const {
return it->second;
}
+uint32_t WasmBinaryWriter::getSignatureIndex(Signature sig) const {
+ auto it = signatureIndexes.find(sig);
+#ifndef NDEBUG
+ if (it == signatureIndexes.end()) {
+ std::cout << "Missing signature: " << sig << '\n';
+ assert(0);
+ }
+#endif
+ return it->second;
+}
+
uint32_t WasmBinaryWriter::getStringIndex(Name string) const {
auto it = stringIndexes.find(string);
assert(it != stringIndexes.end());
diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp
index 215d349e1..4074e4792 100644
--- a/src/wasm/wasm-s-parser.cpp
+++ b/src/wasm/wasm-s-parser.cpp
@@ -1445,7 +1445,7 @@ Expression* SExpressionWasmBuilder::makeUnary(Element& s, UnaryOp op) {
Expression* SExpressionWasmBuilder::makeSelect(Element& s) {
auto ret = allocator.alloc<Select>();
Index i = 1;
- Type type = parseOptionalResultType(s, i);
+ Type type = parseBlockType(s, i);
ret->ifTrue = parseExpression(s[i++]);
ret->ifFalse = parseExpression(s[i++]);
ret->condition = parseExpression(s[i]);
@@ -1603,7 +1603,7 @@ Expression* SExpressionWasmBuilder::makeBlock(Element& s) {
stack.emplace_back(Info{sp, curr, hadName});
curr->name = nameMapper.pushLabelName(sName);
// block signature
- curr->type = parseOptionalResultType(s, i);
+ curr->type = parseBlockType(s, i);
if (i >= s.size()) {
break; // empty block
}
@@ -1630,7 +1630,8 @@ Expression* SExpressionWasmBuilder::makeBlock(Element& s) {
while (i < s.size() && s[i]->isStr()) {
i++;
}
- if (i < s.size() && elementStartsWith(*s[i], RESULT)) {
+ while (i < s.size() && (elementStartsWith(*s[i], RESULT) ||
+ elementStartsWith(*s[i], TYPE))) {
i++;
}
if (t < int(stack.size()) - 1) {
@@ -2370,7 +2371,7 @@ Expression* SExpressionWasmBuilder::makeIf(Element& s) {
}
auto label = nameMapper.pushLabelName(sName);
// if signature
- Type type = parseOptionalResultType(s, i);
+ Type type = parseBlockType(s, i);
ret->condition = parseExpression(s[i++]);
ret->ifTrue = parseExpression(*s[i++]);
if (i < s.size()) {
@@ -2409,7 +2410,7 @@ SExpressionWasmBuilder::makeMaybeBlock(Element& s, size_t i, Type type) {
return ret;
}
-Type SExpressionWasmBuilder::parseOptionalResultType(Element& s, Index& i) {
+Type SExpressionWasmBuilder::parseBlockType(Element& s, Index& i) {
if (s.size() == i) {
return Type::none;
}
@@ -2420,11 +2421,34 @@ Type SExpressionWasmBuilder::parseOptionalResultType(Element& s, Index& i) {
return stringToType(s[i++]->str());
}
- Element& results = *s[i];
- IString id = results[0]->str();
+ Element* results = s[i];
+ IString id = (*results)[0]->str();
+ std::optional<Signature> usedType;
+ if (id == TYPE) {
+ auto type = parseHeapType(*(*results)[1]);
+ if (!type.isSignature()) {
+ throw SParseException("unexpected non-function type", s);
+ }
+ usedType = type.getSignature();
+ if (usedType->params != Type::none) {
+ throw SParseException("block input values are not yet supported", s);
+ }
+ i++;
+ results = s[i];
+ id = (*results)[0]->str();
+ }
+
if (id == RESULT) {
i++;
- return Type(parseResults(results));
+ auto type = Type(parseResults(*results));
+ if (usedType && usedType->results != type) {
+ throw SParseException("results do not match type", s);
+ }
+ return type;
+ }
+
+ if (usedType && usedType->results != Type::none) {
+ throw SParseException("results do not match type", s);
}
return Type::none;
}
@@ -2439,7 +2463,7 @@ Expression* SExpressionWasmBuilder::makeLoop(Element& s) {
sName = "loop-in";
}
ret->name = nameMapper.pushLabelName(sName);
- ret->type = parseOptionalResultType(s, i);
+ ret->type = parseBlockType(s, i);
ret->body = makeMaybeBlock(s, i, ret->type);
nameMapper.popLabelName(ret->name);
ret->finalize(ret->type);
@@ -2690,7 +2714,7 @@ Expression* SExpressionWasmBuilder::makeTry(Element& s) {
sName = "try";
}
ret->name = nameMapper.pushLabelName(sName);
- Type type = parseOptionalResultType(s, i); // signature
+ Type type = parseBlockType(s, i); // signature
if (!elementStartsWith(*s[i], "do")) {
throw SParseException("try body should start with 'do'", s, *s[i]);
diff --git a/src/wasm/wasm-stack.cpp b/src/wasm/wasm-stack.cpp
index 1f1141821..1d2363be3 100644
--- a/src/wasm/wasm-stack.cpp
+++ b/src/wasm/wasm-stack.cpp
@@ -26,7 +26,7 @@ void BinaryInstWriter::emitResultType(Type type) {
if (type == Type::unreachable) {
parent.writeType(Type::none);
} else if (type.isTuple()) {
- o << S32LEB(parent.getTypeIndex(Signature(Type::none, type)));
+ o << S32LEB(parent.getSignatureIndex(Signature(Type::none, type)));
} else {
parent.writeType(type);
}