summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlon Zakai <azakai@google.com>2021-10-04 15:45:43 -0700
committerGitHub <noreply@github.com>2021-10-04 15:45:43 -0700
commit8895a2417d37e3444c98f0023e98ab9151d04290 (patch)
tree4b233b178e80f39a1177c26e093613fa932fbc25 /src
parent972062c373e599a5f75d70fa0569e9d0b57854bf (diff)
downloadbinaryen-8895a2417d37e3444c98f0023e98ab9151d04290.tar.gz
binaryen-8895a2417d37e3444c98f0023e98ab9151d04290.tar.bz2
binaryen-8895a2417d37e3444c98f0023e98ab9151d04290.zip
Optimize call_indirect of a select of two constants (#4208)
(call_indirect ..args.. (select (i32.const x) (i32.const y) (condition) ) ) => (if (condition) (call $func-for-x ..args.. ) (call $func-for-y ..args.. ) ) To do this we must reorder the condition with the args, and also use the args more than once, so place them all in locals. This works towards the goal of polymorphic devirtualization, that is, turning an indirect call of more than one possible target into more than one direct call.
Diffstat (limited to 'src')
-rw-r--r--src/passes/Directize.cpp129
1 files changed, 100 insertions, 29 deletions
diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp
index 0f04b5f4c..5a9b1b984 100644
--- a/src/passes/Directize.cpp
+++ b/src/passes/Directize.cpp
@@ -23,6 +23,7 @@
#include <unordered_map>
#include "ir/table-utils.h"
+#include "ir/type-updating.h"
#include "ir/utils.h"
#include "pass.h"
#include "wasm-builder.h"
@@ -50,30 +51,65 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
auto& flatTable = it->second;
- if (auto* c = curr->target->dynCast<Const>()) {
- Index index = c->value.geti32();
- // If the index is invalid, or the type is wrong, we can
- // emit an unreachable here, since in Binaryen it is ok to
- // reorder/replace traps when optimizing (but never to
- // remove them, at least not by default).
- if (index >= flatTable.names.size()) {
- replaceWithUnreachable(curr);
- return;
- }
- auto name = flatTable.names[index];
- if (!name.is()) {
- replaceWithUnreachable(curr);
- return;
- }
- auto* func = getModule()->getFunction(name);
- if (curr->sig != func->getSig()) {
- replaceWithUnreachable(curr);
- return;
+ // If the target is constant, we can emit a direct call.
+ if (curr->target->is<Const>()) {
+ std::vector<Expression*> operands(curr->operands.begin(),
+ curr->operands.end());
+ replaceCurrent(makeDirectCall(operands, curr->target, flatTable, curr));
+ return;
+ }
+
+ // If the target is a select of two different constants, we can emit two
+ // direct calls.
+ // TODO: handle 3+
+ // TODO: handle the case where just one arm is a constant?
+ if (auto* select = curr->target->dynCast<Select>()) {
+ if (select->ifTrue->is<Const>() && select->ifFalse->is<Const>()) {
+ Builder builder(*getModule());
+ auto* func = getFunction();
+ std::vector<Expression*> blockContents;
+
+ // We must use the operands twice, and also must move the condition to
+ // execute first; use locals for them all. While doing so, if we see
+ // any are unreachable, stop trying to optimize and leave this for DCE.
+ std::vector<Index> operandLocals;
+ for (auto* operand : curr->operands) {
+ if (operand->type == Type::unreachable ||
+ !TypeUpdating::canHandleAsLocal(operand->type)) {
+ return;
+ }
+ auto currLocal = builder.addVar(func, operand->type);
+ operandLocals.push_back(currLocal);
+ blockContents.push_back(builder.makeLocalSet(currLocal, operand));
+ }
+
+ if (select->condition->type == Type::unreachable) {
+ return;
+ }
+
+ // Build the calls.
+ auto numOperands = curr->operands.size();
+ auto getOperands = [&]() {
+ std::vector<Expression*> newOperands(numOperands);
+ for (Index i = 0; i < numOperands; i++) {
+ newOperands[i] =
+ builder.makeLocalGet(operandLocals[i], curr->operands[i]->type);
+ }
+ return newOperands;
+ };
+ auto* ifTrueCall =
+ makeDirectCall(getOperands(), select->ifTrue, flatTable, curr);
+ auto* ifFalseCall =
+ makeDirectCall(getOperands(), select->ifFalse, flatTable, curr);
+
+ // Create the if to pick the calls, and emit the final block.
+ auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall);
+ blockContents.push_back(iff);
+ replaceCurrent(builder.makeBlock(blockContents));
+
+ // By adding locals we must make type adjustments at the end.
+ changedTypes = true;
}
- // Everything looks good!
- replaceCurrent(
- Builder(*getModule())
- .makeCall(name, curr->operands, curr->type, curr->isReturn));
}
}
@@ -81,6 +117,7 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
WalkerPass<PostWalker<FunctionDirectizer>>::doWalkFunction(func);
if (changedTypes) {
ReFinalize().walkFunctionInModule(func, getModule());
+ TypeUpdating::handleNonDefaultableLocals(func, *getModule());
}
}
@@ -89,14 +126,48 @@ private:
bool changedTypes = false;
- void replaceWithUnreachable(CallIndirect* call) {
- Builder builder(*getModule());
- for (auto*& operand : call->operands) {
- operand = builder.makeDrop(operand);
+ // Create a direct call for a given list of operands, an expression which is
+ // known to contain a constant indicating the table offset, and the relevant
+ // table. If we can see that the call will trap, instead return an
+ // unreachable.
+ Expression* makeDirectCall(const std::vector<Expression*>& operands,
+ Expression* c,
+ const TableUtils::FlatTable& flatTable,
+ CallIndirect* original) {
+ Index index = c->cast<Const>()->value.geti32();
+
+ // If the index is invalid, or the type is wrong, we can
+ // emit an unreachable here, since in Binaryen it is ok to
+ // reorder/replace traps when optimizing (but never to
+ // remove them, at least not by default).
+ if (index >= flatTable.names.size()) {
+ return replaceWithUnreachable(operands);
+ }
+ auto name = flatTable.names[index];
+ if (!name.is()) {
+ return replaceWithUnreachable(operands);
}
- replaceCurrent(builder.makeSequence(builder.makeBlock(call->operands),
- builder.makeUnreachable()));
+ auto* func = getModule()->getFunction(name);
+ if (original->sig != func->getSig()) {
+ return replaceWithUnreachable(operands);
+ }
+
+ // Everything looks good!
+ return Builder(*getModule())
+ .makeCall(name, operands, original->type, original->isReturn);
+ }
+
+ Expression* replaceWithUnreachable(const std::vector<Expression*>& operands) {
+ // Emitting an unreachable means we must update parent types.
changedTypes = true;
+
+ Builder builder(*getModule());
+ std::vector<Expression*> newOperands;
+ for (auto* operand : operands) {
+ newOperands.push_back(builder.makeDrop(operand));
+ }
+ return builder.makeSequence(builder.makeBlock(newOperands),
+ builder.makeUnreachable());
}
};