diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/ir/effects.h | 113 | ||||
-rw-r--r-- | src/ir/localize.h | 15 | ||||
-rw-r--r-- | src/passes/Inlining.cpp | 210 | ||||
-rw-r--r-- | src/tools/wasm-ctor-eval.cpp | 194 | ||||
-rw-r--r-- | src/wasm-builder.h | 1 | ||||
-rw-r--r-- | src/wasm-interpreter.h | 168 | ||||
-rw-r--r-- | src/wasm/wasm.cpp | 1 |
7 files changed, 476 insertions, 226 deletions
diff --git a/src/ir/effects.h b/src/ir/effects.h index 6901f99de..3ecb54641 100644 --- a/src/ir/effects.h +++ b/src/ir/effects.h @@ -67,9 +67,23 @@ public: // noticeable from the perspective of the caller, that is, effects that are // only noticeable during the call, but "vanish" when the call stack is // unwound. + // + // Unlike walking just the body, walking the function will also + // include the effects of any return calls the function makes. For that + // reason, it is a bug if a user of this code calls walk(Expression*) and not + // walk(Function*) if their intention is to scan an entire function body. + // Putting it another way, a return_call is syntax sugar for a return and a + // call, where the call executes at the function scope, so there is a + // meaningful difference between scanning an expression and scanning + // the entire function body. void walk(Function* func) { walk(func->body); + // Effects of return-called functions will be visible to the caller. + if (hasReturnCallThrow) { + throws_ = true; + } + // We can ignore branching out of the function body - this can only be // a return, and that is only noticeable in the function, not outside. branchesOut = false; @@ -143,6 +157,22 @@ public: // or a continuation that is never continued, are examples of that. bool mayNotReturn = false; + // Since return calls return out of the body of the function before performing + // their call, they are indistinguishable from normal returns from the + // perspective of their surrounding code, and the return-callee's effects only + // become visible when considering the effects of the whole function + // containing the return call. To model this correctly, stash the callee's + // effects on the side and only merge them in after walking a full function + // body. + // + // We currently do this stashing only for the throw effect, but in principle + // we could do it for all effects if it made a difference. (Only throw is + // noticeable now because the only thing that can change between doing the + // call here and doing it outside at the function exit is the scoping of + // try-catch blocks. If future wasm scoping additions are added, we may need + // more here.) + bool hasReturnCallThrow = false; + // Helper functions to check for various effect types bool accessesLocal() const { @@ -466,43 +496,63 @@ private: return; } + const EffectAnalyzer* targetEffects = nullptr; + if (parent.funcEffectsMap) { + auto iter = parent.funcEffectsMap->find(curr->target); + if (iter != parent.funcEffectsMap->end()) { + targetEffects = &iter->second; + } + } + if (curr->isReturn) { parent.branchesOut = true; + // When EH is enabled, any call can throw. + if (parent.features.hasExceptionHandling() && + (!targetEffects || targetEffects->throws())) { + parent.hasReturnCallThrow = true; + } } - if (parent.funcEffectsMap) { - auto iter = parent.funcEffectsMap->find(curr->target); - if (iter != parent.funcEffectsMap->end()) { - // We have effect information for this call target, and can just use - // that. The one change we may want to make is to remove throws_, if - // the target function throws and we know that will be caught anyhow, - // the same as the code below for the general path. - const auto& targetEffects = iter->second; - if (targetEffects.throws_ && parent.tryDepth > 0) { - auto filteredEffects = targetEffects; - filteredEffects.throws_ = false; - parent.mergeIn(filteredEffects); - } else { - // Just merge in all the effects. - parent.mergeIn(targetEffects); - } - return; + if (targetEffects) { + // We have effect information for this call target, and can just use + // that. The one change we may want to make is to remove throws_, if the + // target function throws and we know that will be caught anyhow, the + // same as the code below for the general path. We can always filter out + // throws for return calls because they are already more precisely + // captured by `branchesOut`, which models the return, and + // `hasReturnCallThrow`, which models the throw that will happen after + // the return. + if (targetEffects->throws_ && (parent.tryDepth > 0 || curr->isReturn)) { + auto filteredEffects = *targetEffects; + filteredEffects.throws_ = false; + parent.mergeIn(filteredEffects); + } else { + // Just merge in all the effects. + parent.mergeIn(*targetEffects); } + return; } parent.calls = true; - // When EH is enabled, any call can throw. - if (parent.features.hasExceptionHandling() && parent.tryDepth == 0) { + // When EH is enabled, any call can throw. Skip this for return calls + // because the throw is already more precisely captured by the combination + // of `hasReturnCallThrow` and `branchesOut`. + if (parent.features.hasExceptionHandling() && parent.tryDepth == 0 && + !curr->isReturn) { parent.throws_ = true; } } void visitCallIndirect(CallIndirect* curr) { parent.calls = true; - if (parent.features.hasExceptionHandling() && parent.tryDepth == 0) { - parent.throws_ = true; - } if (curr->isReturn) { parent.branchesOut = true; + if (parent.features.hasExceptionHandling()) { + parent.hasReturnCallThrow = true; + } + } + if (parent.features.hasExceptionHandling() && + (parent.tryDepth == 0 && !curr->isReturn)) { + parent.throws_ = true; } } void visitLocalGet(LocalGet* curr) { @@ -745,21 +795,26 @@ private: } } void visitCallRef(CallRef* curr) { + if (curr->isReturn) { + parent.branchesOut = true; + if (parent.features.hasExceptionHandling()) { + parent.hasReturnCallThrow = true; + } + } if (curr->target->type.isNull()) { parent.trap = true; return; } - parent.calls = true; - if (parent.features.hasExceptionHandling() && parent.tryDepth == 0) { - parent.throws_ = true; - } - if (curr->isReturn) { - parent.branchesOut = true; - } // traps when the call target is null if (curr->target->type.isNullable()) { parent.implicitTrap = true; } + + parent.calls = true; + if (parent.features.hasExceptionHandling() && + (parent.tryDepth == 0 && !curr->isReturn)) { + parent.throws_ = true; + } } void visitRefTest(RefTest* curr) {} void visitRefCast(RefCast* curr) { diff --git a/src/ir/localize.h b/src/ir/localize.h index d44fb9be5..85e4415f5 100644 --- a/src/ir/localize.h +++ b/src/ir/localize.h @@ -153,17 +153,20 @@ struct ChildLocalizer { // Nothing to add. return parent; } + auto* block = getChildrenReplacement(); + if (!hasUnreachableChild) { + block->list.push_back(parent); + block->finalize(); + } + return block; + } + // Like `getReplacement`, but the result never contains the parent. + Block* getChildrenReplacement() { auto* block = Builder(wasm).makeBlock(); block->list.set(sets); if (hasUnreachableChild) { - // If there is an unreachable child then we do not need the parent at all, - // and we know the type is unreachable. block->type = Type::unreachable; - } else { - // Otherwise, add the parent and finalize. - block->list.push_back(parent); - block->finalize(); } return block; } diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index 51dddaaa9..0643389c2 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -37,6 +37,7 @@ #include "ir/element-utils.h" #include "ir/find_all.h" #include "ir/literal-utils.h" +#include "ir/localize.h" #include "ir/module-utils.h" #include "ir/names.h" #include "ir/type-updating.h" @@ -298,22 +299,34 @@ struct Updater : public PostWalker<Updater> { Module* module; std::map<Index, Index> localMapping; Name returnName; + Type resultType; bool isReturn; Builder* builder; PassOptions& options; + struct ReturnCallInfo { + // The original `return_call` or `return_call_indirect` or `return_call_ref` + // with its operands replaced with `local.get`s. + Expression* call; + // The branch that is serving as the "return" part of the original + // `return_call`. + Break* branch; + }; + + // Collect information on return_calls in the inlined body. Each will be + // turned into branches out of the original inlined body followed by + // non-return version of the original `return_call`, followed by a branch out + // to the caller. The branch labels will be filled in at the end of the walk. + std::vector<ReturnCallInfo> returnCallInfos; + Updater(PassOptions& options) : options(options) {} void visitReturn(Return* curr) { replaceCurrent(builder->makeBreak(returnName, curr->value)); } - // Return calls in inlined functions should only break out of the scope of - // the inlined code, not the entire function they are being inlined into. To - // achieve this, make the call a non-return call and add a break. This does - // not cause unbounded stack growth because inlining and return calling both - // avoid creating a new stack frame. - template<typename T> void handleReturnCall(T* curr, Type results) { - if (isReturn) { + + template<typename T> void handleReturnCall(T* curr, Signature sig) { + if (isReturn || !curr->isReturn) { // If the inlined callsite was already a return_call, then we can keep // return_calls in the inlined function rather than downgrading them. // That is, if A->B and B->C and both those calls are return_calls @@ -321,45 +334,85 @@ struct Updater : public PostWalker<Updater> { // return_call. return; } + + // Set the children to locals as necessary, then add a branch out of the + // inlined body. The branch label will be set later when we create branch + // targets for the calls. + Block* childBlock = ChildLocalizer(curr, getFunction(), *module, options) + .getChildrenReplacement(); + Break* branch = builder->makeBreak(Name()); + childBlock->list.push_back(branch); + childBlock->type = Type::unreachable; + replaceCurrent(childBlock); + curr->isReturn = false; - curr->type = results; - // There might still be unreachable children causing this to be unreachable. - curr->finalize(); - if (results.isConcrete()) { - replaceCurrent(builder->makeBreak(returnName, curr)); - } else { - replaceCurrent(builder->blockify(curr, builder->makeBreak(returnName))); - } + curr->type = sig.results; + returnCallInfos.push_back({curr, branch}); } + void visitCall(Call* curr) { - if (curr->isReturn) { - handleReturnCall(curr, module->getFunction(curr->target)->getResults()); - } + handleReturnCall(curr, module->getFunction(curr->target)->getSig()); } + void visitCallIndirect(CallIndirect* curr) { - if (curr->isReturn) { - handleReturnCall(curr, curr->heapType.getSignature().results); - } + handleReturnCall(curr, curr->heapType.getSignature()); } + void visitCallRef(CallRef* curr) { Type targetType = curr->target->type; - if (targetType.isNull()) { - // We don't know what type the call should return, but we can't leave it - // as a potentially-invalid return_call_ref, either. - replaceCurrent(getDroppedChildrenAndAppend( - curr, *module, options, Builder(*module).makeUnreachable())); + if (!targetType.isSignature()) { + // We don't know what type the call should return, but it will also never + // be reached, so we don't need to do anything here. return; } - if (curr->isReturn) { - handleReturnCall(curr, targetType.getHeapType().getSignature().results); - } + handleReturnCall(curr, targetType.getHeapType().getSignature()); } + void visitLocalGet(LocalGet* curr) { curr->index = localMapping[curr->index]; } + void visitLocalSet(LocalSet* curr) { curr->index = localMapping[curr->index]; } + + void walk(Expression*& curr) { + PostWalker<Updater>::walk(curr); + if (returnCallInfos.empty()) { + return; + } + + Block* body = builder->blockify(curr); + curr = body; + auto blockNames = BranchUtils::BranchAccumulator::get(body); + + for (Index i = 0; i < returnCallInfos.size(); ++i) { + auto& info = returnCallInfos[i]; + + // Add a block containing the previous body and a branch up to the caller. + // Give the block a name that will allow this return_call's original + // callsite to branch out of it then execute the call before returning to + // the caller. + auto name = Names::getValidName( + "__return_call", [&](Name test) { return !blockNames.count(test); }, i); + blockNames.insert(name); + info.branch->name = name; + Block* oldBody = builder->makeBlock(body->list, body->type); + body->list.clear(); + + if (resultType.isConcrete()) { + body->list.push_back(builder->makeBlock( + name, {builder->makeBreak(returnName, oldBody)}, Type::none)); + } else { + oldBody->list.push_back(builder->makeBreak(returnName)); + oldBody->name = name; + oldBody->type = Type::none; + body->list.push_back(oldBody); + } + body->list.push_back(info.call); + body->finalize(resultType); + } + } }; // Core inlining logic. Modifies the outside function (adding locals as @@ -376,8 +429,11 @@ static Expression* doInlining(Module* module, Index nameHint = 0) { Function* from = action.contents; auto* call = (*action.callSite)->cast<Call>(); + // Works for return_call, too Type retType = module->getFunction(call->target)->getResults(); + + // Build the block that will contain the inlined contents. Builder builder(*module); auto* block = builder.makeBlock(); auto name = std::string("__inlined_func$") + from->name.toString(); @@ -385,6 +441,7 @@ static Expression* doInlining(Module* module, name += '$' + std::to_string(nameHint); } block->name = Name(name); + // In the unlikely event that the function already has a branch target with // this name, fix that up, as otherwise we can get unexpected capture of our // branches, that is, we could end up with this: @@ -407,27 +464,24 @@ static Expression* doInlining(Module* module, // // (In this case we could use a second block and define the named block $X // after the call's parameters, but that adds work for an extremely rare - // situation.) + // situation.) The latter case does not apply if the call is a return_call, + // because in that case the call's children do not appear inside the same + // block as the inlined body. if (BranchUtils::hasBranchTarget(from->body, block->name) || - BranchUtils::BranchSeeker::has(call, block->name)) { + (!call->isReturn && BranchUtils::BranchSeeker::has(call, block->name))) { auto fromNames = BranchUtils::getBranchTargets(from->body); - auto callNames = BranchUtils::BranchAccumulator::get(call); + auto callNames = call->isReturn ? BranchUtils::NameSet{} + : BranchUtils::BranchAccumulator::get(call); block->name = Names::getValidName(block->name, [&](Name test) { return !fromNames.count(test) && !callNames.count(test); }); } - if (call->isReturn) { - if (retType.isConcrete()) { - *action.callSite = builder.makeReturn(block); - } else { - *action.callSite = builder.makeSequence(block, builder.makeReturn()); - } - } else { - *action.callSite = block; - } + // Prepare to update the inlined code's locals and other things. Updater updater(options); + updater.setFunction(into); updater.module = module; + updater.resultType = from->getResults(); updater.returnName = block->name; updater.isReturn = call->isReturn; updater.builder = &builder; @@ -435,31 +489,71 @@ static Expression* doInlining(Module* module, for (Index i = 0; i < from->getNumLocals(); i++) { updater.localMapping[i] = builder.addVar(into, from->getLocalType(i)); } - // Assign the operands into the params - for (Index i = 0; i < from->getParams().size(); i++) { - block->list.push_back( - builder.makeLocalSet(updater.localMapping[i], call->operands[i])); - } - // Zero out the vars (as we may be in a loop, and may depend on their - // zero-init value - for (Index i = 0; i < from->vars.size(); i++) { - auto type = from->vars[i]; - if (!LiteralUtils::canMakeZero(type)) { - // Non-zeroable locals do not need to be zeroed out. As they have no zero - // value they by definition should not be used before being written to, so - // any value we set here would not be observed anyhow. - continue; + + if (call->isReturn) { + // Wrap the existing function body in a block we can branch out of before + // entering the inlined function body. This block must have a name that is + // different from any other block name above the branch. + auto intoNames = BranchUtils::BranchAccumulator::get(into->body); + auto bodyName = + Names::getValidName(Name("__original_body"), + [&](Name test) { return !intoNames.count(test); }); + if (retType.isConcrete()) { + into->body = builder.makeBlock( + bodyName, {builder.makeReturn(into->body)}, Type::none); + } else { + into->body = builder.makeBlock( + bodyName, {into->body, builder.makeReturn()}, Type::none); + } + + // Sequence the inlined function body after the original caller body. + into->body = builder.makeSequence(into->body, block, retType); + + // Replace the original callsite with an expression that assigns the + // operands into the params and branches out of the original body. + auto numParams = from->getParams().size(); + if (numParams) { + auto* branchBlock = builder.makeBlock(); + for (Index i = 0; i < numParams; i++) { + branchBlock->list.push_back( + builder.makeLocalSet(updater.localMapping[i], call->operands[i])); + } + branchBlock->list.push_back(builder.makeBreak(bodyName)); + branchBlock->finalize(Type::unreachable); + *action.callSite = branchBlock; + } else { + *action.callSite = builder.makeBreak(bodyName); + } + } else { + // Assign the operands into the params + for (Index i = 0; i < from->getParams().size(); i++) { + block->list.push_back( + builder.makeLocalSet(updater.localMapping[i], call->operands[i])); } - block->list.push_back( - builder.makeLocalSet(updater.localMapping[from->getVarIndexBase() + i], - LiteralUtils::makeZero(type, *module))); + // Zero out the vars (as we may be in a loop, and may depend on their + // zero-init value + for (Index i = 0; i < from->vars.size(); i++) { + auto type = from->vars[i]; + if (!LiteralUtils::canMakeZero(type)) { + // Non-zeroable locals do not need to be zeroed out. As they have no + // zero value they by definition should not be used before being written + // to, so any value we set here would not be observed anyhow. + continue; + } + block->list.push_back( + builder.makeLocalSet(updater.localMapping[from->getVarIndexBase() + i], + LiteralUtils::makeZero(type, *module))); + } + *action.callSite = block; } + // Generate and update the inlined contents auto* contents = ExpressionManipulator::copy(from->body, *module); debug::copyDebugInfo(from->body, contents, from, into); updater.walk(contents); block->list.push_back(contents); block->type = retType; + // The ReFinalize below will handle propagating unreachability if we need to // do so, that is, if the call was reachable but now the inlined content we // replaced it with was unreachable. The opposite case requires special diff --git a/src/tools/wasm-ctor-eval.cpp b/src/tools/wasm-ctor-eval.cpp index d476bb2cc..fe3d42d09 100644 --- a/src/tools/wasm-ctor-eval.cpp +++ b/src/tools/wasm-ctor-eval.cpp @@ -25,6 +25,7 @@ #include <memory> #include "asmjs/shared-constants.h" +#include "ir/find_all.h" #include "ir/gc-type-utils.h" #include "ir/global-utils.h" #include "ir/import-utils.h" @@ -1061,40 +1062,45 @@ EvalCtorOutcome evalCtor(EvallingModuleRunner& instance, params.push_back(Literal::makeZero(type)); } - // We want to handle the form of the global constructor function in LLVM. That - // looks like this: - // - // (func $__wasm_call_ctors - // (call $ctor.1) - // (call $ctor.2) - // (call $ctor.3) - // ) - // - // Some of those ctors may be inlined, however, which would mean that the - // function could have locals, control flow, etc. However, we assume for now - // that it does not have parameters at least (whose values we can't tell). - // And for now we look for a toplevel block and process its children one at a - // time. This allows us to eval some of the $ctor.* functions (or their - // inlined contents) even if not all. - // - // TODO: Support complete partial evalling, that is, evaluate parts of an - // arbitrary function, and not just a sequence in a single toplevel - // block. + // After we successfully eval a line we will store the operations to set up + // the locals here. That is, we need to save the local state in the function, + // which we do by setting up at the entry. We update this list of expressions + // at the same time as applyToModule() - we must only do it after an entire + // atomic "chunk" has been processed succesfully, we do not want partial + // updates from an item in the block that we only partially evalled. When we + // construct the (partially) evalled function, we will create local.sets of + // these expressions at the beginning. + std::vector<Expression*> localExprs; + + // We might have to evaluate multiple functions due to return calls. +start_eval: + while (true) { + // We want to handle the form of the global constructor function in LLVM. + // That looks like this: + // + // (func $__wasm_call_ctors + // (call $ctor.1) + // (call $ctor.2) + // (call $ctor.3) + // ) + // + // Some of those ctors may be inlined, however, which would mean that the + // function could have locals, control flow, etc. However, we assume for now + // that it does not have parameters at least (whose values we can't tell). + // And for now we look for a toplevel block and process its children one at + // a time. This allows us to eval some of the $ctor.* functions (or their + // inlined contents) even if not all. + // + // TODO: Support complete partial evalling, that is, evaluate parts of an + // arbitrary function, and not just a sequence in a single toplevel + // block. + Builder builder(wasm); + auto* block = builder.blockify(func->body); - if (auto* block = func->body->dynCast<Block>()) { // Go through the items in the block and try to execute them. We do all this // in a single function scope for all the executions. EvallingModuleRunner::FunctionScope scope(func, params, instance); - // After we successfully eval a line we will store the operations to set up - // the locals here. That is, we need to save the local state in the - // function, which we do by setting up at the entry. We update this list of - // local.sets at the same time as applyToModule() - we must only do it after - // an entire atomic "chunk" has been processed succesfully, we do not want - // partial updates from an item in the block that we only partially evalled. - std::vector<Expression*> localSets; - - Builder builder(wasm); Literals results; Index successes = 0; @@ -1116,6 +1122,22 @@ EvalCtorOutcome evalCtor(EvallingModuleRunner& instance, break; } + if (flow.breakTo == RETURN_CALL_FLOW) { + // The return-called function is stored in the last value. + func = wasm.getFunction(flow.values.back().getFunc()); + flow.values.pop_back(); + params = std::move(flow.values); + + // Serialize the arguments for the new function and save the module + // state in case we fail to eval the new function. + localExprs.clear(); + for (auto& param : params) { + localExprs.push_back(interface.getSerialization(param)); + } + interface.applyToModule(); + goto start_eval; + } + // So far so good! Serialize the values of locals, and apply to the // module. Note that we must serialize the locals now as doing so may // cause changes that must be applied to the module (e.g. GC data may @@ -1128,11 +1150,9 @@ EvalCtorOutcome evalCtor(EvallingModuleRunner& instance, // of them, and leave it to the optimizer to remove redundant or // unnecessary operations. We just recompute the entire local // serialization sets from scratch each time here, for all locals. - localSets.clear(); + localExprs.clear(); for (Index i = 0; i < func->getNumLocals(); i++) { - auto value = scope.locals[i]; - localSets.push_back( - builder.makeLocalSet(i, interface.getSerialization(value))); + localExprs.push_back(interface.getSerialization(scope.locals[i])); } interface.applyToModule(); successes++; @@ -1144,41 +1164,97 @@ EvalCtorOutcome evalCtor(EvallingModuleRunner& instance, if (flow.breaking()) { // We are returning out of the function (either via a return, or via a // break to |block|, which has the same outcome. That means we don't - // need to execute any more lines, and can consider them to be executed. + // need to execute any more lines, and can consider them to be + // executed. if (!quiet) { std::cout << " ...stopping in block due to break\n"; } // Mark us as having succeeded on the entire block, since we have: we - // are skipping the rest, which means there is no problem there. We must - // set this here so that lower down we realize that we've evalled + // are skipping the rest, which means there is no problem there. We + // must set this here so that lower down we realize that we've evalled // everything. successes = block->list.size(); break; } } - if (successes > 0 && successes < block->list.size()) { - // We managed to eval some but not all. That means we can't just remove - // the entire function, but need to keep parts of it - the parts we have - // not evalled - around. To do so, we create a copy of the function with - // the partially-evalled contents and make the export use that (as the - // function may be used in other places than the export, which we do not - // want to affect). + // If we have not fully evaluated the current function, but we have + // evaluated part of it, have return-called to a different function, or have + // precomputed values for the current return-called function, then we can + // replace the export with a new function that does less work than the + // original. + if ((func->imported() || successes < block->list.size()) && + (successes > 0 || func->name != funcName || + (localExprs.size() && func->getParams() != Type::none))) { + auto originalFuncType = wasm.getFunction(funcName)->type; auto copyName = Names::getValidFunctionName(wasm, funcName); - auto* copyFunc = ModuleUtils::copyFunction(func, wasm, copyName); wasm.getExport(exportName)->value = copyName; + if (func->imported()) { + // We must have return-called this imported function. Generate a new + // function that return-calls the import with the arguments we have + // evalled. + auto copyFunc = builder.makeFunction( + copyName, + originalFuncType, + {}, + builder.makeCall(func->name, localExprs, func->getResults(), true)); + wasm.addFunction(std::move(copyFunc)); + return EvalCtorOutcome(); + } + + // We may have managed to eval some but not all. That means we can't just + // remove the entire function, but need to keep parts of it - the parts we + // have not evalled - around. To do so, we create a copy of the function + // with the partially-evalled contents and make the export use that (as + // the function may be used in other places than the export, which we do + // not want to affect). + auto* copyBody = + builder.blockify(ExpressionManipulator::copy(func->body, wasm)); + // Remove the items we've evalled. - auto* copyBlock = copyFunc->body->cast<Block>(); for (Index i = 0; i < successes; i++) { - copyBlock->list[i] = builder.makeNop(); + copyBody->list[i] = builder.makeNop(); } - // Put the local sets at the front of the block. We know there must be a - // nop in that position (since we've evalled at least one item in the - // block, and replaced it with a nop), so we can overwrite it. - copyBlock->list[0] = builder.makeBlock(localSets); + // Put the local sets at the front of the function body. + auto* setsBlock = builder.makeBlock(); + for (Index i = 0; i < localExprs.size(); ++i) { + setsBlock->list.push_back(builder.makeLocalSet(i, localExprs[i])); + } + copyBody = builder.makeSequence(setsBlock, copyBody, copyBody->type); + + // We may have return-called into a function with different parameter + // types, but we ultimately need to export a function with the original + // signature. If there is a mismatch, shift the local indices to make room + // for the unused parameters. + std::vector<Type> localTypes; + auto originalParams = originalFuncType.getSignature().params; + if (originalParams != func->getParams()) { + // Add locals for the body to use instead of using the params. + for (auto type : func->getParams()) { + localTypes.push_back(type); + } + + // Shift indices in the body so they will refer to the new locals. + auto localShift = originalParams.size(); + if (localShift != 0) { + for (auto* get : FindAll<LocalGet>(copyBody).list) { + get->index += localShift; + } + for (auto* set : FindAll<LocalSet>(copyBody).list) { + set->index += localShift; + } + } + } + + // Add vars from current function. + localTypes.insert(localTypes.end(), func->vars.begin(), func->vars.end()); + + // Create and add the new function. + auto* copyFunc = wasm.addFunction(builder.makeFunction( + copyName, originalFuncType, std::move(localTypes), copyBody)); // Interesting optimizations may be possible both due to removing some but // not all of the code, and due to the locals we just added. @@ -1196,24 +1272,6 @@ EvalCtorOutcome evalCtor(EvallingModuleRunner& instance, return EvalCtorOutcome(); } } - - // Otherwise, we don't recognize a pattern that allows us to do partial - // evalling. So simply call the entire function at once and see if we can - // optimize that. - - Literals results; - try { - results = instance.callFunction(funcName, params); - } catch (FailToEvalException& fail) { - if (!quiet) { - std::cout << " ...stopping since could not eval: " << fail.why << "\n"; - } - return EvalCtorOutcome(); - } - - // Success! Apply the results. - interface.applyToModule(); - return EvalCtorOutcome(results); } // Eval all ctors in a module. diff --git a/src/wasm-builder.h b/src/wasm-builder.h index cc90a8abe..463faa04f 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -269,6 +269,7 @@ public: call->target = target; call->operands.set(args); call->isReturn = isReturn; + call->finalize(); return call; } template<typename T> diff --git a/src/wasm-interpreter.h b/src/wasm-interpreter.h index c95f694ef..64c0bfb2d 100644 --- a/src/wasm-interpreter.h +++ b/src/wasm-interpreter.h @@ -51,7 +51,7 @@ std::ostream& operator<<(std::ostream& o, const WasmException& exn); // Utilities -extern Name WASM, RETURN_FLOW, NONCONSTANT_FLOW; +extern Name WASM, RETURN_FLOW, RETURN_CALL_FLOW, NONCONSTANT_FLOW; // Stuff that flows around during executing expressions: a literal, or a change // in control flow. @@ -63,6 +63,8 @@ public: Flow(Literals&& values) : values(std::move(values)) {} Flow(Name breakTo) : values(), breakTo(breakTo) {} Flow(Name breakTo, Literal value) : values{value}, breakTo(breakTo) {} + Flow(Name breakTo, Literals&& values) + : values(std::move(values)), breakTo(breakTo) {} Literals values; Name breakTo; // if non-null, a break is going on @@ -2965,32 +2967,33 @@ public: Flow visitCall(Call* curr) { NOTE_ENTER("Call"); NOTE_NAME(curr->target); + Name target = curr->target; Literals arguments; Flow flow = self()->generateArguments(curr->operands, arguments); if (flow.breaking()) { return flow; } auto* func = wasm.getFunction(curr->target); - Flow ret; + auto funcType = func->type; if (Intrinsics(*self()->getModule()).isCallWithoutEffects(func)) { // The call.without.effects intrinsic is a call to an import that actually // calls the given function reference that is the final argument. - auto newArguments = arguments; - auto target = newArguments.back(); - newArguments.pop_back(); - ret.values = callFunctionInternal(target.getFunc(), newArguments); - } else if (func->imported()) { - ret.values = externalInterface->callImport(func, arguments); - } else { - ret.values = callFunctionInternal(curr->target, arguments); + target = arguments.back().getFunc(); + funcType = arguments.back().type.getHeapType(); + arguments.pop_back(); + } + + if (curr->isReturn) { + // Return calls are represented by their arguments followed by a reference + // to the function to be called. + arguments.push_back(Literal::makeFunc(target, funcType)); + return Flow(RETURN_CALL_FLOW, std::move(arguments)); } + + Flow ret = callFunctionInternal(target, arguments); #ifdef WASM_INTERPRETER_DEBUG std::cout << "(returned to " << scope->function->name << ")\n"; #endif - // TODO: make this a proper tail call (return first) - if (curr->isReturn) { - ret.breakTo = RETURN_FLOW; - } return ret; } @@ -3007,18 +3010,28 @@ public: } Index index = target.getSingleValue().geti32(); - Type type = curr->isReturn ? scope->function->getResults() : curr->type; auto info = getTableInterfaceInfo(curr->table); - Flow ret = info.interface->callTable( - info.name, index, curr->heapType, arguments, type, *self()); - // TODO: make this a proper tail call (return first) if (curr->isReturn) { - ret.breakTo = RETURN_FLOW; + // Return calls are represented by their arguments followed by a reference + // to the function to be called. + auto funcref = info.interface->tableLoad(info.name, index); + if (!Type::isSubType(funcref.type, Type(curr->heapType, NonNullable))) { + trap("cast failure in call_indirect"); + } + arguments.push_back(funcref); + return Flow(RETURN_CALL_FLOW, std::move(arguments)); } + + Flow ret = info.interface->callTable( + info.name, index, curr->heapType, arguments, curr->type, *self()); +#ifdef WASM_INTERPRETER_DEBUG + std::cout << "(returned to " << scope->function->name << ")\n"; +#endif return ret; } + Flow visitCallRef(CallRef* curr) { NOTE_ENTER("CallRef"); Literals arguments; @@ -3030,24 +3043,22 @@ public: if (target.breaking()) { return target; } - if (target.getSingleValue().isNull()) { + auto targetRef = target.getSingleValue(); + if (targetRef.isNull()) { trap("null target in call_ref"); } - Name funcName = target.getSingleValue().getFunc(); - auto* func = wasm.getFunction(funcName); - Flow ret; - if (func->imported()) { - ret.values = externalInterface->callImport(func, arguments); - } else { - ret.values = callFunctionInternal(funcName, arguments); + + if (curr->isReturn) { + // Return calls are represented by their arguments followed by a reference + // to the function to be called. + arguments.push_back(targetRef); + return Flow(RETURN_CALL_FLOW, std::move(arguments)); } + + Flow ret = callFunctionInternal(targetRef.getFunc(), arguments); #ifdef WASM_INTERPRETER_DEBUG std::cout << "(returned to " << scope->function->name << ")\n"; #endif - // TODO: make this a proper tail call (return first) - if (curr->isReturn) { - ret.breakTo = RETURN_FLOW; - } return ret; } @@ -4098,12 +4109,7 @@ public: // Call a function, starting an invocation. Literals callFunction(Name name, const Literals& arguments) { - auto* func = wasm.getFunction(name); - if (func->imported()) { - return externalInterface->callImport(func, arguments); - } - - // if the last call ended in a jump up the stack, it might have left stuff + // If the last call ended in a jump up the stack, it might have left stuff // for us to clean up here callDepth = 0; functionStack.clear(); @@ -4112,47 +4118,79 @@ public: // Internal function call. Must be public so that callTable implementations // can use it (refactor?) - Literals callFunctionInternal(Name name, const Literals& arguments) { + Literals callFunctionInternal(Name name, Literals arguments) { if (callDepth > maxDepth) { externalInterface->trap("stack limit"); } - auto previousCallDepth = callDepth; - callDepth++; - auto previousFunctionStackSize = functionStack.size(); - functionStack.push_back(name); - Function* function = wasm.getFunction(name); - assert(function); - FunctionScope scope(function, arguments, *self()); + Flow flow; + std::optional<Type> resultType; + + // We may have to call multiple functions in the event of return calls. + while (true) { + Function* function = wasm.getFunction(name); + assert(function); + + // Return calls can only make the result type more precise. + if (resultType) { + assert(Type::isSubType(function->getResults(), *resultType)); + } + resultType = function->getResults(); + + if (function->imported()) { + // TODO: Allow imported functions to tail call as well. + return externalInterface->callImport(function, arguments); + } + + auto previousCallDepth = callDepth; + callDepth++; + auto previousFunctionStackSize = functionStack.size(); + functionStack.push_back(name); + + FunctionScope scope(function, arguments, *self()); #ifdef WASM_INTERPRETER_DEBUG - std::cout << "entering " << function->name << "\n with arguments:\n"; - for (unsigned i = 0; i < arguments.size(); ++i) { - std::cout << " $" << i << ": " << arguments[i] << '\n'; - } + std::cout << "entering " << function->name << "\n with arguments:\n"; + for (unsigned i = 0; i < arguments.size(); ++i) { + std::cout << " $" << i << ": " << arguments[i] << '\n'; + } +#endif + + flow = self()->visit(function->body); + + // may decrease more than one, if we jumped up the stack + callDepth = previousCallDepth; + // if we jumped up the stack, we also need to pop higher frames + // TODO can FunctionScope handle this automatically? + while (functionStack.size() > previousFunctionStackSize) { + functionStack.pop_back(); + } +#ifdef WASM_INTERPRETER_DEBUG + std::cout << "exiting " << function->name << " with " << flow.values + << '\n'; #endif - Flow flow = self()->visit(function->body); + if (flow.breakTo != RETURN_CALL_FLOW) { + break; + } + + // There was a return call, so we need to call the next function before + // returning to the caller. The flow carries the function arguments and a + // function reference. + name = flow.values.back().getFunc(); + flow.values.pop_back(); + arguments = flow.values; + } + // cannot still be breaking, it means we missed our stop assert(!flow.breaking() || flow.breakTo == RETURN_FLOW); auto type = flow.getType(); - if (!Type::isSubType(type, function->getResults())) { - std::cerr << "calling " << function->name << " resulted in " << type - << " but the function type is " << function->getResults() - << '\n'; + if (!Type::isSubType(type, *resultType)) { + std::cerr << "calling " << name << " resulted in " << type + << " but the function type is " << *resultType << '\n'; WASM_UNREACHABLE("unexpected result type"); } - // may decrease more than one, if we jumped up the stack - callDepth = previousCallDepth; - // if we jumped up the stack, we also need to pop higher frames - // TODO can FunctionScope handle this automatically? - while (functionStack.size() > previousFunctionStackSize) { - functionStack.pop_back(); - } -#ifdef WASM_INTERPRETER_DEBUG - std::cout << "exiting " << function->name << " with " << flow.values - << '\n'; -#endif + return flow.values; } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index bc707890c..c2df2c68c 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -25,6 +25,7 @@ namespace wasm { Name WASM("wasm"); Name RETURN_FLOW("*return:)*"); +Name RETURN_CALL_FLOW("*return-call:)*"); Name NONCONSTANT_FLOW("*nonconstant:)*"); namespace BinaryConsts { |