diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ir/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/ir/return-utils.cpp | 99 | ||||
-rw-r--r-- | src/ir/return-utils.h | 39 | ||||
-rw-r--r-- | src/passes/DeadArgumentElimination.cpp | 19 | ||||
-rw-r--r-- | src/passes/Monomorphize.cpp | 105 |
5 files changed, 227 insertions, 36 deletions
diff --git a/src/ir/CMakeLists.txt b/src/ir/CMakeLists.txt index 996daa564..45b08702d 100644 --- a/src/ir/CMakeLists.txt +++ b/src/ir/CMakeLists.txt @@ -17,6 +17,7 @@ set(ir_SOURCES LocalGraph.cpp LocalStructuralDominance.cpp ReFinalize.cpp + return-utils.cpp stack-utils.cpp table-utils.cpp type-updating.cpp diff --git a/src/ir/return-utils.cpp b/src/ir/return-utils.cpp new file mode 100644 index 000000000..20b3a194b --- /dev/null +++ b/src/ir/return-utils.cpp @@ -0,0 +1,99 @@ +/* + * Copyright 2024 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/return-utils.h" +#include "ir/module-utils.h" +#include "wasm-builder.h" +#include "wasm-traversal.h" +#include "wasm.h" + +namespace wasm::ReturnUtils { + +namespace { + +struct ReturnValueRemover : public PostWalker<ReturnValueRemover> { + void visitReturn(Return* curr) { + auto* value = curr->value; + assert(value); + curr->value = nullptr; + Builder builder(*getModule()); + replaceCurrent(builder.makeSequence(builder.makeDrop(value), curr)); + } + + void visitCall(Call* curr) { handleReturnCall(curr); } + void visitCallIndirect(CallIndirect* curr) { handleReturnCall(curr); } + void visitCallRef(CallRef* curr) { handleReturnCall(curr); } + + template<typename T> void handleReturnCall(T* curr) { + if (curr->isReturn) { + Fatal() << "Cannot remove return_calls in ReturnValueRemover"; + } + } + + void visitFunction(Function* curr) { + if (curr->body->type.isConcrete()) { + curr->body = Builder(*getModule()).makeDrop(curr->body); + } + } +}; + +} // anonymous namespace + +void removeReturns(Function* func, Module& wasm) { + ReturnValueRemover().walkFunctionInModule(func, &wasm); +} + +std::unordered_map<Function*, bool> findReturnCallers(Module& wasm) { + ModuleUtils::ParallelFunctionAnalysis<bool> analysis( + wasm, [&](Function* func, bool& hasReturnCall) { + if (func->imported()) { + return; + } + + struct Finder : PostWalker<Finder> { + bool hasReturnCall = false; + + void visitCall(Call* curr) { + if (curr->isReturn) { + hasReturnCall = true; + } + } + void visitCallIndirect(CallIndirect* curr) { + if (curr->isReturn) { + hasReturnCall = true; + } + } + void visitCallRef(CallRef* curr) { + if (curr->isReturn) { + hasReturnCall = true; + } + } + } finder; + + finder.walk(func->body); + hasReturnCall = finder.hasReturnCall; + }); + + // Convert to an unordered map for fast lookups. TODO: Avoid a copy here. + std::unordered_map<Function*, bool> ret; + ret.reserve(analysis.map.size()); + for (auto& [k, v] : analysis.map) { + ret[k] = v; + } + return ret; +} + +} // namespace wasm::ReturnUtils diff --git a/src/ir/return-utils.h b/src/ir/return-utils.h new file mode 100644 index 000000000..a5214ba01 --- /dev/null +++ b/src/ir/return-utils.h @@ -0,0 +1,39 @@ +/* + * Copyright 2024 WebAssembly Community Group participants + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef wasm_ir_return_h +#define wasm_ir_return_h + +#include "wasm.h" + +namespace wasm::ReturnUtils { + +// Removes values from both explicit returns and implicit ones (values that flow +// from the body). This is useful after changing a function's type to no longer +// return anything. +// +// This does *not* handle return calls, and will error on them. Removing a +// return call may change the semantics of the program, so we do not do it +// automatically here. +void removeReturns(Function* func, Module& wasm); + +// Return a map of every function to whether it does a return call. +using ReturnCallersMap = std::unordered_map<Function*, bool>; +ReturnCallersMap findReturnCallers(Module& wasm); + +} // namespace wasm::ReturnUtils + +#endif // wasm_ir_return_h diff --git a/src/passes/DeadArgumentElimination.cpp b/src/passes/DeadArgumentElimination.cpp index 4a341571e..83cd7e86d 100644 --- a/src/passes/DeadArgumentElimination.cpp +++ b/src/passes/DeadArgumentElimination.cpp @@ -42,6 +42,7 @@ #include "ir/find_all.h" #include "ir/lubs.h" #include "ir/module-utils.h" +#include "ir/return-utils.h" #include "ir/type-updating.h" #include "ir/utils.h" #include "param-utils.h" @@ -358,23 +359,7 @@ private: } } // Remove any return values. - struct ReturnUpdater : public PostWalker<ReturnUpdater> { - Module* module; - ReturnUpdater(Function* func, Module* module) : module(module) { - walk(func->body); - } - void visitReturn(Return* curr) { - auto* value = curr->value; - assert(value); - curr->value = nullptr; - Builder builder(*module); - replaceCurrent(builder.makeSequence(builder.makeDrop(value), curr)); - } - } returnUpdater(func, module); - // Remove any value flowing out. - if (func->body->type.isConcrete()) { - func->body = Builder(*module).makeDrop(func->body); - } + ReturnUtils::removeReturns(func, *module); } // Given a function and all the calls to it, see if we can refine the type of diff --git a/src/passes/Monomorphize.cpp b/src/passes/Monomorphize.cpp index c27f5d6eb..8c08db55b 100644 --- a/src/passes/Monomorphize.cpp +++ b/src/passes/Monomorphize.cpp @@ -92,6 +92,7 @@ #include "ir/manipulation.h" #include "ir/module-utils.h" #include "ir/names.h" +#include "ir/return-utils.h" #include "ir/type-updating.h" #include "ir/utils.h" #include "pass.h" @@ -103,6 +104,36 @@ namespace wasm { namespace { +// Core information about a call: the call itself, and if it is dropped, the +// drop. +struct CallInfo { + Call* call; + // Store a reference to the drop's pointer so that we can replace it, as when + // we optimize a dropped call we need to replace (drop (call)) with (call). + // Or, if the call is not dropped, this is nullptr. + Expression** drop; +}; + +// Finds the calls and whether each one of them is dropped. +struct CallFinder : public PostWalker<CallFinder> { + std::vector<CallInfo> infos; + + void visitCall(Call* curr) { + // Add the call as not having a drop, and update the drop later if we are. + infos.push_back(CallInfo{curr, nullptr}); + } + + void visitDrop(Drop* curr) { + if (curr->value->is<Call>()) { + // The call we just added to |infos| is dropped. + assert(!infos.empty()); + auto& back = infos.back(); + assert(back.call == curr->value); + back.drop = getCurrentPointer(); + } + } +}; + // Relevant information about a callsite for purposes of monomorphization. struct CallContext { // The operands of the call, processed to leave the parts that make sense to @@ -181,12 +212,12 @@ struct CallContext { // remaining values by updating |newOperands| (for example, if all the values // sent are constants, then |newOperands| will end up empty, as we have // nothing left to send). - void buildFromCall(Call* call, + void buildFromCall(CallInfo& info, std::vector<Expression*>& newOperands, Module& wasm) { Builder builder(wasm); - for (auto* operand : call->operands) { + for (auto* operand : info.call->operands) { // Process the operand. This is a copy operation, as we are trying to move // (copy) code from the callsite into the called function. When we find we // can copy then we do so, and when we cannot that value remains as a @@ -212,8 +243,7 @@ struct CallContext { })); } - // TODO: handle drop - dropped = false; + dropped = !!info.drop; } // Checks whether an expression can be moved into the context. @@ -299,6 +329,11 @@ struct Monomorphize : public Pass { void run(Module* module) override { // TODO: parallelize, see comments below + // Find all the return-calling functions. We cannot remove their returns + // (because turning a return call into a normal call may break the program + // by using more stack). + auto returnCallersMap = ReturnUtils::findReturnCallers(*module); + // Note the list of all functions. We'll be adding more, and do not want to // operate on those. std::vector<Name> funcNames; @@ -309,26 +344,38 @@ struct Monomorphize : public Pass { // to call the monomorphized targets. for (auto name : funcNames) { auto* func = module->getFunction(name); - for (auto* call : FindAll<Call>(func->body).list) { - if (call->type == Type::unreachable) { + + CallFinder callFinder; + callFinder.walk(func->body); + for (auto& info : callFinder.infos) { + if (info.call->type == Type::unreachable) { // Ignore unreachable code. // TODO: return_call? continue; } - if (call->target == name) { + if (info.call->target == name) { // Avoid recursion, which adds some complexity (as we'd be modifying // ourselves if we apply optimizations). continue; } - processCall(call, *module); + // If the target function does a return call, then as noted earlier we + // cannot remove its returns, so do not consider the drop as part of the + // context in such cases (as if we reverse-inlined the drop into the + // target then we'd be removing the returns). + if (returnCallersMap[module->getFunction(info.call->target)]) { + info.drop = nullptr; + } + + processCall(info, *module); } } } // Try to optimize a call. - void processCall(Call* call, Module& wasm) { + void processCall(CallInfo& info, Module& wasm) { + auto* call = info.call; auto target = call->target; auto* func = wasm.getFunction(target); if (func->imported()) { @@ -342,7 +389,7 @@ struct Monomorphize : public Pass { // if we use that context. CallContext context; std::vector<Expression*> newOperands; - context.buildFromCall(call, newOperands, wasm); + context.buildFromCall(info, newOperands, wasm); // See if we've already evaluated this function + call context. If so, then // we've memoized the result. @@ -350,11 +397,8 @@ struct Monomorphize : public Pass { if (iter != funcContextMap.end()) { auto newTarget = iter->second; if (newTarget != target) { - // When we computed this before we found a benefit to optimizing, and - // created a new monomorphized function to call. Use it by simply - // applying the new operands we computed, and adjusting the call target. - call->operands.set(newOperands); - call->target = newTarget; + // We saw benefit to optimizing this case. Apply that. + updateCall(info, newTarget, newOperands, wasm); } return; } @@ -419,8 +463,7 @@ struct Monomorphize : public Pass { if (worthwhile) { // We are using the monomorphized function, so update the call and add it // to the module. - call->operands.set(newOperands); - call->target = monoFunc->name; + updateCall(info, monoFunc->name, newOperands, wasm); wasm.addFunction(std::move(monoFunc)); } @@ -453,8 +496,9 @@ struct Monomorphize : public Pass { newParams.push_back(operand->type); } } - // TODO: support changes to results. - auto newResults = func->getResults(); + // If we were dropped then we are pulling the drop into the monomorphized + // function, which means we return nothing. + auto newResults = context.dropped ? Type::none : func->getResults(); newFunc->type = Signature(Type(newParams), newResults); // We must update local indexes: the new function has a potentially @@ -549,9 +593,32 @@ struct Monomorphize : public Pass { newFunc->body = builder.makeBlock(pre); } + if (context.dropped) { + ReturnUtils::removeReturns(newFunc.get(), wasm); + } + return newFunc; } + // Given a call and a new target it should be calling, apply that new target, + // including updating the operands and handling dropping. + void updateCall(const CallInfo& info, + Name newTarget, + const std::vector<Expression*>& newOperands, + Module& wasm) { + info.call->target = newTarget; + info.call->operands.set(newOperands); + + if (info.drop) { + // Replace (drop (call)) with (call), that is, replace the drop with the + // (updated) call which now has type none. Note we should have handled + // unreachability before getting here. + assert(info.call->type != Type::unreachable); + info.call->type = Type::none; + *info.drop = info.call; + } + } + // Run some function-level optimizations on a function. Ideally we would run a // minimal amount of optimizations here, but we do want to give the optimizer // as much of a chance to work as possible, so for now do all of -O3 (in |