/* * Copyright 2022 WebAssembly Community Group participants * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef wasm_ir_function_h #define wasm_ir_function_h #include <variant> #include "ir/debuginfo.h" #include "ir/type-updating.h" #include "wasm.h" namespace wasm::CallUtils { // Define a variant to describe the information we know about an indirect call, // which is one of three things: // * Unknown: Nothing is known this call. // * Trap: This call target is invalid and will trap at runtime. // * Known: This call goes to a known static call target, which is provided. struct Unknown : public std::monostate {}; struct Trap : public std::monostate {}; struct Known { Name target; }; using IndirectCallInfo = std::variant<Unknown, Trap, Known>; // Converts indirect calls that target selects between values into ifs over // direct calls. For example, consider this input: // // (call_ref // (select // (ref.func A) // (ref.func B) // (..condition..) // ) // ) // // We'll check if the input falls into such a pattern, and if so, return the new // form: // // (if // (..condition..) // (call $A) // (call $B) // ) // // If we fail to find the expected pattern, or we decide it is not worth // optimizing it for some reason, we return nullptr. // // If this returns the new form, it will modify the function as necessary, // adding new locals etc., which later passes should optimize. // // |getCallInfo| is given one of the arms of the select and should return an // IndirectCallInfo that says what we know about it. We may know nothing, or // that it will trap, or that it will go to a known static target. template<typename T> inline Expression* convertToDirectCalls(T* curr, std::function<IndirectCallInfo(Expression*)> getCallInfo, Function& func, Module& wasm) { auto* select = curr->target->template dynCast<Select>(); if (!select) { return nullptr; } if (select->type == Type::unreachable) { // Leave this for DCE. return nullptr; } // Check if we can find useful info for both arms: either known call targets, // or traps. // TODO: support more than 2 targets (with nested selects) auto ifTrueCallInfo = getCallInfo(select->ifTrue); auto ifFalseCallInfo = getCallInfo(select->ifFalse); if (std::get_if<Unknown>(&ifTrueCallInfo) || std::get_if<Unknown>(&ifFalseCallInfo)) { // We know nothing about at least one arm, so give up. // TODO: Perhaps emitting a direct call for one arm is enough even if the // other remains indirect? return nullptr; } auto& operands = curr->operands; // We must use the operands twice, and also must move the condition to // execute first, so we'll use locals for them all. First, see if any are // unreachable, and if so stop trying to optimize and leave this for DCE. for (auto* operand : operands) { if (operand->type == Type::unreachable || !TypeUpdating::canHandleAsLocal(operand->type)) { return nullptr; } } Builder builder(wasm); std::vector<Expression*> blockContents; // None of the types are a problem, so we can proceed to add new vars as // needed and perform this optimization. std::vector<Index> operandLocals; for (auto* operand : operands) { auto currLocal = builder.addVar(&func, operand->type); operandLocals.push_back(currLocal); blockContents.push_back(builder.makeLocalSet(currLocal, operand)); } // Build the calls. auto numOperands = operands.size(); auto getOperands = [&]() { std::vector<Expression*> newOperands(numOperands); for (Index i = 0; i < numOperands; i++) { newOperands[i] = builder.makeLocalGet(operandLocals[i], operands[i]->type); } return newOperands; }; auto makeCall = [&](IndirectCallInfo info) -> Expression* { Expression* ret; if (std::get_if<Trap>(&info)) { ret = builder.makeUnreachable(); } else { ret = builder.makeCall(std::get<Known>(info).target, getOperands(), curr->type, curr->isReturn); } debuginfo::copyOriginalToReplacement(curr, ret, &func); return ret; }; auto* ifTrueCall = makeCall(ifTrueCallInfo); auto* ifFalseCall = makeCall(ifFalseCallInfo); // Create the if to pick the calls, and emit the final block. auto* iff = builder.makeIf(select->condition, ifTrueCall, ifFalseCall); blockContents.push_back(iff); return builder.makeBlock(blockContents); } } // namespace wasm::CallUtils #endif // wasm_ir_function_h