diff options
author | Alon Zakai <azakai@google.com> | 2022-05-27 11:16:39 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-27 11:16:39 -0700 |
commit | f21774f3865e506b6dea912af8f8be16fd29dacb (patch) | |
tree | cd08af70a3301d44c231e093acbb0799a62123bb /src/passes | |
parent | ac2190c1bd1a0be12a1d25e62695f791f270d050 (diff) | |
download | binaryen-f21774f3865e506b6dea912af8f8be16fd29dacb.tar.gz binaryen-f21774f3865e506b6dea912af8f8be16fd29dacb.tar.bz2 binaryen-f21774f3865e506b6dea912af8f8be16fd29dacb.zip |
OptimizeInstructions: Turn call_ref of a select into an if over two direct calls (#4660)
This extends the existing call_indirect code to do the same for call_ref,
basically. The shared code is added to a new helper utility.
Diffstat (limited to 'src/passes')
-rw-r--r-- | src/passes/Directize.cpp | 115 | ||||
-rw-r--r-- | src/passes/OptimizeInstructions.cpp | 18 | ||||
-rw-r--r-- | src/passes/call-utils.h | 153 |
3 files changed, 219 insertions, 67 deletions
diff --git a/src/passes/Directize.cpp b/src/passes/Directize.cpp index 6853f1faf..e338272a9 100644 --- a/src/passes/Directize.cpp +++ b/src/passes/Directize.cpp @@ -22,6 +22,7 @@ #include <unordered_map> +#include "call-utils.h" #include "ir/table-utils.h" #include "ir/type-updating.h" #include "ir/utils.h" @@ -59,62 +60,19 @@ struct FunctionDirectizer : public WalkerPass<PostWalker<FunctionDirectizer>> { return; } - // If the target is a select of two different constants, we can emit two - // direct calls. - // TODO: handle 3+ - // TODO: handle the case where just one arm is a constant? - if (auto* select = curr->target->dynCast<Select>()) { - if (select->ifTrue->is<Const>() && select->ifFalse->is<Const>()) { - Builder builder(*getModule()); - auto* func = getFunction(); - std::vector<Expression*> blockContents; - - if (select->condition->type == Type::unreachable) { - // Leave this for DCE. - return; - } - - // We must use the operands twice, and also must move the condition to - // execute first; use locals for them all. While doing so, if we see - // any are unreachable, stop trying to optimize and leave this for DCE. - std::vector<Index> operandLocals; - for (auto* operand : curr->operands) { - if (operand->type == Type::unreachable || - !TypeUpdating::canHandleAsLocal(operand->type)) { - return; - } - } - - // None of the types are a problem, so we can proceed to add new vars as - // needed and perform this optimization. - for (auto* operand : curr->operands) { - auto currLocal = builder.addVar(func, operand->type); - operandLocals.push_back(currLocal); - blockContents.push_back(builder.makeLocalSet(currLocal, operand)); - // By adding locals we must make type adjustments at the end. - changedTypes = true; - } - - // Build the calls. - auto numOperands = curr->operands.size(); - auto getOperands = [&]() { - std::vector<Expression*> newOperands(numOperands); - for (Index i = 0; i < numOperands; i++) { - newOperands[i] = - builder.makeLocalGet(operandLocals[i], curr->operands[i]->type); - } - return newOperands; - }; - auto* ifTrueCall = - makeDirectCall(getOperands(), select->ifTrue, flatTable, curr); - auto* ifFalseCall = - makeDirectCall(getOperands(), select->ifFalse, flatTable, curr); - - // Create the if to pick the calls, and emit the final block. - auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall); - blockContents.push_back(iff); - replaceCurrent(builder.makeBlock(blockContents)); - } + // Emit direct calls for things like a select over constants. + if (auto* calls = CallUtils::convertToDirectCalls( + curr, + [&](Expression* target) { + return getTargetInfo(target, flatTable, curr); + }, + *getFunction(), + *getModule())) { + replaceCurrent(calls); + // Note that types may have changed, as the utility here can add locals + // which require fixups if they are non-nullable, for example. + changedTypes = true; + return; } } @@ -131,6 +89,36 @@ private: bool changedTypes = false; + // Given an expression that we will use as the target of an indirect call, + // analyze it and return one of the results of CallUtils::IndirectCallInfo, + // that is, whether we know a direct call target, or we know it will trap, or + // if we know nothing. + CallUtils::IndirectCallInfo + getTargetInfo(Expression* target, + const TableUtils::FlatTable& flatTable, + CallIndirect* original) { + auto* c = target->dynCast<Const>(); + if (!c) { + return CallUtils::Unknown{}; + } + + Index index = c->value.geti32(); + + // If the index is invalid, or the type is wrong, then this will trap. + if (index >= flatTable.names.size()) { + return CallUtils::Trap{}; + } + auto name = flatTable.names[index]; + if (!name.is()) { + return CallUtils::Trap{}; + } + auto* func = getModule()->getFunction(name); + if (original->heapType != func->type) { + return CallUtils::Trap{}; + } + return CallUtils::Known{name}; + } + // Create a direct call for a given list of operands, an expression which is // known to contain a constant indicating the table offset, and the relevant // table. If we can see that the call will trap, instead return an @@ -139,23 +127,16 @@ private: Expression* c, const TableUtils::FlatTable& flatTable, CallIndirect* original) { - Index index = c->cast<Const>()->value.geti32(); - // If the index is invalid, or the type is wrong, we can // emit an unreachable here, since in Binaryen it is ok to // reorder/replace traps when optimizing (but never to // remove them, at least not by default). - if (index >= flatTable.names.size()) { - return replaceWithUnreachable(operands); - } - auto name = flatTable.names[index]; - if (!name.is()) { - return replaceWithUnreachable(operands); - } - auto* func = getModule()->getFunction(name); - if (original->heapType != func->type) { + auto info = getTargetInfo(c, flatTable, original); + if (std::get_if<CallUtils::Trap>(&info)) { return replaceWithUnreachable(operands); } + assert(std::get_if<CallUtils::Known>(&info)); + auto name = std::get_if<CallUtils::Known>(&info)->target; // Everything looks good! return Builder(*getModule()) diff --git a/src/passes/OptimizeInstructions.cpp b/src/passes/OptimizeInstructions.cpp index 4b499213f..c4a5b2ed4 100644 --- a/src/passes/OptimizeInstructions.cpp +++ b/src/passes/OptimizeInstructions.cpp @@ -42,6 +42,8 @@ #include <support/threads.h> #include <wasm.h> +#include "call-utils.h" + // TODO: Use the new sign-extension opcodes where appropriate. This needs to be // conditionalized on the availability of atomics. @@ -1334,6 +1336,22 @@ struct OptimizeInstructions curr->operands.back() = builder.makeBlock({set, drop, get}); replaceCurrent(builder.makeCall( ref->func, curr->operands, curr->type, curr->isReturn)); + return; + } + + // If the target is a select of two different constants, we can emit an if + // over two direct calls. + if (auto* calls = CallUtils::convertToDirectCalls( + curr, + [](Expression* target) -> CallUtils::IndirectCallInfo { + if (auto* refFunc = target->dynCast<RefFunc>()) { + return CallUtils::Known{refFunc->func}; + } + return CallUtils::Unknown{}; + }, + *getFunction(), + *getModule())) { + replaceCurrent(calls); } } diff --git a/src/passes/call-utils.h b/src/passes/call-utils.h new file mode 100644 index 000000000..175946c0e --- /dev/null +++ b/src/passes/call-utils.h @@ -0,0 +1,153 @@ +/* + * Copyright 2022 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_function_h +#define wasm_ir_function_h + +#include <variant> + +#include "ir/type-updating.h" +#include "wasm.h" + +namespace wasm::CallUtils { + +// Define a variant to describe the information we know about an indirect call, +// which is one of three things: +// * Unknown: Nothing is known this call. +// * Trap: This call target is invalid and will trap at runtime. +// * Known: This call goes to a known static call target, which is provided. +struct Unknown : public std::monostate {}; +struct Trap : public std::monostate {}; +struct Known { + Name target; +}; +using IndirectCallInfo = std::variant<Unknown, Trap, Known>; + +// Converts indirect calls that target selects between values into ifs over +// direct calls. For example, consider this input: +// +// (call_ref +// (select +// (ref.func A) +// (ref.func B) +// (..condition..) +// ) +// ) +// +// We'll check if the input falls into such a pattern, and if so, return the new +// form: +// +// (if +// (..condition..) +// (call $A) +// (call $B) +// ) +// +// If we fail to find the expected pattern, or we decide it is not worth +// optimizing it for some reason, we return nullptr. +// +// If this returns the new form, it will modify the function as necessary, +// adding new locals etc., which later passes should optimize. +// +// |getCallInfo| is given one of the arms of the select and should return an +// IndirectCallInfo that says what we know about it. We may know nothing, or +// that it will trap, or that it will go to a known static target. +template<typename T> +inline Expression* +convertToDirectCalls(T* curr, + std::function<IndirectCallInfo(Expression*)> getCallInfo, + Function& func, + Module& wasm) { + auto* select = curr->target->template dynCast<Select>(); + if (!select) { + return nullptr; + } + + if (select->condition->type == Type::unreachable) { + // Leave this for DCE. + return nullptr; + } + + // Check if we can find useful info for both arms: either known call targets, + // or traps. + // TODO: support more than 2 targets (with nested selects) + auto ifTrueCallInfo = getCallInfo(select->ifTrue); + auto ifFalseCallInfo = getCallInfo(select->ifFalse); + if (std::get_if<Unknown>(&ifTrueCallInfo) || + std::get_if<Unknown>(&ifFalseCallInfo)) { + // We know nothing about at least one arm, so give up. + // TODO: Perhaps emitting a direct call for one arm is enough even if the + // other remains indirect? + return nullptr; + } + + auto& operands = curr->operands; + + // We must use the operands twice, and also must move the condition to + // execute first, so we'll use locals for them all. First, see if any are + // unreachable, and if so stop trying to optimize and leave this for DCE. + for (auto* operand : operands) { + if (operand->type == Type::unreachable || + !TypeUpdating::canHandleAsLocal(operand->type)) { + return nullptr; + } + } + + Builder builder(wasm); + std::vector<Expression*> blockContents; + + // None of the types are a problem, so we can proceed to add new vars as + // needed and perform this optimization. + std::vector<Index> operandLocals; + for (auto* operand : operands) { + auto currLocal = builder.addVar(&func, operand->type); + operandLocals.push_back(currLocal); + blockContents.push_back(builder.makeLocalSet(currLocal, operand)); + } + + // Build the calls. + auto numOperands = operands.size(); + auto getOperands = [&]() { + std::vector<Expression*> newOperands(numOperands); + for (Index i = 0; i < numOperands; i++) { + newOperands[i] = + builder.makeLocalGet(operandLocals[i], operands[i]->type); + } + return newOperands; + }; + + auto makeCall = [&](IndirectCallInfo info) -> Expression* { + if (std::get_if<Trap>(&info)) { + return builder.makeUnreachable(); + } else { + return builder.makeCall(std::get<Known>(info).target, + getOperands(), + curr->type, + curr->isReturn); + } + }; + auto* ifTrueCall = makeCall(ifTrueCallInfo); + auto* ifFalseCall = makeCall(ifFalseCallInfo); + + // Create the if to pick the calls, and emit the final block. + auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall); + blockContents.push_back(iff); + return builder.makeBlock(blockContents); +} + +} // namespace wasm::CallUtils + +#endif // wasm_ir_function_h |