diff options
Diffstat (limited to 'src/passes')
-rw-r--r-- | src/passes/DeadArgumentElimination.cpp | 6 | ||||
-rw-r--r-- | src/passes/Directize.cpp | 44 | ||||
-rw-r--r-- | src/passes/Inlining.cpp | 5 | ||||
-rw-r--r-- | src/passes/MergeBlocks.cpp | 6 | ||||
-rw-r--r-- | src/passes/Print.cpp | 42 |
5 files changed, 89 insertions, 14 deletions
diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index 89d03f461..34637cf5a 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -143,6 +143,12 @@ struct DAEScanner } } + void visitCallRef(CallRef* curr) { + if (curr->isReturn) { + info->hasTailCalls = true; + } + } + void visitDrop(Drop* curr) { if (auto* call = curr->value->dynCast<Call>()) { info->droppedCalls[call] = getCurrentPointer(); diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp index 0c1132b04..f966d1a5a 100644 --- a/src/passes/Directize.cpp +++ b/src/passes/Directize.cpp @@ -41,6 +41,9 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { FunctionDirectizer(TableUtils::FlatTable* flatTable) : flatTable(flatTable) {} void visitCallIndirect(CallIndirect* curr) { + if (!flatTable) { + return; + } if (auto* c = curr->target->dynCast<Const>()) { Index index = c->value.geti32(); // If the index is invalid, or the type is wrong, we can @@ -68,6 +71,15 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { } } + void visitCallRef(CallRef* curr) { + if (auto* ref = curr->target->dynCast<RefFunc>()) { + // We know the target! + replaceCurrent( + Builder(*getModule()) + .makeCall(ref->func, curr->operands, curr->type, curr->isReturn)); + } + } + void doWalkFunction(Function* func) { WalkerPass<PostWalker<FunctionDirectizer>>::doWalkFunction(func); if (changedTypes) { @@ -76,7 +88,9 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { } private: + // If null, then we cannot optimize call_indirects. TableUtils::FlatTable* flatTable; + bool changedTypes = false; void replaceWithUnreachable(CallIndirect* call) { @@ -92,23 +106,31 @@ private: struct Directize : public Pass { void run(PassRunner* runner, Module* module) override { + bool canOptimizeCallIndirect = true; + TableUtils::FlatTable flatTable(module->table); if (!module->table.exists) { - return; - } - if (module->table.imported()) { - return; - } - for (auto& ex : module->exports) { - if (ex->kind == ExternalKind::Table) { - return; + canOptimizeCallIndirect = false; + } else if (module->table.imported()) { + canOptimizeCallIndirect = false; + } else { + for (auto& ex : module->exports) { + if (ex->kind == ExternalKind::Table) { + canOptimizeCallIndirect = false; + } + } + if (!flatTable.valid) { + canOptimizeCallIndirect = false; } } - TableUtils::FlatTable flatTable(module->table); - if (!flatTable.valid) { + // Without typed function references, all we can do is optimize table + // accesses, so if we can't do that, stop. + if (!canOptimizeCallIndirect && + !module->features.hasTypedFunctionReferences()) { return; } // The table exists and is constant, so this is possible. - FunctionDirectizer(&flatTable).run(runner, module); + FunctionDirectizer(canOptimizeCallIndirect ? &flatTable : nullptr) + .run(runner, module); } }; diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index bcab7318f..a44f02426 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -211,6 +211,11 @@ struct Updater : public PostWalker<Updater> { handleReturnCall(curr, curr->sig.results); } } + void visitCallRef(CallRef* curr) { + if (curr->isReturn) { + handleReturnCall(curr, curr->target->type); + } + } void visitLocalGet(LocalGet* curr) { curr->index = localMapping[curr->index]; } diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp index 4ecec6669..33dbec77c 100644 --- a/src/passes/MergeBlocks.cpp +++ b/src/passes/MergeBlocks.cpp @@ -564,7 +564,7 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> { void visitCall(Call* curr) { handleCall(curr); } - void visitCallIndirect(CallIndirect* curr) { + template<typename T> void handleNonDirectCall(T* curr) { FeatureSet features = getModule()->features; Block* outer = nullptr; for (Index i = 0; i < curr->operands.size(); i++) { @@ -581,6 +581,10 @@ struct MergeBlocks : public WalkerPass<PostWalker<MergeBlocks>> { optimize(curr, curr->target, outer); } + void visitCallIndirect(CallIndirect* curr) { handleNonDirectCall(curr); } + + void visitCallRef(CallRef* curr) { handleNonDirectCall(curr); } + void visitThrow(Throw* curr) { Block* outer = nullptr; for (Index i = 0; i < curr->operands.size(); i++) { diff --git a/src/passes/Print.cpp b/src/passes/Print.cpp index e512d398f..864a46362 100644 --- a/src/passes/Print.cpp +++ b/src/passes/Print.cpp @@ -87,14 +87,35 @@ struct SigName { }; std::ostream& operator<<(std::ostream& os, SigName sigName) { - auto printType = [&](Type type) { + std::function<void(Type)> printType = [&](Type type) { if (type == Type::none) { os << "none"; } else { auto sep = ""; for (const auto& t : type) { - os << sep << t; + os << sep; sep = "_"; + if (t.isRef()) { + auto heapType = t.getHeapType(); + if (heapType.isSignature()) { + auto sig = heapType.getSignature(); + os << "ref"; + if (t.isNullable()) { + os << "_null"; + } + os << "<"; + for (auto s : sig.params) { + printType(s); + } + os << "_->_"; + for (auto s : sig.results) { + printType(s); + } + os << ">"; + continue; + } + } + os << t; } } }; @@ -1561,6 +1582,13 @@ struct PrintExpressionContents void visitI31Get(I31Get* curr) { printMedium(o, curr->signed_ ? "i31.get_s" : "i31.get_u"); } + void visitCallRef(CallRef* curr) { + if (curr->isReturn) { + printMedium(o, "return_call_ref"); + } else { + printMedium(o, "call_ref"); + } + } void visitRefTest(RefTest* curr) { printMedium(o, "ref.test"); WASM_UNREACHABLE("TODO (gc): ref.test"); @@ -2216,6 +2244,16 @@ struct PrintSExpression : public OverriddenVisitor<PrintSExpression> { printFullLine(curr->i31); decIndent(); } + void visitCallRef(CallRef* curr) { + o << '('; + PrintExpressionContents(currFunction, o).visit(curr); + incIndent(); + for (auto operand : curr->operands) { + printFullLine(operand); + } + printFullLine(curr->target); + decIndent(); + } void visitRefTest(RefTest* curr) { o << '('; PrintExpressionContents(currFunction, o).visit(curr); |