diff options
Diffstat (limited to 'src/passes/Monomorphize.cpp')
-rw-r--r-- | src/passes/Monomorphize.cpp | 105 |
1 files changed, 86 insertions, 19 deletions
diff --git a/src/passes/Monomorphize.cpp b/src/passes/Monomorphize.cpp index c27f5d6eb..8c08db55b 100644 --- a/src/passes/Monomorphize.cpp +++ b/src/passes/Monomorphize.cpp @@ -92,6 +92,7 @@ #include "ir/manipulation.h" #include "ir/module-utils.h" #include "ir/names.h" +#include "ir/return-utils.h" #include "ir/type-updating.h" #include "ir/utils.h" #include "pass.h" @@ -103,6 +104,36 @@ namespace wasm { namespace { +// Core information about a call: the call itself, and if it is dropped, the +// drop. +struct CallInfo { + Call* call; + // Store a reference to the drop's pointer so that we can replace it, as when + // we optimize a dropped call we need to replace (drop (call)) with (call). + // Or, if the call is not dropped, this is nullptr. + Expression** drop; +}; + +// Finds the calls and whether each one of them is dropped. +struct CallFinder : public PostWalker<CallFinder> { + std::vector<CallInfo> infos; + + void visitCall(Call* curr) { + // Add the call as not having a drop, and update the drop later if we are. + infos.push_back(CallInfo{curr, nullptr}); + } + + void visitDrop(Drop* curr) { + if (curr->value->is<Call>()) { + // The call we just added to |infos| is dropped. + assert(!infos.empty()); + auto& back = infos.back(); + assert(back.call == curr->value); + back.drop = getCurrentPointer(); + } + } +}; + // Relevant information about a callsite for purposes of monomorphization. struct CallContext { // The operands of the call, processed to leave the parts that make sense to @@ -181,12 +212,12 @@ struct CallContext { // remaining values by updating |newOperands| (for example, if all the values // sent are constants, then |newOperands| will end up empty, as we have // nothing left to send). - void buildFromCall(Call* call, + void buildFromCall(CallInfo& info, std::vector<Expression*>& newOperands, Module& wasm) { Builder builder(wasm); - for (auto* operand : call->operands) { + for (auto* operand : info.call->operands) { // Process the operand. This is a copy operation, as we are trying to move // (copy) code from the callsite into the called function. When we find we // can copy then we do so, and when we cannot that value remains as a @@ -212,8 +243,7 @@ struct CallContext { })); } - // TODO: handle drop - dropped = false; + dropped = !!info.drop; } // Checks whether an expression can be moved into the context. @@ -299,6 +329,11 @@ struct Monomorphize : public Pass { void run(Module* module) override { // TODO: parallelize, see comments below + // Find all the return-calling functions. We cannot remove their returns + // (because turning a return call into a normal call may break the program + // by using more stack). + auto returnCallersMap = ReturnUtils::findReturnCallers(*module); + // Note the list of all functions. We'll be adding more, and do not want to // operate on those. std::vector<Name> funcNames; @@ -309,26 +344,38 @@ struct Monomorphize : public Pass { // to call the monomorphized targets. for (auto name : funcNames) { auto* func = module->getFunction(name); - for (auto* call : FindAll<Call>(func->body).list) { - if (call->type == Type::unreachable) { + + CallFinder callFinder; + callFinder.walk(func->body); + for (auto& info : callFinder.infos) { + if (info.call->type == Type::unreachable) { // Ignore unreachable code. // TODO: return_call? continue; } - if (call->target == name) { + if (info.call->target == name) { // Avoid recursion, which adds some complexity (as we'd be modifying // ourselves if we apply optimizations). continue; } - processCall(call, *module); + // If the target function does a return call, then as noted earlier we + // cannot remove its returns, so do not consider the drop as part of the + // context in such cases (as if we reverse-inlined the drop into the + // target then we'd be removing the returns). + if (returnCallersMap[module->getFunction(info.call->target)]) { + info.drop = nullptr; + } + + processCall(info, *module); } } } // Try to optimize a call. - void processCall(Call* call, Module& wasm) { + void processCall(CallInfo& info, Module& wasm) { + auto* call = info.call; auto target = call->target; auto* func = wasm.getFunction(target); if (func->imported()) { @@ -342,7 +389,7 @@ struct Monomorphize : public Pass { // if we use that context. CallContext context; std::vector<Expression*> newOperands; - context.buildFromCall(call, newOperands, wasm); + context.buildFromCall(info, newOperands, wasm); // See if we've already evaluated this function + call context. If so, then // we've memoized the result. @@ -350,11 +397,8 @@ struct Monomorphize : public Pass { if (iter != funcContextMap.end()) { auto newTarget = iter->second; if (newTarget != target) { - // When we computed this before we found a benefit to optimizing, and - // created a new monomorphized function to call. Use it by simply - // applying the new operands we computed, and adjusting the call target. - call->operands.set(newOperands); - call->target = newTarget; + // We saw benefit to optimizing this case. Apply that. + updateCall(info, newTarget, newOperands, wasm); } return; } @@ -419,8 +463,7 @@ struct Monomorphize : public Pass { if (worthwhile) { // We are using the monomorphized function, so update the call and add it // to the module. - call->operands.set(newOperands); - call->target = monoFunc->name; + updateCall(info, monoFunc->name, newOperands, wasm); wasm.addFunction(std::move(monoFunc)); } @@ -453,8 +496,9 @@ struct Monomorphize : public Pass { newParams.push_back(operand->type); } } - // TODO: support changes to results. - auto newResults = func->getResults(); + // If we were dropped then we are pulling the drop into the monomorphized + // function, which means we return nothing. + auto newResults = context.dropped ? Type::none : func->getResults(); newFunc->type = Signature(Type(newParams), newResults); // We must update local indexes: the new function has a potentially @@ -549,9 +593,32 @@ struct Monomorphize : public Pass { newFunc->body = builder.makeBlock(pre); } + if (context.dropped) { + ReturnUtils::removeReturns(newFunc.get(), wasm); + } + return newFunc; } + // Given a call and a new target it should be calling, apply that new target, + // including updating the operands and handling dropping. + void updateCall(const CallInfo& info, + Name newTarget, + const std::vector<Expression*>& newOperands, + Module& wasm) { + info.call->target = newTarget; + info.call->operands.set(newOperands); + + if (info.drop) { + // Replace (drop (call)) with (call), that is, replace the drop with the + // (updated) call which now has type none. Note we should have handled + // unreachability before getting here. + assert(info.call->type != Type::unreachable); + info.call->type = Type::none; + *info.drop = info.call; + } + } + // Run some function-level optimizations on a function. Ideally we would run a // minimal amount of optimizations here, but we do want to give the optimizer // as much of a chance to work as possible, so for now do all of -O3 (in |