summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/passes/Directize.cpp115
-rw-r--r--src/passes/OptimizeInstructions.cpp18
-rw-r--r--src/passes/call-utils.h153
-rw-r--r--test/lit/passes/directize_all-features.wast59
-rw-r--r--test/lit/passes/optimize-instructions-call_ref.wast111
5 files changed, 355 insertions, 101 deletions
diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp
index 6853f1faf..e338272a9 100644
--- a/src/passes/Directize.cpp
+++ b/src/passes/Directize.cpp
@@ -22,6 +22,7 @@
#include <unordered_map>
+#include "call-utils.h"
#include "ir/table-utils.h"
#include "ir/type-updating.h"
#include "ir/utils.h"
@@ -59,62 +60,19 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> {
return;
}
- // If the target is a select of two different constants, we can emit two
- // direct calls.
- // TODO: handle 3+
- // TODO: handle the case where just one arm is a constant?
- if (auto* select = curr->target->dynCast<Select>()) {
- if (select->ifTrue->is<Const>() && select->ifFalse->is<Const>()) {
- Builder builder(*getModule());
- auto* func = getFunction();
- std::vector<Expression*> blockContents;
-
- if (select->condition->type == Type::unreachable) {
- // Leave this for DCE.
- return;
- }
-
- // We must use the operands twice, and also must move the condition to
- // execute first; use locals for them all. While doing so, if we see
- // any are unreachable, stop trying to optimize and leave this for DCE.
- std::vector<Index> operandLocals;
- for (auto* operand : curr->operands) {
- if (operand->type == Type::unreachable ||
- !TypeUpdating::canHandleAsLocal(operand->type)) {
- return;
- }
- }
-
- // None of the types are a problem, so we can proceed to add new vars as
- // needed and perform this optimization.
- for (auto* operand : curr->operands) {
- auto currLocal = builder.addVar(func, operand->type);
- operandLocals.push_back(currLocal);
- blockContents.push_back(builder.makeLocalSet(currLocal, operand));
- // By adding locals we must make type adjustments at the end.
- changedTypes = true;
- }
-
- // Build the calls.
- auto numOperands = curr->operands.size();
- auto getOperands = [&]() {
- std::vector<Expression*> newOperands(numOperands);
- for (Index i = 0; i < numOperands; i++) {
- newOperands[i] =
- builder.makeLocalGet(operandLocals[i], curr->operands[i]->type);
- }
- return newOperands;
- };
- auto* ifTrueCall =
- makeDirectCall(getOperands(), select->ifTrue, flatTable, curr);
- auto* ifFalseCall =
- makeDirectCall(getOperands(), select->ifFalse, flatTable, curr);
-
- // Create the if to pick the calls, and emit the final block.
- auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall);
- blockContents.push_back(iff);
- replaceCurrent(builder.makeBlock(blockContents));
- }
+ // Emit direct calls for things like a select over constants.
+ if (auto* calls = CallUtils::convertToDirectCalls(
+ curr,
+ [&](Expression* target) {
+ return getTargetInfo(target, flatTable, curr);
+ },
+ *getFunction(),
+ *getModule())) {
+ replaceCurrent(calls);
+ // Note that types may have changed, as the utility here can add locals
+ // which require fixups if they are non-nullable, for example.
+ changedTypes = true;
+ return;
}
}
@@ -131,6 +89,36 @@ private:
bool changedTypes = false;
+  // Classifies |target|, the expression that provides the table index of the
+  // indirect call |original|, as a CallUtils::IndirectCallInfo:
+  //  * Unknown if |target| is not a constant,
+  //  * Trap if the constant index is past the end of |flatTable|, the name in
+  //    that slot is not set, or that function's type differs from the call's,
+  //  * Known (carrying the function's name) otherwise.
+  CallUtils::IndirectCallInfo
+  getTargetInfo(Expression* target,
+                const TableUtils::FlatTable& flatTable,
+                CallIndirect* original) {
+    auto* constant = target->dynCast<Const>();
+    if (constant == nullptr) {
+      return CallUtils::Unknown{};
+    }
+
+    const Index slot = constant->value.geti32();
+    const auto& names = flatTable.names;
+    // A slot outside the table, or one whose name is not set, will trap.
+    if (slot >= names.size() || !names[slot].is()) {
+      return CallUtils::Trap{};
+    }
+
+    // So will a function whose type is not the one the call expects.
+    auto callee = names[slot];
+    if (original->heapType != getModule()->getFunction(callee)->type) {
+      return CallUtils::Trap{};
+    }
+    return CallUtils::Known{callee};
+  }
+
// Create a direct call for a given list of operands, an expression which is
// known to contain a constant indicating the table offset, and the relevant
// table. If we can see that the call will trap, instead return an
@@ -139,23 +127,16 @@ private:
Expression* c,
const TableUtils::FlatTable& flatTable,
CallIndirect* original) {
- Index index = c->cast<Const>()->value.geti32();
-
// If the index is invalid, or the type is wrong, we can
// emit an unreachable here, since in Binaryen it is ok to
// reorder/replace traps when optimizing (but never to
// remove them, at least not by default).
- if (index >= flatTable.names.size()) {
- return replaceWithUnreachable(operands);
- }
- auto name = flatTable.names[index];
- if (!name.is()) {
- return replaceWithUnreachable(operands);
- }
- auto* func = getModule()->getFunction(name);
- if (original->heapType != func->type) {
+ auto info = getTargetInfo(c, flatTable, original);
+ if (std::get_if<CallUtils::Trap>(&info)) {
return replaceWithUnreachable(operands);
}
+ assert(std::get_if<CallUtils::Known>(&info));
+ auto name = std::get_if<CallUtils::Known>(&info)->target;
// Everything looks good!
return Builder(*getModule())
diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp
index 4b499213f..c4a5b2ed4 100644
--- a/src/passes/OptimizeInstructions.cpp
+++ b/src/passes/OptimizeInstructions.cpp
@@ -42,6 +42,8 @@
#include <support/threads.h>
#include <wasm.h>
+#include "call-utils.h"
+
// TODO: Use the new sign-extension opcodes where appropriate. This needs to be
// conditionalized on the availability of atomics.
@@ -1334,6 +1336,22 @@ struct OptimizeInstructions
curr->operands.back() = builder.makeBlock({set, drop, get});
replaceCurrent(builder.makeCall(
ref->func, curr->operands, curr->type, curr->isReturn));
+ return;
+ }
+
+ // If the target is a select of two different constants, we can emit an if
+ // over two direct calls.
+ if (auto* calls = CallUtils::convertToDirectCalls(
+ curr,
+ [](Expression* target) -> CallUtils::IndirectCallInfo {
+ if (auto* refFunc = target->dynCast<RefFunc>()) {
+ return CallUtils::Known{refFunc->func};
+ }
+ return CallUtils::Unknown{};
+ },
+ *getFunction(),
+ *getModule())) {
+ replaceCurrent(calls);
}
}
diff --git a/src/passes/call-utils.h b/src/passes/call-utils.h
new file mode 100644
index 000000000..175946c0e
--- /dev/null
+++ b/src/passes/call-utils.h
@@ -0,0 +1,153 @@
+/*
+ * Copyright 2022 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef wasm_passes_call_utils_h
+#define wasm_passes_call_utils_h
+
+#include <functional>
+#include <variant>
+#include <vector>
+
+#include "ir/type-updating.h"
+#include "wasm.h"
+
+namespace wasm::CallUtils {
+
+// Define a variant to describe the information we know about an indirect call,
+// which is one of three things:
+//  * Unknown: Nothing is known about this call.
+//  * Trap: This call target is invalid and will trap at runtime.
+//  * Known: This call goes to a known static call target, which is provided.
+struct Unknown : public std::monostate {};
+struct Trap : public std::monostate {};
+struct Known {
+  Name target;
+};
+using IndirectCallInfo = std::variant<Unknown, Trap, Known>;
+
+// Converts indirect calls that target selects between values into ifs over
+// direct calls. For example, consider this input:
+//
+//  (call_ref
+//    (select
+//      (ref.func A)
+//      (ref.func B)
+//      (..condition..)
+//    )
+//  )
+//
+// We'll check if the input falls into such a pattern, and if so, return the new
+// form:
+//
+//  (if
+//    (..condition..)
+//    (call $A)
+//    (call $B)
+//  )
+//
+// If we fail to find the expected pattern, or we decide it is not worth
+// optimizing it for some reason, we return nullptr.
+//
+// If this returns the new form, it will have added new locals to |func| (one
+// per call operand), which later passes should optimize. The caller may need
+// to fix up the types of those locals, e.g. if they are non-nullable.
+//
+// |getCallInfo| is given one of the arms of the select and should return an
+// IndirectCallInfo that says what we know about it. We may know nothing, or
+// that it will trap, or that it will go to a known static target.
+template<typename T>
+inline Expression*
+convertToDirectCalls(T* curr,
+                     std::function<IndirectCallInfo(Expression*)> getCallInfo,
+                     Function& func,
+                     Module& wasm) {
+  auto* select = curr->target->template dynCast<Select>();
+  if (!select) {
+    return nullptr;
+  }
+
+  if (select->condition->type == Type::unreachable) {
+    // Leave this for DCE.
+    return nullptr;
+  }
+
+  // We need useful info for both arms: either known call targets, or traps.
+  // TODO: support more than 2 targets (with nested selects)
+  auto ifTrueCallInfo = getCallInfo(select->ifTrue);
+  auto ifFalseCallInfo = getCallInfo(select->ifFalse);
+  if (std::holds_alternative<Unknown>(ifTrueCallInfo) ||
+      std::holds_alternative<Unknown>(ifFalseCallInfo)) {
+    // We know nothing about at least one arm, so give up.
+    // TODO: Perhaps emitting a direct call for one arm is enough even if the
+    //       other remains indirect?
+    return nullptr;
+  }
+
+  auto& operands = curr->operands;
+
+  // We must use the operands twice, and also must move the condition to
+  // execute first, so we'll use locals for them all. First, give up if any
+  // operand is unreachable (leave that for DCE) or cannot be put in a local.
+  for (auto* operand : operands) {
+    if (operand->type == Type::unreachable ||
+        !TypeUpdating::canHandleAsLocal(operand->type)) {
+      return nullptr;
+    }
+  }
+
+  Builder builder(wasm);
+  std::vector<Expression*> blockContents;
+
+  // None of the types are a problem, so we can proceed to add new vars as
+  // needed and perform this optimization.
+  std::vector<Index> operandLocals;
+  operandLocals.reserve(operands.size());
+  for (auto* operand : operands) {
+    auto currLocal = builder.addVar(&func, operand->type);
+    operandLocals.push_back(currLocal);
+    blockContents.push_back(builder.makeLocalSet(currLocal, operand));
+  }
+
+  // Build the calls.
+  auto numOperands = operands.size();
+  auto getOperands = [&]() {
+    std::vector<Expression*> newOperands(numOperands);
+    for (Index i = 0; i < numOperands; i++) {
+      newOperands[i] =
+        builder.makeLocalGet(operandLocals[i], operands[i]->type);
+    }
+    return newOperands;
+  };
+
+  auto makeCall = [&](const IndirectCallInfo& info) -> Expression* {
+    if (std::holds_alternative<Trap>(info)) {
+      return builder.makeUnreachable();
+    }
+    return builder.makeCall(
+      std::get<Known>(info).target, getOperands(), curr->type, curr->isReturn);
+  };
+  auto* ifTrueCall = makeCall(ifTrueCallInfo);
+  auto* ifFalseCall = makeCall(ifFalseCallInfo);
+
+  // Create the if to pick the calls, and emit the final block.
+  auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall);
+  blockContents.push_back(iff);
+  return builder.makeBlock(blockContents);
+}
+
+} // namespace wasm::CallUtils
+
+#endif // wasm_passes_call_utils_h
diff --git a/test/lit/passes/directize_all-features.wast b/test/lit/passes/directize_all-features.wast
index 4c1400c54..89c6cfcb1 100644
--- a/test/lit/passes/directize_all-features.wast
+++ b/test/lit/passes/directize_all-features.wast
@@ -519,6 +519,11 @@
;; CHECK: (type $ii (func (param i32 i32)))
(type $ii (func (param i32 i32)))
+
+ (type $none (func))
+
+ ;; CHECK: (type $i32_=>_none (func (param i32)))
+
;; CHECK: (table $0 5 5 funcref)
(table $0 5 5 funcref)
(elem (i32.const 1) $foo1 $foo2)
@@ -627,17 +632,7 @@
;; CHECK-NEXT: )
;; CHECK-NEXT: (if
;; CHECK-NEXT: (local.get $z)
- ;; CHECK-NEXT: (block
- ;; CHECK-NEXT: (block
- ;; CHECK-NEXT: (drop
- ;; CHECK-NEXT: (local.get $3)
- ;; CHECK-NEXT: )
- ;; CHECK-NEXT: (drop
- ;; CHECK-NEXT: (local.get $4)
- ;; CHECK-NEXT: )
- ;; CHECK-NEXT: )
- ;; CHECK-NEXT: (unreachable)
- ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: (call $foo2
;; CHECK-NEXT: (local.get $3)
;; CHECK-NEXT: (local.get $4)
@@ -668,28 +663,8 @@
;; CHECK-NEXT: )
;; CHECK-NEXT: (if
;; CHECK-NEXT: (local.get $z)
- ;; CHECK-NEXT: (block
- ;; CHECK-NEXT: (block
- ;; CHECK-NEXT: (drop
- ;; CHECK-NEXT: (local.get $3)
- ;; CHECK-NEXT: )
- ;; CHECK-NEXT: (drop
- ;; CHECK-NEXT: (local.get $4)
- ;; CHECK-NEXT: )
- ;; CHECK-NEXT: )
- ;; CHECK-NEXT: (unreachable)
- ;; CHECK-NEXT: )
- ;; CHECK-NEXT: (block
- ;; CHECK-NEXT: (block
- ;; CHECK-NEXT: (drop
- ;; CHECK-NEXT: (local.get $3)
- ;; CHECK-NEXT: )
- ;; CHECK-NEXT: (drop
- ;; CHECK-NEXT: (local.get $4)
- ;; CHECK-NEXT: )
- ;; CHECK-NEXT: )
- ;; CHECK-NEXT: (unreachable)
- ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
(func $select-both-out-of-range (param $x i32) (param $y i32) (param $z i32)
@@ -751,6 +726,24 @@
)
)
)
+ ;; CHECK: (func $select-bad-type (param $z i32)
+ ;; CHECK-NEXT: (if
+ ;; CHECK-NEXT: (local.get $z)
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $select-bad-type (param $z i32)
+ ;; The type here is $none, which does not match the functions at indexes 1 and
+ ;; 2, so we know they will trap and can emit unreachables.
+ (call_indirect (type $none)
+ (select
+ (i32.const 1)
+ (i32.const 2)
+ (local.get $z)
+ )
+ )
+ )
)
(module
diff --git a/test/lit/passes/optimize-instructions-call_ref.wast b/test/lit/passes/optimize-instructions-call_ref.wast
index 553f8d07b..b92e976bb 100644
--- a/test/lit/passes/optimize-instructions-call_ref.wast
+++ b/test/lit/passes/optimize-instructions-call_ref.wast
@@ -19,13 +19,15 @@
;; CHECK: (type $data_=>_none (func (param dataref)))
(type $data_=>_none (func (param (ref data))))
+ ;; CHECK: (type $i32_i32_i32_ref|$i32_i32_=>_none|_=>_none (func (param i32 i32 i32 (ref $i32_i32_=>_none))))
+
;; CHECK: (table $table-1 10 (ref null $i32_i32_=>_none))
(table $table-1 10 (ref null $i32_i32_=>_none))
;; CHECK: (elem $elem-1 (table $table-1) (i32.const 0) (ref null $i32_i32_=>_none) (ref.func $foo))
(elem $elem-1 (table $table-1) (i32.const 0) (ref null $i32_i32_=>_none)
(ref.func $foo))
- ;; CHECK: (elem declare func $fallthrough-no-params $fallthrough-non-nullable $return-nothing)
+ ;; CHECK: (elem declare func $bar $fallthrough-no-params $fallthrough-non-nullable $return-nothing)
;; CHECK: (func $foo (param $0 i32) (param $1 i32)
;; CHECK-NEXT: (unreachable)
@@ -33,6 +35,14 @@
(func $foo (param i32) (param i32)
(unreachable)
)
+
+ ;; CHECK: (func $bar (param $0 i32) (param $1 i32)
+ ;; CHECK-NEXT: (unreachable)
+ ;; CHECK-NEXT: )
+ (func $bar (param i32) (param i32)
+ (unreachable)
+ )
+
;; CHECK: (func $call_ref-to-direct (param $x i32) (param $y i32)
;; CHECK-NEXT: (call $foo
;; CHECK-NEXT: (local.get $x)
@@ -229,4 +239,103 @@
)
)
)
+
+ ;; CHECK: (func $call_ref-to-select (param $x i32) (param $y i32) (param $z i32) (param $f (ref $i32_i32_=>_none))
+ ;; CHECK-NEXT: (local $4 i32)
+ ;; CHECK-NEXT: (local $5 i32)
+ ;; CHECK-NEXT: (block
+ ;; CHECK-NEXT: (local.set $4
+ ;; CHECK-NEXT: (local.get $x)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (local.set $5
+ ;; CHECK-NEXT: (local.get $y)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (if
+ ;; CHECK-NEXT: (local.get $z)
+ ;; CHECK-NEXT: (call $foo
+ ;; CHECK-NEXT: (local.get $4)
+ ;; CHECK-NEXT: (local.get $5)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (call $bar
+ ;; CHECK-NEXT: (local.get $4)
+ ;; CHECK-NEXT: (local.get $5)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (call_ref
+ ;; CHECK-NEXT: (local.get $x)
+ ;; CHECK-NEXT: (local.get $y)
+ ;; CHECK-NEXT: (select (result (ref $i32_i32_=>_none))
+ ;; CHECK-NEXT: (local.get $f)
+ ;; CHECK-NEXT: (ref.func $bar)
+ ;; CHECK-NEXT: (local.get $z)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $call_ref-to-select (param $x i32) (param $y i32) (param $z i32) (param $f (ref $i32_i32_=>_none))
+ ;; This call_ref should become an if over two direct calls.
+ (call_ref
+ (local.get $x)
+ (local.get $y)
+ (select
+ (ref.func $foo)
+ (ref.func $bar)
+ (local.get $z)
+ )
+ )
+
+ ;; But here one arm is not constant, so we do not optimize.
+ (call_ref
+ (local.get $x)
+ (local.get $y)
+ (select
+ (local.get $f)
+ (ref.func $bar)
+ (local.get $z)
+ )
+ )
+ )
+
+ ;; CHECK: (func $return_call_ref-to-select (param $x i32) (param $y i32)
+ ;; CHECK-NEXT: (local $2 i32)
+ ;; CHECK-NEXT: (local $3 i32)
+ ;; CHECK-NEXT: (local.set $2
+ ;; CHECK-NEXT: (local.get $x)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (local.set $3
+ ;; CHECK-NEXT: (local.get $y)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (if
+ ;; CHECK-NEXT: (call $get-i32)
+ ;; CHECK-NEXT: (return_call $foo
+ ;; CHECK-NEXT: (local.get $2)
+ ;; CHECK-NEXT: (local.get $3)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: (return_call $bar
+ ;; CHECK-NEXT: (local.get $2)
+ ;; CHECK-NEXT: (local.get $3)
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ ;; CHECK-NEXT: )
+ (func $return_call_ref-to-select (param $x i32) (param $y i32)
+ ;; As above, but with a return call. We optimize this too, and turn a
+ ;; return_call_ref over a select into an if over return_calls.
+ (return_call_ref
+ (local.get $x)
+ (local.get $y)
+ (select
+ (ref.func $foo)
+ (ref.func $bar)
+ (call $get-i32)
+ )
+ )
+ )
+
+ ;; CHECK: (func $get-i32 (result i32)
+ ;; CHECK-NEXT: (i32.const 42)
+ ;; CHECK-NEXT: )
+ (func $get-i32 (result i32)
+ ;; Helper for the above.
+ (i32.const 42)
+ )
)