summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlon Zakai <azakai@google.com>2021-10-04 15:45:43 -0700
committerGitHub <noreply@github.com>2021-10-04 15:45:43 -0700
commit8895a2417d37e3444c98f0023e98ab9151d04290 (patch)
tree4b233b178e80f39a1177c26e093613fa932fbc25
parent972062c373e599a5f75d70fa0569e9d0b57854bf (diff)
downloadbinaryen-8895a2417d37e3444c98f0023e98ab9151d04290.tar.gz
binaryen-8895a2417d37e3444c98f0023e98ab9151d04290.tar.bz2
binaryen-8895a2417d37e3444c98f0023e98ab9151d04290.zip
Optimize call_indirect of a select of two constants (#4208)
(call_indirect ..args.. (select (i32.const x) (i32.const y) (condition) ) ) => (if (condition) (call $func-for-x ..args.. ) (call $func-for-y ..args.. ) ) To do this we must reorder the condition with the args, and also use the args more than once, so place them all in locals. This works towards the goal of polymorphic devirtualization, that is, turning an indirect call of more than one possible target into more than one direct call.
-rw-r--r--src/passes/Directize.cpp129
-rw-r--r--test/lit/passes/directize_all-features.wast299
2 files changed, 399 insertions, 29 deletions
diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp
index 0f04b5f4c..5a9b1b984 100644
--- a/src/passes/Directize.cpp
+++ b/src/passes/Directize.cpp
@@ -23,6 +23,7 @@
#include <unordered_map>
#include "ir/table-utils.h"
+#include "ir/type-updating.h"
#include "ir/utils.h"
#include "pass.h"
#include "wasm-builder.h"
@@ -50,30 +51,65 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
auto& flatTable = it->second;
- if (auto* c = curr->target->dynCast<Const>()) {
- Index index = c->value.geti32();
- // If the index is invalid, or the type is wrong, we can
- // emit an unreachable here, since in Binaryen it is ok to
- // reorder/replace traps when optimizing (but never to
- // remove them, at least not by default).
- if (index >= flatTable.names.size()) {
- replaceWithUnreachable(curr);
- return;
- }
- auto name = flatTable.names[index];
- if (!name.is()) {
- replaceWithUnreachable(curr);
- return;
- }
- auto* func = getModule()->getFunction(name);
- if (curr->sig != func->getSig()) {
- replaceWithUnreachable(curr);
- return;
+ // If the target is constant, we can emit a direct call.
+ if (curr->target->is<Const>()) {
+ std::vector<Expression*> operands(curr->operands.begin(),
+ curr->operands.end());
+ replaceCurrent(makeDirectCall(operands, curr->target, flatTable, curr));
+ return;
+ }
+
+ // If the target is a select of two different constants, we can emit two
+ // direct calls.
+ // TODO: handle 3+
+ // TODO: handle the case where just one arm is a constant?
+ if (auto* select = curr->target->dynCast<Select>()) {
+ if (select->ifTrue->is<Const>() && select->ifFalse->is<Const>()) {
+ Builder builder(*getModule());
+ auto* func = getFunction();
+ std::vector<Expression*> blockContents;
+
+ // We must use the operands twice, and also must move the condition to
+ // execute first; use locals for them all. While doing so, if we see
+ // any are unreachable, stop trying to optimize and leave this for DCE.
+ std::vector<Index> operandLocals;
+ for (auto* operand : curr->operands) {
+ if (operand->type == Type::unreachable ||
+ !TypeUpdating::canHandleAsLocal(operand->type)) {
+ return;
+ }
+ auto currLocal = builder.addVar(func, operand->type);
+ operandLocals.push_back(currLocal);
+ blockContents.push_back(builder.makeLocalSet(currLocal, operand));
+ }
+
+ if (select->condition->type == Type::unreachable) {
+ return;
+ }
+
+ // Build the calls.
+ auto numOperands = curr->operands.size();
+ auto getOperands = [&]() {
+ std::vector<Expression*> newOperands(numOperands);
+ for (Index i = 0; i < numOperands; i++) {
+ newOperands[i] =
+ builder.makeLocalGet(operandLocals[i], curr->operands[i]->type);
+ }
+ return newOperands;
+ };
+ auto* ifTrueCall =
+ makeDirectCall(getOperands(), select->ifTrue, flatTable, curr);
+ auto* ifFalseCall =
+ makeDirectCall(getOperands(), select->ifFalse, flatTable, curr);
+
+ // Create the if to pick the calls, and emit the final block.
+ auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall);
+ blockContents.push_back(iff);
+ replaceCurrent(builder.makeBlock(blockContents));
+
+ // By adding locals we must make type adjustments at the end.
+ changedTypes = true;
}
- // Everything looks good!
- replaceCurrent(
- Builder(*getModule())
- .makeCall(name, curr->operands, curr->type, curr->isReturn));
}
}
@@ -81,6 +117,7 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
WalkerPass<PostWalker<FunctionDirectizer>>::doWalkFunction(func);
if (changedTypes) {
ReFinalize().walkFunctionInModule(func, getModule());
+ TypeUpdating::handleNonDefaultableLocals(func, *getModule());
}
}
@@ -89,14 +126,48 @@ private:
bool changedTypes = false;
- void replaceWithUnreachable(CallIndirect* call) {
- Builder builder(*getModule());
- for (auto*& operand : call->operands) {
- operand = builder.makeDrop(operand);
+ // Create a direct call for a given list of operands, an expression which is
+ // known to contain a constant indicating the table offset, and the relevant
+ // table. If we can see that the call will trap, instead return an
+ // unreachable.
+ Expression* makeDirectCall(const std::vector<Expression*>& operands,
+ Expression* c,
+ const TableUtils::FlatTable& flatTable,
+ CallIndirect* original) {
+ Index index = c->cast<Const>()->value.geti32();
+
+ // If the index is invalid, or the type is wrong, we can
+ // emit an unreachable here, since in Binaryen it is ok to
+ // reorder/replace traps when optimizing (but never to
+ // remove them, at least not by default).
+ if (index >= flatTable.names.size()) {
+ return replaceWithUnreachable(operands);
+ }
+ auto name = flatTable.names[index];
+ if (!name.is()) {
+ return replaceWithUnreachable(operands);
}
- replaceCurrent(builder.makeSequence(builder.makeBlock(call->operands),
- builder.makeUnreachable()));
+ auto* func = getModule()->getFunction(name);
+ if (original->sig != func->getSig()) {
+ return replaceWithUnreachable(operands);
+ }
+
+ // Everything looks good!
+ return Builder(*getModule())
+ .makeCall(name, operands, original->type, original->isReturn);
+ }
+
+ Expression* replaceWithUnreachable(const std::vector<Expression*>& operands) {
+ // Emitting an unreachable means we must update parent types.
changedTypes = true;
+
+ Builder builder(*getModule());
+ std::vector<Expression*> newOperands;
+ for (auto* operand : operands) {
+ newOperands.push_back(builder.makeDrop(operand));
+ }
+ return builder.makeSequence(builder.makeBlock(newOperands),
+ builder.makeUnreachable());
}
};
diff --git a/test/lit/passes/directize_all-features.wast b/test/lit/passes/directize_all-features.wast
index 54cf07556..041c4a47c 100644
--- a/test/lit/passes/directize_all-features.wast
+++ b/test/lit/passes/directize_all-features.wast
@@ -6,17 +6,21 @@
(module
;; CHECK: (type $ii (func (param i32 i32)))
(type $ii (func (param i32 i32)))
+
;; CHECK: (table $0 5 5 funcref)
(table $0 5 5 funcref)
(elem (i32.const 1) $foo)
+
;; CHECK: (elem (i32.const 1) $foo)
;; CHECK: (func $foo (param $0 i32) (param $1 i32)
;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: )
(func $foo (param i32) (param i32)
+ ;; helper function
(unreachable)
)
+
;; CHECK: (func $bar (param $x i32) (param $y i32)
;; CHECK-NEXT: (call $foo
;; CHECK-NEXT: (local.get $x)
@@ -494,3 +498,298 @@
)
)
)
+
+(module
+ ;; CHECK: (type $i32_i32_i32_=>_none (func (param i32 i32 i32)))
+
+ ;; CHECK: (type $ii (func (param i32 i32)))
+ (type $ii (func (param i32 i32)))
+ ;; CHECK: (table $0 5 5 funcref)
+ (table $0 5 5 funcref)
+ (elem (i32.const 1) $foo1 $foo2)
+ ;; CHECK: (elem (i32.const 1) $foo1 $foo2)
+
+ ;; CHECK: (func $foo1 (param $0 i32) (param $1 i32)
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ (func $foo1 (param i32) (param i32)
+ (unreachable)
+ )
+ ;; CHECK: (func $foo2 (param $0 i32) (param $1 i32)
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ (func $foo2 (param i32) (param i32)
+ (unreachable)
+ )
+ ;; CHECK: (func $select (param $x i32) (param $y i32) (param $z i32)
+ ;; CHECK-NEXT: (local $3 i32)
+ ;; CHECK-NEXT: (local $4 i32)
+ ;; CHECK-NEXT: (local.set $3
+ ;; CHECK-NEXT: (local.get $x)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (local.set $4
+ ;; CHECK-NEXT: (local.get $y)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (if
+ ;; CHECK-NEXT: (local.get $z)
+ ;; CHECK-NEXT: (call $foo1
+ ;; CHECK-NEXT: (local.get $3)
+ ;; CHECK-NEXT: (local.get $4)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (call $foo2
+ ;; CHECK-NEXT: (local.get $3)
+ ;; CHECK-NEXT: (local.get $4)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $select (param $x i32) (param $y i32) (param $z i32)
+ ;; Test we can optimize a call_indirect whose index is a select between two
+ ;; constants. We can emit an if and two direct calls for that.
+ (call_indirect (type $ii)
+ (local.get $x)
+ (local.get $y)
+ (select
+ (i32.const 1)
+ (i32.const 2)
+ (local.get $z)
+ )
+ )
+ )
+ ;; CHECK: (func $select-bad-1 (param $x i32) (param $y i32) (param $z i32)
+ ;; CHECK-NEXT: (call_indirect $0 (type $ii)
+ ;; CHECK-NEXT: (local.get $x)
+ ;; CHECK-NEXT: (local.get $y)
+ ;; CHECK-NEXT: (select
+ ;; CHECK-NEXT: (local.get $z)
+ ;; CHECK-NEXT: (i32.const 2)
+ ;; CHECK-NEXT: (local.get $z)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $select-bad-1 (param $x i32) (param $y i32) (param $z i32)
+ ;; As above but one select arm is not constant.
+ (call_indirect (type $ii)
+ (local.get $x)
+ (local.get $y)
+ (select
+ (local.get $z)
+ (i32.const 2)
+ (local.get $z)
+ )
+ )
+ )
+ ;; CHECK: (func $select-bad-2 (param $x i32) (param $y i32) (param $z i32)
+ ;; CHECK-NEXT: (call_indirect $0 (type $ii)
+ ;; CHECK-NEXT: (local.get $x)
+ ;; CHECK-NEXT: (local.get $y)
+ ;; CHECK-NEXT: (select
+ ;; CHECK-NEXT: (i32.const 2)
+ ;; CHECK-NEXT: (local.get $z)
+ ;; CHECK-NEXT: (local.get $z)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $select-bad-2 (param $x i32) (param $y i32) (param $z i32)
+ ;; As above but the other select arm is not constant.
+ (call_indirect (type $ii)
+ (local.get $x)
+ (local.get $y)
+ (select
+ (i32.const 2)
+ (local.get $z)
+ (local.get $z)
+ )
+ )
+ )
+ ;; CHECK: (func $select-out-of-range (param $x i32) (param $y i32) (param $z i32)
+ ;; CHECK-NEXT: (local $3 i32)
+ ;; CHECK-NEXT: (local $4 i32)
+ ;; CHECK-NEXT: (local.set $3
+ ;; CHECK-NEXT: (local.get $x)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (local.set $4
+ ;; CHECK-NEXT: (local.get $y)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (if
+ ;; CHECK-NEXT: (local.get $z)
+ ;; CHECK-NEXT: (block
+ ;; CHECK-NEXT: (block
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (local.get $3)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (local.get $4)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (call $foo2
+ ;; CHECK-NEXT: (local.get $3)
+ ;; CHECK-NEXT: (local.get $4)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $select-out-of-range (param $x i32) (param $y i32) (param $z i32)
+ ;; Both are constants, but one is out of range for the table, and there is no
+ ;; valid function to call there; emit an unreachable.
+ (call_indirect (type $ii)
+ (local.get $x)
+ (local.get $y)
+ (select
+ (i32.const 99999)
+ (i32.const 2)
+ (local.get $z)
+ )
+ )
+ )
+ ;; CHECK: (func $select-both-out-of-range (param $x i32) (param $y i32) (param $z i32)
+ ;; CHECK-NEXT: (local $3 i32)
+ ;; CHECK-NEXT: (local $4 i32)
+ ;; CHECK-NEXT: (local.set $3
+ ;; CHECK-NEXT: (local.get $x)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (local.set $4
+ ;; CHECK-NEXT: (local.get $y)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (if
+ ;; CHECK-NEXT: (local.get $z)
+ ;; CHECK-NEXT: (block
+ ;; CHECK-NEXT: (block
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (local.get $3)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (local.get $4)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (block
+ ;; CHECK-NEXT: (block
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (local.get $3)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (drop
+ ;; CHECK-NEXT: (local.get $4)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $select-both-out-of-range (param $x i32) (param $y i32) (param $z i32)
+ ;; Both are constants, and both are out of range for the table.
+ (call_indirect (type $ii)
+ (local.get $x)
+ (local.get $y)
+ (select
+ (i32.const 99999)
+ (i32.const -1)
+ (local.get $z)
+ )
+ )
+ )
+ ;; CHECK: (func $select-unreachable-operand (param $x i32) (param $y i32) (param $z i32)
+ ;; CHECK-NEXT: (call_indirect $0 (type $ii)
+ ;; CHECK-NEXT: (local.get $x)
+ ;; CHECK-NEXT: (local.get $y)
+ ;; CHECK-NEXT: (select
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: (i32.const 2)
+ ;; CHECK-NEXT: (local.get $z)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $select-unreachable-operand (param $x i32) (param $y i32) (param $z i32)
+ ;; One operand is unreachable.
+ (call_indirect (type $ii)
+ (local.get $x)
+ (local.get $y)
+ (select
+ (unreachable)
+ (i32.const 2)
+ (local.get $z)
+ )
+ )
+ )
+ ;; CHECK: (func $select-unreachable-condition (param $x i32) (param $y i32) (param $z i32)
+ ;; CHECK-NEXT: (local $3 i32)
+ ;; CHECK-NEXT: (local $4 i32)
+ ;; CHECK-NEXT: (call_indirect $0 (type $ii)
+ ;; CHECK-NEXT: (local.get $x)
+ ;; CHECK-NEXT: (local.get $y)
+ ;; CHECK-NEXT: (select
+ ;; CHECK-NEXT: (i32.const 1)
+ ;; CHECK-NEXT: (i32.const 2)
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $select-unreachable-condition (param $x i32) (param $y i32) (param $z i32)
+ ;; The condition is unreachable.
+ (call_indirect (type $ii)
+ (local.get $x)
+ (local.get $y)
+ (select
+ (i32.const 1)
+ (i32.const 2)
+ (unreachable)
+ )
+ )
+ )
+)
+
+(module
+ ;; CHECK: (type $i32_=>_none (func (param i32)))
+
+ ;; CHECK: (type $F (func (param (ref func))))
+
+ ;; CHECK: (table $0 15 15 funcref)
+ (table $0 15 15 funcref)
+ (type $F (func (param (ref func))))
+ (elem (i32.const 10) $foo-ref $foo-ref)
+
+ ;; CHECK: (elem (i32.const 10) $foo-ref $foo-ref)
+
+ ;; CHECK: (elem declare func $select-non-nullable)
+
+ ;; CHECK: (func $foo-ref (param $0 (ref func))
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ (func $foo-ref (param (ref func))
+ ;; helper function
+ (unreachable)
+ )
+
+ ;; CHECK: (func $select-non-nullable (param $x i32)
+ ;; CHECK-NEXT: (local $1 (ref null $i32_=>_none))
+ ;; CHECK-NEXT: (local.set $1
+ ;; CHECK-NEXT: (ref.func $select-non-nullable)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (if
+ ;; CHECK-NEXT: (local.get $x)
+ ;; CHECK-NEXT: (call $foo-ref
+ ;; CHECK-NEXT: (ref.as_non_null
+ ;; CHECK-NEXT: (local.get $1)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (call $foo-ref
+ ;; CHECK-NEXT: (ref.as_non_null
+ ;; CHECK-NEXT: (local.get $1)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $select-non-nullable (param $x i32)
+ ;; Test we can handle a non-nullable value when optimizing a select, during
+ ;; which we place values in locals.
+ (call_indirect (type $F)
+ (ref.func $select-non-nullable)
+ (select
+ (i32.const 10)
+ (i32.const 11)
+ (local.get $x)
+ )
+ )
+ )
+)