summaryrefslogtreecommitdiff
path: root/src/passes/Directize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/passes/Directize.cpp')
-rw-r--r--src/passes/Directize.cpp129
1 files changed, 100 insertions, 29 deletions
diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp
index 0f04b5f4c..5a9b1b984 100644
--- a/src/passes/Directize.cpp
+++ b/src/passes/Directize.cpp
@@ -23,6 +23,7 @@
#include <unordered_map>
#include "ir/table-utils.h"
+#include "ir/type-updating.h"
#include "ir/utils.h"
#include "pass.h"
#include "wasm-builder.h"
@@ -50,30 +51,65 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
auto& flatTable = it->second;
- if (auto* c = curr->target->dynCast<Const>()) {
- Index index = c->value.geti32();
- // If the index is invalid, or the type is wrong, we can
- // emit an unreachable here, since in Binaryen it is ok to
- // reorder/replace traps when optimizing (but never to
- // remove them, at least not by default).
- if (index >= flatTable.names.size()) {
- replaceWithUnreachable(curr);
- return;
- }
- auto name = flatTable.names[index];
- if (!name.is()) {
- replaceWithUnreachable(curr);
- return;
- }
- auto* func = getModule()->getFunction(name);
- if (curr->sig != func->getSig()) {
- replaceWithUnreachable(curr);
- return;
+ // If the target is constant, we can emit a direct call.
+ if (curr->target->is<Const>()) {
+ std::vector<Expression*> operands(curr->operands.begin(),
+ curr->operands.end());
+ replaceCurrent(makeDirectCall(operands, curr->target, flatTable, curr));
+ return;
+ }
+
+ // If the target is a select of two different constants, we can emit two
+ // direct calls.
+ // TODO: handle 3+
+ // TODO: handle the case where just one arm is a constant?
+ if (auto* select = curr->target->dynCast<Select>()) {
+ if (select->ifTrue->is<Const>() && select->ifFalse->is<Const>()) {
+ Builder builder(*getModule());
+ auto* func = getFunction();
+ std::vector<Expression*> blockContents;
+
+ // We must use the operands twice, and also must move the condition to
+ // execute first; use locals for them all. While doing so, if we see
+ // any are unreachable, stop trying to optimize and leave this for DCE.
+ std::vector<Index> operandLocals;
+ for (auto* operand : curr->operands) {
+ if (operand->type == Type::unreachable ||
+ !TypeUpdating::canHandleAsLocal(operand->type)) {
+ return;
+ }
+ auto currLocal = builder.addVar(func, operand->type);
+ operandLocals.push_back(currLocal);
+ blockContents.push_back(builder.makeLocalSet(currLocal, operand));
+ }
+
+ if (select->condition->type == Type::unreachable) {
+ return;
+ }
+
+ // Build the calls.
+ auto numOperands = curr->operands.size();
+ auto getOperands = [&]() {
+ std::vector<Expression*> newOperands(numOperands);
+ for (Index i = 0; i < numOperands; i++) {
+ newOperands[i] =
+ builder.makeLocalGet(operandLocals[i], curr->operands[i]->type);
+ }
+ return newOperands;
+ };
+ auto* ifTrueCall =
+ makeDirectCall(getOperands(), select->ifTrue, flatTable, curr);
+ auto* ifFalseCall =
+ makeDirectCall(getOperands(), select->ifFalse, flatTable, curr);
+
+ // Create the if to pick the calls, and emit the final block.
+ auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall);
+ blockContents.push_back(iff);
+ replaceCurrent(builder.makeBlock(blockContents));
+
+ // By adding locals we must make type adjustments at the end.
+ changedTypes = true;
}
- // Everything looks good!
- replaceCurrent(
- Builder(*getModule())
- .makeCall(name, curr->operands, curr->type, curr->isReturn));
}
}
@@ -81,6 +117,7 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
WalkerPass<PostWalker<FunctionDirectizer>>::doWalkFunction(func);
if (changedTypes) {
ReFinalize().walkFunctionInModule(func, getModule());
+ TypeUpdating::handleNonDefaultableLocals(func, *getModule());
}
}
@@ -89,14 +126,48 @@ private:
bool changedTypes = false;
- void replaceWithUnreachable(CallIndirect* call) {
- Builder builder(*getModule());
- for (auto*& operand : call->operands) {
- operand = builder.makeDrop(operand);
+ // Create a direct call for a given list of operands, an expression which is
+ // known to contain a constant indicating the table offset, and the relevant
+ // table. If we can see that the call will trap, instead return an
+ // unreachable.
+ Expression* makeDirectCall(const std::vector<Expression*>& operands,
+ Expression* c,
+ const TableUtils::FlatTable& flatTable,
+ CallIndirect* original) {
+ Index index = c->cast<Const>()->value.geti32();
+
+ // If the index is invalid, or the type is wrong, we can
+ // emit an unreachable here, since in Binaryen it is ok to
+ // reorder/replace traps when optimizing (but never to
+ // remove them, at least not by default).
+ if (index >= flatTable.names.size()) {
+ return replaceWithUnreachable(operands);
+ }
+ auto name = flatTable.names[index];
+ if (!name.is()) {
+ return replaceWithUnreachable(operands);
}
- replaceCurrent(builder.makeSequence(builder.makeBlock(call->operands),
- builder.makeUnreachable()));
+ auto* func = getModule()->getFunction(name);
+ if (original->sig != func->getSig()) {
+ return replaceWithUnreachable(operands);
+ }
+
+ // Everything looks good!
+ return Builder(*getModule())
+ .makeCall(name, operands, original->type, original->isReturn);
+ }
+
+ Expression* replaceWithUnreachable(const std::vector<Expression*>& operands) {
+ // Emitting an unreachable means we must update parent types.
changedTypes = true;
+
+ Builder builder(*getModule());
+ std::vector<Expression*> newOperands;
+ for (auto* operand : operands) {
+ newOperands.push_back(builder.makeDrop(operand));
+ }
+ return builder.makeSequence(builder.makeBlock(newOperands),
+ builder.makeUnreachable());
}
};