diff options
-rw-r--r-- | src/passes/Directize.cpp | 129 | ||||
-rw-r--r-- | test/lit/passes/directize_all-features.wast | 299 |
2 files changed, 399 insertions, 29 deletions
diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp index 0f04b5f4c..5a9b1b984 100644 --- a/src/passes/Directize.cpp +++ b/src/passes/Directize.cpp @@ -23,6 +23,7 @@ #include <unordered_map> #include "ir/table-utils.h" +#include "ir/type-updating.h" #include "ir/utils.h" #include "pass.h" #include "wasm-builder.h" @@ -50,30 +51,65 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { auto& flatTable = it->second; - if (auto* c = curr->target->dynCast<Const>()) { - Index index = c->value.geti32(); - // If the index is invalid, or the type is wrong, we can - // emit an unreachable here, since in Binaryen it is ok to - // reorder/replace traps when optimizing (but never to - // remove them, at least not by default). - if (index >= flatTable.names.size()) { - replaceWithUnreachable(curr); - return; - } - auto name = flatTable.names[index]; - if (!name.is()) { - replaceWithUnreachable(curr); - return; - } - auto* func = getModule()->getFunction(name); - if (curr->sig != func->getSig()) { - replaceWithUnreachable(curr); - return; + // If the target is constant, we can emit a direct call. + if (curr->target->is<Const>()) { + std::vector<Expression*> operands(curr->operands.begin(), + curr->operands.end()); + replaceCurrent(makeDirectCall(operands, curr->target, flatTable, curr)); + return; + } + + // If the target is a select of two different constants, we can emit two + // direct calls. + // TODO: handle 3+ + // TODO: handle the case where just one arm is a constant? + if (auto* select = curr->target->dynCast<Select>()) { + if (select->ifTrue->is<Const>() && select->ifFalse->is<Const>()) { + Builder builder(*getModule()); + auto* func = getFunction(); + std::vector<Expression*> blockContents; + + // We must use the operands twice, and also must move the condition to + // execute first; use locals for them all. While doing so, if we see + // any are unreachable, stop trying to optimize and leave this for DCE. + std::vector<Index> operandLocals; + for (auto* operand : curr->operands) { + if (operand->type == Type::unreachable || + !TypeUpdating::canHandleAsLocal(operand->type)) { + return; + } + auto currLocal = builder.addVar(func, operand->type); + operandLocals.push_back(currLocal); + blockContents.push_back(builder.makeLocalSet(currLocal, operand)); + } + + if (select->condition->type == Type::unreachable) { + return; + } + + // Build the calls. + auto numOperands = curr->operands.size(); + auto getOperands = [&]() { + std::vector<Expression*> newOperands(numOperands); + for (Index i = 0; i < numOperands; i++) { + newOperands[i] = + builder.makeLocalGet(operandLocals[i], curr->operands[i]->type); + } + return newOperands; + }; + auto* ifTrueCall = + makeDirectCall(getOperands(), select->ifTrue, flatTable, curr); + auto* ifFalseCall = + makeDirectCall(getOperands(), select->ifFalse, flatTable, curr); + + // Create the if to pick the calls, and emit the final block. + auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall); + blockContents.push_back(iff); + replaceCurrent(builder.makeBlock(blockContents)); + + // By adding locals we must make type adjustments at the end. + changedTypes = true; } - // Everything looks good! - replaceCurrent( - Builder(*getModule()) - .makeCall(name, curr->operands, curr->type, curr->isReturn)); } } @@ -81,6 +117,7 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { WalkerPass<PostWalker<FunctionDirectizer>>::doWalkFunction(func); if (changedTypes) { ReFinalize().walkFunctionInModule(func, getModule()); + TypeUpdating::handleNonDefaultableLocals(func, *getModule()); } } @@ -89,14 +126,48 @@ private: bool changedTypes = false; - void replaceWithUnreachable(CallIndirect* call) { - Builder builder(*getModule()); - for (auto*& operand : call->operands) { - operand = builder.makeDrop(operand); + // Create a direct call for a given list of operands, an expression which is + // known to contain a constant indicating the table offset, and the relevant + // table. If we can see that the call will trap, instead return an + // unreachable. + Expression* makeDirectCall(const std::vector<Expression*>& operands, + Expression* c, + const TableUtils::FlatTable& flatTable, + CallIndirect* original) { + Index index = c->cast<Const>()->value.geti32(); + + // If the index is invalid, or the type is wrong, we can + // emit an unreachable here, since in Binaryen it is ok to + // reorder/replace traps when optimizing (but never to + // remove them, at least not by default). + if (index >= flatTable.names.size()) { + return replaceWithUnreachable(operands); + } + auto name = flatTable.names[index]; + if (!name.is()) { + return replaceWithUnreachable(operands); } - replaceCurrent(builder.makeSequence(builder.makeBlock(call->operands), - builder.makeUnreachable())); + auto* func = getModule()->getFunction(name); + if (original->sig != func->getSig()) { + return replaceWithUnreachable(operands); + } + + // Everything looks good! + return Builder(*getModule()) + .makeCall(name, operands, original->type, original->isReturn); + } + + Expression* replaceWithUnreachable(const std::vector<Expression*>& operands) { + // Emitting an unreachable means we must update parent types. changedTypes = true; + + Builder builder(*getModule()); + std::vector<Expression*> newOperands; + for (auto* operand : operands) { + newOperands.push_back(builder.makeDrop(operand)); + } + return builder.makeSequence(builder.makeBlock(newOperands), + builder.makeUnreachable()); } }; diff --git a/test/lit/passes/directize_all-features.wast b/test/lit/passes/directize_all-features.wast index 54cf07556..041c4a47c 100644 --- a/test/lit/passes/directize_all-features.wast +++ b/test/lit/passes/directize_all-features.wast @@ -6,17 +6,21 @@ (module ;; CHECK: (type $ii (func (param i32 i32))) (type $ii (func (param i32 i32))) + ;; CHECK: (table $0 5 5 funcref) (table $0 5 5 funcref) (elem (i32.const 1) $foo) + ;; CHECK: (elem (i32.const 1) $foo) ;; CHECK: (func $foo (param $0 i32) (param $1 i32) ;; CHECK-NEXT: (unreachable) ;; CHECK-NEXT: ) (func $foo (param i32) (param i32) + ;; helper function (unreachable) ) + ;; CHECK: (func $bar (param $x i32) (param $y i32) ;; CHECK-NEXT: (call $foo ;; CHECK-NEXT: (local.get $x) @@ -494,3 +498,298 @@ ) ) ) + +(module + ;; CHECK: (type $i32_i32_i32_=>_none (func (param i32 i32 i32))) + + ;; CHECK: (type $ii (func (param i32 i32))) + (type $ii (func (param i32 i32))) + ;; CHECK: (table $0 5 5 funcref) + (table $0 5 5 funcref) + (elem (i32.const 1) $foo1 $foo2) + ;; CHECK: (elem (i32.const 1) $foo1 $foo2) + + ;; CHECK: (func $foo1 (param $0 i32) (param $1 i32) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + (func $foo1 (param i32) (param i32) + (unreachable) + ) + ;; CHECK: (func $foo2 (param $0 i32) (param $1 i32) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + (func $foo2 (param i32) (param i32) + (unreachable) + ) + ;; CHECK: (func $select (param $x i32) (param $y i32) (param $z i32) + ;; CHECK-NEXT: (local $3 i32) + ;; CHECK-NEXT: (local $4 i32) + ;; CHECK-NEXT: (local.set $3 + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $4 + ;; CHECK-NEXT: (local.get $y) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (if + ;; CHECK-NEXT: (local.get $z) + ;; CHECK-NEXT: (call $foo1 + ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: (local.get $4) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (call $foo2 + ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: (local.get $4) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $select (param $x i32) (param $y i32) (param $z i32) + ;; Test we can optimize a call_indirect whose index is a select between two + ;; constants. We can emit an if and two direct calls for that. + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (select + (i32.const 1) + (i32.const 2) + (local.get $z) + ) + ) + ) + ;; CHECK: (func $select-bad-1 (param $x i32) (param $y i32) (param $z i32) + ;; CHECK-NEXT: (call_indirect $0 (type $ii) + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (local.get $y) + ;; CHECK-NEXT: (select + ;; CHECK-NEXT: (local.get $z) + ;; CHECK-NEXT: (i32.const 2) + ;; CHECK-NEXT: (local.get $z) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $select-bad-1 (param $x i32) (param $y i32) (param $z i32) + ;; As above but one select arm is not constant. + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (select + (local.get $z) + (i32.const 2) + (local.get $z) + ) + ) + ) + ;; CHECK: (func $select-bad-2 (param $x i32) (param $y i32) (param $z i32) + ;; CHECK-NEXT: (call_indirect $0 (type $ii) + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (local.get $y) + ;; CHECK-NEXT: (select + ;; CHECK-NEXT: (i32.const 2) + ;; CHECK-NEXT: (local.get $z) + ;; CHECK-NEXT: (local.get $z) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $select-bad-2 (param $x i32) (param $y i32) (param $z i32) + ;; As above but the other select arm is not constant. + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (select + (i32.const 2) + (local.get $z) + (local.get $z) + ) + ) + ) + ;; CHECK: (func $select-out-of-range (param $x i32) (param $y i32) (param $z i32) + ;; CHECK-NEXT: (local $3 i32) + ;; CHECK-NEXT: (local $4 i32) + ;; CHECK-NEXT: (local.set $3 + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $4 + ;; CHECK-NEXT: (local.get $y) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (if + ;; CHECK-NEXT: (local.get $z) + ;; CHECK-NEXT: (block + ;; CHECK-NEXT: (block + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.get $4) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (call $foo2 + ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: (local.get $4) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $select-out-of-range (param $x i32) (param $y i32) (param $z i32) + ;; Both are constants, but one is out of range for the table, and there is no + ;; valid function to call there; emit an unreachable. + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (select + (i32.const 99999) + (i32.const 2) + (local.get $z) + ) + ) + ) + ;; CHECK: (func $select-both-out-of-range (param $x i32) (param $y i32) (param $z i32) + ;; CHECK-NEXT: (local $3 i32) + ;; CHECK-NEXT: (local $4 i32) + ;; CHECK-NEXT: (local.set $3 + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $4 + ;; CHECK-NEXT: (local.get $y) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (if + ;; CHECK-NEXT: (local.get $z) + ;; CHECK-NEXT: (block + ;; CHECK-NEXT: (block + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.get $4) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (block + ;; CHECK-NEXT: (block + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (local.get $4) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $select-both-out-of-range (param $x i32) (param $y i32) (param $z i32) + ;; Both are constants, and both are out of range for the table. + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (select + (i32.const 99999) + (i32.const -1) + (local.get $z) + ) + ) + ) + ;; CHECK: (func $select-unreachable-operand (param $x i32) (param $y i32) (param $z i32) + ;; CHECK-NEXT: (call_indirect $0 (type $ii) + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (local.get $y) + ;; CHECK-NEXT: (select + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: (i32.const 2) + ;; CHECK-NEXT: (local.get $z) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $select-unreachable-operand (param $x i32) (param $y i32) (param $z i32) + ;; One operand is unreachable. + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (select + (unreachable) + (i32.const 2) + (local.get $z) + ) + ) + ) + ;; CHECK: (func $select-unreachable-condition (param $x i32) (param $y i32) (param $z i32) + ;; CHECK-NEXT: (local $3 i32) + ;; CHECK-NEXT: (local $4 i32) + ;; CHECK-NEXT: (call_indirect $0 (type $ii) + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (local.get $y) + ;; CHECK-NEXT: (select + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: (i32.const 2) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $select-unreachable-condition (param $x i32) (param $y i32) (param $z i32) + ;; The condition is unreachable. + (call_indirect (type $ii) + (local.get $x) + (local.get $y) + (select + (i32.const 1) + (i32.const 2) + (unreachable) + ) + ) + ) +) + +(module + ;; CHECK: (type $i32_=>_none (func (param i32))) + + ;; CHECK: (type $F (func (param (ref func)))) + + ;; CHECK: (table $0 15 15 funcref) + (table $0 15 15 funcref) + (type $F (func (param (ref func)))) + (elem (i32.const 10) $foo-ref $foo-ref) + + ;; CHECK: (elem (i32.const 10) $foo-ref $foo-ref) + + ;; CHECK: (elem declare func $select-non-nullable) + + ;; CHECK: (func $foo-ref (param $0 (ref func)) + ;; CHECK-NEXT: (unreachable) + ;; CHECK-NEXT: ) + (func $foo-ref (param (ref func)) + ;; helper function + (unreachable) + ) + + ;; CHECK: (func $select-non-nullable (param $x i32) + ;; CHECK-NEXT: (local $1 (ref null $i32_=>_none)) + ;; CHECK-NEXT: (local.set $1 + ;; CHECK-NEXT: (ref.func $select-non-nullable) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (if + ;; CHECK-NEXT: (local.get $x) + ;; CHECK-NEXT: (call $foo-ref + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (call $foo-ref + ;; CHECK-NEXT: (ref.as_non_null + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $select-non-nullable (param $x i32) + ;; Test we can handle a non-nullable value when optimizing a select, during + ;; which we place values in locals. + (call_indirect (type $F) + (ref.func $select-non-nullable) + (select + (i32.const 10) + (i32.const 11) + (local.get $x) + ) + ) + ) +) |