/* * Copyright 2021 WebAssembly Community Group participants * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // // Merge similar functions that only differs constant values (like immediate // operand of const and call insts) by parameterization. // Performing this pass at post-link time can merge more functions across // objects. Inspired by Swift compiler's optimization which is derived from // LLVM's one: // https://github.com/apple/swift/blob/main/lib/LLVMPasses/LLVMMergeFunctions.cpp // https://github.com/llvm/llvm-project/blob/main/llvm/docs/MergeFunctions.rst // // The basic idea is: // // 1. Group possible mergeable functions by hashing instruction kind // 2. Create a group of mergeable functions (EquivalentClass) that can be merged // by parameterization. The classes are collected by comparing functions on // a pairwise basis. // 3. Derive the parameters to be parameterized (ParamInfo) from each // EquivalentClass. A ParamInfo contains positions of parameter use and a // set of constant values (ConstDiff) for each functions in an // EquivalentClass. (A parameter can be used in multiple times in a function, // so ParamInfo contains an array of use position) // 4. Create a shared function from a function picked from EquivalentClass and // an array of ParamInfo. // 5. Create thunks for each functions in an EquivalentClass. // // e.g. // // Before: // (func $big-const-42 (result i32) // [[many instr 1]] // (i32.const 42) // [[many instr 2]] // ) // (func $big-const-43 (result i32) // [[many instr 1]] // (i32.const 43) // [[many instr 2]] // ) // After: // (func $byn$mgfn-shared$big-const-42 (result i32) // [[many instr 1]] // (local.get $0) // [[many instr 2]] // ) // (func $big-const-42 (result i32) // (call $byn$mgfn-shared$big-const-42 // (i32.const 42) // ) // ) // (func $big-const-43 (result i32) // (call $byn$mgfn-shared$big-const-42 // (i32.const 43) // ) // ) // // In the above example, there is an EquivalentClass `[$big-const-42, // $big-const-43]`, and a ParamInfo `{ values: [i32(42), i32(43)], uses: // [location of (i32.const 42)] }` is derived. Then, clone `$big-const-42` // replacing uses of params with local.get, and create thunks for $big-const-42 // and $big-const-43. #include "ir/hashed.h" #include "ir/manipulation.h" #include "ir/module-utils.h" #include "ir/names.h" #include "ir/utils.h" #include "opt-utils.h" #include "pass.h" #include "support/hash.h" #include "support/utilities.h" #include "wasm.h" #include #include #include #include #include #include #include #include #include namespace wasm { // A set of constant values of an instruction different between each functions // in an EquivalentClass using ConstDiff = std::variant>; // Describes a parameter which we create to parameterize the merged function. struct ParamInfo { // Actual values of the parameter ordered by the EquivalentClass's // `functions`. ConstDiff values; // All uses of the parameter in the primary function. std::vector uses; ParamInfo(ConstDiff values, std::vector uses) : values(std::move(values)), uses(uses) {} // Returns the type of the parameter value. Type getValueType(Module* module) const { if (const auto literals = std::get_if(&values)) { return (*literals)[0].type; } else if (auto callees = std::get_if>(&values)) { auto* callee = module->getFunction((*callees)[0]); return Type(callee->type, NonNullable); } else { WASM_UNREACHABLE("unexpected const value type"); } } // Lower the constant value at a given index to an expression Expression* lowerToExpression(Builder& builder, Module* module, size_t index) const { if (const auto literals = std::get_if(&values)) { return builder.makeConst((*literals)[index]); } else if (auto callees = std::get_if>(&values)) { auto fnName = (*callees)[index]; auto heapType = module->getFunction(fnName)->type; return builder.makeRefFunc(fnName, heapType); } else { WASM_UNREACHABLE("unexpected const value type"); } } }; // Describes the set of functions which are considered as "equivalent" (i.e. // only differing by some constants). struct EquivalentClass { // Primary function in the `functions`, which will be the base for the merged // function. Function* primaryFunction; // List of functions belonging to this equivalence class. std::vector functions; EquivalentClass(Function* primaryFunction, std::vector functions) : primaryFunction(primaryFunction), functions(functions) {} bool isEligibleToMerge() { return this->functions.size() >= 2; } // Merge the functions in this class. void merge(Module* module, const std::vector& params); bool hasMergeBenefit(Module* module, const std::vector& params); Function* createShared(Module* module, const std::vector& params); Function* replaceWithThunk(Builder& builder, Function* target, Function* shared, const std::vector& params, const std::vector& extraArgs); bool deriveParams(Module* module, std::vector& params, bool isIndirectionEnabled); }; struct MergeSimilarFunctions : public Pass { bool invalidatesDWARF() override { return true; } void run(Module* module) override { std::vector classes; collectEquivalentClasses(classes, module); std::sort( classes.begin(), classes.end(), [](const auto& left, const auto& right) { return left.primaryFunction->name < right.primaryFunction->name; }); for (auto& clazz : classes) { if (!clazz.isEligibleToMerge()) { continue; } std::vector params; if (!clazz.deriveParams( module, params, isCallIndirectionEnabled(module))) { continue; } if (!clazz.hasMergeBenefit(module, params)) { continue; } clazz.merge(module, params); } } // Parameterize direct calls if the module supports func ref values. bool isCallIndirectionEnabled(Module* module) const { return module->features.hasReferenceTypes() && module->features.hasGC(); } bool areInEquvalentClass(Function* lhs, Function* rhs, Module* module); void collectEquivalentClasses(std::vector& classes, Module* module); }; // Determine if two functions are equivalent ignoring constants. bool MergeSimilarFunctions::areInEquvalentClass(Function* lhs, Function* rhs, Module* module) { if (lhs->imported() || rhs->imported()) { return false; } if (lhs->type != rhs->type) { return false; } if (lhs->getNumVars() != rhs->getNumVars()) { return false; } ExpressionAnalyzer::ExprComparer comparer = [&](Expression* lhsExpr, Expression* rhsExpr) { if (lhsExpr->_id != rhsExpr->_id) { return false; } if (lhsExpr->type != rhsExpr->type) { return false; } if (lhsExpr->is()) { if (!this->isCallIndirectionEnabled(module)) { return false; } auto lhsCast = lhsExpr->dynCast(); auto rhsCast = rhsExpr->dynCast(); if (lhsCast->operands.size() != rhsCast->operands.size()) { return false; } if (lhsCast->type != rhsCast->type) { return false; } auto* lhsCallee = module->getFunction(lhsCast->target); auto* rhsCallee = module->getFunction(rhsCast->target); if (lhsCallee->type != rhsCallee->type) { return false; } // Arguments operands should be also equivalent ignoring constants. for (Index i = 0; i < lhsCast->operands.size(); i++) { if (!ExpressionAnalyzer::flexibleEqual( lhsCast->operands[i], rhsCast->operands[i], comparer)) { return false; } } return true; } if (lhsExpr->is()) { auto lhsCast = lhsExpr->dynCast(); auto rhsCast = rhsExpr->dynCast(); // Types should be the same at least. if (lhsCast->value.type != rhsCast->value.type) { return false; } return true; } return false; }; if (!ExpressionAnalyzer::flexibleEqual(lhs->body, rhs->body, comparer)) { return false; } return true; } // Collect all equivalent classes to be merged. void MergeSimilarFunctions::collectEquivalentClasses( std::vector& classes, Module* module) { auto hashes = FunctionHasher::createMap(module); PassRunner runner(module); std::function ignoringConsts = [&](Expression* expr, size_t& digest) { // Ignore const's immediate operands. if (expr->is()) { return true; } // Ignore callee operands. if (auto* call = expr->dynCast()) { for (auto operand : call->operands) { rehash(digest, ExpressionAnalyzer::flexibleHash(operand, ignoringConsts)); } rehash(digest, call->isReturn); return true; } return false; }; FunctionHasher(&hashes, ignoringConsts).run(&runner, module); // Find hash-equal groups. std::map> hashGroups; ModuleUtils::iterDefinedFunctions( *module, [&](Function* func) { hashGroups[hashes[func]].push_back(func); }); for (auto& [_, hashGroup] : hashGroups) { if (hashGroup.size() < 2) { continue; } // Collect exactly equivalent functions ignoring constants. std::vector classesInGroup = { EquivalentClass(hashGroup[0], {hashGroup[0]})}; for (Index i = 1; i < hashGroup.size(); i++) { auto* func = hashGroup[i]; bool found = false; for (auto& newClass : classesInGroup) { if (areInEquvalentClass(newClass.primaryFunction, func, module)) { newClass.functions.push_back(func); found = true; break; } } if (!found) { // Same hash but different instruction pattern. classesInGroup.push_back(EquivalentClass(func, {func})); } } std::copy(classesInGroup.begin(), classesInGroup.end(), std::back_inserter(classes)); } } // Find the set of parameters which are required to merge the functions in the // class Returns false if unable to derive parameters. bool EquivalentClass::deriveParams(Module* module, std::vector& params, bool isCallIndirectionEnabled) { // Allows iteration over children of the root expression recursively. struct DeepValueIterator { // The DFS work list. SmallVector tasks; DeepValueIterator(Expression** root) { tasks.push_back(root); } void operator++() { ChildIterator it(*tasks.back()); tasks.pop_back(); for (Expression*& child : it) { tasks.push_back(&child); } } Expression*& operator*() { assert(!empty()); return *tasks.back(); } bool empty() { return tasks.empty(); } }; if (primaryFunction->imported()) { return false; } DeepValueIterator primaryIt(&primaryFunction->body); std::vector siblingIterators; // Skip the first function, as it is the primary function to compare the // primary function with the other functions based on the primary instr type. assert(functions.size() >= 2); for (auto func = functions.begin() + 1; func != functions.end(); ++func) { siblingIterators.emplace_back(&(*func)->body); } for (; !primaryIt.empty(); ++primaryIt) { Expression*& primary = *primaryIt; ConstDiff diff; Literals values; std::vector names; bool isAllSame = true; if (auto* primaryConst = primary->dynCast()) { values.push_back(primaryConst->value); for (auto& it : siblingIterators) { Expression*& sibling = *it; ++it; if (auto* siblingConst = sibling->dynCast()) { isAllSame &= primaryConst->value == siblingConst->value; values.push_back(siblingConst->value); } else { WASM_UNREACHABLE( "all sibling functions should have the same instruction type"); } } diff = values; } else if (isCallIndirectionEnabled && primary->is()) { auto* primaryCall = primary->dynCast(); names.push_back(primaryCall->target); for (auto& it : siblingIterators) { Expression*& sibling = *it; ++it; if (auto* siblingCall = sibling->dynCast()) { isAllSame &= primaryCall->target == siblingCall->target; names.push_back(siblingCall->target); } else { WASM_UNREACHABLE( "all sibling functions should have the same instruction type"); } } diff = names; } else { // Skip non-constant expressions, which are ensured to be the exactly // same. for (auto& it : siblingIterators) { // Sibling functions in a class should have the same instruction type. assert((*it)->_id == primary->_id); ++it; } continue; } // If all values are the same, skip to parameterize it. if (isAllSame) { continue; } // If the derived param is already in the params, reuse it. // e.g. // // ``` // (func $use-42-twice (result i32) // (i32.add (i32.const 42) (i32.const 42)) // ) // (func $use-43-twice (result i32) // (i32.add (i32.const 43) (i32.const 43)) // ) // ``` // // will be merged reusing the parameter [42, 43] // // ``` // (func $use-42-twice (result i32) // (call $byn$mgfn-shared$use-42-twice (i32.const 42)) // ) // (func $use-43-twice (result i32) // (call $byn$mgfn-shared$use-42-twice (i32.const 43)) // ) // (func $byn$mgfn-shared$use-42-twice (param $0 i32) (result i32) // (i32.add (local.get $0) (local.get $0)) // ) // ``` // bool paramReused = false; for (auto& param : params) { if (param.values == diff) { param.uses.push_back(&primary); paramReused = true; break; } } if (!paramReused) { params.push_back(ParamInfo(diff, {&primary})); } } return true; } void EquivalentClass::merge(Module* module, const std::vector& params) { Function* sharedFn = createShared(module, params); for (size_t i = 0; i < functions.size(); ++i) { Builder builder(*module); auto* func = functions[i]; std::vector extraArgs; for (auto& param : params) { extraArgs.push_back(param.lowerToExpression(builder, module, i)); } replaceWithThunk(builder, func, sharedFn, params, extraArgs); } return; } // Determine if it's beneficial to merge the functions in the class // Merging functions by creating a shared function and thunks is not always // beneficial. If the functions are very small, added glue code may be larger // than the reduced size. bool EquivalentClass::hasMergeBenefit(Module* module, const std::vector& params) { size_t funcCount = functions.size(); Index exprSize = Measurer::measure(primaryFunction->body); size_t thunkCount = funcCount; // -1 for cloned primary func size_t removedInstrs = (funcCount - 1) * exprSize; // Each thunks will add local.get and call instructions to forward the params // and pass extra parameterized values. size_t addedInstrsPerThunk = thunkCount * ( // call 1 + // local.get primaryFunction->getParams().size() + params.size()); constexpr size_t INSTR_WEIGHT = 1; constexpr size_t CODE_SEC_LOCALS_WEIGHT = 1; constexpr size_t CODE_SEC_ENTRY_WEIGHT = 2; constexpr size_t FUNC_SEC_ENTRY_WEIGHT = 2; // Glue instrs for thunks and a merged function entry will be added by the // merge. size_t negativeScore = addedInstrsPerThunk * INSTR_WEIGHT + thunkCount * ( // Locals entries in merged function in code section. (params.size() * CODE_SEC_LOCALS_WEIGHT) + // Code size field in merged function entry. CODE_SEC_ENTRY_WEIGHT) + // Thunk function entries in function section. (thunkCount * FUNC_SEC_ENTRY_WEIGHT); size_t positiveScore = INSTR_WEIGHT * removedInstrs; return negativeScore < positiveScore; } Function* EquivalentClass::createShared(Module* module, const std::vector& params) { Name fnName = Names::getValidFunctionName(*module, std::string("byn$mgfn-shared$") + primaryFunction->name.toString()); Builder builder(*module); std::vector sigParams; Index extraParamBase = primaryFunction->getNumParams(); Index newVarBase = primaryFunction->getNumParams() + params.size(); for (const auto& param : primaryFunction->getParams()) { sigParams.push_back(param); } for (const auto& param : params) { sigParams.push_back(param.getValueType(module)); } Signature sig(Type(sigParams), primaryFunction->getResults()); // Cloning the primary function while replacing the parameterized values ExpressionManipulator::CustomCopier copier = [&](Expression* expr) -> Expression* { if (!expr) { return nullptr; } // Replace the use of the parameter with extra locals for (Index paramIdx = 0; paramIdx < params.size(); paramIdx++) { for (auto& use : params[paramIdx].uses) { if (*use != expr) { continue; } auto* paramExpr = builder.makeLocalGet( extraParamBase + paramIdx, params[paramIdx].getValueType(module)); if (expr->is()) { return paramExpr; } else if (auto* call = expr->cast()) { ExpressionList operands(module->allocator); // Clone the children of the call for (auto* operand : call->operands) { operands.push_back( ExpressionManipulator::flexibleCopy(operand, *module, copier)); } auto returnType = module->getFunction(call->target)->getResults(); return builder.makeCallRef( paramExpr, operands, returnType, call->isReturn); } } } // Re-number local indices of variables (not params) to offset for the extra // params if (auto* localGet = expr->dynCast()) { if (primaryFunction->isVar(localGet->index)) { localGet->index = newVarBase + (localGet->index - primaryFunction->getNumParams()); localGet->finalize(); return localGet; } } if (auto* localSet = expr->dynCast()) { if (primaryFunction->isVar(localSet->index)) { auto operand = ExpressionManipulator::flexibleCopy(localSet->value, *module, copier); localSet->index = newVarBase + (localSet->index - primaryFunction->getNumParams()); localSet->value = operand; localSet->finalize(); return localSet; } } return nullptr; }; Expression* body = ExpressionManipulator::flexibleCopy(primaryFunction->body, *module, copier); auto vars = primaryFunction->vars; std::unique_ptr f = builder.makeFunction(fnName, sig, std::move(vars), body); return module->addFunction(std::move(f)); } Function* EquivalentClass::replaceWithThunk(Builder& builder, Function* target, Function* shared, const std::vector& params, const std::vector& extraArgs) { std::vector callOperands; Type targetParams = target->getParams(); for (Index i = 0; i < targetParams.size(); i++) { callOperands.push_back(builder.makeLocalGet(i, targetParams[i])); } for (const auto& value : extraArgs) { callOperands.push_back(value); } // TODO: make a return_call when possible? auto ret = builder.makeCall(shared->name, callOperands, target->getResults()); target->vars.clear(); target->body = ret; return target; } Pass* createMergeSimilarFunctionsPass() { return new MergeSimilarFunctions(); } } // namespace wasm