diff options
Diffstat (limited to 'src/passes/call-utils.h')
-rw-r--r-- | src/passes/call-utils.h | 153 |
1 files changed, 153 insertions, 0 deletions
diff --git a/src/passes/call-utils.h b/src/passes/call-utils.h new file mode 100644 index 000000000..175946c0e --- /dev/null +++ b/src/passes/call-utils.h @@ -0,0 +1,153 @@ +/* + * Copyright 2022 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_function_h +#define wasm_ir_function_h + +#include <variant> + +#include "ir/type-updating.h" +#include "wasm.h" + +namespace wasm::CallUtils { + +// Define a variant to describe the information we know about an indirect call, +// which is one of three things: +// * Unknown: Nothing is known this call. +// * Trap: This call target is invalid and will trap at runtime. +// * Known: This call goes to a known static call target, which is provided. +struct Unknown : public std::monostate {}; +struct Trap : public std::monostate {}; +struct Known { + Name target; +}; +using IndirectCallInfo = std::variant<Unknown, Trap, Known>; + +// Converts indirect calls that target selects between values into ifs over +// direct calls. For example, consider this input: +// +// (call_ref +// (select +// (ref.func A) +// (ref.func B) +// (..condition..) +// ) +// ) +// +// We'll check if the input falls into such a pattern, and if so, return the new +// form: +// +// (if +// (..condition..) +// (call $A) +// (call $B) +// ) +// +// If we fail to find the expected pattern, or we decide it is not worth +// optimizing it for some reason, we return nullptr. +// +// If this returns the new form, it will modify the function as necessary, +// adding new locals etc., which later passes should optimize. +// +// |getCallInfo| is given one of the arms of the select and should return an +// IndirectCallInfo that says what we know about it. We may know nothing, or +// that it will trap, or that it will go to a known static target. +template<typename T> +inline Expression* +convertToDirectCalls(T* curr, + std::function<IndirectCallInfo(Expression*)> getCallInfo, + Function& func, + Module& wasm) { + auto* select = curr->target->template dynCast<Select>(); + if (!select) { + return nullptr; + } + + if (select->condition->type == Type::unreachable) { + // Leave this for DCE. + return nullptr; + } + + // Check if we can find useful info for both arms: either known call targets, + // or traps. + // TODO: support more than 2 targets (with nested selects) + auto ifTrueCallInfo = getCallInfo(select->ifTrue); + auto ifFalseCallInfo = getCallInfo(select->ifFalse); + if (std::get_if<Unknown>(&ifTrueCallInfo) || + std::get_if<Unknown>(&ifFalseCallInfo)) { + // We know nothing about at least one arm, so give up. + // TODO: Perhaps emitting a direct call for one arm is enough even if the + // other remains indirect? + return nullptr; + } + + auto& operands = curr->operands; + + // We must use the operands twice, and also must move the condition to + // execute first, so we'll use locals for them all. First, see if any are + // unreachable, and if so stop trying to optimize and leave this for DCE. + for (auto* operand : operands) { + if (operand->type == Type::unreachable || + !TypeUpdating::canHandleAsLocal(operand->type)) { + return nullptr; + } + } + + Builder builder(wasm); + std::vector<Expression*> blockContents; + + // None of the types are a problem, so we can proceed to add new vars as + // needed and perform this optimization. + std::vector<Index> operandLocals; + for (auto* operand : operands) { + auto currLocal = builder.addVar(&func, operand->type); + operandLocals.push_back(currLocal); + blockContents.push_back(builder.makeLocalSet(currLocal, operand)); + } + + // Build the calls. + auto numOperands = operands.size(); + auto getOperands = [&]() { + std::vector<Expression*> newOperands(numOperands); + for (Index i = 0; i < numOperands; i++) { + newOperands[i] = + builder.makeLocalGet(operandLocals[i], operands[i]->type); + } + return newOperands; + }; + + auto makeCall = [&](IndirectCallInfo info) -> Expression* { + if (std::get_if<Trap>(&info)) { + return builder.makeUnreachable(); + } else { + return builder.makeCall(std::get<Known>(info).target, + getOperands(), + curr->type, + curr->isReturn); + } + }; + auto* ifTrueCall = makeCall(ifTrueCallInfo); + auto* ifFalseCall = makeCall(ifFalseCallInfo); + + // Create the if to pick the calls, and emit the final block. + auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall); + blockContents.push_back(iff); + return builder.makeBlock(blockContents); +} + +} // namespace wasm::CallUtils + +#endif // wasm_ir_function_h |