summaryrefslogtreecommitdiff
path: root/src/passes
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes')
-rw-r--r--src/passes/DeadArgumentElimination.cpp6
-rw-r--r--src/passes/Directize.cpp44
-rw-r--r--src/passes/Inlining.cpp5
-rw-r--r--src/passes/MergeBlocks.cpp6
-rw-r--r--src/passes/Print.cpp42
5 files changed, 89 insertions, 14 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp
index 89d03f461..34637cf5a 100644
--- a/src/passes/DeadArgumentElimination.cpp
+++ b/src/passes/DeadArgumentElimination.cpp
@@ -143,6 +143,12 @@ struct DAEScanner
}
}
+ void visitCallRef(CallRef* curr) {
+ if (curr->isReturn) {
+ info->hasTailCalls = true;
+ }
+ }
+
void visitDrop(Drop* curr) {
if (auto* call = curr->value->dynCast<Call>()) {
info->droppedCalls[call] = getCurrentPointer();
diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp
index 0c1132b04..f966d1a5a 100644
--- a/src/passes/Directize.cpp
+++ b/src/passes/Directize.cpp
@@ -41,6 +41,9 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
FunctionDirectizer(TableUtils::FlatTable* flatTable) : flatTable(flatTable) {}
void visitCallIndirect(CallIndirect* curr) {
+ if (!flatTable) {
+ return;
+ }
if (auto* c = curr->target->dynCast<Const>()) {
Index index = c->value.geti32();
// If the index is invalid, or the type is wrong, we can
@@ -68,6 +71,15 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
}
}
+ void visitCallRef(CallRef* curr) {
+ if (auto* ref = curr->target->dynCast<RefFunc>()) {
+ // We know the target!
+ replaceCurrent(
+ Builder(*getModule())
+ .makeCall(ref->func, curr->operands, curr->type, curr->isReturn));
+ }
+ }
+
void doWalkFunction(Function* func) {
WalkerPass<PostWalker<FunctionDirectizer>>::doWalkFunction(func);
if (changedTypes) {
@@ -76,7 +88,9 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
}
private:
+ // If null, then we cannot optimize call_indirects.
TableUtils::FlatTable* flatTable;
+
bool changedTypes = false;
void replaceWithUnreachable(CallIndirect* call) {
@@ -92,23 +106,31 @@ private:
struct Directize : public Pass {
void run(PassRunner* runner, Module* module) override {
+ bool canOptimizeCallIndirect = true;
+ TableUtils::FlatTable flatTable(module->table);
if (!module->table.exists) {
- return;
- }
- if (module->table.imported()) {
- return;
- }
- for (auto& ex : module->exports) {
- if (ex->kind == ExternalKind::Table) {
- return;
+ canOptimizeCallIndirect = false;
+ } else if (module->table.imported()) {
+ canOptimizeCallIndirect = false;
+ } else {
+ for (auto& ex : module->exports) {
+ if (ex->kind == ExternalKind::Table) {
+ canOptimizeCallIndirect = false;
+ }
+ }
+ if (!flatTable.valid) {
+ canOptimizeCallIndirect = false;
}
}
- TableUtils::FlatTable flatTable(module->table);
- if (!flatTable.valid) {
+ // Without typed function references, all we can do is optimize table
+ // accesses, so if we can't do that, stop.
+ if (!canOptimizeCallIndirect &&
+ !module->features.hasTypedFunctionReferences()) {
return;
}
// The table exists and is constant, so this is possible.
- FunctionDirectizer(&flatTable).run(runner, module);
+ FunctionDirectizer(canOptimizeCallIndirect ? &flatTable : nullptr)
+ .run(runner, module);
}
};
diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp
index bcab7318f..a44f02426 100644
--- a/src/passes/Inlining.cpp
+++ b/src/passes/Inlining.cpp
@@ -211,6 +211,11 @@ struct Updater : public PostWalker<Updater> {
handleReturnCall(curr, curr->sig.results);
}
}
+ void visitCallRef(CallRef* curr) {
+ if (curr->isReturn) {
+ handleReturnCall(curr, curr->target->type);
+ }
+ }
void visitLocalGet(LocalGet* curr) {
curr->index = localMapping[curr->index];
}
diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp
index 4ecec6669..33dbec77c 100644
--- a/src/passes/MergeBlocks.cpp
+++ b/src/passes/MergeBlocks.cpp
@@ -564,7 +564,7 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> {
void visitCall(Call* curr) { handleCall(curr); }
- void visitCallIndirect(CallIndirect* curr) {
+ template<typename T> void handleNonDirectCall(T* curr) {
FeatureSet features = getModule()->features;
Block* outer = nullptr;
for (Index i = 0; i < curr->operands.size(); i++) {
@@ -581,6 +581,10 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> {
optimize(curr, curr->target, outer);
}
+ void visitCallIndirect(CallIndirect* curr) { handleNonDirectCall(curr); }
+
+ void visitCallRef(CallRef* curr) { handleNonDirectCall(curr); }
+
void visitThrow(Throw* curr) {
Block* outer = nullptr;
for (Index i = 0; i < curr->operands.size(); i++) {
diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp
index e512d398f..864a46362 100644
--- a/src/passes/Print.cpp
+++ b/src/passes/Print.cpp
@@ -87,14 +87,35 @@ struct SigName {
};
std::ostream& operator<<(std::ostream& os, SigName sigName) {
- auto printType = [&](Type type) {
+ std::function<void(Type)> printType = [&](Type type) {
if (type == Type::none) {
os << "none";
} else {
auto sep = "";
for (const auto& t : type) {
- os << sep << t;
+ os << sep;
sep = "_";
+ if (t.isRef()) {
+ auto heapType = t.getHeapType();
+ if (heapType.isSignature()) {
+ auto sig = heapType.getSignature();
+ os << "ref";
+ if (t.isNullable()) {
+ os << "_null";
+ }
+ os << "<";
+ for (auto s : sig.params) {
+ printType(s);
+ }
+ os << "_->_";
+ for (auto s : sig.results) {
+ printType(s);
+ }
+ os << ">";
+ continue;
+ }
+ }
+ os << t;
}
}
};
@@ -1561,6 +1582,13 @@ struct PrintExpressionContents
void visitI31Get(I31Get* curr) {
printMedium(o, curr->signed_ ? "i31.get_s" : "i31.get_u");
}
+ void visitCallRef(CallRef* curr) {
+ if (curr->isReturn) {
+ printMedium(o, "return_call_ref");
+ } else {
+ printMedium(o, "call_ref");
+ }
+ }
void visitRefTest(RefTest* curr) {
printMedium(o, "ref.test");
WASM_UNREACHABLE("TODO (gc): ref.test");
@@ -2216,6 +2244,16 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> {
printFullLine(curr->i31);
decIndent();
}
+ void visitCallRef(CallRef* curr) {
+ o << '(';
+ PrintExpressionContents(currFunction, o).visit(curr);
+ incIndent();
+ for (auto operand : curr->operands) {
+ printFullLine(operand);
+ }
+ printFullLine(curr->target);
+ decIndent();
+ }
void visitRefTest(RefTest* curr) {
o << '(';
PrintExpressionContents(currFunction, o).visit(curr);