diff options
author | Alon Zakai <azakai@google.com> | 2021-10-04 15:45:43 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-04 15:45:43 -0700 |
commit | 8895a2417d37e3444c98f0023e98ab9151d04290 (patch) | |
tree | 4b233b178e80f39a1177c26e093613fa932fbc25 /src | |
parent | 972062c373e599a5f75d70fa0569e9d0b57854bf (diff) | |
download | binaryen-8895a2417d37e3444c98f0023e98ab9151d04290.tar.gz binaryen-8895a2417d37e3444c98f0023e98ab9151d04290.tar.bz2 binaryen-8895a2417d37e3444c98f0023e98ab9151d04290.zip |
Optimize call_indirect of a select of two constants (#4208)
(call_indirect
..args..
(select
(i32.const x)
(i32.const y)
(condition)
)
)
=>
(if
(condition)
(call $func-for-x
..args..
)
(call $func-for-y
..args..
)
)
To do this we must reorder the condition with the args, and also use
the args more than once, so place them all in locals.
This works towards the goal of polymorphic devirtualization, that is,
turning an indirect call of more than one possible target into more
than one direct call.
Diffstat (limited to 'src')
-rw-r--r-- | src/passes/Directize.cpp | 129 |
1 files changed, 100 insertions, 29 deletions
diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp index 0f04b5f4c..5a9b1b984 100644 --- a/src/passes/Directize.cpp +++ b/src/passes/Directize.cpp @@ -23,6 +23,7 @@ #include <unordered_map> #include "ir/table-utils.h" +#include "ir/type-updating.h" #include "ir/utils.h" #include "pass.h" #include "wasm-builder.h" @@ -50,30 +51,65 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { auto& flatTable = it->second; - if (auto* c = curr->target->dynCast<Const>()) { - Index index = c->value.geti32(); - // If the index is invalid, or the type is wrong, we can - // emit an unreachable here, since in Binaryen it is ok to - // reorder/replace traps when optimizing (but never to - // remove them, at least not by default). - if (index >= flatTable.names.size()) { - replaceWithUnreachable(curr); - return; - } - auto name = flatTable.names[index]; - if (!name.is()) { - replaceWithUnreachable(curr); - return; - } - auto* func = getModule()->getFunction(name); - if (curr->sig != func->getSig()) { - replaceWithUnreachable(curr); - return; + // If the target is constant, we can emit a direct call. + if (curr->target->is<Const>()) { + std::vector<Expression*> operands(curr->operands.begin(), + curr->operands.end()); + replaceCurrent(makeDirectCall(operands, curr->target, flatTable, curr)); + return; + } + + // If the target is a select of two different constants, we can emit two + // direct calls. + // TODO: handle 3+ + // TODO: handle the case where just one arm is a constant? + if (auto* select = curr->target->dynCast<Select>()) { + if (select->ifTrue->is<Const>() && select->ifFalse->is<Const>()) { + Builder builder(*getModule()); + auto* func = getFunction(); + std::vector<Expression*> blockContents; + + // We must use the operands twice, and also must move the condition to + // execute first; use locals for them all. While doing so, if we see + // any are unreachable, stop trying to optimize and leave this for DCE. + std::vector<Index> operandLocals; + for (auto* operand : curr->operands) { + if (operand->type == Type::unreachable || + !TypeUpdating::canHandleAsLocal(operand->type)) { + return; + } + auto currLocal = builder.addVar(func, operand->type); + operandLocals.push_back(currLocal); + blockContents.push_back(builder.makeLocalSet(currLocal, operand)); + } + + if (select->condition->type == Type::unreachable) { + return; + } + + // Build the calls. + auto numOperands = curr->operands.size(); + auto getOperands = [&]() { + std::vector<Expression*> newOperands(numOperands); + for (Index i = 0; i < numOperands; i++) { + newOperands[i] = + builder.makeLocalGet(operandLocals[i], curr->operands[i]->type); + } + return newOperands; + }; + auto* ifTrueCall = + makeDirectCall(getOperands(), select->ifTrue, flatTable, curr); + auto* ifFalseCall = + makeDirectCall(getOperands(), select->ifFalse, flatTable, curr); + + // Create the if to pick the calls, and emit the final block. + auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall); + blockContents.push_back(iff); + replaceCurrent(builder.makeBlock(blockContents)); + + // By adding locals we must make type adjustments at the end. + changedTypes = true; } - // Everything looks good! - replaceCurrent( - Builder(*getModule()) - .makeCall(name, curr->operands, curr->type, curr->isReturn)); } } @@ -81,6 +117,7 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { WalkerPass<PostWalker<FunctionDirectizer>>::doWalkFunction(func); if (changedTypes) { ReFinalize().walkFunctionInModule(func, getModule()); + TypeUpdating::handleNonDefaultableLocals(func, *getModule()); } } @@ -89,14 +126,48 @@ private: bool changedTypes = false; - void replaceWithUnreachable(CallIndirect* call) { - Builder builder(*getModule()); - for (auto*& operand : call->operands) { - operand = builder.makeDrop(operand); + // Create a direct call for a given list of operands, an expression which is + // known to contain a constant indicating the table offset, and the relevant + // table. If we can see that the call will trap, instead return an + // unreachable. + Expression* makeDirectCall(const std::vector<Expression*>& operands, + Expression* c, + const TableUtils::FlatTable& flatTable, + CallIndirect* original) { + Index index = c->cast<Const>()->value.geti32(); + + // If the index is invalid, or the type is wrong, we can + // emit an unreachable here, since in Binaryen it is ok to + // reorder/replace traps when optimizing (but never to + // remove them, at least not by default). + if (index >= flatTable.names.size()) { + return replaceWithUnreachable(operands); + } + auto name = flatTable.names[index]; + if (!name.is()) { + return replaceWithUnreachable(operands); } - replaceCurrent(builder.makeSequence(builder.makeBlock(call->operands), - builder.makeUnreachable())); + auto* func = getModule()->getFunction(name); + if (original->sig != func->getSig()) { + return replaceWithUnreachable(operands); + } + + // Everything looks good! + return Builder(*getModule()) + .makeCall(name, operands, original->type, original->isReturn); + } + + Expression* replaceWithUnreachable(const std::vector<Expression*>& operands) { + // Emitting an unreachable means we must update parent types. changedTypes = true; + + Builder builder(*getModule()); + std::vector<Expression*> newOperands; + for (auto* operand : operands) { + newOperands.push_back(builder.makeDrop(operand)); + } + return builder.makeSequence(builder.makeBlock(newOperands), + builder.makeUnreachable()); } }; |