summaryrefslogtreecommitdiff
path: root/src/passes/Directize.cpp
diff options
context:
space:
mode:
authorAlon Zakai <azakai@google.com>2022-05-27 11:16:39 -0700
committerGitHub <noreply@github.com>2022-05-27 11:16:39 -0700
commitf21774f3865e506b6dea912af8f8be16fd29dacb (patch)
treecd08af70a3301d44c231e093acbb0799a62123bb /src/passes/Directize.cpp
parentac2190c1bd1a0be12a1d25e62695f791f270d050 (diff)
downloadbinaryen-f21774f3865e506b6dea912af8f8be16fd29dacb.tar.gz
binaryen-f21774f3865e506b6dea912af8f8be16fd29dacb.tar.bz2
binaryen-f21774f3865e506b6dea912af8f8be16fd29dacb.zip
OptimizeInstructions: Turn call_ref of a select into an if over two direct calls (#4660)
This extends the existing call_indirect code to do the same for call_ref, basically. The shared code is added to a new helper utility.
Diffstat (limited to 'src/passes/Directize.cpp')
-rw-r--r--src/passes/Directize.cpp115
1 files changed, 48 insertions, 67 deletions
diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp
index 6853f1faf..e338272a9 100644
--- a/src/passes/Directize.cpp
+++ b/src/passes/Directize.cpp
@@ -22,6 +22,7 @@
#include <unordered_map>
+#include "call-utils.h"
#include "ir/table-utils.h"
#include "ir/type-updating.h"
#include "ir/utils.h"
@@ -59,62 +60,19 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
return;
}
- // If the target is a select of two different constants, we can emit two
- // direct calls.
- // TODO: handle 3+
- // TODO: handle the case where just one arm is a constant?
- if (auto* select = curr->target->dynCast<Select>()) {
- if (select->ifTrue->is<Const>() && select->ifFalse->is<Const>()) {
- Builder builder(*getModule());
- auto* func = getFunction();
- std::vector<Expression*> blockContents;
-
- if (select->condition->type == Type::unreachable) {
- // Leave this for DCE.
- return;
- }
-
- // We must use the operands twice, and also must move the condition to
- // execute first; use locals for them all. While doing so, if we see
- // any are unreachable, stop trying to optimize and leave this for DCE.
- std::vector<Index> operandLocals;
- for (auto* operand : curr->operands) {
- if (operand->type == Type::unreachable ||
- !TypeUpdating::canHandleAsLocal(operand->type)) {
- return;
- }
- }
-
- // None of the types are a problem, so we can proceed to add new vars as
- // needed and perform this optimization.
- for (auto* operand : curr->operands) {
- auto currLocal = builder.addVar(func, operand->type);
- operandLocals.push_back(currLocal);
- blockContents.push_back(builder.makeLocalSet(currLocal, operand));
- // By adding locals we must make type adjustments at the end.
- changedTypes = true;
- }
-
- // Build the calls.
- auto numOperands = curr->operands.size();
- auto getOperands = [&]() {
- std::vector<Expression*> newOperands(numOperands);
- for (Index i = 0; i < numOperands; i++) {
- newOperands[i] =
- builder.makeLocalGet(operandLocals[i], curr->operands[i]->type);
- }
- return newOperands;
- };
- auto* ifTrueCall =
- makeDirectCall(getOperands(), select->ifTrue, flatTable, curr);
- auto* ifFalseCall =
- makeDirectCall(getOperands(), select->ifFalse, flatTable, curr);
-
- // Create the if to pick the calls, and emit the final block.
- auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall);
- blockContents.push_back(iff);
- replaceCurrent(builder.makeBlock(blockContents));
- }
+ // Emit direct calls for things like a select over constants.
+ if (auto* calls = CallUtils::convertToDirectCalls(
+ curr,
+ [&](Expression* target) {
+ return getTargetInfo(target, flatTable, curr);
+ },
+ *getFunction(),
+ *getModule())) {
+ replaceCurrent(calls);
+ // Note that types may have changed, as the utility here can add locals
+ // which require fixups if they are non-nullable, for example.
+ changedTypes = true;
+ return;
}
}
@@ -131,6 +89,36 @@ private:
bool changedTypes = false;
+ // Given an expression that we will use as the target of an indirect call,
+ // analyze it and return one of the results of CallUtils::IndirectCallInfo,
+ // that is, whether we know a direct call target, or we know it will trap, or
+ // if we know nothing.
+ CallUtils::IndirectCallInfo
+ getTargetInfo(Expression* target,
+ const TableUtils::FlatTable& flatTable,
+ CallIndirect* original) {
+ auto* c = target->dynCast<Const>();
+ if (!c) {
+ return CallUtils::Unknown{};
+ }
+
+ Index index = c->value.geti32();
+
+ // If the index is invalid, or the type is wrong, then this will trap.
+ if (index >= flatTable.names.size()) {
+ return CallUtils::Trap{};
+ }
+ auto name = flatTable.names[index];
+ if (!name.is()) {
+ return CallUtils::Trap{};
+ }
+ auto* func = getModule()->getFunction(name);
+ if (original->heapType != func->type) {
+ return CallUtils::Trap{};
+ }
+ return CallUtils::Known{name};
+ }
+
// Create a direct call for a given list of operands, an expression which is
// known to contain a constant indicating the table offset, and the relevant
// table. If we can see that the call will trap, instead return an
@@ -139,23 +127,16 @@ private:
Expression* c,
const TableUtils::FlatTable& flatTable,
CallIndirect* original) {
- Index index = c->cast<Const>()->value.geti32();
-
// If the index is invalid, or the type is wrong, we can
// emit an unreachable here, since in Binaryen it is ok to
// reorder/replace traps when optimizing (but never to
// remove them, at least not by default).
- if (index >= flatTable.names.size()) {
- return replaceWithUnreachable(operands);
- }
- auto name = flatTable.names[index];
- if (!name.is()) {
- return replaceWithUnreachable(operands);
- }
- auto* func = getModule()->getFunction(name);
- if (original->heapType != func->type) {
+ auto info = getTargetInfo(c, flatTable, original);
+ if (std::get_if<CallUtils::Trap>(&info)) {
return replaceWithUnreachable(operands);
}
+ assert(std::get_if<CallUtils::Known>(&info));
+ auto name = std::get_if<CallUtils::Known>(&info)->target;
// Everything looks good!
return Builder(*getModule())